Skip to content

Commit a6625ec

Browse files
committed
fix: mini workflow fixes
1 parent 1a2104f commit a6625ec

14 files changed

Lines changed: 1261 additions & 136 deletions

File tree

application/agents/base.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77

88
from application.agents.tools.tool_action_parser import ToolActionParser
99
from application.agents.tools.tool_manager import ToolManager
10+
from application.core.json_schema_utils import (
11+
JsonSchemaValidationError,
12+
normalize_json_schema_payload,
13+
)
1014
from application.core.mongo_db import MongoDB
1115
from application.core.settings import settings
1216
from application.llm.handlers.handler_creator import LLMHandlerCreator
@@ -63,7 +67,12 @@ def __init__(
6367
llm_name if llm_name else "default"
6468
)
6569
self.attachments = attachments or []
66-
self.json_schema = json_schema
70+
self.json_schema = None
71+
if json_schema is not None:
72+
try:
73+
self.json_schema = normalize_json_schema_payload(json_schema)
74+
except JsonSchemaValidationError as exc:
75+
logger.warning("Ignoring invalid JSON schema payload: %s", exc)
6776
self.limited_token_mode = limited_token_mode
6877
self.token_limit = token_limit
6978
self.limited_request_mode = limited_request_mode

application/agents/workflow_agent.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -211,8 +211,21 @@ def _determine_run_status(self) -> ExecutionStatus:
211211
def _serialize_state(self, state: Dict[str, Any]) -> Dict[str, Any]:
212212
serialized: Dict[str, Any] = {}
213213
for key, value in state.items():
214-
if isinstance(value, (str, int, float, bool, type(None))):
215-
serialized[key] = value
216-
else:
217-
serialized[key] = str(value)
214+
serialized[key] = self._serialize_state_value(value)
218215
return serialized
216+
217+
def _serialize_state_value(self, value: Any) -> Any:
218+
if isinstance(value, dict):
219+
return {
220+
str(dict_key): self._serialize_state_value(dict_value)
221+
for dict_key, dict_value in value.items()
222+
}
223+
if isinstance(value, list):
224+
return [self._serialize_state_value(item) for item in value]
225+
if isinstance(value, tuple):
226+
return [self._serialize_state_value(item) for item in value]
227+
if isinstance(value, datetime):
228+
return value.isoformat()
229+
if isinstance(value, (str, int, float, bool, type(None))):
230+
return value
231+
return str(value)

application/agents/workflows/workflow_engine.py

Lines changed: 189 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12
import logging
23
from datetime import datetime, timezone
34
from typing import Any, Dict, Generator, List, Optional, TYPE_CHECKING
@@ -13,13 +14,25 @@
1314
WorkflowGraph,
1415
WorkflowNode,
1516
)
17+
from application.core.json_schema_utils import (
18+
JsonSchemaValidationError,
19+
normalize_json_schema_payload,
20+
)
21+
from application.templates.namespaces import NamespaceManager
22+
from application.templates.template_engine import TemplateEngine, TemplateRenderError
23+
24+
try:
25+
import jsonschema
26+
except ImportError: # pragma: no cover - optional dependency in some deployments.
27+
jsonschema = None
1628

1729
if TYPE_CHECKING:
1830
from application.agents.base import BaseAgent
1931
logger = logging.getLogger(__name__)
2032

2133
# Workflow state values are unconstrained; the aliases only document intent.
StateValue = Any
WorkflowState = Dict[str, StateValue]
# Names that cannot be used directly as top-level template variables.
TEMPLATE_RESERVED_NAMESPACES = {
    "agent",
    "system",
    "source",
    "tools",
    "passthrough",
}
2336

2437

2538
class WorkflowEngine:
@@ -31,6 +44,8 @@ def __init__(self, graph: WorkflowGraph, agent: "BaseAgent"):
3144
self.state: WorkflowState = {}
3245
self.execution_log: List[Dict[str, Any]] = []
3346
self._condition_result: Optional[str] = None
47+
self._template_engine = TemplateEngine()
48+
self._namespace_manager = NamespaceManager()
3449

3550
def execute(
3651
self, initial_inputs: WorkflowState, query: str
@@ -174,35 +189,62 @@ def _execute_note_node(
174189
def _execute_agent_node(
175190
self, node: WorkflowNode
176191
) -> Generator[Dict[str, str], None, None]:
177-
from application.core.model_utils import get_api_key_for_provider
192+
from application.core.model_utils import (
193+
get_api_key_for_provider,
194+
get_model_capabilities,
195+
get_provider_from_model_id,
196+
)
178197

179198
node_config = AgentNodeConfig(**node.config.get("config", node.config))
180199

181200
if node_config.prompt_template:
182201
formatted_prompt = self._format_template(node_config.prompt_template)
183202
else:
184203
formatted_prompt = self.state.get("query", "")
185-
node_llm_name = node_config.llm_name or self.agent.llm_name
204+
node_json_schema = self._normalize_node_json_schema(
205+
node_config.json_schema, node.title
206+
)
207+
node_model_id = node_config.model_id or self.agent.model_id
208+
node_llm_name = (
209+
node_config.llm_name
210+
or get_provider_from_model_id(node_model_id or "")
211+
or self.agent.llm_name
212+
)
186213
node_api_key = get_api_key_for_provider(node_llm_name) or self.agent.api_key
187214

215+
if node_json_schema and node_model_id:
216+
model_capabilities = get_model_capabilities(node_model_id)
217+
if model_capabilities and not model_capabilities.get(
218+
"supports_structured_output", False
219+
):
220+
raise ValueError(
221+
f'Model "{node_model_id}" does not support structured output for node "{node.title}"'
222+
)
223+
188224
node_agent = WorkflowNodeAgentFactory.create(
189225
agent_type=node_config.agent_type,
190226
endpoint=self.agent.endpoint,
191227
llm_name=node_llm_name,
192-
model_id=node_config.model_id or self.agent.model_id,
228+
model_id=node_model_id,
193229
api_key=node_api_key,
194230
tool_ids=node_config.tools,
195231
prompt=node_config.system_prompt,
196232
chat_history=self.agent.chat_history,
197233
decoded_token=self.agent.decoded_token,
198-
json_schema=node_config.json_schema,
234+
json_schema=node_json_schema,
199235
)
200236

201-
full_response = ""
237+
full_response_parts: List[str] = []
238+
structured_response_parts: List[str] = []
239+
has_structured_response = False
202240
first_chunk = True
203241
for event in node_agent.gen(formatted_prompt):
204242
if "answer" in event:
205-
full_response += event["answer"]
243+
chunk = str(event["answer"])
244+
full_response_parts.append(chunk)
245+
if event.get("structured"):
246+
has_structured_response = True
247+
structured_response_parts.append(chunk)
206248
if node_config.stream_to_user:
207249
if first_chunk and hasattr(self, "_has_streamed"):
208250
yield {"answer": "\n\n"}
@@ -212,8 +254,33 @@ def _execute_agent_node(
212254
if node_config.stream_to_user:
213255
self._has_streamed = True
214256

215-
output_key = node_config.output_variable or f"node_{node.id}_output"
216-
self.state[output_key] = full_response.strip()
257+
full_response = "".join(full_response_parts).strip()
258+
output_value: Any = full_response
259+
if has_structured_response:
260+
structured_response = "".join(structured_response_parts).strip()
261+
response_to_parse = structured_response or full_response
262+
parsed_success, parsed_structured = self._parse_structured_output(
263+
response_to_parse
264+
)
265+
output_value = parsed_structured if parsed_success else response_to_parse
266+
if node_json_schema:
267+
self._validate_structured_output(node_json_schema, output_value)
268+
elif node_json_schema:
269+
parsed_success, parsed_structured = self._parse_structured_output(
270+
full_response
271+
)
272+
if not parsed_success:
273+
raise ValueError(
274+
"Structured output was expected but response was not valid JSON"
275+
)
276+
output_value = parsed_structured
277+
self._validate_structured_output(node_json_schema, output_value)
278+
279+
default_output_key = f"node_{node.id}_output"
280+
self.state[default_output_key] = output_value
281+
282+
if node_config.output_variable:
283+
self.state[node_config.output_variable] = output_value
217284

218285
def _execute_state_node(
219286
self, node: WorkflowNode
@@ -254,13 +321,122 @@ def _execute_end_node(
254321
formatted_output = self._format_template(output_template)
255322
yield {"answer": formatted_output}
256323

324+
def _parse_structured_output(self, raw_response: str) -> tuple[bool, Optional[Any]]:
325+
normalized_response = raw_response.strip()
326+
if not normalized_response:
327+
return False, None
328+
329+
try:
330+
return True, json.loads(normalized_response)
331+
except json.JSONDecodeError:
332+
logger.warning(
333+
"Workflow agent returned structured output that was not valid JSON"
334+
)
335+
return False, None
336+
337+
def _normalize_node_json_schema(
338+
self, schema: Optional[Dict[str, Any]], node_title: str
339+
) -> Optional[Dict[str, Any]]:
340+
if schema is None:
341+
return None
342+
try:
343+
return normalize_json_schema_payload(schema)
344+
except JsonSchemaValidationError as exc:
345+
raise ValueError(
346+
f'Invalid JSON schema for node "{node_title}": {exc}'
347+
) from exc
348+
349+
def _validate_structured_output(self, schema: Dict[str, Any], output_value: Any) -> None:
    """Check ``output_value`` against ``schema`` using the ``jsonschema`` package.

    When ``jsonschema`` could not be imported, a warning is logged and the
    value is accepted without validation.

    Raises:
        ValueError: if the schema fails normalization, if ``jsonschema``
            reports the schema itself as invalid, or if the value does not
            match the schema.
    """
    validator = jsonschema
    if validator is None:
        logger.warning(
            "jsonschema package is not available, skipping structured output validation"
        )
        return

    try:
        checked_schema = normalize_json_schema_payload(schema)
    except JsonSchemaValidationError as schema_error:
        raise ValueError(f"Invalid JSON schema: {schema_error}") from schema_error

    try:
        validator.validate(instance=output_value, schema=checked_schema)
    except validator.exceptions.ValidationError as mismatch:
        raise ValueError(
            f"Structured output did not match schema: {mismatch.message}"
        ) from mismatch
    except validator.exceptions.SchemaError as schema_error:
        raise ValueError(f"Invalid JSON schema: {schema_error.message}") from schema_error
367+
257368
def _format_template(self, template: str) -> str:
    """Render ``template`` with the context from ``_build_template_context``.

    Rendering is best-effort: if the template engine raises
    ``TemplateRenderError``, a warning is logged and the unrendered template
    is returned unchanged.
    """
    render_context = self._build_template_context()
    try:
        rendered = self._template_engine.render(template, render_context)
    except TemplateRenderError as exc:
        logger.warning(
            "Workflow template rendering failed, using raw template: %s", str(exc)
        )
        return template
    return rendered
377+
378+
def _build_template_context(self) -> Dict[str, Any]:
    """Assemble the variable context used to render workflow templates.

    The base context comes from ``self._namespace_manager.build_context``,
    fed with the agent's ``user`` / ``request_id`` attributes (``None`` when
    absent), the retrieved-document data, and the ``passthrough`` / ``tools``
    state entries when those are dicts. Workflow state is then exposed under
    ``agent`` and additionally as top-level variables.
    """
    docs, docs_together = self._get_source_template_data()

    passthrough_entry = self.state.get("passthrough")
    tools_entry = self.state.get("tools")

    context = self._namespace_manager.build_context(
        user_id=getattr(self.agent, "user", None),
        request_id=getattr(self.agent, "request_id", None),
        passthrough_data=(
            passthrough_entry if isinstance(passthrough_entry, dict) else None
        ),
        docs=docs,
        docs_together=docs_together,
        tools_data=tools_entry if isinstance(tools_entry, dict) else None,
    )

    # Only string keys that are non-blank after stripping become variables.
    agent_context: Dict[str, Any] = {
        key.strip(): value
        for key, value in self.state.items()
        if isinstance(key, str) and key.strip()
    }
    context["agent"] = agent_context

    # Keep legacy top-level variables working while namespaced variables are adopted.
    for name, value in agent_context.items():
        if name in TEMPLATE_RESERVED_NAMESPACES:
            # Reserved names must not shadow a namespace; expose them prefixed.
            context[f"agent_{name}"] = value
        elif name not in context:
            context[name] = value

    return context
418+
419+
def _get_source_template_data(self) -> tuple[Optional[List[Dict[str, Any]]], Optional[str]]:
420+
docs = getattr(self.agent, "retrieved_docs", None)
421+
if not isinstance(docs, list) or len(docs) == 0:
422+
return None, None
423+
424+
docs_together_parts: List[str] = []
425+
for doc in docs:
426+
if not isinstance(doc, dict):
427+
continue
428+
text = doc.get("text")
429+
if not isinstance(text, str):
430+
continue
431+
432+
filename = doc.get("filename") or doc.get("title") or doc.get("source")
433+
if isinstance(filename, str) and filename.strip():
434+
docs_together_parts.append(f"{filename}\n{text}")
435+
else:
436+
docs_together_parts.append(text)
437+
438+
docs_together = "\n\n".join(docs_together_parts) if docs_together_parts else None
439+
return docs, docs_together
264440

265441
def get_execution_summary(self) -> List[NodeExecutionLog]:
266442
return [

0 commit comments

Comments
 (0)