Skip to content

Commit 6020ee7

Browse files
committed
Making datastore batch/transaction more robust to failure.
- Ensure that the `Batch` methods `put()`, `delete()`, `commit()` and `rollback()` are only called while the batch is in progress.
- In `Batch.__enter__()`, make sure the batch is only put on the stack after `begin()` succeeds.
- `Client.delete_multi()` and `Client.put_multi()` (and downstream methods) now call `begin()` on a new `Batch`, since the batch is required to be in progress.
- In `Transaction.begin()`, if the `begin_transaction()` API call fails, make sure to change the status to `ABORTED` before re-raising the exception from the failure.

Fixes #2297.
1 parent 15e8aec commit 6020ee7

File tree

6 files changed

+132
-9
lines changed

6 files changed

+132
-9
lines changed

google/cloud/datastore/batch.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -183,9 +183,13 @@ def put(self, entity):
183183
:type entity: :class:`google.cloud.datastore.entity.Entity`
184184
:param entity: the entity to be saved.
185185
186-
:raises: ValueError if entity has no key assigned, or if the key's
186+
:raises: :class:`~exceptions.ValueError` if the batch is not in
187+
progress, if entity has no key assigned, or if the key's
187188
``project`` does not match ours.
188189
"""
190+
if self._status != self._IN_PROGRESS:
191+
raise ValueError('Batch must be in progress to put()')
192+
189193
if entity.key is None:
190194
raise ValueError("Entity must have a key")
191195

@@ -206,9 +210,13 @@ def delete(self, key):
206210
:type key: :class:`google.cloud.datastore.key.Key`
207211
:param key: the key to be deleted.
208212
209-
:raises: ValueError if key is not complete, or if the key's
213+
:raises: :class:`~exceptions.ValueError` if the batch is not in
214+
progress, if key is not complete, or if the key's
210215
``project`` does not match ours.
211216
"""
217+
if self._status != self._IN_PROGRESS:
218+
raise ValueError('Batch must be in progress to delete()')
219+
212220
if key.is_partial:
213221
raise ValueError("Key must be complete")
214222

@@ -255,7 +263,13 @@ def commit(self):
255263
This is called automatically upon exiting a with statement,
256264
however it can be called explicitly if you don't want to use a
257265
context manager.
266+
267+
:raises: :class:`~exceptions.ValueError` if the batch is not
268+
in progress.
258269
"""
270+
if self._status != self._IN_PROGRESS:
271+
raise ValueError('Batch must be in progress to commit()')
272+
259273
try:
260274
self._commit()
261275
finally:
@@ -267,12 +281,19 @@ def rollback(self):
267281
Marks the batch as aborted (can't be used again).
268282
269283
Overridden by :class:`google.cloud.datastore.transaction.Transaction`.
284+
285+
:raises: :class:`~exceptions.ValueError` if the batch is not
286+
in progress.
270287
"""
288+
if self._status != self._IN_PROGRESS:
289+
raise ValueError('Batch must be in progress to rollback()')
290+
271291
self._status = self._ABORTED
272292

273293
def __enter__(self):
274-
self._client._push_batch(self)
275294
self.begin()
295+
# NOTE: We make sure begin() succeeds before pushing onto the stack.
296+
self._client._push_batch(self)
276297
return self
277298

278299
def __exit__(self, exc_type, exc_val, exc_tb):

google/cloud/datastore/client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,7 @@ def put_multi(self, entities):
348348

349349
if not in_batch:
350350
current = self.batch()
351+
current.begin()
351352

352353
for entity in entities:
353354
current.put(entity)
@@ -384,6 +385,7 @@ def delete_multi(self, keys):
384385

385386
if not in_batch:
386387
current = self.batch()
388+
current.begin()
387389

388390
for key in keys:
389391
current.delete(key)

google/cloud/datastore/transaction.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@ class Transaction(Batch):
9090
:param client: the client used to connect to datastore.
9191
"""
9292

93+
_status = None
94+
9395
def __init__(self, client):
9496
super(Transaction, self).__init__(client)
9597
self._id = None
@@ -125,10 +127,15 @@ def begin(self):
125127
statement, however it can be called explicitly if you don't want
126128
to use a context manager.
127129
128-
:raises: :class:`ValueError` if the transaction has already begun.
130+
:raises: :class:`~exceptions.ValueError` if the transaction has
131+
already begun.
129132
"""
130133
super(Transaction, self).begin()
131-
self._id = self.connection.begin_transaction(self.project)
134+
try:
135+
self._id = self.connection.begin_transaction(self.project)
136+
except:
137+
self._status = self._ABORTED
138+
raise
132139

133140
def rollback(self):
134141
"""Rolls back the current transaction.

unit_tests/datastore/test_batch.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,20 @@ def test_put_entity_wo_key(self):
6868
client = _Client(_PROJECT, connection)
6969
batch = self._makeOne(client)
7070

71+
batch.begin()
7172
self.assertRaises(ValueError, batch.put, _Entity())
7273

74+
def test_put_entity_wrong_status(self):
75+
_PROJECT = 'PROJECT'
76+
connection = _Connection()
77+
client = _Client(_PROJECT, connection)
78+
batch = self._makeOne(client)
79+
entity = _Entity()
80+
entity.key = _Key('OTHER')
81+
82+
self.assertEqual(batch._status, batch._INITIAL)
83+
self.assertRaises(ValueError, batch.put, entity)
84+
7385
def test_put_entity_w_key_wrong_project(self):
7486
_PROJECT = 'PROJECT'
7587
connection = _Connection()
@@ -78,6 +90,7 @@ def test_put_entity_w_key_wrong_project(self):
7890
entity = _Entity()
7991
entity.key = _Key('OTHER')
8092

93+
batch.begin()
8194
self.assertRaises(ValueError, batch.put, entity)
8295

8396
def test_put_entity_w_partial_key(self):
@@ -90,6 +103,7 @@ def test_put_entity_w_partial_key(self):
90103
key = entity.key = _Key(_PROJECT)
91104
key._id = None
92105

106+
batch.begin()
93107
batch.put(entity)
94108

95109
mutated_entity = _mutated_pb(self, batch.mutations, 'insert')
@@ -113,6 +127,7 @@ def test_put_entity_w_completed_key(self):
113127
entity.exclude_from_indexes = ('baz', 'spam')
114128
key = entity.key = _Key(_PROJECT)
115129

130+
batch.begin()
116131
batch.put(entity)
117132

118133
mutated_entity = _mutated_pb(self, batch.mutations, 'upsert')
@@ -129,6 +144,17 @@ def test_put_entity_w_completed_key(self):
129144
self.assertTrue(spam_values[2].exclude_from_indexes)
130145
self.assertFalse('frotz' in prop_dict)
131146

147+
def test_delete_wrong_status(self):
148+
_PROJECT = 'PROJECT'
149+
connection = _Connection()
150+
client = _Client(_PROJECT, connection)
151+
batch = self._makeOne(client)
152+
key = _Key(_PROJECT)
153+
key._id = None
154+
155+
self.assertEqual(batch._status, batch._INITIAL)
156+
self.assertRaises(ValueError, batch.delete, key)
157+
132158
def test_delete_w_partial_key(self):
133159
_PROJECT = 'PROJECT'
134160
connection = _Connection()
@@ -137,6 +163,7 @@ def test_delete_w_partial_key(self):
137163
key = _Key(_PROJECT)
138164
key._id = None
139165

166+
batch.begin()
140167
self.assertRaises(ValueError, batch.delete, key)
141168

142169
def test_delete_w_key_wrong_project(self):
@@ -146,6 +173,7 @@ def test_delete_w_key_wrong_project(self):
146173
batch = self._makeOne(client)
147174
key = _Key('OTHER')
148175

176+
batch.begin()
149177
self.assertRaises(ValueError, batch.delete, key)
150178

151179
def test_delete_w_completed_key(self):
@@ -155,6 +183,7 @@ def test_delete_w_completed_key(self):
155183
batch = self._makeOne(client)
156184
key = _Key(_PROJECT)
157185

186+
batch.begin()
158187
batch.delete(key)
159188

160189
mutated_key = _mutated_pb(self, batch.mutations, 'delete')
@@ -180,23 +209,43 @@ def test_rollback(self):
180209
_PROJECT = 'PROJECT'
181210
client = _Client(_PROJECT, None)
182211
batch = self._makeOne(client)
183-
self.assertEqual(batch._status, batch._INITIAL)
212+
batch.begin()
213+
self.assertEqual(batch._status, batch._IN_PROGRESS)
184214
batch.rollback()
185215
self.assertEqual(batch._status, batch._ABORTED)
186216

217+
def test_rollback_wrong_status(self):
218+
_PROJECT = 'PROJECT'
219+
client = _Client(_PROJECT, None)
220+
batch = self._makeOne(client)
221+
222+
self.assertEqual(batch._status, batch._INITIAL)
223+
self.assertRaises(ValueError, batch.rollback)
224+
187225
def test_commit(self):
188226
_PROJECT = 'PROJECT'
189227
connection = _Connection()
190228
client = _Client(_PROJECT, connection)
191229
batch = self._makeOne(client)
192230

193231
self.assertEqual(batch._status, batch._INITIAL)
232+
batch.begin()
233+
self.assertEqual(batch._status, batch._IN_PROGRESS)
194234
batch.commit()
195235
self.assertEqual(batch._status, batch._FINISHED)
196236

197237
self.assertEqual(connection._committed,
198238
[(_PROJECT, batch._commit_request, None)])
199239

240+
def test_commit_wrong_status(self):
241+
_PROJECT = 'PROJECT'
242+
connection = _Connection()
243+
client = _Client(_PROJECT, connection)
244+
batch = self._makeOne(client)
245+
246+
self.assertEqual(batch._status, batch._INITIAL)
247+
self.assertRaises(ValueError, batch.commit)
248+
200249
def test_commit_w_partial_key_entities(self):
201250
_PROJECT = 'PROJECT'
202251
_NEW_ID = 1234
@@ -209,6 +258,8 @@ def test_commit_w_partial_key_entities(self):
209258
batch._partial_key_entities.append(entity)
210259

211260
self.assertEqual(batch._status, batch._INITIAL)
261+
batch.begin()
262+
self.assertEqual(batch._status, batch._IN_PROGRESS)
212263
batch.commit()
213264
self.assertEqual(batch._status, batch._FINISHED)
214265

@@ -295,6 +346,26 @@ def test_as_context_mgr_w_error(self):
295346
self.assertEqual(mutated_entity.key, key._key)
296347
self.assertEqual(connection._committed, [])
297348

349+
def test_as_context_mgr_enter_fails(self):
350+
klass = self._getTargetClass()
351+
352+
class FailedBegin(klass):
353+
354+
def begin(self):
355+
raise RuntimeError
356+
357+
client = _Client(None, None)
358+
self.assertEqual(client._batches, [])
359+
360+
batch = FailedBegin(client)
361+
with self.assertRaises(RuntimeError):
362+
# The context manager will never be entered because
363+
# of the failure.
364+
with batch: # pragma: NO COVER
365+
pass
366+
# Make sure no batch was added.
367+
self.assertEqual(client._batches, [])
368+
298369

299370
class _PathElementPB(object):
300371

unit_tests/datastore/test_client.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -960,6 +960,7 @@ def __init__(self, client):
960960
from google.cloud.datastore.batch import Batch
961961
self._client = client
962962
self._batch = Batch(client)
963+
self._batch.begin()
963964

964965
def __enter__(self):
965966
self._client._push_batch(self._batch)
@@ -972,10 +973,12 @@ def __exit__(self, *args):
972973
class _NoCommitTransaction(object):
973974

974975
def __init__(self, client, transaction_id='TRANSACTION'):
976+
from google.cloud.datastore.batch import Batch
975977
from google.cloud.datastore.transaction import Transaction
976978
self._client = client
977979
xact = self._transaction = Transaction(client)
978980
xact._id = transaction_id
981+
Batch.begin(xact)
979982

980983
def __enter__(self):
981984
self._client._push_batch(self._transaction)

unit_tests/datastore/test_transaction.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,19 @@ def test_begin_tombstoned(self):
8888

8989
self.assertRaises(ValueError, xact.begin)
9090

91+
def test_begin_w_begin_transaction_failure(self):
92+
_PROJECT = 'PROJECT'
93+
connection = _Connection(234)
94+
client = _Client(_PROJECT, connection)
95+
xact = self._makeOne(client)
96+
97+
connection._side_effect = RuntimeError
98+
with self.assertRaises(RuntimeError):
99+
xact.begin()
100+
101+
self.assertIsNone(xact.id)
102+
self.assertEqual(connection._begun, _PROJECT)
103+
91104
def test_rollback(self):
92105
_PROJECT = 'PROJECT'
93106
connection = _Connection(234)
@@ -118,10 +131,10 @@ def test_commit_w_partial_keys(self):
118131
connection._completed_keys = [_make_key(_KIND, _ID, _PROJECT)]
119132
client = _Client(_PROJECT, connection)
120133
xact = self._makeOne(client)
134+
xact.begin()
121135
entity = _Entity()
122136
xact.put(entity)
123137
xact._commit_request = commit_request = object()
124-
xact.begin()
125138
xact.commit()
126139
self.assertEqual(connection._committed,
127140
(_PROJECT, commit_request, 234))
@@ -176,7 +189,10 @@ def _make_key(kind, id_, project):
176189

177190
class _Connection(object):
178191
_marker = object()
179-
_begun = _rolled_back = _committed = None
192+
_begun = None
193+
_rolled_back = None
194+
_committed = None
195+
_side_effect = None
180196

181197
def __init__(self, xact_id=123):
182198
self._xact_id = xact_id
@@ -185,7 +201,10 @@ def __init__(self, xact_id=123):
185201

186202
def begin_transaction(self, project):
187203
self._begun = project
188-
return self._xact_id
204+
if self._side_effect is None:
205+
return self._xact_id
206+
else:
207+
raise self._side_effect
189208

190209
def rollback(self, project, transaction_id):
191210
self._rolled_back = project, transaction_id

0 commit comments

Comments
 (0)