@@ -347,10 +347,10 @@ class NeMoGymChatCompletionMessageToolCallParam(ChatCompletionMessageToolCallPar
347347 function : NeMoGymChatCompletionMessageToolCallFunctionParam
348348
349349
350- class NeMoGymChatCompletionAssistantMessageParam (ChatCompletionAssistantMessageParam ):
350+ class NeMoGymChatCompletionAssistantMessageParam (ChatCompletionAssistantMessageParam , total = False ):
351351 # Re-declare the inherited iterable-typed fields with concrete List types; iterables are annoying to work with.
352352 content : Union [str , List [ContentArrayOfContentPart ], None ]
353- tool_calls : List [NeMoGymChatCompletionMessageToolCallParam ]
353+ tool_calls : Optional [ List [NeMoGymChatCompletionMessageToolCallParam ]] = None  # NOTE(review): `total = False` suggests a TypedDict base; type checkers reject `= None` defaults in a TypedDict body - confirm this is intended.
354354
355355
356356class NeMoGymChatCompletionAssistantMessageForTrainingParam (
@@ -434,7 +434,10 @@ async def _request(self, **request_kwargs: Dict) -> ClientResponse:
434434 tries += 1
435435 response = await request (** request_kwargs )
436436 # See https://platform.openai.com/docs/guides/error-codes/api-errors
437- if response .status in (429 , 500 , 503 ):
437+ # 500 is Internal Server Error, which may occur sporadically
438+ # 502 is Bad Gateway (returned when the endpoint is overloaded)
439+ # 504 is Gateway Timeout (the endpoint's gateway timeout is set too low for the model to finish generating)
440+ if response .status in (429 , 500 , 502 , 503 , 504 ):
438441 content = (await response .content .read ()).decode ()
439442 print (
440443 f"Hit a { response .status } trying to query an OpenAI endpoint (try { tries } ). Sleeping 0.5s. Error message: { content } "
0 commit comments