mirror of https://github.com/langgenius/dify.git
feat: support structured output in sandbox and tool mode
Signed-off-by: Stream <Stream_2@qq.com>
This commit is contained in:
parent
d3fc457331
commit
7926024569
|
|
@ -4,7 +4,7 @@ from copy import deepcopy
|
|||
from typing import Any
|
||||
|
||||
from core.agent.base_agent_runner import BaseAgentRunner
|
||||
from core.agent.entities import AgentEntity, AgentLog, AgentResult
|
||||
from core.agent.entities import AgentEntity, AgentLog, AgentOutputKind, AgentResult
|
||||
from core.agent.patterns.strategy_factory import StrategyFactory
|
||||
from core.app.apps.base_app_queue_manager import PublishFrom
|
||||
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
|
||||
|
|
@ -13,6 +13,7 @@ from core.model_runtime.entities import (
|
|||
AssistantPromptMessage,
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
LLMUsage,
|
||||
PromptMessage,
|
||||
PromptMessageContentType,
|
||||
|
|
@ -106,7 +107,6 @@ class AgentAppRunner(BaseAgentRunner):
|
|||
|
||||
# Initialize state variables
|
||||
current_agent_thought_id = None
|
||||
has_published_thought = False
|
||||
current_tool_name: str | None = None
|
||||
self._current_message_file_ids: list[str] = []
|
||||
|
||||
|
|
@ -118,7 +118,7 @@ class AgentAppRunner(BaseAgentRunner):
|
|||
prompt_messages=prompt_messages,
|
||||
model_parameters=app_generate_entity.model_conf.parameters,
|
||||
stop=app_generate_entity.model_conf.stop,
|
||||
stream=True,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
# Consume generator and collect result
|
||||
|
|
@ -133,17 +133,10 @@ class AgentAppRunner(BaseAgentRunner):
|
|||
break
|
||||
|
||||
if isinstance(output, LLMResultChunk):
|
||||
# Handle LLM chunk
|
||||
if current_agent_thought_id and not has_published_thought:
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
has_published_thought = True
|
||||
# No more expect streaming data
|
||||
continue
|
||||
|
||||
yield output
|
||||
|
||||
elif isinstance(output, AgentLog):
|
||||
else:
|
||||
# Handle Agent Log using log_type for type-safe dispatch
|
||||
if output.status == AgentLog.LogStatus.START:
|
||||
if output.log_type == AgentLog.LogType.ROUND:
|
||||
|
|
@ -156,7 +149,6 @@ class AgentAppRunner(BaseAgentRunner):
|
|||
tool_input="",
|
||||
messages_ids=message_file_ids,
|
||||
)
|
||||
has_published_thought = False
|
||||
|
||||
elif output.log_type == AgentLog.LogType.TOOL_CALL:
|
||||
if current_agent_thought_id is None:
|
||||
|
|
@ -265,7 +257,22 @@ class AgentAppRunner(BaseAgentRunner):
|
|||
|
||||
# Process final result
|
||||
if isinstance(result, AgentResult):
|
||||
final_answer = result.text
|
||||
output_payload = result.output
|
||||
if isinstance(output_payload, AgentResult.StructuredOutput):
|
||||
if output_payload.output_kind == AgentOutputKind.ILLEGAL_OUTPUT:
|
||||
raise ValueError("Agent returned illegal output")
|
||||
if output_payload.output_kind not in {
|
||||
AgentOutputKind.FINAL_OUTPUT_ANSWER,
|
||||
AgentOutputKind.OUTPUT_TEXT,
|
||||
}:
|
||||
raise ValueError("Agent did not return text output")
|
||||
if not output_payload.output_text:
|
||||
raise ValueError("Agent returned empty text output")
|
||||
final_answer = output_payload.output_text
|
||||
else:
|
||||
if not output_payload:
|
||||
raise ValueError("Agent returned empty output")
|
||||
final_answer = str(output_payload)
|
||||
usage = result.usage or LLMUsage.empty_usage()
|
||||
|
||||
# Publish end event
|
||||
|
|
@ -282,6 +289,17 @@ class AgentAppRunner(BaseAgentRunner):
|
|||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
if False:
|
||||
yield LLMResultChunk(
|
||||
model="",
|
||||
prompt_messages=[],
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content=""),
|
||||
usage=None,
|
||||
),
|
||||
)
|
||||
|
||||
def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
Initialize system message
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from typing import Union, cast
|
|||
from sqlalchemy import select
|
||||
|
||||
from core.agent.entities import AgentEntity, AgentToolEntity, ExecutionContext
|
||||
from core.agent.output_tools import build_agent_output_tools
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
|
|
@ -36,6 +37,7 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
|
|||
from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolInvokeFrom,
|
||||
ToolParameter,
|
||||
)
|
||||
from core.tools.tool_manager import ToolManager
|
||||
|
|
@ -251,6 +253,14 @@ class BaseAgentRunner(AppRunner):
|
|||
# save tool entity
|
||||
tool_instances[dataset_tool.entity.identity.name] = dataset_tool
|
||||
|
||||
output_tools = build_agent_output_tools(
|
||||
tenant_id=self.tenant_id,
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
tool_invoke_from=ToolInvokeFrom.AGENT,
|
||||
)
|
||||
for tool in output_tools:
|
||||
tool_instances[tool.entity.identity.name] = tool
|
||||
|
||||
return tool_instances, prompt_messages_tools
|
||||
|
||||
def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) -> PromptMessageTool:
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from enum import StrEnum
|
||||
from typing import Any, Union
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.agent.output_tools import FINAL_OUTPUT_TOOL
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
|
||||
|
||||
|
||||
|
|
@ -41,9 +42,9 @@ class AgentScratchpadUnit(BaseModel):
|
|||
"""
|
||||
|
||||
action_name: str
|
||||
action_input: Union[dict, str]
|
||||
action_input: dict[str, Any] | str
|
||||
|
||||
def to_dict(self):
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""
|
||||
Convert to dictionary.
|
||||
"""
|
||||
|
|
@ -62,9 +63,9 @@ class AgentScratchpadUnit(BaseModel):
|
|||
"""
|
||||
Check if the scratchpad unit is final.
|
||||
"""
|
||||
return self.action is None or (
|
||||
"final" in self.action.action_name.lower() and "answer" in self.action.action_name.lower()
|
||||
)
|
||||
if self.action is None:
|
||||
return False
|
||||
return self.action.action_name.lower() == FINAL_OUTPUT_TOOL
|
||||
|
||||
|
||||
class AgentEntity(BaseModel):
|
||||
|
|
@ -125,7 +126,7 @@ class ExecutionContext(BaseModel):
|
|||
"tenant_id": self.tenant_id,
|
||||
}
|
||||
|
||||
def with_updates(self, **kwargs) -> "ExecutionContext":
|
||||
def with_updates(self, **kwargs: Any) -> "ExecutionContext":
|
||||
"""Create a new context with updated fields."""
|
||||
data = self.to_dict()
|
||||
data.update(kwargs)
|
||||
|
|
@ -178,12 +179,35 @@ class AgentLog(BaseModel):
|
|||
metadata: Mapping[LogMetadata, Any] = Field(default={}, description="The metadata of the log")
|
||||
|
||||
|
||||
class AgentOutputKind(StrEnum):
|
||||
"""
|
||||
Agent output kind.
|
||||
"""
|
||||
|
||||
OUTPUT_TEXT = "output_text"
|
||||
FINAL_OUTPUT_ANSWER = "final_output_answer"
|
||||
FINAL_STRUCTURED_OUTPUT = "final_structured_output"
|
||||
ILLEGAL_OUTPUT = "illegal_output"
|
||||
|
||||
|
||||
OutputKind = AgentOutputKind
|
||||
|
||||
|
||||
class AgentResult(BaseModel):
|
||||
"""
|
||||
Agent execution result.
|
||||
"""
|
||||
|
||||
text: str = Field(default="", description="The generated text")
|
||||
class StructuredOutput(BaseModel):
|
||||
"""
|
||||
Structured output payload from output tools.
|
||||
"""
|
||||
|
||||
output_kind: AgentOutputKind
|
||||
output_text: str | None = None
|
||||
output_data: Mapping[str, Any] | None = None
|
||||
|
||||
output: str | StructuredOutput = Field(default="", description="The generated output")
|
||||
files: list[Any] = Field(default_factory=list, description="Files produced during execution")
|
||||
usage: Any | None = Field(default=None, description="LLM usage statistics")
|
||||
finish_reason: str | None = Field(default=None, description="Reason for completion")
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import json
|
||||
import re
|
||||
from collections.abc import Generator
|
||||
from typing import Union
|
||||
from typing import Any, Union, cast
|
||||
|
||||
from core.agent.entities import AgentScratchpadUnit
|
||||
from core.model_runtime.entities.llm_entities import LLMResultChunk
|
||||
|
|
@ -10,46 +10,52 @@ from core.model_runtime.entities.llm_entities import LLMResultChunk
|
|||
class CotAgentOutputParser:
|
||||
@classmethod
|
||||
def handle_react_stream_output(
|
||||
cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict
|
||||
cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict[str, Any]
|
||||
) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
|
||||
def parse_action(action) -> Union[str, AgentScratchpadUnit.Action]:
|
||||
action_name = None
|
||||
action_input = None
|
||||
if isinstance(action, str):
|
||||
def parse_action(action: Any) -> Union[str, AgentScratchpadUnit.Action]:
|
||||
action_name: str | None = None
|
||||
action_input: Any | None = None
|
||||
parsed_action: Any = action
|
||||
if isinstance(parsed_action, str):
|
||||
try:
|
||||
action = json.loads(action, strict=False)
|
||||
parsed_action = json.loads(parsed_action, strict=False)
|
||||
except json.JSONDecodeError:
|
||||
return action or ""
|
||||
return parsed_action or ""
|
||||
|
||||
# cohere always returns a list
|
||||
if isinstance(action, list) and len(action) == 1:
|
||||
action = action[0]
|
||||
if isinstance(parsed_action, list):
|
||||
action_list: list[Any] = cast(list[Any], parsed_action)
|
||||
if len(action_list) == 1:
|
||||
parsed_action = action_list[0]
|
||||
|
||||
for key, value in action.items():
|
||||
if "input" in key.lower():
|
||||
action_input = value
|
||||
else:
|
||||
action_name = value
|
||||
if isinstance(parsed_action, dict):
|
||||
action_dict: dict[str, Any] = cast(dict[str, Any], parsed_action)
|
||||
for key, value in action_dict.items():
|
||||
if "input" in key.lower():
|
||||
action_input = value
|
||||
elif isinstance(value, str):
|
||||
action_name = value
|
||||
else:
|
||||
return json.dumps(parsed_action)
|
||||
|
||||
if action_name is not None and action_input is not None:
|
||||
return AgentScratchpadUnit.Action(
|
||||
action_name=action_name,
|
||||
action_input=action_input,
|
||||
)
|
||||
else:
|
||||
return json.dumps(action)
|
||||
return json.dumps(parsed_action)
|
||||
|
||||
def extra_json_from_code_block(code_block) -> list[Union[list, dict]]:
|
||||
def extra_json_from_code_block(code_block: str) -> list[dict[str, Any] | list[Any]]:
|
||||
blocks = re.findall(r"```[json]*\s*([\[{].*[]}])\s*```", code_block, re.DOTALL | re.IGNORECASE)
|
||||
if not blocks:
|
||||
return []
|
||||
try:
|
||||
json_blocks = []
|
||||
json_blocks: list[dict[str, Any] | list[Any]] = []
|
||||
for block in blocks:
|
||||
json_text = re.sub(r"^[a-zA-Z]+\n", "", block.strip(), flags=re.MULTILINE)
|
||||
json_blocks.append(json.loads(json_text, strict=False))
|
||||
return json_blocks
|
||||
except:
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
code_block_cache = ""
|
||||
|
|
|
|||
|
|
@ -0,0 +1,57 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import ToolInvokeFrom, ToolParameter, ToolProviderType
|
||||
from core.tools.tool_manager import ToolManager
|
||||
|
||||
OUTPUT_TOOL_PROVIDER = "agent_output"
|
||||
|
||||
OUTPUT_TEXT_TOOL = "output_text"
|
||||
FINAL_OUTPUT_TOOL = "final_output_answer"
|
||||
FINAL_STRUCTURED_OUTPUT_TOOL = "final_structured_output"
|
||||
ILLEGAL_OUTPUT_TOOL = "illegal_output"
|
||||
|
||||
OUTPUT_TOOL_NAMES: Sequence[str] = (
|
||||
OUTPUT_TEXT_TOOL,
|
||||
FINAL_OUTPUT_TOOL,
|
||||
FINAL_STRUCTURED_OUTPUT_TOOL,
|
||||
ILLEGAL_OUTPUT_TOOL,
|
||||
)
|
||||
|
||||
OUTPUT_TOOL_NAME_SET = set(OUTPUT_TOOL_NAMES)
|
||||
|
||||
|
||||
def build_agent_output_tools(
|
||||
*,
|
||||
tenant_id: str,
|
||||
invoke_from: InvokeFrom,
|
||||
tool_invoke_from: ToolInvokeFrom,
|
||||
structured_output_schema: dict[str, Any] | None = None,
|
||||
) -> list[Tool]:
|
||||
tools: list[Tool] = []
|
||||
for tool_name in OUTPUT_TOOL_NAMES:
|
||||
tool = ToolManager.get_tool_runtime(
|
||||
provider_type=ToolProviderType.BUILT_IN,
|
||||
provider_id=OUTPUT_TOOL_PROVIDER,
|
||||
tool_name=tool_name,
|
||||
tenant_id=tenant_id,
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
)
|
||||
|
||||
if tool_name == FINAL_STRUCTURED_OUTPUT_TOOL and structured_output_schema:
|
||||
tool.entity = tool.entity.model_copy(deep=True)
|
||||
for parameter in tool.entity.parameters:
|
||||
if parameter.name != "data":
|
||||
continue
|
||||
parameter.type = ToolParameter.ToolParameterType.OBJECT
|
||||
parameter.form = ToolParameter.ToolParameterForm.LLM
|
||||
parameter.required = True
|
||||
parameter.input_schema = structured_output_schema
|
||||
tools.append(tool)
|
||||
|
||||
return tools
|
||||
|
|
@ -1,10 +1,18 @@
|
|||
"""Function Call strategy implementation."""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Union
|
||||
from typing import Any, Literal, Protocol, Union, cast
|
||||
|
||||
from core.agent.entities import AgentLog, AgentResult
|
||||
from core.agent.entities import AgentLog, AgentOutputKind, AgentResult
|
||||
from core.agent.output_tools import (
|
||||
FINAL_OUTPUT_TOOL,
|
||||
FINAL_STRUCTURED_OUTPUT_TOOL,
|
||||
ILLEGAL_OUTPUT_TOOL,
|
||||
OUTPUT_TEXT_TOOL,
|
||||
OUTPUT_TOOL_NAME_SET,
|
||||
)
|
||||
from core.file import File
|
||||
from core.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
|
|
@ -42,9 +50,24 @@ class FunctionCallStrategy(AgentPattern):
|
|||
total_usage: dict[str, LLMUsage | None] = {"usage": None}
|
||||
messages: list[PromptMessage] = list(prompt_messages) # Create mutable copy
|
||||
final_text: str = ""
|
||||
structured_output_payload: dict[str, Any] | None = None
|
||||
output_text_payload: str | None = None
|
||||
finish_reason: str | None = None
|
||||
output_files: list[File] = [] # Track files produced by tools
|
||||
|
||||
class _LLMInvoker(Protocol):
|
||||
def invoke_llm(
|
||||
self,
|
||||
*,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict[str, Any],
|
||||
tools: list[PromptMessageTool],
|
||||
stop: list[str],
|
||||
stream: Literal[False],
|
||||
user: str | None,
|
||||
callbacks: list[Any],
|
||||
) -> LLMResult: ...
|
||||
|
||||
while function_call_state and iteration_step <= max_iterations:
|
||||
function_call_state = False
|
||||
round_log = self._create_log(
|
||||
|
|
@ -54,8 +77,11 @@ class FunctionCallStrategy(AgentPattern):
|
|||
data={},
|
||||
)
|
||||
yield round_log
|
||||
# On last iteration, remove tools to force final answer
|
||||
current_tools: list[PromptMessageTool] = [] if iteration_step == max_iterations else prompt_tools
|
||||
# On last iteration, restrict tools to output tools
|
||||
if iteration_step == max_iterations:
|
||||
current_tools = [tool for tool in prompt_tools if tool.name in OUTPUT_TOOL_NAME_SET]
|
||||
else:
|
||||
current_tools = prompt_tools
|
||||
model_log = self._create_log(
|
||||
label=f"{self.model_instance.model} Thought",
|
||||
log_type=AgentLog.LogType.THOUGHT,
|
||||
|
|
@ -72,31 +98,41 @@ class FunctionCallStrategy(AgentPattern):
|
|||
round_usage: dict[str, LLMUsage | None] = {"usage": None}
|
||||
|
||||
# Invoke model
|
||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm(
|
||||
invoker = cast(_LLMInvoker, self.model_instance)
|
||||
chunks = invoker.invoke_llm(
|
||||
prompt_messages=messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=current_tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream=False,
|
||||
user=self.context.user_id,
|
||||
callbacks=[],
|
||||
)
|
||||
|
||||
# Process response
|
||||
tool_calls, response_content, chunk_finish_reason = yield from self._handle_chunks(
|
||||
chunks, round_usage, model_log
|
||||
chunks, round_usage, model_log, emit_chunks=False
|
||||
)
|
||||
messages.append(self._create_assistant_message(response_content, tool_calls))
|
||||
|
||||
if not tool_calls:
|
||||
tool_calls = [
|
||||
(
|
||||
str(uuid.uuid4()),
|
||||
ILLEGAL_OUTPUT_TOOL,
|
||||
{
|
||||
"raw": response_content,
|
||||
},
|
||||
)
|
||||
]
|
||||
response_content = ""
|
||||
|
||||
messages.append(self._create_assistant_message("", tool_calls))
|
||||
|
||||
# Accumulate to total usage
|
||||
round_usage_value = round_usage.get("usage")
|
||||
if round_usage_value:
|
||||
self._accumulate_usage(total_usage, round_usage_value)
|
||||
|
||||
# Update final text if no tool calls (this is likely the final answer)
|
||||
if not tool_calls:
|
||||
final_text = response_content
|
||||
|
||||
# Update finish reason
|
||||
if chunk_finish_reason:
|
||||
finish_reason = chunk_finish_reason
|
||||
|
|
@ -105,14 +141,27 @@ class FunctionCallStrategy(AgentPattern):
|
|||
tool_outputs: dict[str, str] = {}
|
||||
if tool_calls:
|
||||
function_call_state = True
|
||||
terminal_tool_seen = False
|
||||
# Execute tools
|
||||
for tool_call_id, tool_name, tool_args in tool_calls:
|
||||
if tool_name == OUTPUT_TEXT_TOOL:
|
||||
output_text_payload = self._format_output_text(tool_args.get("text"))
|
||||
elif tool_name == FINAL_STRUCTURED_OUTPUT_TOOL:
|
||||
data = tool_args.get("data")
|
||||
structured_output_payload = cast(dict[str, Any] | None, data)
|
||||
elif tool_name == FINAL_OUTPUT_TOOL:
|
||||
final_text = self._format_output_text(tool_args.get("text"))
|
||||
terminal_tool_seen = True
|
||||
|
||||
tool_response, tool_files, _ = yield from self._handle_tool_call(
|
||||
tool_name, tool_args, tool_call_id, messages, round_log
|
||||
)
|
||||
tool_outputs[tool_name] = tool_response
|
||||
# Track files produced by tools
|
||||
output_files.extend(tool_files)
|
||||
|
||||
if terminal_tool_seen:
|
||||
function_call_state = False
|
||||
yield self._finish_log(
|
||||
round_log,
|
||||
data={
|
||||
|
|
@ -131,8 +180,28 @@ class FunctionCallStrategy(AgentPattern):
|
|||
# Return final result
|
||||
from core.agent.entities import AgentResult
|
||||
|
||||
output_payload: str | AgentResult.StructuredOutput
|
||||
if final_text:
|
||||
output_payload = AgentResult.StructuredOutput(
|
||||
output_kind=AgentOutputKind.FINAL_OUTPUT_ANSWER,
|
||||
output_text=final_text,
|
||||
output_data=structured_output_payload,
|
||||
)
|
||||
elif output_text_payload:
|
||||
output_payload = AgentResult.StructuredOutput(
|
||||
output_kind=AgentOutputKind.OUTPUT_TEXT,
|
||||
output_text=str(output_text_payload),
|
||||
output_data=None,
|
||||
)
|
||||
else:
|
||||
output_payload = AgentResult.StructuredOutput(
|
||||
output_kind=AgentOutputKind.ILLEGAL_OUTPUT,
|
||||
output_text="Model failed to produce a final output.",
|
||||
output_data=None,
|
||||
)
|
||||
|
||||
return AgentResult(
|
||||
text=final_text,
|
||||
output=output_payload,
|
||||
files=output_files,
|
||||
usage=total_usage.get("usage") or LLMUsage.empty_usage(),
|
||||
finish_reason=finish_reason,
|
||||
|
|
@ -143,6 +212,8 @@ class FunctionCallStrategy(AgentPattern):
|
|||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult],
|
||||
llm_usage: dict[str, LLMUsage | None],
|
||||
start_log: AgentLog,
|
||||
*,
|
||||
emit_chunks: bool,
|
||||
) -> Generator[
|
||||
LLMResultChunk | AgentLog,
|
||||
None,
|
||||
|
|
@ -174,7 +245,8 @@ class FunctionCallStrategy(AgentPattern):
|
|||
if chunk.delta.finish_reason:
|
||||
finish_reason = chunk.delta.finish_reason
|
||||
|
||||
yield chunk
|
||||
if emit_chunks:
|
||||
yield chunk
|
||||
else:
|
||||
# Non-streaming response
|
||||
result: LLMResult = chunks
|
||||
|
|
@ -189,11 +261,12 @@ class FunctionCallStrategy(AgentPattern):
|
|||
self._accumulate_usage(llm_usage, result.usage)
|
||||
|
||||
# Convert to streaming format
|
||||
yield LLMResultChunk(
|
||||
model=result.model,
|
||||
prompt_messages=result.prompt_messages,
|
||||
delta=LLMResultChunkDelta(index=0, message=result.message, usage=result.usage),
|
||||
)
|
||||
if emit_chunks:
|
||||
yield LLMResultChunk(
|
||||
model=result.model,
|
||||
prompt_messages=result.prompt_messages,
|
||||
delta=LLMResultChunkDelta(index=0, message=result.message, usage=result.usage),
|
||||
)
|
||||
yield self._finish_log(
|
||||
start_log,
|
||||
data={
|
||||
|
|
@ -203,6 +276,14 @@ class FunctionCallStrategy(AgentPattern):
|
|||
)
|
||||
return tool_calls, response_content, finish_reason
|
||||
|
||||
@staticmethod
|
||||
def _format_output_text(value: Any) -> str:
|
||||
if value is None:
|
||||
return ""
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
return json.dumps(value, ensure_ascii=False)
|
||||
|
||||
def _create_assistant_message(
|
||||
self, content: str, tool_calls: list[tuple[str, str, dict[str, Any]]] | None = None
|
||||
) -> AssistantPromptMessage:
|
||||
|
|
|
|||
|
|
@ -4,10 +4,17 @@ from __future__ import annotations
|
|||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
from typing import TYPE_CHECKING, Any, Union, cast
|
||||
|
||||
from core.agent.entities import AgentLog, AgentResult, AgentScratchpadUnit, ExecutionContext
|
||||
from core.agent.entities import AgentLog, AgentOutputKind, AgentResult, AgentScratchpadUnit, ExecutionContext
|
||||
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
|
||||
from core.agent.output_tools import (
|
||||
FINAL_OUTPUT_TOOL,
|
||||
FINAL_STRUCTURED_OUTPUT_TOOL,
|
||||
ILLEGAL_OUTPUT_TOOL,
|
||||
OUTPUT_TEXT_TOOL,
|
||||
OUTPUT_TOOL_NAME_SET,
|
||||
)
|
||||
from core.file import File
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities import (
|
||||
|
|
@ -67,6 +74,8 @@ class ReActStrategy(AgentPattern):
|
|||
total_usage: dict[str, Any] = {"usage": None}
|
||||
output_files: list[File] = [] # Track files produced by tools
|
||||
final_text: str = ""
|
||||
structured_output_payload: dict[str, Any] | None = None
|
||||
output_text_payload: str | None = None
|
||||
finish_reason: str | None = None
|
||||
|
||||
# Add "Observation" to stop sequences
|
||||
|
|
@ -84,10 +93,13 @@ class ReActStrategy(AgentPattern):
|
|||
)
|
||||
yield round_log
|
||||
|
||||
# Build prompt with/without tools based on iteration
|
||||
include_tools = iteration_step < max_iterations
|
||||
# Build prompt with tool restrictions on last iteration
|
||||
if iteration_step == max_iterations:
|
||||
tools_for_prompt = [tool for tool in self.tools if tool.entity.identity.name in OUTPUT_TOOL_NAME_SET]
|
||||
else:
|
||||
tools_for_prompt = self.tools
|
||||
current_messages = self._build_prompt_with_react_format(
|
||||
prompt_messages, agent_scratchpad, include_tools, self.instruction
|
||||
prompt_messages, agent_scratchpad, tools_for_prompt, self.instruction
|
||||
)
|
||||
|
||||
model_log = self._create_log(
|
||||
|
|
@ -109,18 +121,21 @@ class ReActStrategy(AgentPattern):
|
|||
messages_to_use = current_messages
|
||||
|
||||
# Invoke model
|
||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm(
|
||||
prompt_messages=messages_to_use,
|
||||
model_parameters=model_parameters,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=self.context.user_id or "",
|
||||
callbacks=[],
|
||||
chunks = cast(
|
||||
Union[Generator[LLMResultChunk, None, None], LLMResult],
|
||||
self.model_instance.invoke_llm(
|
||||
prompt_messages=messages_to_use,
|
||||
model_parameters=model_parameters,
|
||||
stop=stop,
|
||||
stream=False,
|
||||
user=self.context.user_id or "",
|
||||
callbacks=[],
|
||||
),
|
||||
)
|
||||
|
||||
# Process response
|
||||
scratchpad, chunk_finish_reason = yield from self._handle_chunks(
|
||||
chunks, round_usage, model_log, current_messages
|
||||
chunks, round_usage, model_log, current_messages, emit_chunks=False
|
||||
)
|
||||
agent_scratchpad.append(scratchpad)
|
||||
|
||||
|
|
@ -134,28 +149,46 @@ class ReActStrategy(AgentPattern):
|
|||
finish_reason = chunk_finish_reason
|
||||
|
||||
# Check if we have an action to execute
|
||||
if scratchpad.action and scratchpad.action.action_name.lower() != "final answer":
|
||||
react_state = True
|
||||
# Execute tool
|
||||
observation, tool_files = yield from self._handle_tool_call(
|
||||
scratchpad.action, current_messages, round_log
|
||||
if scratchpad.action is None:
|
||||
illegal_action = AgentScratchpadUnit.Action(
|
||||
action_name=ILLEGAL_OUTPUT_TOOL,
|
||||
action_input={"raw": scratchpad.thought or ""},
|
||||
)
|
||||
scratchpad.action = illegal_action
|
||||
scratchpad.action_str = json.dumps(illegal_action.to_dict())
|
||||
react_state = True
|
||||
observation, tool_files = yield from self._handle_tool_call(illegal_action, current_messages, round_log)
|
||||
scratchpad.observation = observation
|
||||
# Track files produced by tools
|
||||
output_files.extend(tool_files)
|
||||
|
||||
# Add observation to scratchpad for display
|
||||
yield self._create_text_chunk(f"\nObservation: {observation}\n", current_messages)
|
||||
else:
|
||||
# Extract final answer
|
||||
if scratchpad.action and scratchpad.action.action_input:
|
||||
final_answer = scratchpad.action.action_input
|
||||
if isinstance(final_answer, dict):
|
||||
final_answer = json.dumps(final_answer, ensure_ascii=False)
|
||||
final_text = str(final_answer)
|
||||
elif scratchpad.thought:
|
||||
# If no action but we have thought, use thought as final answer
|
||||
final_text = scratchpad.thought
|
||||
action_name = scratchpad.action.action_name
|
||||
if action_name == FINAL_OUTPUT_TOOL:
|
||||
if isinstance(scratchpad.action.action_input, dict):
|
||||
final_text = self._format_output_text(scratchpad.action.action_input.get("text"))
|
||||
else:
|
||||
final_text = self._format_output_text(scratchpad.action.action_input)
|
||||
observation, tool_files = yield from self._handle_tool_call(
|
||||
scratchpad.action, current_messages, round_log
|
||||
)
|
||||
scratchpad.observation = observation
|
||||
output_files.extend(tool_files)
|
||||
react_state = False
|
||||
else:
|
||||
if action_name == OUTPUT_TEXT_TOOL and isinstance(scratchpad.action.action_input, dict):
|
||||
output_text_payload = scratchpad.action.action_input.get("text")
|
||||
elif action_name == FINAL_STRUCTURED_OUTPUT_TOOL and isinstance(
|
||||
scratchpad.action.action_input, dict
|
||||
):
|
||||
data = scratchpad.action.action_input.get("data")
|
||||
if isinstance(data, dict):
|
||||
structured_output_payload = data
|
||||
|
||||
react_state = True
|
||||
observation, tool_files = yield from self._handle_tool_call(
|
||||
scratchpad.action, current_messages, round_log
|
||||
)
|
||||
scratchpad.observation = observation
|
||||
output_files.extend(tool_files)
|
||||
|
||||
yield self._finish_log(
|
||||
round_log,
|
||||
|
|
@ -173,15 +206,38 @@ class ReActStrategy(AgentPattern):
|
|||
|
||||
from core.agent.entities import AgentResult
|
||||
|
||||
output_payload: str | AgentResult.StructuredOutput
|
||||
if final_text:
|
||||
output_payload = AgentResult.StructuredOutput(
|
||||
output_kind=AgentOutputKind.FINAL_OUTPUT_ANSWER,
|
||||
output_text=final_text,
|
||||
output_data=structured_output_payload,
|
||||
)
|
||||
elif output_text_payload:
|
||||
output_payload = AgentResult.StructuredOutput(
|
||||
output_kind=AgentOutputKind.OUTPUT_TEXT,
|
||||
output_text=str(output_text_payload),
|
||||
output_data=structured_output_payload,
|
||||
)
|
||||
else:
|
||||
output_payload = AgentResult.StructuredOutput(
|
||||
output_kind=AgentOutputKind.ILLEGAL_OUTPUT,
|
||||
output_text="Model failed to produce a final output.",
|
||||
output_data=structured_output_payload,
|
||||
)
|
||||
|
||||
return AgentResult(
|
||||
text=final_text, files=output_files, usage=total_usage.get("usage"), finish_reason=finish_reason
|
||||
output=output_payload,
|
||||
files=output_files,
|
||||
usage=total_usage.get("usage"),
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
|
||||
def _build_prompt_with_react_format(
|
||||
self,
|
||||
original_messages: list[PromptMessage],
|
||||
agent_scratchpad: list[AgentScratchpadUnit],
|
||||
include_tools: bool = True,
|
||||
tools: list[Tool] | None,
|
||||
instruction: str = "",
|
||||
) -> list[PromptMessage]:
|
||||
"""Build prompt messages with ReAct format."""
|
||||
|
|
@ -198,9 +254,9 @@ class ReActStrategy(AgentPattern):
|
|||
# Format tools
|
||||
tools_str = ""
|
||||
tool_names = []
|
||||
if include_tools and self.tools:
|
||||
if tools:
|
||||
# Convert tools to prompt message tools format
|
||||
prompt_tools = [tool.to_prompt_message_tool() for tool in self.tools]
|
||||
prompt_tools = [tool.to_prompt_message_tool() for tool in tools]
|
||||
tool_names = [tool.name for tool in prompt_tools]
|
||||
|
||||
# Format tools as JSON for comprehensive information
|
||||
|
|
@ -253,6 +309,8 @@ class ReActStrategy(AgentPattern):
|
|||
llm_usage: dict[str, Any],
|
||||
model_log: AgentLog,
|
||||
current_messages: list[PromptMessage],
|
||||
*,
|
||||
emit_chunks: bool,
|
||||
) -> Generator[
|
||||
LLMResultChunk | AgentLog,
|
||||
None,
|
||||
|
|
@ -306,14 +364,16 @@ class ReActStrategy(AgentPattern):
|
|||
scratchpad.action_str = action_str
|
||||
scratchpad.action = chunk
|
||||
|
||||
yield self._create_text_chunk(json.dumps(chunk.model_dump()), current_messages)
|
||||
if emit_chunks:
|
||||
yield self._create_text_chunk(json.dumps(chunk.model_dump()), current_messages)
|
||||
else:
|
||||
# Text chunk
|
||||
chunk_text = str(chunk)
|
||||
scratchpad.agent_response = (scratchpad.agent_response or "") + chunk_text
|
||||
scratchpad.thought = (scratchpad.thought or "") + chunk_text
|
||||
|
||||
yield self._create_text_chunk(chunk_text, current_messages)
|
||||
if emit_chunks:
|
||||
yield self._create_text_chunk(chunk_text, current_messages)
|
||||
|
||||
# Update usage
|
||||
if usage_dict.get("usage"):
|
||||
|
|
@ -337,6 +397,14 @@ class ReActStrategy(AgentPattern):
|
|||
|
||||
return scratchpad, finish_reason
|
||||
|
||||
@staticmethod
|
||||
def _format_output_text(value: Any) -> str:
|
||||
if value is None:
|
||||
return ""
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
return json.dumps(value, ensure_ascii=False)
|
||||
|
||||
def _handle_tool_call(
|
||||
self,
|
||||
action: AgentScratchpadUnit.Action,
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ You have access to the following tools:
|
|||
{{tools}}
|
||||
|
||||
Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
|
||||
Valid "action" values: "Final Answer" or {{tool_names}}
|
||||
Valid "action" values: {{tool_names}}. You must call "final_output_answer" to finish.
|
||||
|
||||
Provide only ONE action per $JSON_BLOB, as shown:
|
||||
|
||||
|
|
@ -32,12 +32,14 @@ Thought: I know what to respond
|
|||
Action:
|
||||
```
|
||||
{
|
||||
"action": "Final Answer",
|
||||
"action_input": "Final response to human"
|
||||
"action": "final_output_answer",
|
||||
"action_input": {
|
||||
"text": "Final response to human"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
|
||||
Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Do not respond with plain text. Format is Action:```$JSON_BLOB```then Observation:.
|
||||
{{historic_messages}}
|
||||
Question: {{query}}
|
||||
{{agent_scratchpad}}
|
||||
|
|
@ -56,7 +58,7 @@ You have access to the following tools:
|
|||
{{tools}}
|
||||
|
||||
Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
|
||||
Valid "action" values: "Final Answer" or {{tool_names}}
|
||||
Valid "action" values: {{tool_names}}. You must call "final_output_answer" to finish.
|
||||
|
||||
Provide only ONE action per $JSON_BLOB, as shown:
|
||||
|
||||
|
|
@ -81,12 +83,14 @@ Thought: I know what to respond
|
|||
Action:
|
||||
```
|
||||
{
|
||||
"action": "Final Answer",
|
||||
"action_input": "Final response to human"
|
||||
"action": "final_output_answer",
|
||||
"action_input": {
|
||||
"text": "Final response to human"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
|
||||
Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Do not respond with plain text. Format is Action:```$JSON_BLOB```then Observation:.
|
||||
""" # noqa: E501
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,5 @@
|
|||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none">
|
||||
<rect x="3" y="3" width="18" height="18" rx="4" stroke="#2F2F2F" stroke-width="1.5"/>
|
||||
<path d="M7 12h10" stroke="#2F2F2F" stroke-width="1.5" stroke-linecap="round"/>
|
||||
<path d="M12 7v10" stroke="#2F2F2F" stroke-width="1.5" stroke-linecap="round"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 332 B |
|
|
@ -0,0 +1,8 @@
|
|||
from typing import Any
|
||||
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class AgentOutputProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
|
||||
pass
|
||||
|
|
@ -0,0 +1,10 @@
|
|||
identity:
|
||||
author: Dify
|
||||
name: agent_output
|
||||
label:
|
||||
en_US: Agent Output
|
||||
description:
|
||||
en_US: Internal tools for agent output control.
|
||||
icon: icon.svg
|
||||
tags:
|
||||
- utilities
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
|
||||
|
||||
class FinalOutputAnswerTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
yield self.create_text_message("Final answer recorded.")
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
identity:
|
||||
name: final_output_answer
|
||||
author: Dify
|
||||
label:
|
||||
en_US: Final Output Answer
|
||||
description:
|
||||
human:
|
||||
en_US: Internal tool to deliver the final answer.
|
||||
llm: Use this tool when you are ready to provide the final answer.
|
||||
parameters:
|
||||
- name: text
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Text
|
||||
human_description:
|
||||
en_US: Final answer text.
|
||||
form: llm
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
|
||||
|
||||
class FinalStructuredOutputTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
yield self.create_text_message("Structured output recorded.")
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
identity:
|
||||
name: final_structured_output
|
||||
author: Dify
|
||||
label:
|
||||
en_US: Final Structured Output
|
||||
description:
|
||||
human:
|
||||
en_US: Internal tool to deliver structured output.
|
||||
llm: Use this tool to provide structured output data.
|
||||
parameters:
|
||||
- name: data
|
||||
type: object
|
||||
required: true
|
||||
label:
|
||||
en_US: Data
|
||||
human_description:
|
||||
en_US: Structured output data.
|
||||
form: llm
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
|
||||
|
||||
class IllegalOutputTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
message = (
|
||||
"Protocol violation: do not output plain text. "
|
||||
"Call output_text, final_structured_output, then final_output_answer."
|
||||
)
|
||||
yield self.create_text_message(message)
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
identity:
|
||||
name: illegal_output
|
||||
author: Dify
|
||||
label:
|
||||
en_US: Illegal Output
|
||||
description:
|
||||
human:
|
||||
en_US: Internal tool for output protocol violations.
|
||||
llm: Use this tool to correct output protocol violations.
|
||||
parameters:
|
||||
- name: raw
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Raw Output
|
||||
human_description:
|
||||
en_US: Raw model output that violated the protocol.
|
||||
form: llm
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
|
||||
|
||||
class OutputTextTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
yield self.create_text_message("Output recorded.")
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
identity:
|
||||
name: output_text
|
||||
author: Dify
|
||||
label:
|
||||
en_US: Output Text
|
||||
description:
|
||||
human:
|
||||
en_US: Internal tool to store intermediate text output.
|
||||
llm: Use this tool to emit non-final text output.
|
||||
parameters:
|
||||
- name: text
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Text
|
||||
human_description:
|
||||
en_US: Output text.
|
||||
form: llm
|
||||
|
|
@ -7,7 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_valid
|
|||
from core.agent.entities import AgentLog, AgentResult
|
||||
from core.file import File
|
||||
from core.model_runtime.entities import ImagePromptMessageContent, LLMMode
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.entities.llm_entities import LLMStructuredOutput, LLMUsage
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.workflow.entities import ToolCall, ToolCallResult
|
||||
|
|
@ -156,6 +156,9 @@ class LLMGenerationData(BaseModel):
|
|||
finish_reason: str | None = Field(None, description="Finish reason from LLM")
|
||||
files: list[File] = Field(default_factory=list, description="Generated files")
|
||||
trace: list[LLMTraceSegment] = Field(default_factory=list, description="Streaming trace in emitted order")
|
||||
structured_output: LLMStructuredOutput | None = Field(
|
||||
default=None, description="Structured output from tool-only agent runs"
|
||||
)
|
||||
|
||||
|
||||
class ThinkTagStreamParser:
|
||||
|
|
@ -284,6 +287,7 @@ class AggregatedResult(BaseModel):
|
|||
files: list[File] = Field(default_factory=list)
|
||||
usage: LLMUsage = Field(default_factory=LLMUsage.empty_usage)
|
||||
finish_reason: str | None = None
|
||||
structured_output: LLMStructuredOutput | None = None
|
||||
|
||||
|
||||
class AgentContext(BaseModel):
|
||||
|
|
|
|||
|
|
@ -12,7 +12,8 @@ from typing import TYPE_CHECKING, Any, Literal, cast
|
|||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.agent.entities import AgentEntity, AgentLog, AgentResult, AgentToolEntity, ExecutionContext
|
||||
from core.agent.entities import AgentEntity, AgentLog, AgentOutputKind, AgentResult, AgentToolEntity, ExecutionContext
|
||||
from core.agent.output_tools import build_agent_output_tools
|
||||
from core.agent.patterns import StrategyFactory
|
||||
from core.app.entities.app_asset_entities import AppAssetFileTree
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
|
|
@ -20,6 +21,7 @@ from core.app_assets.constants import AppAssetsAttrs
|
|||
from core.file import File, FileTransferMethod, FileType, file_manager
|
||||
from core.helper.code_executor import CodeExecutor, CodeLanguage
|
||||
from core.llm_generator.output_parser.errors import OutputParserError
|
||||
from core.llm_generator.output_parser.file_ref import convert_file_refs_in_output
|
||||
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
|
||||
from core.memory.base import BaseMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
|
|
@ -62,6 +64,7 @@ from core.skill.entities.skill_document import SkillDocument
|
|||
from core.skill.entities.tool_dependencies import ToolDependencies, ToolDependency
|
||||
from core.skill.skill_compiler import SkillCompiler
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import ToolInvokeFrom
|
||||
from core.tools.signature import sign_upload_file
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.variables import (
|
||||
|
|
@ -355,6 +358,8 @@ class LLMNode(Node[LLMNodeData]):
|
|||
reasoning_content = ""
|
||||
usage = generation_data.usage
|
||||
finish_reason = generation_data.finish_reason
|
||||
if generation_data.structured_output:
|
||||
structured_output = generation_data.structured_output
|
||||
|
||||
# Unified process_data building
|
||||
process_data = {
|
||||
|
|
@ -1900,7 +1905,7 @@ class LLMNode(Node[LLMNodeData]):
|
|||
prompt_messages=list(prompt_messages),
|
||||
model_parameters=self._node_data.model.completion_params,
|
||||
stop=list(stop or []),
|
||||
stream=True,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
result = yield from self._process_tool_outputs(outputs)
|
||||
|
|
@ -1921,11 +1926,22 @@ class LLMNode(Node[LLMNodeData]):
|
|||
with SandboxBashSession(sandbox=sandbox, node_id=self.id, tools=tool_dependencies) as session:
|
||||
prompt_files = self._extract_prompt_files(variable_pool)
|
||||
model_features = self._get_model_features(model_instance)
|
||||
structured_output_schema = None
|
||||
if self._node_data.structured_output_enabled:
|
||||
structured_output_schema = LLMNode.fetch_structured_output_schema(
|
||||
structured_output=self._node_data.structured_output or {},
|
||||
)
|
||||
output_tools = build_agent_output_tools(
|
||||
tenant_id=self.tenant_id,
|
||||
invoke_from=self.invoke_from,
|
||||
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
|
||||
structured_output_schema=structured_output_schema,
|
||||
)
|
||||
|
||||
strategy = StrategyFactory.create_strategy(
|
||||
model_features=model_features,
|
||||
model_instance=model_instance,
|
||||
tools=[session.bash_tool],
|
||||
tools=[session.bash_tool, *output_tools],
|
||||
files=prompt_files,
|
||||
max_iterations=self._node_data.max_iterations or 100,
|
||||
agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING,
|
||||
|
|
@ -1936,7 +1952,7 @@ class LLMNode(Node[LLMNodeData]):
|
|||
prompt_messages=list(prompt_messages),
|
||||
model_parameters=self._node_data.model.completion_params,
|
||||
stop=list(stop or []),
|
||||
stream=True,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
result = yield from self._process_tool_outputs(outputs)
|
||||
|
|
@ -2011,6 +2027,20 @@ class LLMNode(Node[LLMNodeData]):
|
|||
logger.warning("Failed to load tool %s: %s", tool, str(e))
|
||||
continue
|
||||
|
||||
structured_output_schema = None
|
||||
if self._node_data.structured_output_enabled:
|
||||
structured_output_schema = LLMNode.fetch_structured_output_schema(
|
||||
structured_output=self._node_data.structured_output or {},
|
||||
)
|
||||
tool_instances.extend(
|
||||
build_agent_output_tools(
|
||||
tenant_id=self.tenant_id,
|
||||
invoke_from=self.invoke_from,
|
||||
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
|
||||
structured_output_schema=structured_output_schema,
|
||||
)
|
||||
)
|
||||
|
||||
return tool_instances
|
||||
|
||||
def _extract_prompt_files(self, variable_pool: VariablePool) -> list[File]:
|
||||
|
|
@ -2480,6 +2510,7 @@ class LLMNode(Node[LLMNodeData]):
|
|||
finish_reason=aggregate.finish_reason,
|
||||
files=aggregate.files,
|
||||
trace=trace_state.trace_segments,
|
||||
structured_output=aggregate.structured_output,
|
||||
)
|
||||
|
||||
def _process_tool_outputs(
|
||||
|
|
@ -2494,19 +2525,54 @@ class LLMNode(Node[LLMNodeData]):
|
|||
if isinstance(output, AgentLog):
|
||||
yield from self._handle_agent_log_output(output, state.stream, state.trace, state.agent)
|
||||
else:
|
||||
yield from self._handle_llm_chunk_output(output, state.stream, state.trace, state.aggregate)
|
||||
continue
|
||||
except StopIteration as exception:
|
||||
if isinstance(getattr(exception, "value", None), AgentResult):
|
||||
state.agent.agent_result = exception.value
|
||||
|
||||
if state.agent.agent_result:
|
||||
state.aggregate.text = state.agent.agent_result.text or state.aggregate.text
|
||||
output_payload = state.agent.agent_result.output
|
||||
structured_output_data: Mapping[str, Any] | None = None
|
||||
if isinstance(output_payload, AgentResult.StructuredOutput):
|
||||
output_kind = output_payload.output_kind
|
||||
if output_kind == AgentOutputKind.ILLEGAL_OUTPUT:
|
||||
raise ValueError("Agent returned illegal output")
|
||||
if output_kind in {AgentOutputKind.FINAL_OUTPUT_ANSWER, AgentOutputKind.OUTPUT_TEXT}:
|
||||
if not output_payload.output_text:
|
||||
raise ValueError("Agent returned empty text output")
|
||||
state.aggregate.text = output_payload.output_text
|
||||
elif output_kind == AgentOutputKind.FINAL_STRUCTURED_OUTPUT:
|
||||
if output_payload.output_data is None:
|
||||
raise ValueError("Agent returned empty structured output")
|
||||
else:
|
||||
raise ValueError("Agent returned unsupported output kind")
|
||||
|
||||
if output_payload.output_data is not None:
|
||||
if not isinstance(output_payload.output_data, Mapping):
|
||||
raise ValueError("Agent returned invalid structured output")
|
||||
structured_output_data = output_payload.output_data
|
||||
else:
|
||||
if not output_payload:
|
||||
raise ValueError("Agent returned empty output")
|
||||
state.aggregate.text = str(output_payload)
|
||||
|
||||
state.aggregate.files = state.agent.agent_result.files
|
||||
if state.agent.agent_result.usage:
|
||||
state.aggregate.usage = state.agent.agent_result.usage
|
||||
if state.agent.agent_result.finish_reason:
|
||||
state.aggregate.finish_reason = state.agent.agent_result.finish_reason
|
||||
|
||||
if structured_output_data is not None:
|
||||
output_schema = LLMNode.fetch_structured_output_schema(
|
||||
structured_output=self._node_data.structured_output or {},
|
||||
)
|
||||
converted_output = convert_file_refs_in_output(
|
||||
output=structured_output_data,
|
||||
json_schema=output_schema,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
state.aggregate.structured_output = LLMStructuredOutput(structured_output=converted_output)
|
||||
|
||||
yield from self._flush_remaining_stream(state.stream, state.trace, state.aggregate)
|
||||
yield from self._close_streams()
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch
|
|||
|
||||
import pytest
|
||||
|
||||
from core.agent.entities import AgentEntity, AgentLog, AgentPromptEntity, AgentResult
|
||||
from core.agent.entities import AgentEntity, AgentLog, AgentOutputKind, AgentPromptEntity, AgentResult
|
||||
from core.model_runtime.entities import SystemPromptMessage, UserPromptMessage
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
|
||||
|
|
@ -329,13 +329,20 @@ class TestAgentLogProcessing:
|
|||
)
|
||||
|
||||
result = AgentResult(
|
||||
text="Final answer",
|
||||
output=AgentResult.StructuredOutput(
|
||||
output_kind=AgentOutputKind.FINAL_OUTPUT_ANSWER,
|
||||
output_text="Final answer",
|
||||
output_data=None,
|
||||
),
|
||||
files=[],
|
||||
usage=usage,
|
||||
finish_reason="stop",
|
||||
)
|
||||
|
||||
assert result.text == "Final answer"
|
||||
output_payload = result.output
|
||||
assert isinstance(output_payload, AgentResult.StructuredOutput)
|
||||
assert output_payload.output_text == "Final answer"
|
||||
assert output_payload.output_kind == AgentOutputKind.FINAL_OUTPUT_ANSWER
|
||||
assert result.files == []
|
||||
assert result.usage == usage
|
||||
assert result.finish_reason == "stop"
|
||||
|
|
|
|||
Loading…
Reference in New Issue