diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index f26e8c68e8..d1455a6e05 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -257,7 +257,10 @@ class BaseAgentRunner(AppRunner): tenant_id=self.tenant_id, invoke_from=self.application_generate_entity.invoke_from, tool_invoke_from=ToolInvokeFrom.AGENT, - output_tool_names=select_output_tool_names(structured_output_enabled=False), + output_tool_names=select_output_tool_names( + structured_output_enabled=False, + include_illegal_output=True, + ), ) for tool in output_tools: tool_instances[tool.entity.identity.name] = tool diff --git a/api/core/agent/patterns/base.py b/api/core/agent/patterns/base.py index d98fa005a3..fbb8d304a5 100644 --- a/api/core/agent/patterns/base.py +++ b/api/core/agent/patterns/base.py @@ -10,6 +10,7 @@ from collections.abc import Callable, Generator from typing import TYPE_CHECKING, Any from core.agent.entities import AgentLog, AgentResult, ExecutionContext +from core.agent.output_tools import ILLEGAL_OUTPUT_TOOL from core.file import File from core.model_manager import ModelInstance from core.model_runtime.entities import ( @@ -465,6 +466,8 @@ class AgentPattern(ABC): """Convert tools to prompt message format.""" prompt_tools: list[PromptMessageTool] = [] for tool in self.tools: + if tool.entity.identity.name == ILLEGAL_OUTPUT_TOOL: + continue prompt_tools.append(tool.to_prompt_message_tool()) return prompt_tools diff --git a/api/core/agent/patterns/function_call.py b/api/core/agent/patterns/function_call.py index da7b6422b3..8da0308ba4 100644 --- a/api/core/agent/patterns/function_call.py +++ b/api/core/agent/patterns/function_call.py @@ -42,6 +42,7 @@ class FunctionCallStrategy(AgentPattern): """Execute the function call agent strategy.""" # Convert tools to prompt format prompt_tools: list[PromptMessageTool] = self._convert_tools_to_prompt_format() + tool_instance_names = {tool.entity.identity.name for tool in self.tools} available_output_tool_names = {tool.name for tool in prompt_tools if tool.name in OUTPUT_TOOL_NAME_SET} if FINAL_STRUCTURED_OUTPUT_TOOL in available_output_tool_names: terminal_tool_name = FINAL_STRUCTURED_OUTPUT_TOOL @@ -49,7 +50,7 @@ class FunctionCallStrategy(AgentPattern): terminal_tool_name = FINAL_OUTPUT_TOOL else: raise ValueError("No terminal output tool configured") - allow_illegal_output = ILLEGAL_OUTPUT_TOOL in available_output_tool_names + allow_illegal_output = ILLEGAL_OUTPUT_TOOL in tool_instance_names # Initialize tracking iteration_step: int = 1 diff --git a/api/core/agent/patterns/react.py b/api/core/agent/patterns/react.py index 14986093cf..ca4831b5c4 100644 --- a/api/core/agent/patterns/react.py +++ b/api/core/agent/patterns/react.py @@ -78,8 +78,11 @@ class ReActStrategy(AgentPattern): output_text_payload: str | None = None finish_reason: str | None = None terminal_output_seen = False + tool_instance_names = {tool.entity.identity.name for tool in self.tools} available_output_tool_names = { - tool.entity.identity.name for tool in self.tools if tool.entity.identity.name in OUTPUT_TOOL_NAME_SET + tool_name + for tool_name in tool_instance_names + if tool_name in OUTPUT_TOOL_NAME_SET and tool_name != ILLEGAL_OUTPUT_TOOL } if FINAL_STRUCTURED_OUTPUT_TOOL in available_output_tool_names: terminal_tool_name = FINAL_STRUCTURED_OUTPUT_TOOL @@ -87,7 +90,7 @@ class ReActStrategy(AgentPattern): terminal_tool_name = FINAL_OUTPUT_TOOL else: raise ValueError("No terminal output tool configured") - allow_illegal_output = ILLEGAL_OUTPUT_TOOL in available_output_tool_names + allow_illegal_output = ILLEGAL_OUTPUT_TOOL in tool_instance_names # Add "Observation" to stop sequences if "Observation" not in stop: @@ -110,7 +113,7 @@ class ReActStrategy(AgentPattern): tool for tool in self.tools if tool.entity.identity.name in available_output_tool_names ] else: - tools_for_prompt = self.tools + tools_for_prompt = [tool for tool in self.tools if tool.entity.identity.name != ILLEGAL_OUTPUT_TOOL] current_messages = self._build_prompt_with_react_format( prompt_messages, agent_scratchpad, tools_for_prompt, self.instruction ) @@ -273,7 +276,11 @@ class ReActStrategy(AgentPattern): tool_names = [] if tools: # Convert tools to prompt message tools format - prompt_tools = [tool.to_prompt_message_tool() for tool in tools] + prompt_tools = [ + tool.to_prompt_message_tool() + for tool in tools + if tool.entity.identity.name != ILLEGAL_OUTPUT_TOOL + ] tool_names = [tool.name for tool in prompt_tools] # Format tools as JSON for comprehensive information diff --git a/api/core/plugin/utils/chunk_merger.py b/api/core/plugin/utils/chunk_merger.py index 46bfa04f54..6a266d1d9d 100644 --- a/api/core/plugin/utils/chunk_merger.py +++ b/api/core/plugin/utils/chunk_merger.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, TypeVar from core.tools.entities.tool_entities import ToolInvokeMessage if TYPE_CHECKING: - from core.agent.entities import AgentInvokeMessage + pass MessageType = TypeVar("MessageType", bound=ToolInvokeMessage) diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index d93c0c9e5b..0b51163462 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -368,7 +368,7 @@ class ToolManager: app_id: str, agent_tool: AgentToolEntity, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, - variable_pool: Optional["VariablePool"] = None, + variable_pool: Optional[VariablePool] = None, ) -> Tool: """ get the agent tool runtime @@ -408,9 +408,9 @@ class ToolManager: tenant_id: str, app_id: str, node_id: str, - workflow_tool: "ToolEntity", + workflow_tool: ToolEntity, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, - variable_pool: Optional["VariablePool"] = None, + variable_pool: Optional[VariablePool] = None, ) -> Tool: """ get the workflow tool runtime @@ -1017,7 +1017,7 @@ class ToolManager: def _convert_tool_parameters_type( cls, parameters: list[ToolParameter], - variable_pool: Optional["VariablePool"], + variable_pool: Optional[VariablePool], tool_configurations: dict[str, Any], typ: Literal["agent", "workflow", "tool"] = "workflow", ) -> dict[str, Any]: diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 628aaaa334..2403b9e841 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -1936,7 +1936,8 @@ class LLMNode(Node[LLMNodeData]): invoke_from=self.invoke_from, tool_invoke_from=ToolInvokeFrom.WORKFLOW, output_tool_names=select_output_tool_names( - structured_output_enabled=self._node_data.structured_output_enabled + structured_output_enabled=self._node_data.structured_output_enabled, + include_illegal_output=True, ), structured_output_schema=structured_output_schema, ) @@ -2041,7 +2042,8 @@ class LLMNode(Node[LLMNodeData]): invoke_from=self.invoke_from, tool_invoke_from=ToolInvokeFrom.WORKFLOW, output_tool_names=select_output_tool_names( - structured_output_enabled=self._node_data.structured_output_enabled + structured_output_enabled=self._node_data.structured_output_enabled, + include_illegal_output=True, ), structured_output_schema=structured_output_schema, )