Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions pydgraph/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,9 +335,7 @@ async def __aenter__(self) -> AsyncDgraphClient:
"""
return self

async def __aexit__(
self, exc_type: Any, exc_val: Any, exc_tb: Any
) -> bool:
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
"""Async context manager exit.

Automatically closes all client connections.
Expand Down
75 changes: 52 additions & 23 deletions pydgraph/async_txn.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,9 +216,11 @@ async def do_request( # noqa: C901
query_error = error

if query_error is not None:
# Try to discard the transaction on error
# Try to discard the transaction on error.
# Note: We use _discard_internal() here because we already hold self._lock,
# and asyncio.Lock is not reentrant. Calling self.discard() would deadlock.
try:
await self.discard(
await self._discard_internal(
timeout=timeout, metadata=metadata, credentials=credentials
)
except asyncio.CancelledError:
Expand Down Expand Up @@ -458,32 +460,61 @@ async def discard(
Various gRPC errors on failure
"""
async with self._lock:
if not self._common_discard():
return
await self._discard_internal(
timeout=timeout, metadata=metadata, credentials=credentials
)

new_metadata = self._dg.add_login_metadata(metadata)
try:
async def _discard_internal(
self,
timeout: float | None = None,
metadata: list[tuple[str, str]] | None = None,
credentials: grpc.CallCredentials | None = None,
) -> None:
"""Internal discard implementation that doesn't acquire the lock.

This method must only be called when the caller already holds self._lock.
Use discard() for the public API.

Args:
timeout: Request timeout in seconds
metadata: Request metadata
credentials: Call credentials

Raises:
AssertionError: If called without holding self._lock
Various gRPC errors on failure
"""
# Defensive check: ensure caller holds the lock to prevent misuse
assert self._lock.locked(), (
"_discard_internal must only be called while holding self._lock"
)

if not self._common_discard():
return

new_metadata = self._dg.add_login_metadata(metadata)
try:
await self._dc.commit_or_abort(
self._ctx,
timeout=timeout,
metadata=new_metadata,
credentials=credentials,
)
except asyncio.CancelledError:
raise
except Exception as error:
# Handle JWT expiration with automatic retry
if util.is_jwt_expired(error):
await self._dg.retry_login()
new_metadata = self._dg.add_login_metadata(metadata)
await self._dc.commit_or_abort(
self._ctx,
timeout=timeout,
metadata=new_metadata,
credentials=credentials,
)
except asyncio.CancelledError:
else:
raise
except Exception as error:
# Handle JWT expiration with automatic retry
if util.is_jwt_expired(error):
await self._dg.retry_login()
new_metadata = self._dg.add_login_metadata(metadata)
await self._dc.commit_or_abort(
self._ctx,
timeout=timeout,
metadata=new_metadata,
credentials=credentials,
)
else:
raise

def _common_discard(self) -> bool:
"""Validates and prepares for discard.
Expand Down Expand Up @@ -533,9 +564,7 @@ async def __aenter__(self) -> AsyncTxn:
"""
return self

async def __aexit__(
self, exc_type: Any, exc_val: Any, exc_tb: Any
) -> bool:
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
"""Async context manager exit.

Automatically discards transaction if not already finished.
Expand Down
1 change: 1 addition & 0 deletions pydgraph/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

"""Dgraph python client."""

from __future__ import annotations

import contextlib
Expand Down
1 change: 1 addition & 0 deletions pydgraph/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def upsert_user(client, name: str):
txn.mutate(set_obj={"name": name})
txn.commit()
"""

import asyncio
import functools
import logging
Expand Down
Loading