|
16 | 16 | import asyncio |
17 | 17 | import json |
18 | 18 | import logging |
| 19 | +import os |
19 | 20 | import re |
20 | 21 | from typing import Any |
21 | 22 | from typing import Dict |
22 | 23 | from typing import Optional |
23 | 24 | import urllib.parse |
24 | 25 |
|
25 | 26 | from dateutil import parser |
| 27 | +from google.genai.errors import ClientError |
26 | 28 | from typing_extensions import override |
27 | 29 |
|
28 | 30 | from google import genai |
@@ -95,25 +97,46 @@ async def create_session( |
95 | 97 | operation_id = api_response['name'].split('/')[-1] |
96 | 98 |
|
97 | 99 | max_retry_attempt = 5 |
98 | | - lro_response = None |
99 | | - while max_retry_attempt >= 0: |
100 | | - lro_response = await api_client.async_request( |
101 | | - http_method='GET', |
102 | | - path=f'operations/{operation_id}', |
103 | | - request_dict={}, |
104 | | - ) |
105 | | - lro_response = _convert_api_response(lro_response) |
106 | 100 |
|
107 | | - if lro_response.get('done', None): |
108 | | - break |
109 | | - |
110 | | - await asyncio.sleep(1) |
111 | | - max_retry_attempt -= 1 |
112 | | - |
113 | | - if lro_response is None or not lro_response.get('done', None): |
114 | | - raise TimeoutError( |
115 | | - f'Timeout waiting for operation {operation_id} to complete.' |
116 | | - ) |
| 101 | + if _is_vertex_express_mode(self._project, self._location): |
| 102 | + # Express mode doesn't support LRO, so we need to poll |
| 103 | + # the session resource. |
| 104 | + # TODO: remove this once LRO polling is supported in Express mode. |
| 105 | + for i in range(max_retry_attempt): |
| 106 | + try: |
| 107 | + await api_client.async_request( |
| 108 | + http_method='GET', |
| 109 | + path=( |
| 110 | + f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}' |
| 111 | + ), |
| 112 | + request_dict={}, |
| 113 | + ) |
| 114 | + break |
| 115 | + except ClientError as e: |
| 116 | + logger.info('Polling for session %s: %s', session_id, e) |
| 117 | + # Add slight exponential backoff to avoid excessive polling. |
| 118 | + await asyncio.sleep(1 + 0.5 * i) |
| 119 | + else: |
| 120 | + raise TimeoutError('Session creation failed.') |
| 121 | + else: |
| 122 | + lro_response = None |
| 123 | + for _ in range(max_retry_attempt): |
| 124 | + lro_response = await api_client.async_request( |
| 125 | + http_method='GET', |
| 126 | + path=f'operations/{operation_id}', |
| 127 | + request_dict={}, |
| 128 | + ) |
| 129 | + lro_response = _convert_api_response(lro_response) |
| 130 | + |
| 131 | + if lro_response.get('done', None): |
| 132 | + break |
| 133 | + |
| 134 | + await asyncio.sleep(1) |
| 135 | + |
| 136 | + if lro_response is None or not lro_response.get('done', None): |
| 137 | + raise TimeoutError( |
| 138 | + f'Timeout waiting for operation {operation_id} to complete.' |
| 139 | + ) |
117 | 140 |
|
118 | 141 | # Get session resource |
119 | 142 | get_session_api_response = await api_client.async_request( |
@@ -312,6 +335,18 @@ def _get_api_client(self): |
312 | 335 | return client._api_client |
313 | 336 |
|
314 | 337 |
|
| 338 | +def _is_vertex_express_mode( |
| 339 | + project: Optional[str], location: Optional[str] |
| 340 | +) -> bool: |
| 341 | + """Check if Vertex AI and API key are both enabled replacing project and location, meaning the user is using the Vertex Express Mode.""" |
| 342 | + return ( |
| 343 | + os.environ.get('GOOGLE_GENAI_USE_VERTEXAI', '0').lower() in ['true', '1'] |
| 344 | + and os.environ.get('GOOGLE_API_KEY', None) is not None |
| 345 | + and project is None |
| 346 | + and location is None |
| 347 | + ) |
| 348 | + |
| 349 | + |
315 | 350 | def _convert_api_response(api_response): |
316 | 351 | """Converts the API response to a JSON object based on the type.""" |
317 | 352 | if hasattr(api_response, 'body'): |
|
0 commit comments