mirror of https://github.com/langgenius/dify.git
fix: provides correct prompts, tools and terminal predicates
Signed-off-by: Stream <Stream_2@qq.com>
This commit is contained in:
parent
ec9ade62f5
commit
22b0a08a5f
|
|
@ -7,7 +7,7 @@ from typing import Union, cast
|
|||
from sqlalchemy import select
|
||||
|
||||
from core.agent.entities import AgentEntity, AgentToolEntity, ExecutionContext
|
||||
from core.agent.output_tools import build_agent_output_tools
|
||||
from core.agent.output_tools import build_agent_output_tools, select_output_tool_names
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
|
|
@ -257,6 +257,7 @@ class BaseAgentRunner(AppRunner):
|
|||
tenant_id=self.tenant_id,
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
tool_invoke_from=ToolInvokeFrom.AGENT,
|
||||
output_tool_names=select_output_tool_names(structured_output_enabled=False),
|
||||
)
|
||||
for tool in output_tools:
|
||||
tool_instances[tool.entity.identity.name] = tool
|
||||
|
|
|
|||
|
|
@ -25,15 +25,42 @@ OUTPUT_TOOL_NAMES: Sequence[str] = (
|
|||
OUTPUT_TOOL_NAME_SET = set(OUTPUT_TOOL_NAMES)
|
||||
|
||||
|
||||
def select_output_tool_names(
|
||||
*,
|
||||
structured_output_enabled: bool,
|
||||
include_illegal_output: bool = False,
|
||||
) -> list[str]:
|
||||
tool_names = [OUTPUT_TEXT_TOOL]
|
||||
if structured_output_enabled:
|
||||
tool_names.append(FINAL_STRUCTURED_OUTPUT_TOOL)
|
||||
else:
|
||||
tool_names.append(FINAL_OUTPUT_TOOL)
|
||||
if include_illegal_output:
|
||||
tool_names.append(ILLEGAL_OUTPUT_TOOL)
|
||||
return tool_names
|
||||
|
||||
|
||||
def select_terminal_tool_name(*, structured_output_enabled: bool) -> str:
|
||||
return FINAL_STRUCTURED_OUTPUT_TOOL if structured_output_enabled else FINAL_OUTPUT_TOOL
|
||||
|
||||
|
||||
def build_agent_output_tools(
|
||||
*,
|
||||
tenant_id: str,
|
||||
invoke_from: InvokeFrom,
|
||||
tool_invoke_from: ToolInvokeFrom,
|
||||
output_tool_names: Sequence[str],
|
||||
structured_output_schema: dict[str, Any] | None = None,
|
||||
) -> list[Tool]:
|
||||
tools: list[Tool] = []
|
||||
for tool_name in OUTPUT_TOOL_NAMES:
|
||||
tool_names: list[str] = []
|
||||
for tool_name in output_tool_names:
|
||||
if tool_name not in OUTPUT_TOOL_NAME_SET:
|
||||
raise ValueError(f"Unknown output tool name: {tool_name}")
|
||||
if tool_name not in tool_names:
|
||||
tool_names.append(tool_name)
|
||||
|
||||
for tool_name in tool_names:
|
||||
tool = ToolManager.get_tool_runtime(
|
||||
provider_type=ToolProviderType.BUILT_IN,
|
||||
provider_id=OUTPUT_TOOL_PROVIDER,
|
||||
|
|
|
|||
|
|
@ -42,6 +42,14 @@ class FunctionCallStrategy(AgentPattern):
|
|||
"""Execute the function call agent strategy."""
|
||||
# Convert tools to prompt format
|
||||
prompt_tools: list[PromptMessageTool] = self._convert_tools_to_prompt_format()
|
||||
available_output_tool_names = {tool.name for tool in prompt_tools if tool.name in OUTPUT_TOOL_NAME_SET}
|
||||
if FINAL_STRUCTURED_OUTPUT_TOOL in available_output_tool_names:
|
||||
terminal_tool_name = FINAL_STRUCTURED_OUTPUT_TOOL
|
||||
elif FINAL_OUTPUT_TOOL in available_output_tool_names:
|
||||
terminal_tool_name = FINAL_OUTPUT_TOOL
|
||||
else:
|
||||
raise ValueError("No terminal output tool configured")
|
||||
allow_illegal_output = ILLEGAL_OUTPUT_TOOL in available_output_tool_names
|
||||
|
||||
# Initialize tracking
|
||||
iteration_step: int = 1
|
||||
|
|
@ -54,6 +62,7 @@ class FunctionCallStrategy(AgentPattern):
|
|||
output_text_payload: str | None = None
|
||||
finish_reason: str | None = None
|
||||
output_files: list[File] = [] # Track files produced by tools
|
||||
terminal_output_seen = False
|
||||
|
||||
class _LLMInvoker(Protocol):
|
||||
def invoke_llm(
|
||||
|
|
@ -79,7 +88,7 @@ class FunctionCallStrategy(AgentPattern):
|
|||
yield round_log
|
||||
# On last iteration, restrict tools to output tools
|
||||
if iteration_step == max_iterations:
|
||||
current_tools = [tool for tool in prompt_tools if tool.name in OUTPUT_TOOL_NAME_SET]
|
||||
current_tools = [tool for tool in prompt_tools if tool.name in available_output_tool_names]
|
||||
else:
|
||||
current_tools = prompt_tools
|
||||
model_log = self._create_log(
|
||||
|
|
@ -115,6 +124,8 @@ class FunctionCallStrategy(AgentPattern):
|
|||
)
|
||||
|
||||
if not tool_calls:
|
||||
if not allow_illegal_output:
|
||||
raise ValueError("Model did not call any tools")
|
||||
tool_calls = [
|
||||
(
|
||||
str(uuid.uuid4()),
|
||||
|
|
@ -149,9 +160,12 @@ class FunctionCallStrategy(AgentPattern):
|
|||
elif tool_name == FINAL_STRUCTURED_OUTPUT_TOOL:
|
||||
data = tool_args.get("data")
|
||||
structured_output_payload = cast(dict[str, Any] | None, data)
|
||||
if tool_name == terminal_tool_name:
|
||||
terminal_tool_seen = True
|
||||
elif tool_name == FINAL_OUTPUT_TOOL:
|
||||
final_text = self._format_output_text(tool_args.get("text"))
|
||||
terminal_tool_seen = True
|
||||
if tool_name == terminal_tool_name:
|
||||
terminal_tool_seen = True
|
||||
|
||||
tool_response, tool_files, _ = yield from self._handle_tool_call(
|
||||
tool_name, tool_args, tool_call_id, messages, round_log
|
||||
|
|
@ -161,6 +175,7 @@ class FunctionCallStrategy(AgentPattern):
|
|||
output_files.extend(tool_files)
|
||||
|
||||
if terminal_tool_seen:
|
||||
terminal_output_seen = True
|
||||
function_call_state = False
|
||||
yield self._finish_log(
|
||||
round_log,
|
||||
|
|
@ -181,7 +196,13 @@ class FunctionCallStrategy(AgentPattern):
|
|||
from core.agent.entities import AgentResult
|
||||
|
||||
output_payload: str | AgentResult.StructuredOutput
|
||||
if final_text:
|
||||
if terminal_tool_name == FINAL_STRUCTURED_OUTPUT_TOOL and terminal_output_seen:
|
||||
output_payload = AgentResult.StructuredOutput(
|
||||
output_kind=AgentOutputKind.FINAL_STRUCTURED_OUTPUT,
|
||||
output_text=None,
|
||||
output_data=structured_output_payload,
|
||||
)
|
||||
elif final_text:
|
||||
output_payload = AgentResult.StructuredOutput(
|
||||
output_kind=AgentOutputKind.FINAL_OUTPUT_ANSWER,
|
||||
output_text=final_text,
|
||||
|
|
|
|||
|
|
@ -77,6 +77,17 @@ class ReActStrategy(AgentPattern):
|
|||
structured_output_payload: dict[str, Any] | None = None
|
||||
output_text_payload: str | None = None
|
||||
finish_reason: str | None = None
|
||||
terminal_output_seen = False
|
||||
available_output_tool_names = {
|
||||
tool.entity.identity.name for tool in self.tools if tool.entity.identity.name in OUTPUT_TOOL_NAME_SET
|
||||
}
|
||||
if FINAL_STRUCTURED_OUTPUT_TOOL in available_output_tool_names:
|
||||
terminal_tool_name = FINAL_STRUCTURED_OUTPUT_TOOL
|
||||
elif FINAL_OUTPUT_TOOL in available_output_tool_names:
|
||||
terminal_tool_name = FINAL_OUTPUT_TOOL
|
||||
else:
|
||||
raise ValueError("No terminal output tool configured")
|
||||
allow_illegal_output = ILLEGAL_OUTPUT_TOOL in available_output_tool_names
|
||||
|
||||
# Add "Observation" to stop sequences
|
||||
if "Observation" not in stop:
|
||||
|
|
@ -95,7 +106,9 @@ class ReActStrategy(AgentPattern):
|
|||
|
||||
# Build prompt with tool restrictions on last iteration
|
||||
if iteration_step == max_iterations:
|
||||
tools_for_prompt = [tool for tool in self.tools if tool.entity.identity.name in OUTPUT_TOOL_NAME_SET]
|
||||
tools_for_prompt = [
|
||||
tool for tool in self.tools if tool.entity.identity.name in available_output_tool_names
|
||||
]
|
||||
else:
|
||||
tools_for_prompt = self.tools
|
||||
current_messages = self._build_prompt_with_react_format(
|
||||
|
|
@ -150,6 +163,8 @@ class ReActStrategy(AgentPattern):
|
|||
|
||||
# Check if we have an action to execute
|
||||
if scratchpad.action is None:
|
||||
if not allow_illegal_output:
|
||||
raise ValueError("Model did not call any tools")
|
||||
illegal_action = AgentScratchpadUnit.Action(
|
||||
action_name=ILLEGAL_OUTPUT_TOOL,
|
||||
action_input={"raw": scratchpad.thought or ""},
|
||||
|
|
@ -162,33 +177,29 @@ class ReActStrategy(AgentPattern):
|
|||
output_files.extend(tool_files)
|
||||
else:
|
||||
action_name = scratchpad.action.action_name
|
||||
if action_name == FINAL_OUTPUT_TOOL:
|
||||
if action_name == OUTPUT_TEXT_TOOL and isinstance(scratchpad.action.action_input, dict):
|
||||
output_text_payload = scratchpad.action.action_input.get("text")
|
||||
elif action_name == FINAL_STRUCTURED_OUTPUT_TOOL and isinstance(scratchpad.action.action_input, dict):
|
||||
data = scratchpad.action.action_input.get("data")
|
||||
if isinstance(data, dict):
|
||||
structured_output_payload = data
|
||||
elif action_name == FINAL_OUTPUT_TOOL:
|
||||
if isinstance(scratchpad.action.action_input, dict):
|
||||
final_text = self._format_output_text(scratchpad.action.action_input.get("text"))
|
||||
else:
|
||||
final_text = self._format_output_text(scratchpad.action.action_input)
|
||||
observation, tool_files = yield from self._handle_tool_call(
|
||||
scratchpad.action, current_messages, round_log
|
||||
)
|
||||
scratchpad.observation = observation
|
||||
output_files.extend(tool_files)
|
||||
|
||||
observation, tool_files = yield from self._handle_tool_call(
|
||||
scratchpad.action, current_messages, round_log
|
||||
)
|
||||
scratchpad.observation = observation
|
||||
output_files.extend(tool_files)
|
||||
|
||||
if action_name == terminal_tool_name:
|
||||
terminal_output_seen = True
|
||||
react_state = False
|
||||
else:
|
||||
if action_name == OUTPUT_TEXT_TOOL and isinstance(scratchpad.action.action_input, dict):
|
||||
output_text_payload = scratchpad.action.action_input.get("text")
|
||||
elif action_name == FINAL_STRUCTURED_OUTPUT_TOOL and isinstance(
|
||||
scratchpad.action.action_input, dict
|
||||
):
|
||||
data = scratchpad.action.action_input.get("data")
|
||||
if isinstance(data, dict):
|
||||
structured_output_payload = data
|
||||
|
||||
react_state = True
|
||||
observation, tool_files = yield from self._handle_tool_call(
|
||||
scratchpad.action, current_messages, round_log
|
||||
)
|
||||
scratchpad.observation = observation
|
||||
output_files.extend(tool_files)
|
||||
|
||||
yield self._finish_log(
|
||||
round_log,
|
||||
|
|
@ -207,7 +218,13 @@ class ReActStrategy(AgentPattern):
|
|||
from core.agent.entities import AgentResult
|
||||
|
||||
output_payload: str | AgentResult.StructuredOutput
|
||||
if final_text:
|
||||
if terminal_tool_name == FINAL_STRUCTURED_OUTPUT_TOOL and terminal_output_seen:
|
||||
output_payload = AgentResult.StructuredOutput(
|
||||
output_kind=AgentOutputKind.FINAL_STRUCTURED_OUTPUT,
|
||||
output_text=None,
|
||||
output_data=structured_output_payload,
|
||||
)
|
||||
elif final_text:
|
||||
output_payload = AgentResult.StructuredOutput(
|
||||
output_kind=AgentOutputKind.FINAL_OUTPUT_ANSWER,
|
||||
output_text=final_text,
|
||||
|
|
@ -268,12 +285,19 @@ class ReActStrategy(AgentPattern):
|
|||
tools_str = "No tools available"
|
||||
tool_names_str = ""
|
||||
|
||||
final_tool_name = FINAL_OUTPUT_TOOL
|
||||
if FINAL_STRUCTURED_OUTPUT_TOOL in tool_names:
|
||||
final_tool_name = FINAL_STRUCTURED_OUTPUT_TOOL
|
||||
if final_tool_name not in tool_names:
|
||||
raise ValueError("No terminal output tool available for prompt")
|
||||
|
||||
# Replace placeholders in the existing system prompt
|
||||
updated_content = msg.content
|
||||
assert isinstance(updated_content, str)
|
||||
updated_content = updated_content.replace("{{instruction}}", instruction)
|
||||
updated_content = updated_content.replace("{{tools}}", tools_str)
|
||||
updated_content = updated_content.replace("{{tool_names}}", tool_names_str)
|
||||
updated_content = updated_content.replace("{{final_tool_name}}", final_tool_name)
|
||||
|
||||
# Create new SystemPromptMessage with updated content
|
||||
messages[i] = SystemPromptMessage(content=updated_content)
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ You have access to the following tools:
|
|||
{{tools}}
|
||||
|
||||
Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
|
||||
Valid "action" values: {{tool_names}}. You must call "final_output_answer" to finish.
|
||||
Valid "action" values: {{tool_names}}. You must call "{{final_tool_name}}" to finish.
|
||||
|
||||
Provide only ONE action per $JSON_BLOB, as shown:
|
||||
|
||||
|
|
@ -32,7 +32,7 @@ Thought: I know what to respond
|
|||
Action:
|
||||
```
|
||||
{
|
||||
"action": "final_output_answer",
|
||||
"action": "{{final_tool_name}}",
|
||||
"action_input": {
|
||||
"text": "Final response to human"
|
||||
}
|
||||
|
|
@ -58,7 +58,7 @@ You have access to the following tools:
|
|||
{{tools}}
|
||||
|
||||
Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
|
||||
Valid "action" values: {{tool_names}}. You must call "final_output_answer" to finish.
|
||||
Valid "action" values: {{tool_names}}. You must call "{{final_tool_name}}" to finish.
|
||||
|
||||
Provide only ONE action per $JSON_BLOB, as shown:
|
||||
|
||||
|
|
@ -83,7 +83,7 @@ Thought: I know what to respond
|
|||
Action:
|
||||
```
|
||||
{
|
||||
"action": "final_output_answer",
|
||||
"action": "{{final_tool_name}}",
|
||||
"action_input": {
|
||||
"text": "Final response to human"
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,6 +16,6 @@ class IllegalOutputTool(BuiltinTool):
|
|||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
message = (
|
||||
"Protocol violation: do not output plain text. "
|
||||
"Call output_text, final_structured_output, then final_output_answer."
|
||||
"Call an output tool and finish with the configured terminal tool."
|
||||
)
|
||||
yield self.create_text_message(message)
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from typing import TYPE_CHECKING, Any, Literal, cast
|
|||
from sqlalchemy import select
|
||||
|
||||
from core.agent.entities import AgentEntity, AgentLog, AgentOutputKind, AgentResult, AgentToolEntity, ExecutionContext
|
||||
from core.agent.output_tools import build_agent_output_tools
|
||||
from core.agent.output_tools import build_agent_output_tools, select_output_tool_names
|
||||
from core.agent.patterns import StrategyFactory
|
||||
from core.app.entities.app_asset_entities import AppAssetFileTree
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
|
|
@ -1935,6 +1935,9 @@ class LLMNode(Node[LLMNodeData]):
|
|||
tenant_id=self.tenant_id,
|
||||
invoke_from=self.invoke_from,
|
||||
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
|
||||
output_tool_names=select_output_tool_names(
|
||||
structured_output_enabled=self._node_data.structured_output_enabled
|
||||
),
|
||||
structured_output_schema=structured_output_schema,
|
||||
)
|
||||
|
||||
|
|
@ -2037,6 +2040,9 @@ class LLMNode(Node[LLMNodeData]):
|
|||
tenant_id=self.tenant_id,
|
||||
invoke_from=self.invoke_from,
|
||||
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
|
||||
output_tool_names=select_output_tool_names(
|
||||
structured_output_enabled=self._node_data.structured_output_enabled
|
||||
),
|
||||
structured_output_schema=structured_output_schema,
|
||||
)
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue