This commit is contained in:
Haohao 2026-03-24 20:57:43 +08:00 committed by GitHub
commit a0d57a5ed7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 544 additions and 28 deletions

View File

@ -297,12 +297,12 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
def _handle_node_pause_requested(self, event: NodeRunPauseRequestedEvent) -> None:
domain_execution = self._get_node_execution(event.id)
# Paused nodes may already have partial outputs/metadata (for example agent clarification context).
self._update_node_execution(
domain_execution,
event.node_run_result,
WorkflowNodeExecutionStatus.PAUSED,
error="",
update_outputs=False,
update_outputs=True,
)
# ------------------------------------------------------------------

View File

@ -146,9 +146,12 @@ class AgentNode(Node[AgentNodeData]):
parameters_for_log=parameters_for_log,
user_id=dify_ctx.user_id,
tenant_id=dify_ctx.tenant_id,
app_id=dify_ctx.app_id,
workflow_execution_id=self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id,
node_type=self.node_type,
node_id=self._node_id,
node_execution_id=self.id,
node_title=self.title,
node_execution_id=self.execution_id,
)
except PluginDaemonClientSideError as e:
transform_error = AgentMessageTransformError(

View File

@ -0,0 +1,138 @@
from __future__ import annotations
from collections.abc import Callable, Mapping
from typing import Any
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
from dify_graph.entities.pause_reason import HumanInputRequired
from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.node_events import AgentLogEvent, NodeRunResult, PauseRequestedEvent
from dify_graph.nodes.human_input.entities import HumanInputNodeData
from dify_graph.repositories.human_input_form_repository import FormCreateParams, HumanInputFormRepository
from .entities import AgentClarificationPayload
def _default_form_repository_factory(tenant_id: str) -> HumanInputFormRepository:
    """Return the production form repository scoped to *tenant_id* (default factory)."""
    return HumanInputFormRepositoryImpl(tenant_id=tenant_id)
class AgentClarificationHelper:
    """Translate agent clarification payloads into standard workflow pause events."""
    def __init__(
        self,
        *,
        form_repository_factory: Callable[[str], HumanInputFormRepository] | None = None,
    ) -> None:
        # Factory is injectable so tests can substitute an in-memory repository;
        # falls back to the tenant-scoped production implementation.
        self._form_repository_factory = form_repository_factory or _default_form_repository_factory
    def extract_payload(self, json_object: Mapping[str, Any]) -> AgentClarificationPayload | None:
        """Return the clarification payload embedded in *json_object*, or None.

        Accepts either a ``human_required`` or ``clarification`` key; a falsy
        ``human_required`` value (e.g. ``{}``) also falls through to
        ``clarification`` because of the ``or`` chain. Raises pydantic's
        ``ValidationError`` when the raw payload is malformed.
        """
        raw_payload = json_object.get("human_required") or json_object.get("clarification")
        if raw_payload is None:
            return None
        return AgentClarificationPayload.model_validate(raw_payload)
    def build_pause_event(
        self,
        *,
        payload: AgentClarificationPayload,
        tenant_id: str,
        app_id: str,
        workflow_execution_id: str | None,
        node_id: str,
        node_title: str,
        node_execution_id: str,
        tool_info: Mapping[str, Any],
        parameters_for_log: Mapping[str, Any],
        partial_outputs: Mapping[str, Any],
        execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any],
        llm_usage: LLMUsage,
        agent_logs: list[AgentLogEvent],
    ) -> PauseRequestedEvent:
        """Persist a human-input form for *payload* and build the pause event.

        Side effect: creates a form row via the tenant's form repository.
        The returned event carries a PAUSED NodeRunResult whose outputs merge
        *partial_outputs* with a ``clarification`` summary, so partial agent
        work survives the pause.
        """
        # Reuse the human-input node's own config model so downstream form
        # rendering/resume logic needs no agent-specific handling.
        form_config = payload.to_human_input_node_data(node_title=node_title)
        form_entity = self._form_repository_factory(tenant_id).create_form(
            FormCreateParams(
                app_id=app_id,
                workflow_execution_id=workflow_execution_id,
                node_id=node_id,
                form_config=form_config,
                rendered_content=form_config.form_content,
                delivery_methods=form_config.delivery_methods,
                display_in_ui=payload.display_in_ui,
                resolved_default_values={},
                # Match HumanInputNode's baseline behavior so non-UI clarifications are still recoverable in Console.
                backstage_recipient_required=True,
            )
        )
        pause_info = self._build_pause_info(
            payload=payload,
            form_config=form_config,
            form_id=form_entity.id,
            form_token=form_entity.web_app_token,
            node_id=node_id,
            node_execution_id=node_execution_id,
            node_title=node_title,
            tool_info=tool_info,
        )
        # Preserve the agent's accumulated metadata and append pause context.
        pause_metadata = {
            **execution_metadata,
            WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
            WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs,
            WorkflowNodeExecutionMetadataKey.PAUSE_INFO: pause_info,
        }
        return PauseRequestedEvent(
            reason=HumanInputRequired(
                form_id=form_entity.id,
                form_content=form_entity.rendered_content,
                inputs=form_config.inputs,
                actions=form_config.user_actions,
                display_in_ui=payload.display_in_ui,
                node_id=node_id,
                node_title=node_title,
                form_token=form_entity.web_app_token,
                resolved_default_values={},
            ),
            node_run_result=NodeRunResult(
                status=WorkflowNodeExecutionStatus.PAUSED,
                inputs=parameters_for_log,
                outputs={**partial_outputs, "clarification": pause_info},
                metadata=pause_metadata,
                llm_usage=llm_usage,
            ),
        )
    @staticmethod
    def _build_pause_info(
        *,
        payload: AgentClarificationPayload,
        form_config: HumanInputNodeData,
        form_id: str,
        form_token: str | None,
        node_id: str,
        node_execution_id: str,
        node_title: str,
        tool_info: Mapping[str, Any],
    ) -> dict[str, Any]:
        """Return the JSON-serializable pause summary stored in outputs/metadata."""
        required_fields = [
            {"name": field.name, "type": field.type.value} for field in payload.normalized_required_fields()
        ]
        return {
            "type": "agent_clarification",
            "human_required": True,
            "resumable": True,
            "question": payload.question,
            "required_fields": required_fields,
            "form_id": form_id,
            "form_token": form_token,
            "form_content": payload.to_form_content(),
            "display_in_ui": payload.display_in_ui,
            "node_id": node_id,
            "node_execution_id": node_execution_id,
            "node_title": node_title,
            "agent_strategy": tool_info.get("agent_strategy"),
            "actions": [action.model_dump(mode="json") for action in form_config.user_actions],
        }

View File

@ -1,12 +1,24 @@
import re
from enum import IntEnum, StrEnum, auto
from typing import Any, Literal, Union
from pydantic import BaseModel
from pydantic import BaseModel, Field, field_validator
from pydantic_core import PydanticCustomError
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.tools.entities.tool_entities import ToolSelector
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
from dify_graph.nodes.human_input.entities import (
DeliveryChannelConfig,
FormInput,
HumanInputNodeData,
UserAction,
WebAppDeliveryMethod,
)
from dify_graph.nodes.human_input.enums import ButtonStyle, FormInputType
_OUTPUT_IDENTIFIER_PATTERN = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
class AgentNodeData(BaseNodeData):
@ -27,6 +39,70 @@ class AgentNodeData(BaseNodeData):
agent_parameters: dict[str, AgentInput]
class AgentClarificationField(BaseModel):
    """One input field the agent asks the human to fill in."""
    name: str
    # Defaults to a plain text input when the strategy does not specify a type.
    type: FormInputType = FormInputType.TEXT_INPUT
    @field_validator("name")
    @classmethod
    def _validate_name(cls, value: str) -> str:
        """Reject names that are not valid output-variable identifiers."""
        if not _OUTPUT_IDENTIFIER_PATTERN.match(value):
            raise PydanticCustomError(
                "agent_clarification_field_name",
                "field name must start with a letter or underscore and contain only letters, numbers, or underscores",
            )
        return value
def _default_clarification_actions() -> list[UserAction]:
    """Return the single primary "Submit" action used on auto-generated clarification forms."""
    return [UserAction(id="submit", title="Submit", button_style=ButtonStyle.PRIMARY)]
class AgentClarificationPayload(BaseModel):
    """Minimal clarification contract emitted by agent strategies to request human input."""
    question: str
    # Each entry is either a bare field name (defaults to text input) or a full field spec.
    required_fields: list[str | AgentClarificationField] = Field(default_factory=list)
    display_in_ui: bool = False
    @field_validator("question")
    @classmethod
    def _validate_question(cls, value: str) -> str:
        """Reject blank/whitespace-only questions."""
        if not value.strip():
            raise PydanticCustomError("agent_clarification_question", "question must not be empty")
        return value
    def normalized_required_fields(self) -> list[AgentClarificationField]:
        """Return required_fields with bare names promoted to AgentClarificationField."""
        normalized: list[AgentClarificationField] = []
        for field in self.required_fields:
            if isinstance(field, AgentClarificationField):
                normalized.append(field)
            else:
                normalized.append(AgentClarificationField(name=field))
        return normalized
    def to_form_inputs(self) -> list[FormInput]:
        """Map each required field onto a human-input-node form input."""
        return [
            FormInput(type=field.type, output_variable_name=field.name) for field in self.normalized_required_fields()
        ]
    def to_form_content(self) -> str:
        """Render the question plus one template line per field, separated by blank lines."""
        lines = [self.question.strip()]
        for field in self.normalized_required_fields():
            # `{{#$output.<name>#}}` is the form-template placeholder for the field's value.
            lines.append(f"- `{field.name}`: {{{{#$output.{field.name}#}}}}")
        return "\n\n".join(lines)
    def to_human_input_node_data(self, *, node_title: str) -> HumanInputNodeData:
        """Build a HumanInputNodeData equivalent for this clarification request."""
        # Web-app delivery is only attached when the form should surface in the UI.
        delivery_methods: list[DeliveryChannelConfig] = [WebAppDeliveryMethod()] if self.display_in_ui else []
        return HumanInputNodeData(
            title=node_title,
            form_content=self.to_form_content(),
            inputs=self.to_form_inputs(),
            user_actions=_default_clarification_actions(),
            delivery_methods=delivery_methods,
        )
class ParamsAutoGenerated(IntEnum):
    """Integer toggle: CLOSE (0) / OPEN (1)."""
    CLOSE = 0
    OPEN = 1

View File

@ -25,10 +25,14 @@ from factories import file_factory
from models import ToolFile
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from .clarification_helper import AgentClarificationHelper
from .exceptions import AgentNodeError, AgentVariableTypeError, ToolFileNotFoundError
class AgentMessageTransformer:
    def __init__(self, *, clarification_helper: AgentClarificationHelper | None = None) -> None:
        # Helper is injectable for tests; defaults to the standard implementation.
        self._clarification_helper = clarification_helper or AgentClarificationHelper()
def transform(
self,
*,
@ -37,8 +41,11 @@ class AgentMessageTransformer:
parameters_for_log: dict[str, Any],
user_id: str,
tenant_id: str,
app_id: str,
workflow_execution_id: str | None,
node_type: NodeType,
node_id: str,
node_title: str,
node_execution_id: str,
) -> Generator[NodeEventBase, None, None]:
from core.plugin.impl.plugin import PluginInstaller
@ -123,20 +130,52 @@ class AgentMessageTransformer:
)
elif message.type == ToolInvokeMessage.MessageType.JSON:
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
json_object = message.message.json_object
if node_type == BuiltinNodeTypes.AGENT:
if isinstance(message.message.json_object, dict):
msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
if isinstance(json_object, dict):
json_object = dict(json_object)
msg_metadata: dict[str, Any] = json_object.pop("execution_metadata", {})
llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata))
agent_execution_metadata = {
WorkflowNodeExecutionMetadataKey(key): value
for key, value in msg_metadata.items()
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
}
clarification_payload = self._clarification_helper.extract_payload(json_object)
if clarification_payload is not None:
json_object.pop("human_required", None)
json_object.pop("clarification", None)
if json_object:
json_list.append(json_object)
# A clarification payload turns the agent node into a paused node result,
# so we must stop before emitting the normal success completion event.
yield self._clarification_helper.build_pause_event(
payload=clarification_payload,
tenant_id=tenant_id,
app_id=app_id,
workflow_execution_id=workflow_execution_id,
node_id=node_id,
node_title=node_title,
node_execution_id=node_execution_id,
tool_info=tool_info,
parameters_for_log=parameters_for_log,
partial_outputs={
"text": text,
"usage": jsonable_encoder(llm_usage),
"files": ArrayFileSegment(value=files),
"json": self._build_json_output(agent_logs=agent_logs, json_list=json_list),
**variables,
},
execution_metadata=agent_execution_metadata,
llm_usage=llm_usage,
agent_logs=agent_logs,
)
return
else:
llm_usage = LLMUsage.empty_usage()
agent_execution_metadata = {}
if message.message.json_object:
json_list.append(message.message.json_object)
if json_object:
json_list.append(json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
stream_text = f"Link: {message.message.text}\n"
@ -238,25 +277,7 @@ class AgentMessageTransformer:
yield agent_log
json_output: list[dict[str, Any] | list[Any]] = []
if agent_logs:
for log in agent_logs:
json_output.append(
{
"id": log.message_id,
"parent_id": log.parent_id,
"error": log.error,
"status": log.status,
"data": log.data,
"label": log.label,
"metadata": log.metadata,
"node_id": log.node_id,
}
)
if json_list:
json_output.extend(json_list)
else:
json_output.append({"data": []})
json_output = self._build_json_output(agent_logs=agent_logs, json_list=json_list)
yield StreamChunkEvent(
selector=[node_id, "text"],
@ -290,3 +311,30 @@ class AgentMessageTransformer:
llm_usage=llm_usage,
)
)
    @staticmethod
    def _build_json_output(
        *,
        agent_logs: list[AgentLogEvent],
        json_list: list[dict[str, Any] | list[Any]],
    ) -> list[dict[str, Any] | list[Any]]:
        """Assemble the node's "json" output from agent logs plus collected raw JSON messages.

        Agent logs are flattened into plain dicts first; raw JSON messages are
        appended after them. When no raw JSON was collected, a ``{"data": []}``
        placeholder is appended instead so the output is never empty.
        """
        json_output: list[dict[str, Any] | list[Any]] = []
        if agent_logs:
            for log in agent_logs:
                json_output.append(
                    {
                        "id": log.message_id,
                        "parent_id": log.parent_id,
                        "error": log.error,
                        "status": log.status,
                        "data": log.data,
                        "label": log.label,
                        "metadata": log.metadata,
                        "node_id": log.node_id,
                    }
                )
        if json_list:
            json_output.extend(json_list)
        else:
            json_output.append({"data": []})
        return json_output

View File

@ -271,6 +271,7 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
DATASOURCE_INFO = "datasource_info"
TRIGGER_INFO = "trigger_info"
COMPLETED_REASON = "completed_reason" # completed reason for loop node
PAUSE_INFO = "pause_info" # structured pause payload for resumable nodes
class WorkflowNodeExecutionStatus(StrEnum):

View File

@ -5,6 +5,7 @@ from typing import Any
from pydantic import Field
from dify_graph.entities.pause_reason import PauseReason
from dify_graph.enums import WorkflowNodeExecutionStatus
from dify_graph.file import File
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.node_events import NodeRunResult
@ -47,6 +48,10 @@ class StreamCompletedEvent(NodeEventBase):
class PauseRequestedEvent(NodeEventBase):
    """Node event asking the engine to pause the running node."""
    reason: PauseReason = Field(..., description="pause reason")
    # Carries any partial outputs/metadata produced before pausing; defaults to
    # a bare PAUSED result for nodes that pause without partial state.
    node_run_result: NodeRunResult = Field(
        default_factory=lambda: NodeRunResult(status=WorkflowNodeExecutionStatus.PAUSED),
        description="partial node result persisted when the node pauses",
    )
class HumanInputFormFilledEvent(NodeEventBase):

View File

@ -639,11 +639,12 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: PauseRequestedEvent) -> NodeRunPauseRequestedEvent:
node_run_result = event.node_run_result.model_copy(update={"status": WorkflowNodeExecutionStatus.PAUSED})
return NodeRunPauseRequestedEvent(
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.PAUSED),
node_run_result=node_run_result,
reason=event.reason,
)

View File

@ -0,0 +1,244 @@
from __future__ import annotations
from collections.abc import Generator
from datetime import UTC, datetime
from typing import Any
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.workflow.nodes.agent.agent_node import AgentNode
from core.workflow.nodes.agent.clarification_helper import AgentClarificationHelper
from core.workflow.nodes.agent.message_transformer import AgentMessageTransformer
from dify_graph.entities.graph_init_params import GraphInitParams
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from dify_graph.graph_events import NodeRunPauseRequestedEvent, NodeRunSucceededEvent
from dify_graph.node_events import StreamCompletedEvent
from dify_graph.nodes.human_input.enums import HumanInputFormStatus
from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
class _FakeFormEntity:
    """Stand-in for the persisted human-input form entity returned by the repository."""
    id = "form-1"
    web_app_token = "token-1"
    recipients = []
    # Mirrors what AgentClarificationPayload.to_form_content() produces for the test payload.
    rendered_content = "Please provide the missing customer id.\n\n- `customer_id`: {{#$output.customer_id#}}"
    selected_action_id = None
    submitted_data = None
    submitted = False
    status = HumanInputFormStatus.WAITING
    # Far-future expiry so the form never times out during the test run.
    expiration_time = datetime(2030, 1, 1, tzinfo=UTC)
class _FakeFormRepository:
def __init__(self) -> None:
self.last_params = None
def get_form(self, workflow_execution_id: str, node_id: str):
return None
def create_form(self, params):
self.last_params = params
return _FakeFormEntity()
class _FakeStrategy:
def __init__(self, messages: list[ToolInvokeMessage]) -> None:
self._messages = messages
def get_parameters(self):
return []
def invoke(
self,
*,
params: dict[str, Any],
user_id: str,
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
credentials: object | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
_ = (params, user_id, conversation_id, app_id, message_id, credentials)
yield from self._messages
class _FakeStrategyResolver:
def __init__(self, strategy: _FakeStrategy) -> None:
self._strategy = strategy
def resolve(
self,
*,
tenant_id: str,
agent_strategy_provider_name: str,
agent_strategy_name: str,
) -> _FakeStrategy:
_ = (tenant_id, agent_strategy_provider_name, agent_strategy_name)
return self._strategy
class _FakePresentationProvider:
def get_icon(self, *, tenant_id: str, agent_strategy_provider_name: str) -> str:
_ = (tenant_id, agent_strategy_provider_name)
return "icon.svg"
class _FakeRuntimeSupport:
def build_parameters(self, *, for_log: bool = False, **_: Any) -> dict[str, Any]:
return {"query": "Need clarification"} if for_log else {"query": "Need clarification"}
def build_credentials(self, *, parameters: dict[str, Any]) -> object:
_ = parameters
return object()
def _build_agent_node(
    *,
    messages: list[ToolInvokeMessage],
    form_repository: _FakeFormRepository,
) -> AgentNode:
    """Construct an AgentNode whose collaborators are all fakes.

    The node replays *messages* from the strategy and routes clarification
    form persistence into *form_repository* so tests can inspect it.
    """
    # Single-node graph: just the agent node, no edges.
    graph_config: dict[str, Any] = {
        "nodes": [
            {
                "id": "agent-node",
                "data": {
                    "type": BuiltinNodeTypes.AGENT,
                    "title": "Agent Node",
                    "desc": "",
                    "agent_strategy_provider_name": "provider",
                    "agent_strategy_name": "strategy",
                    "agent_strategy_label": "Strategy",
                    "agent_parameters": {},
                },
            }
        ],
        "edges": [],
    }
    init_params = GraphInitParams(
        workflow_id="workflow-id",
        graph_config=graph_config,
        run_context={
            "_dify": {
                "tenant_id": "tenant-id",
                "app_id": "app-id",
                "user_id": "user-id",
                "user_from": "account",
                "invoke_from": "debugger",
            }
        },
        call_depth=0,
    )
    # workflow_execution_id here is what the clarification helper passes to the repository.
    variable_pool = VariablePool(
        system_variables=SystemVariable(
            user_id="user-id",
            app_id="app-id",
            workflow_execution_id="workflow-run-id",
        )
    )
    graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
    # Route form creation into the injected fake regardless of tenant.
    clarification_helper = AgentClarificationHelper(form_repository_factory=lambda _tenant_id: form_repository)
    return AgentNode(
        id="agent-node",
        config=graph_config["nodes"][0],
        graph_init_params=init_params,
        graph_runtime_state=graph_runtime_state,
        strategy_resolver=_FakeStrategyResolver(_FakeStrategy(messages)),
        presentation_provider=_FakePresentationProvider(),
        runtime_support=_FakeRuntimeSupport(),
        message_transformer=AgentMessageTransformer(clarification_helper=clarification_helper),
    )
def test_agent_node_clarification_payload_pauses_workflow() -> None:
    """A `human_required` JSON payload must pause the node instead of completing it."""
    form_repository = _FakeFormRepository()
    node = _build_agent_node(
        form_repository=form_repository,
        messages=[
            ToolInvokeMessage(
                type=ToolInvokeMessage.MessageType.TEXT,
                message=ToolInvokeMessage.TextMessage(text="Need more context. "),
            ),
            ToolInvokeMessage(
                type=ToolInvokeMessage.MessageType.JSON,
                message=ToolInvokeMessage.JsonMessage(
                    json_object={
                        "human_required": {
                            "question": "Please provide the missing customer id.",
                            "required_fields": ["customer_id"],
                            "display_in_ui": True,
                        },
                        "execution_metadata": {
                            "total_tokens": 12,
                            "total_price": 0,
                            "currency": "USD",
                        },
                    }
                ),
            ),
        ],
    )
    events = list(node.run())
    # The pause event must carry the PAUSED result plus the created form's identity.
    pause_event = next(event for event in events if isinstance(event, NodeRunPauseRequestedEvent))
    assert pause_event.node_run_result.status == WorkflowNodeExecutionStatus.PAUSED
    assert pause_event.reason.form_id == "form-1"
    assert pause_event.reason.node_id == "agent-node"
    # Partial outputs (streamed text) and the clarification summary survive the pause.
    assert pause_event.node_run_result.outputs["text"] == "Need more context. "
    assert pause_event.node_run_result.outputs["clarification"]["question"] == "Please provide the missing customer id."
    assert pause_event.node_run_result.outputs["clarification"]["agent_strategy"] == "strategy"
    assert pause_event.node_run_result.metadata[WorkflowNodeExecutionMetadataKey.PAUSE_INFO]["form_id"] == "form-1"
    # The form repository received the expected persistence parameters.
    assert form_repository.last_params is not None
    assert form_repository.last_params.workflow_execution_id == "workflow-run-id"
    assert form_repository.last_params.node_id == "agent-node"
    assert form_repository.last_params.backstage_recipient_required is True
    # A paused node must not also emit a success event.
    assert not any(isinstance(event, NodeRunSucceededEvent) for event in events)
def test_message_transformer_keeps_success_path_without_clarification_payload() -> None:
    """JSON messages without a clarification payload must still complete normally."""
    transformer = AgentMessageTransformer(
        clarification_helper=AgentClarificationHelper(form_repository_factory=lambda _tenant_id: _FakeFormRepository())
    )
    events = list(
        transformer.transform(
            messages=iter(
                [
                    ToolInvokeMessage(
                        type=ToolInvokeMessage.MessageType.TEXT,
                        message=ToolInvokeMessage.TextMessage(text="Final answer"),
                    ),
                    ToolInvokeMessage(
                        type=ToolInvokeMessage.MessageType.JSON,
                        message=ToolInvokeMessage.JsonMessage(
                            json_object={
                                "answer": {"ok": True},
                                "execution_metadata": {
                                    "total_tokens": 9,
                                    "total_price": 0,
                                    "currency": "USD",
                                },
                            }
                        ),
                    ),
                ]
            ),
            tool_info={"icon": "icon.svg", "agent_strategy": "strategy"},
            parameters_for_log={"query": "Need clarification"},
            user_id="user-id",
            tenant_id="tenant-id",
            app_id="app-id",
            workflow_execution_id="workflow-run-id",
            node_type=BuiltinNodeTypes.AGENT,
            node_id="agent-node",
            node_title="Agent Node",
            node_execution_id="exec-1",
        )
    )
    # Without a clarification payload the final event is a normal success completion.
    completed_event = events[-1]
    assert isinstance(completed_event, StreamCompletedEvent)
    assert completed_event.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
    assert completed_event.node_run_result.outputs["text"] == "Final answer"
    # No clarification output leaks into the success path.
    assert "clarification" not in completed_event.node_run_result.outputs
    assert completed_event.node_run_result.metadata[WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS] == 9