fix: capture per-task token usage to fix race condition in async execution #4286
@@ -41,6 +41,7 @@
 from crewai.tasks.output_format import OutputFormat
 from crewai.tasks.task_output import TaskOutput
 from crewai.tools.base_tool import BaseTool
+from crewai.types.usage_metrics import UsageMetrics
 from crewai.utilities.config import process_config
 from crewai.utilities.constants import NOT_SPECIFIED, _NotSpecified
 from crewai.utilities.converter import Converter, convert_to_model

@@ -527,16 +528,80 @@ def execute_async(
         ).start()
         return future

+    @staticmethod
+    def _get_agent_token_usage(agent: BaseAgent | None) -> UsageMetrics:
+        """Get current token usage from an agent's LLM.
+
+        Captures a snapshot of the agent's LLM token usage at the current moment.
+        This is used to calculate per-task token deltas for accurate tracking
+        when multiple tasks run concurrently.
+
+        Args:
+            agent: The agent to get token usage from.
+
+        Returns:
+            UsageMetrics with current token counts, or empty metrics if unavailable.
+        """
+        if agent is None:
+            return UsageMetrics()
+
+        # Try to get usage from the agent's LLM (BaseLLM instances)
+        if hasattr(agent, "llm") and agent.llm is not None:
+            from crewai.llms.base_llm import BaseLLM
+
+            if isinstance(agent.llm, BaseLLM):
+                return agent.llm.get_token_usage_summary()
+
+        # Fallback for litellm-based agents
+        if hasattr(agent, "_token_process"):
+            return agent._token_process.get_summary()
+
+        return UsageMetrics()
+
+    @staticmethod
+    def _calculate_token_delta(before: UsageMetrics, after: UsageMetrics) -> UsageMetrics:
+        """Calculate the token usage delta between two snapshots.
+
+        Args:
+            before: Token usage snapshot before task execution.
+            after: Token usage snapshot after task execution.
+
+        Returns:
+            UsageMetrics containing the difference (tokens used during task).
+        """
+        return UsageMetrics(
+            total_tokens=after.total_tokens - before.total_tokens,
+            prompt_tokens=after.prompt_tokens - before.prompt_tokens,
+            cached_prompt_tokens=after.cached_prompt_tokens - before.cached_prompt_tokens,
+            completion_tokens=after.completion_tokens - before.completion_tokens,
+            successful_requests=after.successful_requests - before.successful_requests,
+        )
+
     def _execute_task_async(
         self,
         agent: BaseAgent | None,
         context: str | None,
         tools: list[Any] | None,
         future: Future[TaskOutput],
     ) -> None:
-        """Execute the task asynchronously with context handling."""
+        """Execute the task asynchronously with context handling.
+
+        This method captures token usage before and after task execution within
+        the thread to ensure accurate per-task token tracking even when multiple
+        async tasks run concurrently.
+        """
         try:
+            # Capture token usage BEFORE execution within the thread
+            tokens_before = self._get_agent_token_usage(agent or self.agent)
+
             result = self._execute_core(agent, context, tools)
+
+            # Capture token usage AFTER execution within the thread
+            tokens_after = self._get_agent_token_usage(agent or self.agent)
+
+            # Calculate and store the delta in the result
+            result.token_usage = self._calculate_token_delta(tokens_before, tokens_after)
+
             future.set_result(result)
         except Exception as e:
             future.set_exception(e)

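For readers skimming the diff: the two new helpers implement a snapshot-and-delta pattern, reading the agent's cumulative usage immediately before and after the task runs (inside the thread that runs it) and storing only the difference on the task output. Below is a minimal, self-contained sketch of that pattern; `CumulativeUsage`, `run_task`, and the token counts are invented for illustration and are not crewai APIs.

```python
# Toy illustration of the snapshot-and-delta pattern used by this PR.
# An agent's LLM keeps a single cumulative token counter; if every task
# simply read that counter when it finished, a later (or concurrently
# finishing) task would report tokens that other tasks consumed. Taking
# a before/after snapshot around each task and storing the delta
# attributes tokens to the task that actually used them.
from dataclasses import dataclass


@dataclass
class CumulativeUsage:
    total_tokens: int = 0  # stands in for the agent's lifetime counter


def run_task(usage: CumulativeUsage, tokens_for_this_task: int) -> int:
    """Simulate one task: snapshot, do the 'LLM work', snapshot, return the delta."""
    before = usage.total_tokens                  # snapshot before execution
    usage.total_tokens += tokens_for_this_task   # the task's own LLM calls
    after = usage.total_tokens                   # snapshot after execution
    return after - before                        # per-task delta, not the global total


agent_usage = CumulativeUsage()
task_a_tokens = run_task(agent_usage, 120)   # -> 120
task_b_tokens = run_task(agent_usage, 300)   # -> 300, not 420

# Reading the cumulative counter instead would report 420 for task B,
# which is the misattribution the per-task delta avoids.
print(task_a_tokens, task_b_tokens, agent_usage.total_tokens)  # 120 300 420
```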
@@ -568,6 +633,9 @@ async def _aexecute_core(
         self.start_time = datetime.datetime.now()

+        # Capture token usage before execution for accurate per-task tracking
+        tokens_before = self._get_agent_token_usage(agent)
+
         self.prompt_context = context
         tools = tools or self.tools or []

@@ -579,6 +647,10 @@ async def _aexecute_core(
             tools=tools,
         )

+        # Capture token usage after execution
+        tokens_after = self._get_agent_token_usage(agent)
+        token_delta = self._calculate_token_delta(tokens_before, tokens_after)
+
         if not self._guardrails and not self._guardrail:
             pydantic_output, json_output = self._export_output(result)
         else:

@@ -594,6 +666,7 @@ async def _aexecute_core(
             agent=agent.role,
             output_format=self._get_output_format(),
             messages=agent.last_messages,  # type: ignore[attr-defined]
+            token_usage=token_delta,
         )

         if self._guardrails:

@@ -663,6 +736,9 @@ def _execute_core(
         self.start_time = datetime.datetime.now()

+        # Capture token usage before execution for accurate per-task tracking
+        tokens_before = self._get_agent_token_usage(agent)
+
         self.prompt_context = context
         tools = tools or self.tools or []

@@ -674,6 +750,10 @@ def _execute_core(
             tools=tools,
         )

+        # Capture token usage after execution
+        tokens_after = self._get_agent_token_usage(agent)
+        token_delta = self._calculate_token_delta(tokens_before, tokens_after)
+
Reviewer comment (Medium severity): Token capture misses output conversion LLM calls. The `tokens_after` snapshot above is taken before the output-conversion step below runs, so tokens spent during conversion fall outside the measured window. Additional locations (1). A toy sketch of this gap appears after this hunk.
         if not self._guardrails and not self._guardrail:
             pydantic_output, json_output = self._export_output(result)
         else:

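To make the comment above concrete, here is a toy, self-contained illustration of how tokens spent by a post-execution conversion step escape a measurement window that closes before that step runs. The counter and `convert_output` helper are hypothetical, not crewai code.

```python
# Toy illustration of the reviewer's point: tokens consumed by a
# post-execution conversion step are missed if the "after" snapshot is
# taken before that step runs. All names here are hypothetical.
usage = {"total_tokens": 0}


def llm_call(tokens: int) -> None:
    """Pretend to call the LLM and add to the cumulative counter."""
    usage["total_tokens"] += tokens


def convert_output() -> None:
    """Simulate an extra LLM call made while coercing the result into a schema."""
    llm_call(80)


tokens_before = usage["total_tokens"]
llm_call(500)                          # the task's main execution
tokens_after = usage["total_tokens"]   # snapshot taken here...
convert_output()                       # ...but conversion still spends tokens

delta = tokens_after - tokens_before
print(delta)                                   # 500: the 80 conversion tokens are unattributed
print(usage["total_tokens"] - tokens_before)   # 580: what a later snapshot would capture
```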
@@ -689,6 +769,7 @@ def _execute_core(
             agent=agent.role,
             output_format=self._get_output_format(),
             messages=agent.last_messages,  # type: ignore[attr-defined]
+            token_usage=token_delta,
Reviewer comment (Medium severity): Token usage lost when guardrails trigger task retry. Additional locations (1). A toy sketch of the underlying accumulation concern appears after this hunk.
         )

         if self._guardrails:

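A minimal sketch of the accumulation concern behind the comment above, under the assumption that a guardrail failure causes the task to execute again: if each attempt overwrites `token_usage`, tokens from failed attempts are dropped, whereas summing per-attempt deltas preserves the true cost. The retry loop and names below are hypothetical, not taken from crewai.

```python
# Hypothetical retry loop illustrating the accumulation concern: if each
# attempt overwrites the reported usage, tokens spent on failed attempts
# disappear from the task's total.
def attempt_task(tokens_spent: int, passes_guardrail: bool) -> tuple[int, bool]:
    """One attempt: returns (tokens used by this attempt, guardrail verdict)."""
    return tokens_spent, passes_guardrail


attempts = [(400, False), (350, True)]  # first attempt fails the guardrail

overwritten = 0   # what you get if each retry replaces the previous delta
accumulated = 0   # what you get if deltas are summed across attempts
for tokens, ok in attempts:
    delta, passed = attempt_task(tokens, ok)
    overwritten = delta
    accumulated += delta
    if passed:
        break

print(overwritten)   # 350: the failed attempt's 400 tokens are lost
print(accumulated)   # 750: total tokens the task actually consumed
```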
Reviewer comment (Low severity): Redundant token tracking in async execution path

Token tracking happens twice when `_execute_task_async` calls `_execute_core`. Both methods capture `tokens_before` and `tokens_after`, and both set `token_usage`. Since `_execute_core` returns a `TaskOutput` with `token_usage` already set, the subsequent overwrite in `_execute_task_async` is redundant. Both run in the same thread, so the "within thread" rationale doesn't justify the duplication.

Additional locations (1): lib/crewai/src/crewai/task.py#L738-L755
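One way to act on this comment, assuming `_execute_core` keeps computing and storing the per-task delta as it does in this diff, is to drop the second measurement from the threaded wrapper entirely. This is an illustrative simplification, not the PR's code:

```python
# Illustrative simplification only (not part of the PR): since
# _execute_core() already snapshots the agent's usage and stores the
# per-task delta on the TaskOutput it returns, the threaded wrapper can
# simply forward that result instead of measuring a second time.
def _execute_task_async(self, agent, context, tools, future) -> None:
    """Run the task in a worker thread and forward its result or error."""
    try:
        result = self._execute_core(agent, context, tools)  # token_usage already set here
        future.set_result(result)
    except Exception as e:
        future.set_exception(e)
```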