mirror of https://github.com/langgenius/dify.git
Merge f4100bfb20 into e3c1112b15
This commit is contained in:
commit
a0d57a5ed7
|
|
@ -297,12 +297,12 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
|||
|
||||
def _handle_node_pause_requested(self, event: NodeRunPauseRequestedEvent) -> None:
|
||||
domain_execution = self._get_node_execution(event.id)
|
||||
# Paused nodes may already have partial outputs/metadata (for example agent clarification context).
|
||||
self._update_node_execution(
|
||||
domain_execution,
|
||||
event.node_run_result,
|
||||
WorkflowNodeExecutionStatus.PAUSED,
|
||||
error="",
|
||||
update_outputs=False,
|
||||
update_outputs=True,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -146,9 +146,12 @@ class AgentNode(Node[AgentNodeData]):
|
|||
parameters_for_log=parameters_for_log,
|
||||
user_id=dify_ctx.user_id,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
app_id=dify_ctx.app_id,
|
||||
workflow_execution_id=self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id,
|
||||
node_type=self.node_type,
|
||||
node_id=self._node_id,
|
||||
node_execution_id=self.id,
|
||||
node_title=self.title,
|
||||
node_execution_id=self.execution_id,
|
||||
)
|
||||
except PluginDaemonClientSideError as e:
|
||||
transform_error = AgentMessageTransformError(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,138 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
|
||||
from dify_graph.entities.pause_reason import HumanInputRequired
|
||||
from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
|
||||
from dify_graph.node_events import AgentLogEvent, NodeRunResult, PauseRequestedEvent
|
||||
from dify_graph.nodes.human_input.entities import HumanInputNodeData
|
||||
from dify_graph.repositories.human_input_form_repository import FormCreateParams, HumanInputFormRepository
|
||||
|
||||
from .entities import AgentClarificationPayload
|
||||
|
||||
|
||||
def _default_form_repository_factory(tenant_id: str) -> HumanInputFormRepository:
|
||||
return HumanInputFormRepositoryImpl(tenant_id=tenant_id)
|
||||
|
||||
|
||||
class AgentClarificationHelper:
|
||||
"""Translate agent clarification payloads into standard workflow pause events."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
form_repository_factory: Callable[[str], HumanInputFormRepository] | None = None,
|
||||
) -> None:
|
||||
self._form_repository_factory = form_repository_factory or _default_form_repository_factory
|
||||
|
||||
def extract_payload(self, json_object: Mapping[str, Any]) -> AgentClarificationPayload | None:
|
||||
raw_payload = json_object.get("human_required") or json_object.get("clarification")
|
||||
if raw_payload is None:
|
||||
return None
|
||||
return AgentClarificationPayload.model_validate(raw_payload)
|
||||
|
||||
def build_pause_event(
|
||||
self,
|
||||
*,
|
||||
payload: AgentClarificationPayload,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
workflow_execution_id: str | None,
|
||||
node_id: str,
|
||||
node_title: str,
|
||||
node_execution_id: str,
|
||||
tool_info: Mapping[str, Any],
|
||||
parameters_for_log: Mapping[str, Any],
|
||||
partial_outputs: Mapping[str, Any],
|
||||
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any],
|
||||
llm_usage: LLMUsage,
|
||||
agent_logs: list[AgentLogEvent],
|
||||
) -> PauseRequestedEvent:
|
||||
form_config = payload.to_human_input_node_data(node_title=node_title)
|
||||
form_entity = self._form_repository_factory(tenant_id).create_form(
|
||||
FormCreateParams(
|
||||
app_id=app_id,
|
||||
workflow_execution_id=workflow_execution_id,
|
||||
node_id=node_id,
|
||||
form_config=form_config,
|
||||
rendered_content=form_config.form_content,
|
||||
delivery_methods=form_config.delivery_methods,
|
||||
display_in_ui=payload.display_in_ui,
|
||||
resolved_default_values={},
|
||||
# Match HumanInputNode's baseline behavior so non-UI clarifications are still recoverable in Console.
|
||||
backstage_recipient_required=True,
|
||||
)
|
||||
)
|
||||
|
||||
pause_info = self._build_pause_info(
|
||||
payload=payload,
|
||||
form_config=form_config,
|
||||
form_id=form_entity.id,
|
||||
form_token=form_entity.web_app_token,
|
||||
node_id=node_id,
|
||||
node_execution_id=node_execution_id,
|
||||
node_title=node_title,
|
||||
tool_info=tool_info,
|
||||
)
|
||||
pause_metadata = {
|
||||
**execution_metadata,
|
||||
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
|
||||
WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs,
|
||||
WorkflowNodeExecutionMetadataKey.PAUSE_INFO: pause_info,
|
||||
}
|
||||
|
||||
return PauseRequestedEvent(
|
||||
reason=HumanInputRequired(
|
||||
form_id=form_entity.id,
|
||||
form_content=form_entity.rendered_content,
|
||||
inputs=form_config.inputs,
|
||||
actions=form_config.user_actions,
|
||||
display_in_ui=payload.display_in_ui,
|
||||
node_id=node_id,
|
||||
node_title=node_title,
|
||||
form_token=form_entity.web_app_token,
|
||||
resolved_default_values={},
|
||||
),
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.PAUSED,
|
||||
inputs=parameters_for_log,
|
||||
outputs={**partial_outputs, "clarification": pause_info},
|
||||
metadata=pause_metadata,
|
||||
llm_usage=llm_usage,
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_pause_info(
|
||||
*,
|
||||
payload: AgentClarificationPayload,
|
||||
form_config: HumanInputNodeData,
|
||||
form_id: str,
|
||||
form_token: str | None,
|
||||
node_id: str,
|
||||
node_execution_id: str,
|
||||
node_title: str,
|
||||
tool_info: Mapping[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
required_fields = [
|
||||
{"name": field.name, "type": field.type.value} for field in payload.normalized_required_fields()
|
||||
]
|
||||
return {
|
||||
"type": "agent_clarification",
|
||||
"human_required": True,
|
||||
"resumable": True,
|
||||
"question": payload.question,
|
||||
"required_fields": required_fields,
|
||||
"form_id": form_id,
|
||||
"form_token": form_token,
|
||||
"form_content": payload.to_form_content(),
|
||||
"display_in_ui": payload.display_in_ui,
|
||||
"node_id": node_id,
|
||||
"node_execution_id": node_execution_id,
|
||||
"node_title": node_title,
|
||||
"agent_strategy": tool_info.get("agent_strategy"),
|
||||
"actions": [action.model_dump(mode="json") for action in form_config.user_actions],
|
||||
}
|
||||
|
|
@ -1,12 +1,24 @@
|
|||
import re
|
||||
from enum import IntEnum, StrEnum, auto
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic_core import PydanticCustomError
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.tools.entities.tool_entities import ToolSelector
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeType
|
||||
from dify_graph.nodes.human_input.entities import (
|
||||
DeliveryChannelConfig,
|
||||
FormInput,
|
||||
HumanInputNodeData,
|
||||
UserAction,
|
||||
WebAppDeliveryMethod,
|
||||
)
|
||||
from dify_graph.nodes.human_input.enums import ButtonStyle, FormInputType
|
||||
|
||||
_OUTPUT_IDENTIFIER_PATTERN = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
|
||||
|
||||
|
||||
class AgentNodeData(BaseNodeData):
|
||||
|
|
@ -27,6 +39,70 @@ class AgentNodeData(BaseNodeData):
|
|||
agent_parameters: dict[str, AgentInput]
|
||||
|
||||
|
||||
class AgentClarificationField(BaseModel):
|
||||
name: str
|
||||
type: FormInputType = FormInputType.TEXT_INPUT
|
||||
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
def _validate_name(cls, value: str) -> str:
|
||||
if not _OUTPUT_IDENTIFIER_PATTERN.match(value):
|
||||
raise PydanticCustomError(
|
||||
"agent_clarification_field_name",
|
||||
"field name must start with a letter or underscore and contain only letters, numbers, or underscores",
|
||||
)
|
||||
return value
|
||||
|
||||
|
||||
def _default_clarification_actions() -> list[UserAction]:
|
||||
return [UserAction(id="submit", title="Submit", button_style=ButtonStyle.PRIMARY)]
|
||||
|
||||
|
||||
class AgentClarificationPayload(BaseModel):
|
||||
"""Minimal clarification contract emitted by agent strategies to request human input."""
|
||||
|
||||
question: str
|
||||
required_fields: list[str | AgentClarificationField] = Field(default_factory=list)
|
||||
display_in_ui: bool = False
|
||||
|
||||
@field_validator("question")
|
||||
@classmethod
|
||||
def _validate_question(cls, value: str) -> str:
|
||||
if not value.strip():
|
||||
raise PydanticCustomError("agent_clarification_question", "question must not be empty")
|
||||
return value
|
||||
|
||||
def normalized_required_fields(self) -> list[AgentClarificationField]:
|
||||
normalized: list[AgentClarificationField] = []
|
||||
for field in self.required_fields:
|
||||
if isinstance(field, AgentClarificationField):
|
||||
normalized.append(field)
|
||||
else:
|
||||
normalized.append(AgentClarificationField(name=field))
|
||||
return normalized
|
||||
|
||||
def to_form_inputs(self) -> list[FormInput]:
|
||||
return [
|
||||
FormInput(type=field.type, output_variable_name=field.name) for field in self.normalized_required_fields()
|
||||
]
|
||||
|
||||
def to_form_content(self) -> str:
|
||||
lines = [self.question.strip()]
|
||||
for field in self.normalized_required_fields():
|
||||
lines.append(f"- `{field.name}`: {{{{#$output.{field.name}#}}}}")
|
||||
return "\n\n".join(lines)
|
||||
|
||||
def to_human_input_node_data(self, *, node_title: str) -> HumanInputNodeData:
|
||||
delivery_methods: list[DeliveryChannelConfig] = [WebAppDeliveryMethod()] if self.display_in_ui else []
|
||||
return HumanInputNodeData(
|
||||
title=node_title,
|
||||
form_content=self.to_form_content(),
|
||||
inputs=self.to_form_inputs(),
|
||||
user_actions=_default_clarification_actions(),
|
||||
delivery_methods=delivery_methods,
|
||||
)
|
||||
|
||||
|
||||
class ParamsAutoGenerated(IntEnum):
|
||||
CLOSE = 0
|
||||
OPEN = 1
|
||||
|
|
|
|||
|
|
@ -25,10 +25,14 @@ from factories import file_factory
|
|||
from models import ToolFile
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
|
||||
from .clarification_helper import AgentClarificationHelper
|
||||
from .exceptions import AgentNodeError, AgentVariableTypeError, ToolFileNotFoundError
|
||||
|
||||
|
||||
class AgentMessageTransformer:
|
||||
def __init__(self, *, clarification_helper: AgentClarificationHelper | None = None) -> None:
|
||||
self._clarification_helper = clarification_helper or AgentClarificationHelper()
|
||||
|
||||
def transform(
|
||||
self,
|
||||
*,
|
||||
|
|
@ -37,8 +41,11 @@ class AgentMessageTransformer:
|
|||
parameters_for_log: dict[str, Any],
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
workflow_execution_id: str | None,
|
||||
node_type: NodeType,
|
||||
node_id: str,
|
||||
node_title: str,
|
||||
node_execution_id: str,
|
||||
) -> Generator[NodeEventBase, None, None]:
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
|
|
@ -123,20 +130,52 @@ class AgentMessageTransformer:
|
|||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.JSON:
|
||||
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
|
||||
json_object = message.message.json_object
|
||||
if node_type == BuiltinNodeTypes.AGENT:
|
||||
if isinstance(message.message.json_object, dict):
|
||||
msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
|
||||
if isinstance(json_object, dict):
|
||||
json_object = dict(json_object)
|
||||
msg_metadata: dict[str, Any] = json_object.pop("execution_metadata", {})
|
||||
llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata))
|
||||
agent_execution_metadata = {
|
||||
WorkflowNodeExecutionMetadataKey(key): value
|
||||
for key, value in msg_metadata.items()
|
||||
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
|
||||
}
|
||||
clarification_payload = self._clarification_helper.extract_payload(json_object)
|
||||
if clarification_payload is not None:
|
||||
json_object.pop("human_required", None)
|
||||
json_object.pop("clarification", None)
|
||||
if json_object:
|
||||
json_list.append(json_object)
|
||||
# A clarification payload turns the agent node into a paused node result,
|
||||
# so we must stop before emitting the normal success completion event.
|
||||
yield self._clarification_helper.build_pause_event(
|
||||
payload=clarification_payload,
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
workflow_execution_id=workflow_execution_id,
|
||||
node_id=node_id,
|
||||
node_title=node_title,
|
||||
node_execution_id=node_execution_id,
|
||||
tool_info=tool_info,
|
||||
parameters_for_log=parameters_for_log,
|
||||
partial_outputs={
|
||||
"text": text,
|
||||
"usage": jsonable_encoder(llm_usage),
|
||||
"files": ArrayFileSegment(value=files),
|
||||
"json": self._build_json_output(agent_logs=agent_logs, json_list=json_list),
|
||||
**variables,
|
||||
},
|
||||
execution_metadata=agent_execution_metadata,
|
||||
llm_usage=llm_usage,
|
||||
agent_logs=agent_logs,
|
||||
)
|
||||
return
|
||||
else:
|
||||
llm_usage = LLMUsage.empty_usage()
|
||||
agent_execution_metadata = {}
|
||||
if message.message.json_object:
|
||||
json_list.append(message.message.json_object)
|
||||
if json_object:
|
||||
json_list.append(json_object)
|
||||
elif message.type == ToolInvokeMessage.MessageType.LINK:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
stream_text = f"Link: {message.message.text}\n"
|
||||
|
|
@ -238,25 +277,7 @@ class AgentMessageTransformer:
|
|||
|
||||
yield agent_log
|
||||
|
||||
json_output: list[dict[str, Any] | list[Any]] = []
|
||||
if agent_logs:
|
||||
for log in agent_logs:
|
||||
json_output.append(
|
||||
{
|
||||
"id": log.message_id,
|
||||
"parent_id": log.parent_id,
|
||||
"error": log.error,
|
||||
"status": log.status,
|
||||
"data": log.data,
|
||||
"label": log.label,
|
||||
"metadata": log.metadata,
|
||||
"node_id": log.node_id,
|
||||
}
|
||||
)
|
||||
if json_list:
|
||||
json_output.extend(json_list)
|
||||
else:
|
||||
json_output.append({"data": []})
|
||||
json_output = self._build_json_output(agent_logs=agent_logs, json_list=json_list)
|
||||
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, "text"],
|
||||
|
|
@ -290,3 +311,30 @@ class AgentMessageTransformer:
|
|||
llm_usage=llm_usage,
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_json_output(
|
||||
*,
|
||||
agent_logs: list[AgentLogEvent],
|
||||
json_list: list[dict[str, Any] | list[Any]],
|
||||
) -> list[dict[str, Any] | list[Any]]:
|
||||
json_output: list[dict[str, Any] | list[Any]] = []
|
||||
if agent_logs:
|
||||
for log in agent_logs:
|
||||
json_output.append(
|
||||
{
|
||||
"id": log.message_id,
|
||||
"parent_id": log.parent_id,
|
||||
"error": log.error,
|
||||
"status": log.status,
|
||||
"data": log.data,
|
||||
"label": log.label,
|
||||
"metadata": log.metadata,
|
||||
"node_id": log.node_id,
|
||||
}
|
||||
)
|
||||
if json_list:
|
||||
json_output.extend(json_list)
|
||||
else:
|
||||
json_output.append({"data": []})
|
||||
return json_output
|
||||
|
|
|
|||
|
|
@ -271,6 +271,7 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
|
|||
DATASOURCE_INFO = "datasource_info"
|
||||
TRIGGER_INFO = "trigger_info"
|
||||
COMPLETED_REASON = "completed_reason" # completed reason for loop node
|
||||
PAUSE_INFO = "pause_info" # structured pause payload for resumable nodes
|
||||
|
||||
|
||||
class WorkflowNodeExecutionStatus(StrEnum):
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from typing import Any
|
|||
from pydantic import Field
|
||||
|
||||
from dify_graph.entities.pause_reason import PauseReason
|
||||
from dify_graph.enums import WorkflowNodeExecutionStatus
|
||||
from dify_graph.file import File
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
|
||||
from dify_graph.node_events import NodeRunResult
|
||||
|
|
@ -47,6 +48,10 @@ class StreamCompletedEvent(NodeEventBase):
|
|||
|
||||
class PauseRequestedEvent(NodeEventBase):
|
||||
reason: PauseReason = Field(..., description="pause reason")
|
||||
node_run_result: NodeRunResult = Field(
|
||||
default_factory=lambda: NodeRunResult(status=WorkflowNodeExecutionStatus.PAUSED),
|
||||
description="partial node result persisted when the node pauses",
|
||||
)
|
||||
|
||||
|
||||
class HumanInputFormFilledEvent(NodeEventBase):
|
||||
|
|
|
|||
|
|
@ -639,11 +639,12 @@ class Node(Generic[NodeDataT]):
|
|||
|
||||
@_dispatch.register
|
||||
def _(self, event: PauseRequestedEvent) -> NodeRunPauseRequestedEvent:
|
||||
node_run_result = event.node_run_result.model_copy(update={"status": WorkflowNodeExecutionStatus.PAUSED})
|
||||
return NodeRunPauseRequestedEvent(
|
||||
id=self.execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.PAUSED),
|
||||
node_run_result=node_run_result,
|
||||
reason=event.reason,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,244 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.workflow.nodes.agent.agent_node import AgentNode
|
||||
from core.workflow.nodes.agent.clarification_helper import AgentClarificationHelper
|
||||
from core.workflow.nodes.agent.message_transformer import AgentMessageTransformer
|
||||
from dify_graph.entities.graph_init_params import GraphInitParams
|
||||
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from dify_graph.graph_events import NodeRunPauseRequestedEvent, NodeRunSucceededEvent
|
||||
from dify_graph.node_events import StreamCompletedEvent
|
||||
from dify_graph.nodes.human_input.enums import HumanInputFormStatus
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
|
||||
|
||||
class _FakeFormEntity:
|
||||
id = "form-1"
|
||||
web_app_token = "token-1"
|
||||
recipients = []
|
||||
rendered_content = "Please provide the missing customer id.\n\n- `customer_id`: {{#$output.customer_id#}}"
|
||||
selected_action_id = None
|
||||
submitted_data = None
|
||||
submitted = False
|
||||
status = HumanInputFormStatus.WAITING
|
||||
expiration_time = datetime(2030, 1, 1, tzinfo=UTC)
|
||||
|
||||
|
||||
class _FakeFormRepository:
|
||||
def __init__(self) -> None:
|
||||
self.last_params = None
|
||||
|
||||
def get_form(self, workflow_execution_id: str, node_id: str):
|
||||
return None
|
||||
|
||||
def create_form(self, params):
|
||||
self.last_params = params
|
||||
return _FakeFormEntity()
|
||||
|
||||
|
||||
class _FakeStrategy:
|
||||
def __init__(self, messages: list[ToolInvokeMessage]) -> None:
|
||||
self._messages = messages
|
||||
|
||||
def get_parameters(self):
|
||||
return []
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
*,
|
||||
params: dict[str, Any],
|
||||
user_id: str,
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
credentials: object | None = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
_ = (params, user_id, conversation_id, app_id, message_id, credentials)
|
||||
yield from self._messages
|
||||
|
||||
|
||||
class _FakeStrategyResolver:
|
||||
def __init__(self, strategy: _FakeStrategy) -> None:
|
||||
self._strategy = strategy
|
||||
|
||||
def resolve(
|
||||
self,
|
||||
*,
|
||||
tenant_id: str,
|
||||
agent_strategy_provider_name: str,
|
||||
agent_strategy_name: str,
|
||||
) -> _FakeStrategy:
|
||||
_ = (tenant_id, agent_strategy_provider_name, agent_strategy_name)
|
||||
return self._strategy
|
||||
|
||||
|
||||
class _FakePresentationProvider:
|
||||
def get_icon(self, *, tenant_id: str, agent_strategy_provider_name: str) -> str:
|
||||
_ = (tenant_id, agent_strategy_provider_name)
|
||||
return "icon.svg"
|
||||
|
||||
|
||||
class _FakeRuntimeSupport:
|
||||
def build_parameters(self, *, for_log: bool = False, **_: Any) -> dict[str, Any]:
|
||||
return {"query": "Need clarification"} if for_log else {"query": "Need clarification"}
|
||||
|
||||
def build_credentials(self, *, parameters: dict[str, Any]) -> object:
|
||||
_ = parameters
|
||||
return object()
|
||||
|
||||
|
||||
def _build_agent_node(
|
||||
*,
|
||||
messages: list[ToolInvokeMessage],
|
||||
form_repository: _FakeFormRepository,
|
||||
) -> AgentNode:
|
||||
graph_config: dict[str, Any] = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "agent-node",
|
||||
"data": {
|
||||
"type": BuiltinNodeTypes.AGENT,
|
||||
"title": "Agent Node",
|
||||
"desc": "",
|
||||
"agent_strategy_provider_name": "provider",
|
||||
"agent_strategy_name": "strategy",
|
||||
"agent_strategy_label": "Strategy",
|
||||
"agent_parameters": {},
|
||||
},
|
||||
}
|
||||
],
|
||||
"edges": [],
|
||||
}
|
||||
init_params = GraphInitParams(
|
||||
workflow_id="workflow-id",
|
||||
graph_config=graph_config,
|
||||
run_context={
|
||||
"_dify": {
|
||||
"tenant_id": "tenant-id",
|
||||
"app_id": "app-id",
|
||||
"user_id": "user-id",
|
||||
"user_from": "account",
|
||||
"invoke_from": "debugger",
|
||||
}
|
||||
},
|
||||
call_depth=0,
|
||||
)
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(
|
||||
user_id="user-id",
|
||||
app_id="app-id",
|
||||
workflow_execution_id="workflow-run-id",
|
||||
)
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
|
||||
clarification_helper = AgentClarificationHelper(form_repository_factory=lambda _tenant_id: form_repository)
|
||||
return AgentNode(
|
||||
id="agent-node",
|
||||
config=graph_config["nodes"][0],
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
strategy_resolver=_FakeStrategyResolver(_FakeStrategy(messages)),
|
||||
presentation_provider=_FakePresentationProvider(),
|
||||
runtime_support=_FakeRuntimeSupport(),
|
||||
message_transformer=AgentMessageTransformer(clarification_helper=clarification_helper),
|
||||
)
|
||||
|
||||
|
||||
def test_agent_node_clarification_payload_pauses_workflow() -> None:
|
||||
form_repository = _FakeFormRepository()
|
||||
node = _build_agent_node(
|
||||
form_repository=form_repository,
|
||||
messages=[
|
||||
ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.TEXT,
|
||||
message=ToolInvokeMessage.TextMessage(text="Need more context. "),
|
||||
),
|
||||
ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.JSON,
|
||||
message=ToolInvokeMessage.JsonMessage(
|
||||
json_object={
|
||||
"human_required": {
|
||||
"question": "Please provide the missing customer id.",
|
||||
"required_fields": ["customer_id"],
|
||||
"display_in_ui": True,
|
||||
},
|
||||
"execution_metadata": {
|
||||
"total_tokens": 12,
|
||||
"total_price": 0,
|
||||
"currency": "USD",
|
||||
},
|
||||
}
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
events = list(node.run())
|
||||
|
||||
pause_event = next(event for event in events if isinstance(event, NodeRunPauseRequestedEvent))
|
||||
assert pause_event.node_run_result.status == WorkflowNodeExecutionStatus.PAUSED
|
||||
assert pause_event.reason.form_id == "form-1"
|
||||
assert pause_event.reason.node_id == "agent-node"
|
||||
assert pause_event.node_run_result.outputs["text"] == "Need more context. "
|
||||
assert pause_event.node_run_result.outputs["clarification"]["question"] == "Please provide the missing customer id."
|
||||
assert pause_event.node_run_result.outputs["clarification"]["agent_strategy"] == "strategy"
|
||||
assert pause_event.node_run_result.metadata[WorkflowNodeExecutionMetadataKey.PAUSE_INFO]["form_id"] == "form-1"
|
||||
assert form_repository.last_params is not None
|
||||
assert form_repository.last_params.workflow_execution_id == "workflow-run-id"
|
||||
assert form_repository.last_params.node_id == "agent-node"
|
||||
assert form_repository.last_params.backstage_recipient_required is True
|
||||
assert not any(isinstance(event, NodeRunSucceededEvent) for event in events)
|
||||
|
||||
|
||||
def test_message_transformer_keeps_success_path_without_clarification_payload() -> None:
|
||||
transformer = AgentMessageTransformer(
|
||||
clarification_helper=AgentClarificationHelper(form_repository_factory=lambda _tenant_id: _FakeFormRepository())
|
||||
)
|
||||
|
||||
events = list(
|
||||
transformer.transform(
|
||||
messages=iter(
|
||||
[
|
||||
ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.TEXT,
|
||||
message=ToolInvokeMessage.TextMessage(text="Final answer"),
|
||||
),
|
||||
ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.JSON,
|
||||
message=ToolInvokeMessage.JsonMessage(
|
||||
json_object={
|
||||
"answer": {"ok": True},
|
||||
"execution_metadata": {
|
||||
"total_tokens": 9,
|
||||
"total_price": 0,
|
||||
"currency": "USD",
|
||||
},
|
||||
}
|
||||
),
|
||||
),
|
||||
]
|
||||
),
|
||||
tool_info={"icon": "icon.svg", "agent_strategy": "strategy"},
|
||||
parameters_for_log={"query": "Need clarification"},
|
||||
user_id="user-id",
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
workflow_execution_id="workflow-run-id",
|
||||
node_type=BuiltinNodeTypes.AGENT,
|
||||
node_id="agent-node",
|
||||
node_title="Agent Node",
|
||||
node_execution_id="exec-1",
|
||||
)
|
||||
)
|
||||
|
||||
completed_event = events[-1]
|
||||
assert isinstance(completed_event, StreamCompletedEvent)
|
||||
assert completed_event.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert completed_event.node_run_result.outputs["text"] == "Final answer"
|
||||
assert "clarification" not in completed_event.node_run_result.outputs
|
||||
assert completed_event.node_run_result.metadata[WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS] == 9
|
||||
Loading…
Reference in New Issue