Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion gcloud/datastore/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def __init__(self, dataset_id=None, connection=None):
'a dataset ID set.')

self._mutation = datastore_pb.Mutation()
self._auto_id_entities = []

@property
def dataset_id(self):
Expand Down Expand Up @@ -137,6 +138,9 @@ def put(self, entity):
self.dataset_id, key_pb, properties,
exclude_from_indexes=exclude, mutation=self.mutation)

if entity.key.is_partial:
self._auto_id_entities.append(entity)

def delete(self, key):
"""Remember a key to be deleted durring ``commit``.

Expand All @@ -159,7 +163,11 @@ def commit(self):
however it can be called explicitly if you don't want to use a
context manager.
"""
self.connection.commit(self._dataset_id, self.mutation)
response = self.connection.commit(self._dataset_id, self.mutation)
for new_key_pb, entity in zip(response.insert_auto_id_key,
self._auto_id_entities):

This comment was marked as spam.

This comment was marked as spam.

This comment was marked as spam.

This comment was marked as spam.

new_id = new_key_pb.path_element[-1].id
entity.key = entity.key.completed_key(new_id)

def __enter__(self):
return self
Expand Down
59 changes: 55 additions & 4 deletions gcloud/datastore/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def test_ctor_explicit(self):
self.assertEqual(batch.dataset_id, _DATASET)
self.assertEqual(batch.connection, connection)
self.assertTrue(isinstance(batch.mutation, Mutation))
self.assertEqual(batch._auto_id_entities, [])

def test_ctor_implicit(self):
from gcloud._testing import _Monkey
Expand All @@ -62,6 +63,7 @@ def test_ctor_implicit(self):
self.assertEqual(batch.dataset_id, DATASET_ID)
self.assertEqual(batch.connection, CONNECTION)
self.assertTrue(isinstance(batch.mutation, Mutation))
self.assertEqual(batch._auto_id_entities, [])

def test_put_entity_wo_key(self):
_DATASET = 'DATASET'
Expand All @@ -70,7 +72,23 @@ def test_put_entity_wo_key(self):

self.assertRaises(ValueError, batch.put, _Entity())

def test_put_entity_w_key(self):
def test_put_entity_w_partial_key(self):
_DATASET = 'DATASET'
_PROPERTIES = {'foo': 'bar'}
connection = _Connection()
batch = self._makeOne(dataset_id=_DATASET, connection=connection)
entity = _Entity(_PROPERTIES)
key = entity.key = _Key(_DATASET)
key._partial = True

batch.put(entity)

self.assertEqual(
connection._saved,
(_DATASET, key._key, _PROPERTIES, (), batch.mutation))
self.assertEqual(batch._auto_id_entities, [entity])

def test_put_entity_w_completed_key(self):
_DATASET = 'DATASET'
_PROPERTIES = {'foo': 'bar'}
connection = _Connection()
Expand Down Expand Up @@ -114,6 +132,22 @@ def test_commit(self):

self.assertEqual(connection._committed, (_DATASET, batch.mutation))

def test_commit_w_auto_id_entities(self):
_DATASET = 'DATASET'
_NEW_ID = 1234
connection = _Connection(_NEW_ID)
batch = self._makeOne(dataset_id=_DATASET, connection=connection)
entity = _Entity({})
key = entity.key = _Key(_DATASET)
key._partial = True
batch._auto_id_entities.append(entity)

batch.commit()

self.assertEqual(connection._committed, (_DATASET, batch.mutation))
self.assertFalse(key._partial)
self.assertEqual(key._id, _NEW_ID)

def test_as_context_mgr_wo_error(self):
_DATASET = 'DATASET'
_PROPERTIES = {'foo': 'bar'}
Expand Down Expand Up @@ -154,16 +188,28 @@ def test_as_context_mgr_w_error(self):
class _CommitResult(object):

def __init__(self, *new_keys):
self.insert_auto_id_key = new_keys
self.insert_auto_id_key = [_KeyPB(key) for key in new_keys]


class _PathElementPB(object):

def __init__(self, id):
self.id = id


class _KeyPB(object):

def __init__(self, id):
self.path_element = [_PathElementPB(id)]


class _Connection(object):
_marker = object()
_committed = _saved = _deleted = None
_save_result = (False, None)

def __init__(self):
self._commit_result = _CommitResult()
def __init__(self, *new_keys):
self._commit_result = _CommitResult(*new_keys)

def save_entity(self, dataset_id, key_pb, properties,
exclude_from_indexes=(), mutation=None):
Expand Down Expand Up @@ -201,3 +247,8 @@ def is_partial(self):

def to_protobuf(self):
return self._key

def completed_key(self, new_id):
assert self._partial
self._id = new_id
self._partial = False