From 7926024569a0703a200bb153088e4ffd888e26c7 Mon Sep 17 00:00:00 2001 From: Stream Date: Fri, 30 Jan 2026 06:46:38 +0800 Subject: [PATCH] feat: support structured output in sandbox and tool mode Signed-off-by: Stream --- api/core/agent/agent_app_runner.py | 48 ++++-- api/core/agent/base_agent_runner.py | 10 ++ api/core/agent/entities.py | 40 ++++- .../agent/output_parser/cot_output_parser.py | 46 +++--- api/core/agent/output_tools.py | 57 +++++++ api/core/agent/patterns/function_call.py | 119 ++++++++++++--- api/core/agent/patterns/react.py | 142 +++++++++++++----- api/core/agent/prompt/template.py | 20 ++- .../providers/agent_output/_assets/icon.svg | 5 + .../providers/agent_output/agent_output.py | 8 + .../providers/agent_output/agent_output.yaml | 10 ++ .../agent_output/tools/final_output_answer.py | 17 +++ .../tools/final_output_answer.yaml | 18 +++ .../tools/final_structured_output.py | 17 +++ .../tools/final_structured_output.yaml | 18 +++ .../agent_output/tools/illegal_output.py | 21 +++ .../agent_output/tools/illegal_output.yaml | 18 +++ .../agent_output/tools/output_text.py | 17 +++ .../agent_output/tools/output_text.yaml | 18 +++ api/core/workflow/nodes/llm/entities.py | 6 +- api/core/workflow/nodes/llm/node.py | 78 +++++++++- .../core/agent/test_agent_app_runner.py | 13 +- 22 files changed, 629 insertions(+), 117 deletions(-) create mode 100644 api/core/agent/output_tools.py create mode 100644 api/core/tools/builtin_tool/providers/agent_output/_assets/icon.svg create mode 100644 api/core/tools/builtin_tool/providers/agent_output/agent_output.py create mode 100644 api/core/tools/builtin_tool/providers/agent_output/agent_output.yaml create mode 100644 api/core/tools/builtin_tool/providers/agent_output/tools/final_output_answer.py create mode 100644 api/core/tools/builtin_tool/providers/agent_output/tools/final_output_answer.yaml create mode 100644 api/core/tools/builtin_tool/providers/agent_output/tools/final_structured_output.py create mode 100644 api/core/tools/builtin_tool/providers/agent_output/tools/final_structured_output.yaml create mode 100644 api/core/tools/builtin_tool/providers/agent_output/tools/illegal_output.py create mode 100644 api/core/tools/builtin_tool/providers/agent_output/tools/illegal_output.yaml create mode 100644 api/core/tools/builtin_tool/providers/agent_output/tools/output_text.py create mode 100644 api/core/tools/builtin_tool/providers/agent_output/tools/output_text.yaml diff --git a/api/core/agent/agent_app_runner.py b/api/core/agent/agent_app_runner.py index 2ee0a23aab..3a0d96a1f9 100644 --- a/api/core/agent/agent_app_runner.py +++ b/api/core/agent/agent_app_runner.py @@ -4,7 +4,7 @@ from copy import deepcopy from typing import Any from core.agent.base_agent_runner import BaseAgentRunner -from core.agent.entities import AgentEntity, AgentLog, AgentResult +from core.agent.entities import AgentEntity, AgentLog, AgentOutputKind, AgentResult from core.agent.patterns.strategy_factory import StrategyFactory from core.app.apps.base_app_queue_manager import PublishFrom from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent @@ -13,6 +13,7 @@ from core.model_runtime.entities import ( AssistantPromptMessage, LLMResult, LLMResultChunk, + LLMResultChunkDelta, LLMUsage, PromptMessage, PromptMessageContentType, @@ -106,7 +107,6 @@ class AgentAppRunner(BaseAgentRunner): # Initialize state variables current_agent_thought_id = None - has_published_thought = False current_tool_name: str | None = None self._current_message_file_ids: list[str] = [] @@ -118,7 +118,7 @@ class AgentAppRunner(BaseAgentRunner): prompt_messages=prompt_messages, model_parameters=app_generate_entity.model_conf.parameters, stop=app_generate_entity.model_conf.stop, - stream=True, + stream=False, ) # Consume generator and collect result @@ -133,17 +133,10 @@ class AgentAppRunner(BaseAgentRunner): break if isinstance(output, LLMResultChunk): - # Handle LLM chunk - if current_agent_thought_id and not has_published_thought: - self.queue_manager.publish( - QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id), - PublishFrom.APPLICATION_MANAGER, - ) - has_published_thought = True + # No more expect streaming data + continue - yield output - - elif isinstance(output, AgentLog): + else: # Handle Agent Log using log_type for type-safe dispatch if output.status == AgentLog.LogStatus.START: if output.log_type == AgentLog.LogType.ROUND: @@ -156,7 +149,6 @@ class AgentAppRunner(BaseAgentRunner): tool_input="", messages_ids=message_file_ids, ) - has_published_thought = False elif output.log_type == AgentLog.LogType.TOOL_CALL: if current_agent_thought_id is None: @@ -265,7 +257,22 @@ class AgentAppRunner(BaseAgentRunner): # Process final result if isinstance(result, AgentResult): - final_answer = result.text + output_payload = result.output + if isinstance(output_payload, AgentResult.StructuredOutput): + if output_payload.output_kind == AgentOutputKind.ILLEGAL_OUTPUT: + raise ValueError("Agent returned illegal output") + if output_payload.output_kind not in { + AgentOutputKind.FINAL_OUTPUT_ANSWER, + AgentOutputKind.OUTPUT_TEXT, + }: + raise ValueError("Agent did not return text output") + if not output_payload.output_text: + raise ValueError("Agent returned empty text output") + final_answer = output_payload.output_text + else: + if not output_payload: + raise ValueError("Agent returned empty output") + final_answer = str(output_payload) usage = result.usage or LLMUsage.empty_usage() # Publish end event @@ -282,6 +289,17 @@ class AgentAppRunner(BaseAgentRunner): PublishFrom.APPLICATION_MANAGER, ) + if False: + yield LLMResultChunk( + model="", + prompt_messages=[], + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content=""), + usage=None, + ), + ) + def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: """ Initialize system message diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index b5459611b1..e5da2b3e12 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -7,6 +7,7 @@ from typing import Union, cast from sqlalchemy import select from core.agent.entities import AgentEntity, AgentToolEntity, ExecutionContext +from core.agent.output_tools import build_agent_output_tools from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig from core.app.apps.base_app_queue_manager import AppQueueManager @@ -36,6 +37,7 @@ from core.model_runtime.model_providers.__base.large_language_model import Large from core.prompt.utils.extract_thread_messages import extract_thread_messages from core.tools.__base.tool import Tool from core.tools.entities.tool_entities import ( + ToolInvokeFrom, ToolParameter, ) from core.tools.tool_manager import ToolManager @@ -251,6 +253,14 @@ class BaseAgentRunner(AppRunner): # save tool entity tool_instances[dataset_tool.entity.identity.name] = dataset_tool + output_tools = build_agent_output_tools( + tenant_id=self.tenant_id, + invoke_from=self.application_generate_entity.invoke_from, + tool_invoke_from=ToolInvokeFrom.AGENT, + ) + for tool in output_tools: + tool_instances[tool.entity.identity.name] = tool + return tool_instances, prompt_messages_tools def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) -> PromptMessageTool: diff --git a/api/core/agent/entities.py b/api/core/agent/entities.py index 46af4d2d72..b606bb9048 100644 --- a/api/core/agent/entities.py +++ b/api/core/agent/entities.py @@ -1,10 +1,11 @@ import uuid from collections.abc import Mapping from enum import StrEnum -from typing import Any, Union +from typing import Any from pydantic import BaseModel, Field +from core.agent.output_tools import FINAL_OUTPUT_TOOL from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType @@ -41,9 +42,9 @@ class AgentScratchpadUnit(BaseModel): """ action_name: str - action_input: Union[dict, str] + action_input: dict[str, Any] | str - def to_dict(self): + def to_dict(self) -> dict[str, Any]: """ Convert to dictionary. """ @@ -62,9 +63,9 @@ class AgentScratchpadUnit(BaseModel): """ Check if the scratchpad unit is final. """ - return self.action is None or ( - "final" in self.action.action_name.lower() and "answer" in self.action.action_name.lower() - ) + if self.action is None: + return False + return self.action.action_name.lower() == FINAL_OUTPUT_TOOL class AgentEntity(BaseModel): @@ -125,7 +126,7 @@ class ExecutionContext(BaseModel): "tenant_id": self.tenant_id, } - def with_updates(self, **kwargs) -> "ExecutionContext": + def with_updates(self, **kwargs: Any) -> "ExecutionContext": """Create a new context with updated fields.""" data = self.to_dict() data.update(kwargs) @@ -178,12 +179,35 @@ class AgentLog(BaseModel): metadata: Mapping[LogMetadata, Any] = Field(default={}, description="The metadata of the log") +class AgentOutputKind(StrEnum): + """ + Agent output kind. + """ + + OUTPUT_TEXT = "output_text" + FINAL_OUTPUT_ANSWER = "final_output_answer" + FINAL_STRUCTURED_OUTPUT = "final_structured_output" + ILLEGAL_OUTPUT = "illegal_output" + + +OutputKind = AgentOutputKind + + class AgentResult(BaseModel): """ Agent execution result. """ - text: str = Field(default="", description="The generated text") + class StructuredOutput(BaseModel): + """ + Structured output payload from output tools. + """ + + output_kind: AgentOutputKind + output_text: str | None = None + output_data: Mapping[str, Any] | None = None + + output: str | StructuredOutput = Field(default="", description="The generated output") files: list[Any] = Field(default_factory=list, description="Files produced during execution") usage: Any | None = Field(default=None, description="LLM usage statistics") finish_reason: str | None = Field(default=None, description="Reason for completion") diff --git a/api/core/agent/output_parser/cot_output_parser.py b/api/core/agent/output_parser/cot_output_parser.py index 7c8f09e6b9..2d4ce58b6a 100644 --- a/api/core/agent/output_parser/cot_output_parser.py +++ b/api/core/agent/output_parser/cot_output_parser.py @@ -1,7 +1,7 @@ import json import re from collections.abc import Generator -from typing import Union +from typing import Any, Union, cast from core.agent.entities import AgentScratchpadUnit from core.model_runtime.entities.llm_entities import LLMResultChunk @@ -10,46 +10,52 @@ from core.model_runtime.entities.llm_entities import LLMResultChunk class CotAgentOutputParser: @classmethod def handle_react_stream_output( - cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict + cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict[str, Any] ) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]: - def parse_action(action) -> Union[str, AgentScratchpadUnit.Action]: - action_name = None - action_input = None - if isinstance(action, str): + def parse_action(action: Any) -> Union[str, AgentScratchpadUnit.Action]: + action_name: str | None = None + action_input: Any | None = None + parsed_action: Any = action + if isinstance(parsed_action, str): try: - action = json.loads(action, strict=False) + parsed_action = json.loads(parsed_action, strict=False) except json.JSONDecodeError: - return action or "" + return parsed_action or "" # cohere always returns a list - if isinstance(action, list) and len(action) == 1: - action = action[0] + if isinstance(parsed_action, list): + action_list: list[Any] = cast(list[Any], parsed_action) + if len(action_list) == 1: + parsed_action = action_list[0] - for key, value in action.items(): - if "input" in key.lower(): - action_input = value - else: - action_name = value + if isinstance(parsed_action, dict): + action_dict: dict[str, Any] = cast(dict[str, Any], parsed_action) + for key, value in action_dict.items(): + if "input" in key.lower(): + action_input = value + elif isinstance(value, str): + action_name = value + else: + return json.dumps(parsed_action) if action_name is not None and action_input is not None: return AgentScratchpadUnit.Action( action_name=action_name, action_input=action_input, ) - else: - return json.dumps(action) + return json.dumps(parsed_action) - def extra_json_from_code_block(code_block) -> list[Union[list, dict]]: + def extra_json_from_code_block(code_block: str) -> list[dict[str, Any] | list[Any]]: blocks = re.findall(r"```[json]*\s*([\[{].*[]}])\s*```", code_block, re.DOTALL | re.IGNORECASE) if not blocks: return [] try: - json_blocks = [] + json_blocks: list[dict[str, Any] | list[Any]] = [] for block in blocks: json_text = re.sub(r"^[a-zA-Z]+\n", "", block.strip(), flags=re.MULTILINE) json_blocks.append(json.loads(json_text, strict=False)) return json_blocks - except: + except Exception: return [] code_block_cache = "" diff --git a/api/core/agent/output_tools.py b/api/core/agent/output_tools.py new file mode 100644 index 0000000000..0e8d18cc04 --- /dev/null +++ b/api/core/agent/output_tools.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +from collections.abc import Sequence +from typing import Any + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.tools.__base.tool import Tool +from core.tools.entities.tool_entities import ToolInvokeFrom, ToolParameter, ToolProviderType +from core.tools.tool_manager import ToolManager + +OUTPUT_TOOL_PROVIDER = "agent_output" + +OUTPUT_TEXT_TOOL = "output_text" +FINAL_OUTPUT_TOOL = "final_output_answer" +FINAL_STRUCTURED_OUTPUT_TOOL = "final_structured_output" +ILLEGAL_OUTPUT_TOOL = "illegal_output" + +OUTPUT_TOOL_NAMES: Sequence[str] = ( + OUTPUT_TEXT_TOOL, + FINAL_OUTPUT_TOOL, + FINAL_STRUCTURED_OUTPUT_TOOL, + ILLEGAL_OUTPUT_TOOL, +) + +OUTPUT_TOOL_NAME_SET = set(OUTPUT_TOOL_NAMES) + + +def build_agent_output_tools( + *, + tenant_id: str, + invoke_from: InvokeFrom, + tool_invoke_from: ToolInvokeFrom, + structured_output_schema: dict[str, Any] | None = None, +) -> list[Tool]: + tools: list[Tool] = [] + for tool_name in OUTPUT_TOOL_NAMES: + tool = ToolManager.get_tool_runtime( + provider_type=ToolProviderType.BUILT_IN, + provider_id=OUTPUT_TOOL_PROVIDER, + tool_name=tool_name, + tenant_id=tenant_id, + invoke_from=invoke_from, + tool_invoke_from=tool_invoke_from, + ) + + if tool_name == FINAL_STRUCTURED_OUTPUT_TOOL and structured_output_schema: + tool.entity = tool.entity.model_copy(deep=True) + for parameter in tool.entity.parameters: + if parameter.name != "data": + continue + parameter.type = ToolParameter.ToolParameterType.OBJECT + parameter.form = ToolParameter.ToolParameterForm.LLM + parameter.required = True + parameter.input_schema = structured_output_schema + tools.append(tool) + + return tools diff --git a/api/core/agent/patterns/function_call.py b/api/core/agent/patterns/function_call.py index a0f0d30ab3..f29e86b3f4 100644 --- a/api/core/agent/patterns/function_call.py +++ b/api/core/agent/patterns/function_call.py @@ -1,10 +1,18 @@ """Function Call strategy implementation.""" import json +import uuid from collections.abc import Generator -from typing import Any, Union +from typing import Any, Literal, Protocol, Union, cast -from core.agent.entities import AgentLog, AgentResult +from core.agent.entities import AgentLog, AgentOutputKind, AgentResult +from core.agent.output_tools import ( + FINAL_OUTPUT_TOOL, + FINAL_STRUCTURED_OUTPUT_TOOL, + ILLEGAL_OUTPUT_TOOL, + OUTPUT_TEXT_TOOL, + OUTPUT_TOOL_NAME_SET, +) from core.file import File from core.model_runtime.entities import ( AssistantPromptMessage, @@ -42,9 +50,24 @@ class FunctionCallStrategy(AgentPattern): total_usage: dict[str, LLMUsage | None] = {"usage": None} messages: list[PromptMessage] = list(prompt_messages) # Create mutable copy final_text: str = "" + structured_output_payload: dict[str, Any] | None = None + output_text_payload: str | None = None finish_reason: str | None = None output_files: list[File] = [] # Track files produced by tools + class _LLMInvoker(Protocol): + def invoke_llm( + self, + *, + prompt_messages: list[PromptMessage], + model_parameters: dict[str, Any], + tools: list[PromptMessageTool], + stop: list[str], + stream: Literal[False], + user: str | None, + callbacks: list[Any], + ) -> LLMResult: ... + while function_call_state and iteration_step <= max_iterations: function_call_state = False round_log = self._create_log( @@ -54,8 +77,11 @@ class FunctionCallStrategy(AgentPattern): data={}, ) yield round_log - # On last iteration, remove tools to force final answer - current_tools: list[PromptMessageTool] = [] if iteration_step == max_iterations else prompt_tools + # On last iteration, restrict tools to output tools + if iteration_step == max_iterations: + current_tools = [tool for tool in prompt_tools if tool.name in OUTPUT_TOOL_NAME_SET] + else: + current_tools = prompt_tools model_log = self._create_log( label=f"{self.model_instance.model} Thought", log_type=AgentLog.LogType.THOUGHT, @@ -72,31 +98,41 @@ class FunctionCallStrategy(AgentPattern): round_usage: dict[str, LLMUsage | None] = {"usage": None} # Invoke model - chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm( + invoker = cast(_LLMInvoker, self.model_instance) + chunks = invoker.invoke_llm( prompt_messages=messages, model_parameters=model_parameters, tools=current_tools, stop=stop, - stream=stream, + stream=False, user=self.context.user_id, callbacks=[], ) # Process response tool_calls, response_content, chunk_finish_reason = yield from self._handle_chunks( - chunks, round_usage, model_log + chunks, round_usage, model_log, emit_chunks=False ) - messages.append(self._create_assistant_message(response_content, tool_calls)) + + if not tool_calls: + tool_calls = [ + ( + str(uuid.uuid4()), + ILLEGAL_OUTPUT_TOOL, + { + "raw": response_content, + }, + ) + ] + response_content = "" + + messages.append(self._create_assistant_message("", tool_calls)) # Accumulate to total usage round_usage_value = round_usage.get("usage") if round_usage_value: self._accumulate_usage(total_usage, round_usage_value) - # Update final text if no tool calls (this is likely the final answer) - if not tool_calls: - final_text = response_content - # Update finish reason if chunk_finish_reason: finish_reason = chunk_finish_reason @@ -105,14 +141,27 @@ class FunctionCallStrategy(AgentPattern): tool_outputs: dict[str, str] = {} if tool_calls: function_call_state = True + terminal_tool_seen = False # Execute tools for tool_call_id, tool_name, tool_args in tool_calls: + if tool_name == OUTPUT_TEXT_TOOL: + output_text_payload = self._format_output_text(tool_args.get("text")) + elif tool_name == FINAL_STRUCTURED_OUTPUT_TOOL: + data = tool_args.get("data") + structured_output_payload = cast(dict[str, Any] | None, data) + elif tool_name == FINAL_OUTPUT_TOOL: + final_text = self._format_output_text(tool_args.get("text")) + terminal_tool_seen = True + tool_response, tool_files, _ = yield from self._handle_tool_call( tool_name, tool_args, tool_call_id, messages, round_log ) tool_outputs[tool_name] = tool_response # Track files produced by tools output_files.extend(tool_files) + + if terminal_tool_seen: + function_call_state = False yield self._finish_log( round_log, data={ @@ -131,8 +180,28 @@ class FunctionCallStrategy(AgentPattern): # Return final result from core.agent.entities import AgentResult + output_payload: str | AgentResult.StructuredOutput + if final_text: + output_payload = AgentResult.StructuredOutput( + output_kind=AgentOutputKind.FINAL_OUTPUT_ANSWER, + output_text=final_text, + output_data=structured_output_payload, + ) + elif output_text_payload: + output_payload = AgentResult.StructuredOutput( + output_kind=AgentOutputKind.OUTPUT_TEXT, + output_text=str(output_text_payload), + output_data=None, + ) + else: + output_payload = AgentResult.StructuredOutput( + output_kind=AgentOutputKind.ILLEGAL_OUTPUT, + output_text="Model failed to produce a final output.", + output_data=None, + ) + return AgentResult( - text=final_text, + output=output_payload, files=output_files, usage=total_usage.get("usage") or LLMUsage.empty_usage(), finish_reason=finish_reason, @@ -143,6 +212,8 @@ class FunctionCallStrategy(AgentPattern): chunks: Union[Generator[LLMResultChunk, None, None], LLMResult], llm_usage: dict[str, LLMUsage | None], start_log: AgentLog, + *, + emit_chunks: bool, ) -> Generator[ LLMResultChunk | AgentLog, None, @@ -174,7 +245,8 @@ class FunctionCallStrategy(AgentPattern): if chunk.delta.finish_reason: finish_reason = chunk.delta.finish_reason - yield chunk + if emit_chunks: + yield chunk else: # Non-streaming response result: LLMResult = chunks @@ -189,11 +261,12 @@ class FunctionCallStrategy(AgentPattern): self._accumulate_usage(llm_usage, result.usage) # Convert to streaming format - yield LLMResultChunk( - model=result.model, - prompt_messages=result.prompt_messages, - delta=LLMResultChunkDelta(index=0, message=result.message, usage=result.usage), - ) + if emit_chunks: + yield LLMResultChunk( + model=result.model, + prompt_messages=result.prompt_messages, + delta=LLMResultChunkDelta(index=0, message=result.message, usage=result.usage), + ) yield self._finish_log( start_log, data={ @@ -203,6 +276,14 @@ class FunctionCallStrategy(AgentPattern): ) return tool_calls, response_content, finish_reason + @staticmethod + def _format_output_text(value: Any) -> str: + if value is None: + return "" + if isinstance(value, str): + return value + return json.dumps(value, ensure_ascii=False) + def _create_assistant_message( self, content: str, tool_calls: list[tuple[str, str, dict[str, Any]]] | None = None ) -> AssistantPromptMessage: diff --git a/api/core/agent/patterns/react.py b/api/core/agent/patterns/react.py index 87a9fa9b65..207a824742 100644 --- a/api/core/agent/patterns/react.py +++ b/api/core/agent/patterns/react.py @@ -4,10 +4,17 @@ from __future__ import annotations import json from collections.abc import Generator -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING, Any, Union, cast -from core.agent.entities import AgentLog, AgentResult, AgentScratchpadUnit, ExecutionContext +from core.agent.entities import AgentLog, AgentOutputKind, AgentResult, AgentScratchpadUnit, ExecutionContext from core.agent.output_parser.cot_output_parser import CotAgentOutputParser +from core.agent.output_tools import ( + FINAL_OUTPUT_TOOL, + FINAL_STRUCTURED_OUTPUT_TOOL, + ILLEGAL_OUTPUT_TOOL, + OUTPUT_TEXT_TOOL, + OUTPUT_TOOL_NAME_SET, +) from core.file import File from core.model_manager import ModelInstance from core.model_runtime.entities import ( @@ -67,6 +74,8 @@ class ReActStrategy(AgentPattern): total_usage: dict[str, Any] = {"usage": None} output_files: list[File] = [] # Track files produced by tools final_text: str = "" + structured_output_payload: dict[str, Any] | None = None + output_text_payload: str | None = None finish_reason: str | None = None # Add "Observation" to stop sequences @@ -84,10 +93,13 @@ class ReActStrategy(AgentPattern): ) yield round_log - # Build prompt with/without tools based on iteration - include_tools = iteration_step < max_iterations + # Build prompt with tool restrictions on last iteration + if iteration_step == max_iterations: + tools_for_prompt = [tool for tool in self.tools if tool.entity.identity.name in OUTPUT_TOOL_NAME_SET] + else: + tools_for_prompt = self.tools current_messages = self._build_prompt_with_react_format( - prompt_messages, agent_scratchpad, include_tools, self.instruction + prompt_messages, agent_scratchpad, tools_for_prompt, self.instruction ) model_log = self._create_log( @@ -109,18 +121,21 @@ class ReActStrategy(AgentPattern): messages_to_use = current_messages # Invoke model - chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm( - prompt_messages=messages_to_use, - model_parameters=model_parameters, - stop=stop, - stream=stream, - user=self.context.user_id or "", - callbacks=[], + chunks = cast( + Union[Generator[LLMResultChunk, None, None], LLMResult], + self.model_instance.invoke_llm( + prompt_messages=messages_to_use, + model_parameters=model_parameters, + stop=stop, + stream=False, + user=self.context.user_id or "", + callbacks=[], + ), ) # Process response scratchpad, chunk_finish_reason = yield from self._handle_chunks( - chunks, round_usage, model_log, current_messages + chunks, round_usage, model_log, current_messages, emit_chunks=False ) agent_scratchpad.append(scratchpad) @@ -134,28 +149,46 @@ class ReActStrategy(AgentPattern): finish_reason = chunk_finish_reason # Check if we have an action to execute - if scratchpad.action and scratchpad.action.action_name.lower() != "final answer": - react_state = True - # Execute tool - observation, tool_files = yield from self._handle_tool_call( - scratchpad.action, current_messages, round_log + if scratchpad.action is None: + illegal_action = AgentScratchpadUnit.Action( + action_name=ILLEGAL_OUTPUT_TOOL, + action_input={"raw": scratchpad.thought or ""}, ) + scratchpad.action = illegal_action + scratchpad.action_str = json.dumps(illegal_action.to_dict()) + react_state = True + observation, tool_files = yield from self._handle_tool_call(illegal_action, current_messages, round_log) scratchpad.observation = observation - # Track files produced by tools output_files.extend(tool_files) - - # Add observation to scratchpad for display - yield self._create_text_chunk(f"\nObservation: {observation}\n", current_messages) else: - # Extract final answer - if scratchpad.action and scratchpad.action.action_input: - final_answer = scratchpad.action.action_input - if isinstance(final_answer, dict): - final_answer = json.dumps(final_answer, ensure_ascii=False) - final_text = str(final_answer) - elif scratchpad.thought: - # If no action but we have thought, use thought as final answer - final_text = scratchpad.thought + action_name = scratchpad.action.action_name + if action_name == FINAL_OUTPUT_TOOL: + if isinstance(scratchpad.action.action_input, dict): + final_text = self._format_output_text(scratchpad.action.action_input.get("text")) + else: + final_text = self._format_output_text(scratchpad.action.action_input) + observation, tool_files = yield from self._handle_tool_call( + scratchpad.action, current_messages, round_log + ) + scratchpad.observation = observation + output_files.extend(tool_files) + react_state = False + else: + if action_name == OUTPUT_TEXT_TOOL and isinstance(scratchpad.action.action_input, dict): + output_text_payload = scratchpad.action.action_input.get("text") + elif action_name == FINAL_STRUCTURED_OUTPUT_TOOL and isinstance( + scratchpad.action.action_input, dict + ): + data = scratchpad.action.action_input.get("data") + if isinstance(data, dict): + structured_output_payload = data + + react_state = True + observation, tool_files = yield from self._handle_tool_call( + scratchpad.action, current_messages, round_log + ) + scratchpad.observation = observation + output_files.extend(tool_files) yield self._finish_log( round_log, @@ -173,15 +206,38 @@ class ReActStrategy(AgentPattern): from core.agent.entities import AgentResult + output_payload: str | AgentResult.StructuredOutput + if final_text: + output_payload = AgentResult.StructuredOutput( + output_kind=AgentOutputKind.FINAL_OUTPUT_ANSWER, + output_text=final_text, + output_data=structured_output_payload, + ) + elif output_text_payload: + output_payload = AgentResult.StructuredOutput( + output_kind=AgentOutputKind.OUTPUT_TEXT, + output_text=str(output_text_payload), + output_data=structured_output_payload, + ) + else: + output_payload = AgentResult.StructuredOutput( + output_kind=AgentOutputKind.ILLEGAL_OUTPUT, + output_text="Model failed to produce a final output.", + output_data=structured_output_payload, + ) + return AgentResult( - text=final_text, files=output_files, usage=total_usage.get("usage"), finish_reason=finish_reason + output=output_payload, + files=output_files, + usage=total_usage.get("usage"), + finish_reason=finish_reason, ) def _build_prompt_with_react_format( self, original_messages: list[PromptMessage], agent_scratchpad: list[AgentScratchpadUnit], - include_tools: bool = True, + tools: list[Tool] | None, instruction: str = "", ) -> list[PromptMessage]: """Build prompt messages with ReAct format.""" @@ -198,9 +254,9 @@ class ReActStrategy(AgentPattern): # Format tools tools_str = "" tool_names = [] - if include_tools and self.tools: + if tools: # Convert tools to prompt message tools format - prompt_tools = [tool.to_prompt_message_tool() for tool in self.tools] + prompt_tools = [tool.to_prompt_message_tool() for tool in tools] tool_names = [tool.name for tool in prompt_tools] # Format tools as JSON for comprehensive information @@ -253,6 +309,8 @@ class ReActStrategy(AgentPattern): llm_usage: dict[str, Any], model_log: AgentLog, current_messages: list[PromptMessage], + *, + emit_chunks: bool, ) -> Generator[ LLMResultChunk | AgentLog, None, @@ -306,14 +364,16 @@ class ReActStrategy(AgentPattern): scratchpad.action_str = action_str scratchpad.action = chunk - yield self._create_text_chunk(json.dumps(chunk.model_dump()), current_messages) + if emit_chunks: + yield self._create_text_chunk(json.dumps(chunk.model_dump()), current_messages) else: # Text chunk chunk_text = str(chunk) scratchpad.agent_response = (scratchpad.agent_response or "") + chunk_text scratchpad.thought = (scratchpad.thought or "") + chunk_text - yield self._create_text_chunk(chunk_text, current_messages) + if emit_chunks: + yield self._create_text_chunk(chunk_text, current_messages) # Update usage if usage_dict.get("usage"): @@ -337,6 +397,14 @@ class ReActStrategy(AgentPattern): return scratchpad, finish_reason + @staticmethod + def _format_output_text(value: Any) -> str: + if value is None: + return "" + if isinstance(value, str): + return value + return json.dumps(value, ensure_ascii=False) + def _handle_tool_call( self, action: AgentScratchpadUnit.Action, diff --git a/api/core/agent/prompt/template.py b/api/core/agent/prompt/template.py index f5ba2119f4..e46d504e30 100644 --- a/api/core/agent/prompt/template.py +++ b/api/core/agent/prompt/template.py @@ -7,7 +7,7 @@ You have access to the following tools: {{tools}} Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). -Valid "action" values: "Final Answer" or {{tool_names}} +Valid "action" values: {{tool_names}}. You must call "final_output_answer" to finish. Provide only ONE action per $JSON_BLOB, as shown: @@ -32,12 +32,14 @@ Thought: I know what to respond Action: ``` { - "action": "Final Answer", - "action_input": "Final response to human" + "action": "final_output_answer", + "action_input": { + "text": "Final response to human" + } } ``` -Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:. +Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Do not respond with plain text. Format is Action:```$JSON_BLOB```then Observation:. {{historic_messages}} Question: {{query}} {{agent_scratchpad}} @@ -56,7 +58,7 @@ You have access to the following tools: {{tools}} Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). -Valid "action" values: "Final Answer" or {{tool_names}} +Valid "action" values: {{tool_names}}. You must call "final_output_answer" to finish. Provide only ONE action per $JSON_BLOB, as shown: @@ -81,12 +83,14 @@ Thought: I know what to respond Action: ``` { - "action": "Final Answer", - "action_input": "Final response to human" + "action": "final_output_answer", + "action_input": { + "text": "Final response to human" + } } ``` -Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:. +Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Do not respond with plain text. Format is Action:```$JSON_BLOB```then Observation:. """ # noqa: E501 diff --git a/api/core/tools/builtin_tool/providers/agent_output/_assets/icon.svg b/api/core/tools/builtin_tool/providers/agent_output/_assets/icon.svg new file mode 100644 index 0000000000..82e3b54edd --- /dev/null +++ b/api/core/tools/builtin_tool/providers/agent_output/_assets/icon.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/api/core/tools/builtin_tool/providers/agent_output/agent_output.py b/api/core/tools/builtin_tool/providers/agent_output/agent_output.py new file mode 100644 index 0000000000..65c3eb5a1d --- /dev/null +++ b/api/core/tools/builtin_tool/providers/agent_output/agent_output.py @@ -0,0 +1,8 @@ +from typing import Any + +from core.tools.builtin_tool.provider import BuiltinToolProviderController + + +class AgentOutputProvider(BuiltinToolProviderController): + def _validate_credentials(self, user_id: str, credentials: dict[str, Any]): + pass diff --git a/api/core/tools/builtin_tool/providers/agent_output/agent_output.yaml b/api/core/tools/builtin_tool/providers/agent_output/agent_output.yaml new file mode 100644 index 0000000000..1e4f624104 --- /dev/null +++ b/api/core/tools/builtin_tool/providers/agent_output/agent_output.yaml @@ -0,0 +1,10 @@ +identity: + author: Dify + name: agent_output + label: + en_US: Agent Output + description: + en_US: Internal tools for agent output control. + icon: icon.svg + tags: + - utilities diff --git a/api/core/tools/builtin_tool/providers/agent_output/tools/final_output_answer.py b/api/core/tools/builtin_tool/providers/agent_output/tools/final_output_answer.py new file mode 100644 index 0000000000..ed5edc9fd3 --- /dev/null +++ b/api/core/tools/builtin_tool/providers/agent_output/tools/final_output_answer.py @@ -0,0 +1,17 @@ +from collections.abc import Generator +from typing import Any + +from core.tools.builtin_tool.tool import BuiltinTool +from core.tools.entities.tool_entities import ToolInvokeMessage + + +class FinalOutputAnswerTool(BuiltinTool): + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, + ) -> Generator[ToolInvokeMessage, None, None]: + yield self.create_text_message("Final answer recorded.") diff --git a/api/core/tools/builtin_tool/providers/agent_output/tools/final_output_answer.yaml b/api/core/tools/builtin_tool/providers/agent_output/tools/final_output_answer.yaml new file mode 100644 index 0000000000..b8e4f14d85 --- /dev/null +++ b/api/core/tools/builtin_tool/providers/agent_output/tools/final_output_answer.yaml @@ -0,0 +1,18 @@ +identity: + name: final_output_answer + author: Dify + label: + en_US: Final Output Answer +description: + human: + en_US: Internal tool to deliver the final answer. + llm: Use this tool when you are ready to provide the final answer. +parameters: + - name: text + type: string + required: true + label: + en_US: Text + human_description: + en_US: Final answer text. + form: llm diff --git a/api/core/tools/builtin_tool/providers/agent_output/tools/final_structured_output.py b/api/core/tools/builtin_tool/providers/agent_output/tools/final_structured_output.py new file mode 100644 index 0000000000..916dc9f0bc --- /dev/null +++ b/api/core/tools/builtin_tool/providers/agent_output/tools/final_structured_output.py @@ -0,0 +1,17 @@ +from collections.abc import Generator +from typing import Any + +from core.tools.builtin_tool.tool import BuiltinTool +from core.tools.entities.tool_entities import ToolInvokeMessage + + +class FinalStructuredOutputTool(BuiltinTool): + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, + ) -> Generator[ToolInvokeMessage, None, None]: + yield self.create_text_message("Structured output recorded.") diff --git a/api/core/tools/builtin_tool/providers/agent_output/tools/final_structured_output.yaml b/api/core/tools/builtin_tool/providers/agent_output/tools/final_structured_output.yaml new file mode 100644 index 0000000000..79eb2d6412 --- /dev/null +++ b/api/core/tools/builtin_tool/providers/agent_output/tools/final_structured_output.yaml @@ -0,0 +1,18 @@ +identity: + name: final_structured_output + author: Dify + label: + en_US: Final Structured Output +description: + human: + en_US: Internal tool to deliver structured output. + llm: Use this tool to provide structured output data. +parameters: + - name: data + type: object + required: true + label: + en_US: Data + human_description: + en_US: Structured output data. + form: llm diff --git a/api/core/tools/builtin_tool/providers/agent_output/tools/illegal_output.py b/api/core/tools/builtin_tool/providers/agent_output/tools/illegal_output.py new file mode 100644 index 0000000000..27301d61b3 --- /dev/null +++ b/api/core/tools/builtin_tool/providers/agent_output/tools/illegal_output.py @@ -0,0 +1,21 @@ +from collections.abc import Generator +from typing import Any + +from core.tools.builtin_tool.tool import BuiltinTool +from core.tools.entities.tool_entities import ToolInvokeMessage + + +class IllegalOutputTool(BuiltinTool): + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, + ) -> Generator[ToolInvokeMessage, None, None]: + message = ( + "Protocol violation: do not output plain text. " + "Call output_text, final_structured_output, then final_output_answer." + ) + yield self.create_text_message(message) diff --git a/api/core/tools/builtin_tool/providers/agent_output/tools/illegal_output.yaml b/api/core/tools/builtin_tool/providers/agent_output/tools/illegal_output.yaml new file mode 100644 index 0000000000..b293a41659 --- /dev/null +++ b/api/core/tools/builtin_tool/providers/agent_output/tools/illegal_output.yaml @@ -0,0 +1,18 @@ +identity: + name: illegal_output + author: Dify + label: + en_US: Illegal Output +description: + human: + en_US: Internal tool for output protocol violations. + llm: Use this tool to correct output protocol violations. +parameters: + - name: raw + type: string + required: true + label: + en_US: Raw Output + human_description: + en_US: Raw model output that violated the protocol. + form: llm diff --git a/api/core/tools/builtin_tool/providers/agent_output/tools/output_text.py b/api/core/tools/builtin_tool/providers/agent_output/tools/output_text.py new file mode 100644 index 0000000000..942846f507 --- /dev/null +++ b/api/core/tools/builtin_tool/providers/agent_output/tools/output_text.py @@ -0,0 +1,17 @@ +from collections.abc import Generator +from typing import Any + +from core.tools.builtin_tool.tool import BuiltinTool +from core.tools.entities.tool_entities import ToolInvokeMessage + + +class OutputTextTool(BuiltinTool): + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, + ) -> Generator[ToolInvokeMessage, None, None]: + yield self.create_text_message("Output recorded.") diff --git a/api/core/tools/builtin_tool/providers/agent_output/tools/output_text.yaml b/api/core/tools/builtin_tool/providers/agent_output/tools/output_text.yaml new file mode 100644 index 0000000000..48d237e521 --- /dev/null +++ b/api/core/tools/builtin_tool/providers/agent_output/tools/output_text.yaml @@ -0,0 +1,18 @@ +identity: + name: output_text + author: Dify + label: + en_US: Output Text +description: + human: + en_US: Internal tool to store intermediate text output. + llm: Use this tool to emit non-final text output. +parameters: + - name: text + type: string + required: true + label: + en_US: Text + human_description: + en_US: Output text. + form: llm diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py index e982ceccf2..b0eacd00ea 100644 --- a/api/core/workflow/nodes/llm/entities.py +++ b/api/core/workflow/nodes/llm/entities.py @@ -7,7 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_valid from core.agent.entities import AgentLog, AgentResult from core.file import File from core.model_runtime.entities import ImagePromptMessageContent, LLMMode -from core.model_runtime.entities.llm_entities import LLMUsage +from core.model_runtime.entities.llm_entities import LLMStructuredOutput, LLMUsage from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig from core.tools.entities.tool_entities import ToolProviderType from core.workflow.entities import ToolCall, ToolCallResult @@ -156,6 +156,9 @@ class LLMGenerationData(BaseModel): finish_reason: str | None = Field(None, description="Finish reason from LLM") files: list[File] = Field(default_factory=list, description="Generated files") trace: list[LLMTraceSegment] = Field(default_factory=list, description="Streaming trace in emitted order") + structured_output: LLMStructuredOutput | None = Field( + default=None, description="Structured output from tool-only agent runs" + ) class ThinkTagStreamParser: @@ -284,6 +287,7 @@ class AggregatedResult(BaseModel): files: list[File] = Field(default_factory=list) usage: LLMUsage = Field(default_factory=LLMUsage.empty_usage) finish_reason: str | None = None + structured_output: LLMStructuredOutput | None = None class AgentContext(BaseModel): diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index feb76e6510..8081055dad 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -12,7 +12,8 @@ from typing import TYPE_CHECKING, Any, Literal, cast from sqlalchemy import select -from core.agent.entities import AgentEntity, AgentLog, AgentResult, AgentToolEntity, ExecutionContext +from core.agent.entities import AgentEntity, AgentLog, AgentOutputKind, AgentResult, AgentToolEntity, ExecutionContext +from core.agent.output_tools import build_agent_output_tools from core.agent.patterns import StrategyFactory from core.app.entities.app_asset_entities import AppAssetFileTree from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity @@ -20,6 +21,7 @@ from core.app_assets.constants import AppAssetsAttrs from core.file import File, FileTransferMethod, FileType, file_manager from core.helper.code_executor import CodeExecutor, CodeLanguage from core.llm_generator.output_parser.errors import OutputParserError +from core.llm_generator.output_parser.file_ref import convert_file_refs_in_output from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output from core.memory.base import BaseMemory from core.model_manager import ModelInstance, ModelManager @@ -62,6 +64,7 @@ from core.skill.entities.skill_document import SkillDocument from core.skill.entities.tool_dependencies import ToolDependencies, ToolDependency from core.skill.skill_compiler import SkillCompiler from core.tools.__base.tool import Tool +from core.tools.entities.tool_entities import ToolInvokeFrom from core.tools.signature import sign_upload_file from core.tools.tool_manager import ToolManager from core.variables import ( @@ -355,6 +358,8 @@ class LLMNode(Node[LLMNodeData]): reasoning_content = "" usage = generation_data.usage finish_reason = generation_data.finish_reason + if generation_data.structured_output: + structured_output = generation_data.structured_output # Unified process_data building process_data = { @@ -1900,7 +1905,7 @@ class LLMNode(Node[LLMNodeData]): prompt_messages=list(prompt_messages), model_parameters=self._node_data.model.completion_params, stop=list(stop or []), - stream=True, + stream=False, ) result = yield from self._process_tool_outputs(outputs) @@ -1921,11 +1926,22 @@ class LLMNode(Node[LLMNodeData]): with SandboxBashSession(sandbox=sandbox, node_id=self.id, tools=tool_dependencies) as session: prompt_files = self._extract_prompt_files(variable_pool) model_features = self._get_model_features(model_instance) + structured_output_schema = None + if self._node_data.structured_output_enabled: + structured_output_schema = LLMNode.fetch_structured_output_schema( + structured_output=self._node_data.structured_output or {}, + ) + output_tools = build_agent_output_tools( + tenant_id=self.tenant_id, + invoke_from=self.invoke_from, + tool_invoke_from=ToolInvokeFrom.WORKFLOW, + structured_output_schema=structured_output_schema, + ) strategy = StrategyFactory.create_strategy( model_features=model_features, model_instance=model_instance, - tools=[session.bash_tool], + tools=[session.bash_tool, *output_tools], files=prompt_files, max_iterations=self._node_data.max_iterations or 100, agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING, @@ -1936,7 +1952,7 @@ class LLMNode(Node[LLMNodeData]): prompt_messages=list(prompt_messages), model_parameters=self._node_data.model.completion_params, stop=list(stop or []), - stream=True, + stream=False, ) result = yield from self._process_tool_outputs(outputs) @@ -2011,6 +2027,20 @@ class LLMNode(Node[LLMNodeData]): logger.warning("Failed to load tool %s: %s", tool, str(e)) continue + structured_output_schema = None + if self._node_data.structured_output_enabled: + structured_output_schema = LLMNode.fetch_structured_output_schema( + structured_output=self._node_data.structured_output or {}, + ) + tool_instances.extend( + build_agent_output_tools( + tenant_id=self.tenant_id, + invoke_from=self.invoke_from, + tool_invoke_from=ToolInvokeFrom.WORKFLOW, + structured_output_schema=structured_output_schema, + ) + ) + return tool_instances def _extract_prompt_files(self, variable_pool: VariablePool) -> list[File]: @@ -2480,6 +2510,7 @@ class LLMNode(Node[LLMNodeData]): finish_reason=aggregate.finish_reason, files=aggregate.files, trace=trace_state.trace_segments, + structured_output=aggregate.structured_output, ) def _process_tool_outputs( @@ -2494,19 +2525,54 @@ class LLMNode(Node[LLMNodeData]): if isinstance(output, AgentLog): yield from self._handle_agent_log_output(output, state.stream, state.trace, state.agent) else: - yield from self._handle_llm_chunk_output(output, state.stream, state.trace, state.aggregate) + continue except StopIteration as exception: if isinstance(getattr(exception, "value", None), AgentResult): state.agent.agent_result = exception.value if state.agent.agent_result: - state.aggregate.text = state.agent.agent_result.text or state.aggregate.text + output_payload = state.agent.agent_result.output + structured_output_data: Mapping[str, Any] | None = None + if isinstance(output_payload, AgentResult.StructuredOutput): + output_kind = output_payload.output_kind + if output_kind == AgentOutputKind.ILLEGAL_OUTPUT: + raise ValueError("Agent returned illegal output") + if output_kind in {AgentOutputKind.FINAL_OUTPUT_ANSWER, AgentOutputKind.OUTPUT_TEXT}: + if not output_payload.output_text: + raise ValueError("Agent returned empty text output") + state.aggregate.text = output_payload.output_text + elif output_kind == AgentOutputKind.FINAL_STRUCTURED_OUTPUT: + if output_payload.output_data is None: + raise ValueError("Agent returned empty structured output") + else: + raise ValueError("Agent returned unsupported output kind") + + if output_payload.output_data is not None: + if not isinstance(output_payload.output_data, Mapping): + raise ValueError("Agent returned invalid structured output") + structured_output_data = output_payload.output_data + else: + if not output_payload: + raise ValueError("Agent returned empty output") + state.aggregate.text = str(output_payload) + state.aggregate.files = state.agent.agent_result.files if state.agent.agent_result.usage: state.aggregate.usage = state.agent.agent_result.usage if state.agent.agent_result.finish_reason: state.aggregate.finish_reason = state.agent.agent_result.finish_reason + if structured_output_data is not None: + output_schema = LLMNode.fetch_structured_output_schema( + structured_output=self._node_data.structured_output or {}, + ) + converted_output = convert_file_refs_in_output( + output=structured_output_data, + json_schema=output_schema, + tenant_id=self.tenant_id, + ) + state.aggregate.structured_output = LLMStructuredOutput(structured_output=converted_output) + yield from self._flush_remaining_stream(state.stream, state.trace, state.aggregate) yield from self._close_streams() diff --git a/api/tests/unit_tests/core/agent/test_agent_app_runner.py b/api/tests/unit_tests/core/agent/test_agent_app_runner.py index d9301ccfe0..8214a56d3f 100644 --- a/api/tests/unit_tests/core/agent/test_agent_app_runner.py +++ b/api/tests/unit_tests/core/agent/test_agent_app_runner.py @@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch import pytest -from core.agent.entities import AgentEntity, AgentLog, AgentPromptEntity, AgentResult +from core.agent.entities import AgentEntity, AgentLog, AgentOutputKind, AgentPromptEntity, AgentResult from core.model_runtime.entities import SystemPromptMessage, UserPromptMessage from core.model_runtime.entities.llm_entities import LLMUsage @@ -329,13 +329,20 @@ class TestAgentLogProcessing: ) result = AgentResult( - text="Final answer", + output=AgentResult.StructuredOutput( + output_kind=AgentOutputKind.FINAL_OUTPUT_ANSWER, + output_text="Final answer", + output_data=None, + ), files=[], usage=usage, finish_reason="stop", ) - assert result.text == "Final answer" + output_payload = result.output + assert isinstance(output_payload, AgentResult.StructuredOutput) + assert output_payload.output_text == "Final answer" + assert output_payload.output_kind == AgentOutputKind.FINAL_OUTPUT_ANSWER assert result.files == [] assert result.usage == usage assert result.finish_reason == "stop"