Skip to content

Commit b5713c8

Browse files
cguardiaandrewsg
andauthored
fix: make nested retry blocks work for RPC calls (#589)
* fix: make nested retry blocks work for RPC calls fixes #567 * fix: use special retry exception to return flow to outer retry block Co-authored-by: Andrew Gorcester <gorcester@google.com>
1 parent 1317569 commit b5713c8

File tree

5 files changed

+88
-2
lines changed

5 files changed

+88
-2
lines changed

packages/google-cloud-ndb/google/cloud/ndb/_retry.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from google.api_core import retry as core_retry
2121
from google.api_core import exceptions as core_exceptions
22+
from google.cloud.ndb import exceptions
2223
from google.cloud.ndb import tasklets
2324

2425
_DEFAULT_INITIAL_DELAY = 1.0 # seconds
@@ -59,24 +60,47 @@ def retry_async(callback, retries=_DEFAULT_RETRIES):
5960
@tasklets.tasklet
6061
@wraps_safely(callback)
6162
def retry_wrapper(*args, **kwargs):
63+
from google.cloud.ndb import context as context_module
64+
6265
sleep_generator = core_retry.exponential_sleep_generator(
6366
_DEFAULT_INITIAL_DELAY,
6467
_DEFAULT_MAXIMUM_DELAY,
6568
_DEFAULT_DELAY_MULTIPLIER,
6669
)
6770

6871
for sleep_time in itertools.islice(sleep_generator, retries + 1):
72+
context = context_module.get_context()
73+
if not context.in_retry():
74+
# We need to be able to identify if we are inside a nested
75+
# retry. Here, we set the retry state in the context. This is
76+
# used for deciding if an exception should be raised
77+
# immediately or passed up to the outer retry block.
78+
context.set_retry_state(repr(callback))
6979
try:
7080
result = callback(*args, **kwargs)
7181
if isinstance(result, tasklets.Future):
7282
result = yield result
83+
except exceptions.NestedRetryException as e:
84+
error = e
7385
except Exception as e:
7486
# `e` is removed from locals at end of block
7587
error = e # See: https://goo.gl/5J8BMK
7688
if not is_transient_error(error):
77-
raise error
89+
# If we are in an inner retry block, use special nested
90+
# retry exception to bubble up to outer retry. Else, raise
91+
# actual exception.
92+
if context.get_retry_state() != repr(callback):
93+
message = getattr(error, "message", str(error))
94+
raise exceptions.NestedRetryException(message)
95+
else:
96+
raise error
7897
else:
7998
raise tasklets.Return(result)
99+
finally:
100+
# No matter what, if we are exiting the top level retry,
101+
# clear the retry state in the context.
102+
if context.get_retry_state() == repr(callback): # pragma: NO BRANCH
103+
context.clear_retry_state()
80104

81105
yield tasklets.sleep(sleep_time)
82106

packages/google-cloud-ndb/google/cloud/ndb/_transaction.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@ def _transaction_async(context, callback, read_only=False):
270270
# new event loop is of the same type as the current one, to propagate
271271
# the event loop class used for testing.
272272
eventloop=type(context.eventloop)(),
273+
retry=context.get_retry_state(),
273274
)
274275

275276
# The outer loop is dependent on the inner loop

packages/google-cloud-ndb/google/cloud/ndb/context.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,7 @@ def __new__(
247247
datastore_policy=None,
248248
on_commit_callbacks=None,
249249
legacy_data=True,
250+
retry=None,
250251
rpc_time=None,
251252
wait_time=None,
252253
):
@@ -286,6 +287,7 @@ def __new__(
286287
context.set_global_cache_policy(global_cache_policy)
287288
context.set_global_cache_timeout_policy(global_cache_timeout_policy)
288289
context.set_datastore_policy(datastore_policy)
290+
context.set_retry_state(retry)
289291

290292
return context
291293

@@ -296,7 +298,9 @@ def new(self, **kwargs):
296298
will be substituted.
297299
"""
298300
fields = self._fields + tuple(self.__dict__.keys())
299-
state = {name: getattr(self, name) for name in fields}
301+
state = {
302+
name: getattr(self, name) for name in fields if not name.startswith("_")
303+
}
300304
state.update(kwargs)
301305
return type(self)(**state)
302306

@@ -544,6 +548,15 @@ def policy(key):
544548

545549
set_memcache_timeout_policy = set_global_cache_timeout_policy
546550

551+
def get_retry_state(self):
552+
return self._retry
553+
554+
def set_retry_state(self, state):
555+
self._retry = state
556+
557+
def clear_retry_state(self):
558+
self._retry = None
559+
547560
def call_on_commit(self, callback):
548561
"""Call a callback upon successful commit of a transaction.
549562
@@ -578,6 +591,15 @@ def in_transaction(self):
578591
"""
579592
return self.transaction is not None
580593

594+
def in_retry(self):
595+
"""Get whether we are already in a retry block.
596+
597+
Returns:
598+
bool: :data:`True` if currently in a retry block, otherwise
599+
:data:`False`.
600+
"""
601+
return self._retry is not None
602+
581603
def memcache_add(self, *args, **kwargs):
582604
"""Direct pass-through to memcache client."""
583605
raise exceptions.NoLongerImplementedError()

packages/google-cloud-ndb/google/cloud/ndb/exceptions.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,3 +121,11 @@ class Cancelled(Error):
121121
a call to ``Future.cancel`` (possibly on a future that depends on this
122122
future).
123123
"""
124+
125+
126+
class NestedRetryException(Error):
127+
"""A nested retry block raised an exception.
128+
129+
Raised when a nested retry block cannot complete due to an exception. This
130+
allows the outer retry to get back control and retry the whole operation.
131+
"""

packages/google-cloud-ndb/tests/unit/test__retry.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,37 @@ def callback():
3838
retry = _retry.retry_async(callback)
3939
assert retry().result() == "foo"
4040

41+
@staticmethod
42+
@pytest.mark.usefixtures("in_context")
43+
def test_nested_retry():
44+
def callback():
45+
def nested_callback():
46+
return "bar"
47+
48+
nested = _retry.retry_async(nested_callback)
49+
assert nested().result() == "bar"
50+
51+
return "foo"
52+
53+
retry = _retry.retry_async(callback)
54+
assert retry().result() == "foo"
55+
56+
@staticmethod
57+
@pytest.mark.usefixtures("in_context")
58+
def test_nested_retry_with_exception():
59+
error = Exception("Fail")
60+
61+
def callback():
62+
def nested_callback():
63+
raise error
64+
65+
nested = _retry.retry_async(nested_callback, retries=1)
66+
return nested()
67+
68+
with pytest.raises(core_exceptions.RetryError):
69+
retry = _retry.retry_async(callback, retries=1)
70+
retry().result()
71+
4172
@staticmethod
4273
@pytest.mark.usefixtures("in_context")
4374
def test_success_callback_is_tasklet():

0 commit comments

Comments
 (0)