From b1f8afb7276ce232770b33dce1db7de640ad08c9 Mon Sep 17 00:00:00 2001 From: Wenbo2105 Date: Mon, 16 Mar 2026 04:07:43 +0800 Subject: [PATCH 1/6] fix:http-request-retry-interval --- api/dify_graph/entities/base_node_data.py | 397 ++++++++-------- api/dify_graph/graph_engine/error_handler.py | 424 +++++++++--------- api/dify_graph/nodes/http_request/node.py | 17 + .../components/workflow/nodes/http/types.ts | 175 ++++---- 4 files changed, 539 insertions(+), 474 deletions(-) diff --git a/api/dify_graph/entities/base_node_data.py b/api/dify_graph/entities/base_node_data.py index 47b37c9daf..88146e552b 100644 --- a/api/dify_graph/entities/base_node_data.py +++ b/api/dify_graph/entities/base_node_data.py @@ -1,178 +1,219 @@ -from __future__ import annotations - -import json -from abc import ABC -from builtins import type as type_ -from enum import StrEnum -from typing import Any, Union - -from pydantic import BaseModel, ConfigDict, Field, model_validator - -from dify_graph.entities.exc import DefaultValueTypeError -from dify_graph.enums import ErrorStrategy, NodeType - -# Project supports Python 3.11+, where `typing.Union[...]` is valid in `isinstance`. 
-_NumberType = Union[int, float] - - -class RetryConfig(BaseModel): - """node retry config""" - - max_retries: int = 0 # max retry times - retry_interval: int = 0 # retry interval in milliseconds - retry_enabled: bool = False # whether retry is enabled - - @property - def retry_interval_seconds(self) -> float: - return self.retry_interval / 1000 - - -class DefaultValueType(StrEnum): - STRING = "string" - NUMBER = "number" - OBJECT = "object" - ARRAY_NUMBER = "array[number]" - ARRAY_STRING = "array[string]" - ARRAY_OBJECT = "array[object]" - ARRAY_FILES = "array[file]" - - -class DefaultValue(BaseModel): - value: Any = None - type: DefaultValueType - key: str - - @staticmethod - def _parse_json(value: str): - """Unified JSON parsing handler""" - try: - return json.loads(value) - except json.JSONDecodeError: - raise DefaultValueTypeError(f"Invalid JSON format for value: {value}") - - @staticmethod - def _validate_array(value: Any, element_type: type_ | tuple[type_, ...]) -> bool: - """Unified array type validation""" - return isinstance(value, list) and all(isinstance(x, element_type) for x in value) - - @staticmethod - def _convert_number(value: str) -> float: - """Unified number conversion handler""" - try: - return float(value) - except ValueError: - raise DefaultValueTypeError(f"Cannot convert to number: {value}") - - @model_validator(mode="after") - def validate_value_type(self) -> DefaultValue: - # Type validation configuration - type_validators: dict[DefaultValueType, dict[str, Any]] = { - DefaultValueType.STRING: { - "type": str, - "converter": lambda x: x, - }, - DefaultValueType.NUMBER: { - "type": _NumberType, - "converter": self._convert_number, - }, - DefaultValueType.OBJECT: { - "type": dict, - "converter": self._parse_json, - }, - DefaultValueType.ARRAY_NUMBER: { - "type": list, - "element_type": _NumberType, - "converter": self._parse_json, - }, - DefaultValueType.ARRAY_STRING: { - "type": list, - "element_type": str, - "converter": self._parse_json, 
- }, - DefaultValueType.ARRAY_OBJECT: { - "type": list, - "element_type": dict, - "converter": self._parse_json, - }, - } - - validator: dict[str, Any] = type_validators.get(self.type, {}) - if not validator: - if self.type == DefaultValueType.ARRAY_FILES: - # Handle files type - return self - raise DefaultValueTypeError(f"Unsupported type: {self.type}") - - # Handle string input cases - if isinstance(self.value, str) and self.type != DefaultValueType.STRING: - self.value = validator["converter"](self.value) - - # Validate base type - if not isinstance(self.value, validator["type"]): - raise DefaultValueTypeError(f"Value must be {validator['type'].__name__} type for {self.value}") - - # Validate array element types - if validator["type"] == list and not self._validate_array(self.value, validator["element_type"]): - raise DefaultValueTypeError(f"All elements must be {validator['element_type'].__name__} for {self.value}") - - return self - - -class BaseNodeData(ABC, BaseModel): - # Raw graph payloads are first validated through `NodeConfigDictAdapter`, where - # `node["data"]` is typed as `BaseNodeData` before the concrete node class is known. - # `type` therefore accepts downstream string node kinds; unknown node implementations - # are rejected later when the node factory resolves the node registry. - # At that boundary, node-specific fields are still "extra" relative to this shared DTO, - # and persisted templates/workflows also carry undeclared compatibility keys such as - # `selected`, `params`, `paramSchemas`, and `datasource_label`. Keep extras permissive - # here until graph parsing becomes discriminated by node type or those legacy payloads - # are normalized. 
- model_config = ConfigDict(extra="allow") - - type: NodeType - title: str = "" - desc: str | None = None - version: str = "1" - error_strategy: ErrorStrategy | None = None - default_value: list[DefaultValue] | None = None - retry_config: RetryConfig = Field(default_factory=RetryConfig) - - @property - def default_value_dict(self) -> dict[str, Any]: - if self.default_value: - return {item.key: item.value for item in self.default_value} - return {} - - def __getitem__(self, key: str) -> Any: - """ - Dict-style access without calling model_dump() on every lookup. - Prefer using model fields and Pydantic's extra storage. - """ - # First, check declared model fields - if key in self.__class__.model_fields: - return getattr(self, key) - - # Then, check undeclared compatibility fields stored in Pydantic's extra dict. - extras = getattr(self, "__pydantic_extra__", None) - if extras is None: - extras = getattr(self, "model_extra", None) - if extras is not None and key in extras: - return extras[key] - - raise KeyError(key) - - def get(self, key: str, default: Any = None) -> Any: - """ - Dict-style .get() without calling model_dump() on every lookup. - """ - if key in self.__class__.model_fields: - return getattr(self, key) - - extras = getattr(self, "__pydantic_extra__", None) - if extras is None: - extras = getattr(self, "model_extra", None) - if extras is not None and key in extras: - return extras.get(key, default) - - return default +from __future__ import annotations + +import json +import random +from abc import ABC +from builtins import type as type_ +from enum import StrEnum +from typing import Any, Union + +from pydantic import BaseModel, ConfigDict, Field, model_validator + +from dify_graph.entities.exc import DefaultValueTypeError +from dify_graph.enums import ErrorStrategy, NodeType + +# Project supports Python 3.11+, where `typing.Union[...]` is valid in `isinstance`. 
+_NumberType = Union[int, float] + + +class RetryConfig(BaseModel): + """node retry config""" + + max_retries: int = 0 # max retry times + retry_interval: int = 0 # retry interval in milliseconds (base interval) + retry_enabled: bool = False # whether retry is enabled + + # Exponential backoff configuration + retry_max_interval: int = 10000 # max retry interval in milliseconds (10 seconds) + retry_jitter_ratio: float = 0.1 # jitter ratio (10% of interval) + + @property + def retry_interval_seconds(self) -> float: + return self.retry_interval / 1000 + + def calculate_retry_interval(self, retry_count: int) -> float: + """ + Calculate retry interval using exponential backoff with jitter. + + Args: + retry_count: Current retry attempt count (0-based) + + Returns: + Retry interval in seconds + + Formula: + interval = base * (2 ** retry_count) + interval += random.uniform(-jitter, jitter) + return min(interval, max_interval) + """ + # Convert base interval to seconds + base_interval = self.retry_interval / 1000.0 + + # Calculate exponential backoff + interval = base_interval * (2 ** retry_count) + + # Add jitter to avoid thundering herd problem + jitter_amount = interval * self.retry_jitter_ratio + jitter = random.uniform(-jitter_amount, jitter_amount) + interval += jitter + + # Ensure minimum interval (at least base interval) + interval = max(interval, base_interval) + + # Cap at maximum interval + max_interval = self.retry_max_interval / 1000.0 + interval = min(interval, max_interval) + + return interval + + +class DefaultValueType(StrEnum): + STRING = "string" + NUMBER = "number" + OBJECT = "object" + ARRAY_NUMBER = "array[number]" + ARRAY_STRING = "array[string]" + ARRAY_OBJECT = "array[object]" + ARRAY_FILES = "array[file]" + + +class DefaultValue(BaseModel): + value: Any = None + type: DefaultValueType + key: str + + @staticmethod + def _parse_json(value: str): + """Unified JSON parsing handler""" + try: + return json.loads(value) + except json.JSONDecodeError: + 
raise DefaultValueTypeError(f"Invalid JSON format for value: {value}") + + @staticmethod + def _validate_array(value: Any, element_type: type_ | tuple[type_, ...]) -> bool: + """Unified array type validation""" + return isinstance(value, list) and all(isinstance(x, element_type) for x in value) + + @staticmethod + def _convert_number(value: str) -> float: + """Unified number conversion handler""" + try: + return float(value) + except ValueError: + raise DefaultValueTypeError(f"Cannot convert to number: {value}") + + @model_validator(mode="after") + def validate_value_type(self) -> DefaultValue: + # Type validation configuration + type_validators: dict[DefaultValueType, dict[str, Any]] = { + DefaultValueType.STRING: { + "type": str, + "converter": lambda x: x, + }, + DefaultValueType.NUMBER: { + "type": _NumberType, + "converter": self._convert_number, + }, + DefaultValueType.OBJECT: { + "type": dict, + "converter": self._parse_json, + }, + DefaultValueType.ARRAY_NUMBER: { + "type": list, + "element_type": _NumberType, + "converter": self._parse_json, + }, + DefaultValueType.ARRAY_STRING: { + "type": list, + "element_type": str, + "converter": self._parse_json, + }, + DefaultValueType.ARRAY_OBJECT: { + "type": list, + "element_type": dict, + "converter": self._parse_json, + }, + } + + validator: dict[str, Any] = type_validators.get(self.type, {}) + if not validator: + if self.type == DefaultValueType.ARRAY_FILES: + # Handle files type + return self + raise DefaultValueTypeError(f"Unsupported type: {self.type}") + + # Handle string input cases + if isinstance(self.value, str) and self.type != DefaultValueType.STRING: + self.value = validator["converter"](self.value) + + # Validate base type + if not isinstance(self.value, validator["type"]): + raise DefaultValueTypeError(f"Value must be {validator['type'].__name__} type for {self.value}") + + # Validate array element types + if validator["type"] == list and not self._validate_array(self.value, 
validator["element_type"]): + raise DefaultValueTypeError(f"All elements must be {validator['element_type'].__name__} for {self.value}") + + return self + + +class BaseNodeData(ABC, BaseModel): + # Raw graph payloads are first validated through `NodeConfigDictAdapter`, where + # `node["data"]` is typed as `BaseNodeData` before the concrete node class is known. + # `type` therefore accepts downstream string node kinds; unknown node implementations + # are rejected later when the node factory resolves the node registry. + # At that boundary, node-specific fields are still "extra" relative to this shared DTO, + # and persisted templates/workflows also carry undeclared compatibility keys such as + # `selected`, `params`, `paramSchemas`, and `datasource_label`. Keep extras permissive + # here until graph parsing becomes discriminated by node type or those legacy payloads + # are normalized. + model_config = ConfigDict(extra="allow") + + type: NodeType + title: str = "" + desc: str | None = None + version: str = "1" + error_strategy: ErrorStrategy | None = None + default_value: list[DefaultValue] | None = None + retry_config: RetryConfig = Field(default_factory=RetryConfig) + + @property + def default_value_dict(self) -> dict[str, Any]: + if self.default_value: + return {item.key: item.value for item in self.default_value} + return {} + + def __getitem__(self, key: str) -> Any: + """ + Dict-style access without calling model_dump() on every lookup. + Prefer using model fields and Pydantic's extra storage. + """ + # First, check declared model fields + if key in self.__class__.model_fields: + return getattr(self, key) + + # Then, check undeclared compatibility fields stored in Pydantic's extra dict. 
+ extras = getattr(self, "__pydantic_extra__", None) + if extras is None: + extras = getattr(self, "model_extra", None) + if extras is not None and key in extras: + return extras[key] + + raise KeyError(key) + + def get(self, key: str, default: Any = None) -> Any: + """ + Dict-style .get() without calling model_dump() on every lookup. + """ + if key in self.__class__.model_fields: + return getattr(self, key) + + extras = getattr(self, "__pydantic_extra__", None) + if extras is None: + extras = getattr(self, "model_extra", None) + if extras is not None and key in extras: + return extras.get(key, default) + + return default + \ No newline at end of file diff --git a/api/dify_graph/graph_engine/error_handler.py b/api/dify_graph/graph_engine/error_handler.py index d4ee2922ec..9b8af83942 100644 --- a/api/dify_graph/graph_engine/error_handler.py +++ b/api/dify_graph/graph_engine/error_handler.py @@ -1,211 +1,213 @@ -""" -Main error handler that coordinates error strategies. -""" - -import logging -import time -from typing import TYPE_CHECKING, final - -from dify_graph.enums import ( - ErrorStrategy as ErrorStrategyEnum, -) -from dify_graph.enums import ( - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from dify_graph.graph import Graph -from dify_graph.graph_events import ( - GraphNodeEventBase, - NodeRunExceptionEvent, - NodeRunFailedEvent, - NodeRunRetryEvent, -) -from dify_graph.node_events import NodeRunResult - -if TYPE_CHECKING: - from .domain import GraphExecution - -logger = logging.getLogger(__name__) - - -@final -class ErrorHandler: - """ - Coordinates error handling strategies for node failures. - - This acts as a facade for the various error strategies, - selecting and applying the appropriate strategy based on - node configuration. - """ - - def __init__(self, graph: Graph, graph_execution: "GraphExecution") -> None: - """ - Initialize the error handler. 
- - Args: - graph: The workflow graph - graph_execution: The graph execution state - """ - self._graph = graph - self._graph_execution = graph_execution - - def handle_node_failure(self, event: NodeRunFailedEvent) -> GraphNodeEventBase | None: - """ - Handle a node failure event. - - Selects and applies the appropriate error strategy based on - the node's configuration. - - Args: - event: The node failure event - - Returns: - Optional new event to process, or None to abort - """ - node = self._graph.nodes[event.node_id] - # Get retry count from NodeExecution - node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) - retry_count = node_execution.retry_count - - # First check if retry is configured and not exhausted - if node.retry and retry_count < node.retry_config.max_retries: - result = self._handle_retry(event, retry_count) - if result: - # Retry count will be incremented when NodeRunRetryEvent is handled - return result - - # Apply configured error strategy - strategy = node.error_strategy - - match strategy: - case None: - return self._handle_abort(event) - case ErrorStrategyEnum.FAIL_BRANCH: - return self._handle_fail_branch(event) - case ErrorStrategyEnum.DEFAULT_VALUE: - return self._handle_default_value(event) - - def _handle_abort(self, event: NodeRunFailedEvent): - """ - Handle error by aborting execution. - - This is the default strategy when no other strategy is specified. - It stops the entire graph execution when a node fails. - - Args: - event: The failure event - - Returns: - None - signals abortion - """ - logger.error("Node %s failed with ABORT strategy: %s", event.node_id, event.error) - # Return None to signal that execution should stop - - def _handle_retry(self, event: NodeRunFailedEvent, retry_count: int): - """ - Handle error by retrying the node. - - This strategy re-attempts node execution up to a configured - maximum number of retries with configurable intervals. 
- - Args: - event: The failure event - retry_count: Current retry attempt count - - Returns: - NodeRunRetryEvent if retry should occur, None otherwise - """ - node = self._graph.nodes[event.node_id] - - # Check if we've exceeded max retries - if not node.retry or retry_count >= node.retry_config.max_retries: - return None - - # Wait for retry interval - time.sleep(node.retry_config.retry_interval_seconds) - - # Create retry event - return NodeRunRetryEvent( - id=event.id, - node_title=node.title, - node_id=event.node_id, - node_type=event.node_type, - node_run_result=event.node_run_result, - start_at=event.start_at, - error=event.error, - retry_index=retry_count + 1, - ) - - def _handle_fail_branch(self, event: NodeRunFailedEvent): - """ - Handle error by taking the fail branch. - - This strategy converts failures to exceptions and routes execution - through a designated fail-branch edge. - - Args: - event: The failure event - - Returns: - NodeRunExceptionEvent to continue via fail branch - """ - outputs = { - "error_message": event.node_run_result.error, - "error_type": event.node_run_result.error_type, - } - - return NodeRunExceptionEvent( - id=event.id, - node_id=event.node_id, - node_type=event.node_type, - start_at=event.start_at, - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.EXCEPTION, - inputs=event.node_run_result.inputs, - process_data=event.node_run_result.process_data, - outputs=outputs, - edge_source_handle="fail-branch", - metadata={ - WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: ErrorStrategyEnum.FAIL_BRANCH, - }, - ), - error=event.error, - ) - - def _handle_default_value(self, event: NodeRunFailedEvent): - """ - Handle error by using default values. - - This strategy allows nodes to fail gracefully by providing - predefined default output values. 
- - Args: - event: The failure event - - Returns: - NodeRunExceptionEvent with default values - """ - node = self._graph.nodes[event.node_id] - - outputs = { - **node.default_value_dict, - "error_message": event.node_run_result.error, - "error_type": event.node_run_result.error_type, - } - - return NodeRunExceptionEvent( - id=event.id, - node_id=event.node_id, - node_type=event.node_type, - start_at=event.start_at, - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.EXCEPTION, - inputs=event.node_run_result.inputs, - process_data=event.node_run_result.process_data, - outputs=outputs, - metadata={ - WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: ErrorStrategyEnum.DEFAULT_VALUE, - }, - ), - error=event.error, - ) +""" +Main error handler that coordinates error strategies. +""" + +import logging +import time +from typing import TYPE_CHECKING, final + +from dify_graph.enums import ( + ErrorStrategy as ErrorStrategyEnum, +) +from dify_graph.enums import ( + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from dify_graph.graph import Graph +from dify_graph.graph_events import ( + GraphNodeEventBase, + NodeRunExceptionEvent, + NodeRunFailedEvent, + NodeRunRetryEvent, +) +from dify_graph.node_events import NodeRunResult + +if TYPE_CHECKING: + from .domain import GraphExecution + +logger = logging.getLogger(__name__) + + +@final +class ErrorHandler: + """ + Coordinates error handling strategies for node failures. + + This acts as a facade for the various error strategies, + selecting and applying the appropriate strategy based on + node configuration. + """ + + def __init__(self, graph: Graph, graph_execution: "GraphExecution") -> None: + """ + Initialize the error handler. 
+ + Args: + graph: The workflow graph + graph_execution: The graph execution state + """ + self._graph = graph + self._graph_execution = graph_execution + + def handle_node_failure(self, event: NodeRunFailedEvent) -> GraphNodeEventBase | None: + """ + Handle a node failure event. + + Selects and applies the appropriate error strategy based on + the node's configuration. + + Args: + event: The node failure event + + Returns: + Optional new event to process, or None to abort + """ + node = self._graph.nodes[event.node_id] + # Get retry count from NodeExecution + node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) + retry_count = node_execution.retry_count + + # First check if retry is configured and not exhausted + if node.retry and retry_count < node.retry_config.max_retries: + result = self._handle_retry(event, retry_count) + if result: + # Retry count will be incremented when NodeRunRetryEvent is handled + return result + + # Apply configured error strategy + strategy = node.error_strategy + + match strategy: + case None: + return self._handle_abort(event) + case ErrorStrategyEnum.FAIL_BRANCH: + return self._handle_fail_branch(event) + case ErrorStrategyEnum.DEFAULT_VALUE: + return self._handle_default_value(event) + + def _handle_abort(self, event: NodeRunFailedEvent): + """ + Handle error by aborting execution. + + This is the default strategy when no other strategy is specified. + It stops the entire graph execution when a node fails. + + Args: + event: The failure event + + Returns: + None - signals abortion + """ + logger.error("Node %s failed with ABORT strategy: %s", event.node_id, event.error) + # Return None to signal that execution should stop + + def _handle_retry(self, event: NodeRunFailedEvent, retry_count: int): + """ + Handle error by retrying the node. + + This strategy re-attempts node execution up to a configured + maximum number of retries with exponential backoff intervals. 
+ + Args: + event: The failure event + retry_count: Current retry attempt count + + Returns: + NodeRunRetryEvent if retry should occur, None otherwise + """ + node = self._graph.nodes[event.node_id] + + # Check if we've exceeded max retries + if not node.retry or retry_count >= node.retry_config.max_retries: + return None + + # Calculate retry interval using exponential backoff with jitter + retry_interval = node.retry_config.calculate_retry_interval(retry_count) + time.sleep(retry_interval) + + # Create retry event + return NodeRunRetryEvent( + id=event.id, + node_title=node.title, + node_id=event.node_id, + node_type=event.node_type, + node_run_result=event.node_run_result, + start_at=event.start_at, + error=event.error, + retry_index=retry_count + 1, + ) + + def _handle_fail_branch(self, event: NodeRunFailedEvent): + """ + Handle error by taking the fail branch. + + This strategy converts failures to exceptions and routes execution + through a designated fail-branch edge. + + Args: + event: The failure event + + Returns: + NodeRunExceptionEvent to continue via fail branch + """ + outputs = { + "error_message": event.node_run_result.error, + "error_type": event.node_run_result.error_type, + } + + return NodeRunExceptionEvent( + id=event.id, + node_id=event.node_id, + node_type=event.node_type, + start_at=event.start_at, + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.EXCEPTION, + inputs=event.node_run_result.inputs, + process_data=event.node_run_result.process_data, + outputs=outputs, + edge_source_handle="fail-branch", + metadata={ + WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: ErrorStrategyEnum.FAIL_BRANCH, + }, + ), + error=event.error, + ) + + def _handle_default_value(self, event: NodeRunFailedEvent): + """ + Handle error by using default values. + + This strategy allows nodes to fail gracefully by providing + predefined default output values. 
+ + Args: + event: The failure event + + Returns: + NodeRunExceptionEvent with default values + """ + node = self._graph.nodes[event.node_id] + + outputs = { + **node.default_value_dict, + "error_message": event.node_run_result.error, + "error_type": event.node_run_result.error_type, + } + + return NodeRunExceptionEvent( + id=event.id, + node_id=event.node_id, + node_type=event.node_type, + start_at=event.start_at, + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.EXCEPTION, + inputs=event.node_run_result.inputs, + process_data=event.node_run_result.process_data, + outputs=outputs, + metadata={ + WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: ErrorStrategyEnum.DEFAULT_VALUE, + }, + ), + error=event.error, + ) + \ No newline at end of file diff --git a/api/dify_graph/nodes/http_request/node.py b/api/dify_graph/nodes/http_request/node.py index b17c820a80..1bfde52e36 100644 --- a/api/dify_graph/nodes/http_request/node.py +++ b/api/dify_graph/nodes/http_request/node.py @@ -257,3 +257,20 @@ class HttpRequestNode(Node[HttpRequestNodeData]): @property def retry(self) -> bool: return self.node_data.retry_config.retry_enabled + + def get_default_config(self) -> dict[str, Any]: + return { + "method": "get", + "authorization": {"type": "no-auth", "config": None}, + "body": {"type": "none", "data": None}, + "headers": [], + "params": [], + "timeout": {"connect": 10, "read": 60, "write": 60}, + "ssl_verify": {"enable": True, "max": None}, + # Use exponential backoff with 100ms base interval + "retry_config": { + "max_retries": http_request_config.ssrf_default_max_retries, + "retry_interval": 100, # Base interval: 100ms (will grow exponentially) + "retry_enabled": True, + }, + } \ No newline at end of file diff --git a/web/app/components/workflow/nodes/http/types.ts b/web/app/components/workflow/nodes/http/types.ts index 0e3fbc0d91..a81fd9016d 100644 --- a/web/app/components/workflow/nodes/http/types.ts +++ b/web/app/components/workflow/nodes/http/types.ts 
@@ -1,85 +1,90 @@ -import type { CommonNodeType, ValueSelector, Variable } from '@/app/components/workflow/types' - -export enum Method { - get = 'get', - post = 'post', - head = 'head', - patch = 'patch', - put = 'put', - delete = 'delete', -} - -export enum BodyType { - none = 'none', - formData = 'form-data', - xWwwFormUrlencoded = 'x-www-form-urlencoded', - rawText = 'raw-text', - json = 'json', - binary = 'binary', -} - -export type KeyValue = { - id?: string - key: string - value: string - type?: string - file?: ValueSelector -} - -export enum BodyPayloadValueType { - text = 'text', - file = 'file', -} - -export type BodyPayload = { - id?: string - key?: string - type: BodyPayloadValueType - file?: ValueSelector // when type is file - value?: string // when type is text -}[] -export type Body = { - type: BodyType - data: string | BodyPayload // string is deprecated, it would convert to BodyPayload after loaded -} - -export enum AuthorizationType { - none = 'no-auth', - apiKey = 'api-key', -} - -export enum APIType { - basic = 'basic', - bearer = 'bearer', - custom = 'custom', -} - -export type Authorization = { - type: AuthorizationType - config?: { - type: APIType - api_key: string - header?: string - } | null -} - -export type Timeout = { - connect?: number - read?: number - write?: number - max_connect_timeout?: number - max_read_timeout?: number - max_write_timeout?: number -} - -export type HttpNodeType = CommonNodeType & { - variables: Variable[] - method: Method - url: string - headers: string - params: string - body: Body - authorization: Authorization - timeout: Timeout - ssl_verify?: boolean -} +import type { CommonNodeType, ValueSelector, Variable } from '@/app/components/workflow/types' + +export enum Method { + get = 'get', + post = 'post', + head = 'head', + patch = 'patch', + put = 'put', + delete = 'delete', +} + +export enum BodyType { + none = 'none', + formData = 'form-data', + xWwwFormUrlencoded = 'x-www-form-urlencoded', + rawText = 
'raw-text', + json = 'json', + binary = 'binary', +} + +export type KeyValue = { + id?: string + key: string + value: string + type?: string + file?: ValueSelector +} + +export enum BodyPayloadValueType { + text = 'text', + file = 'file', +} + +export type BodyPayload = { + id?: string + key?: string + type: BodyPayloadValueType + file?: ValueSelector // when type is file + value?: string // when type is text +}[] +export type Body = { + type: BodyType + data: string | BodyPayload // string is deprecated, it would convert to BodyPayload after loaded +} + +export enum AuthorizationType { + none = 'no-auth', + apiKey = 'api-key', +} + +export enum APIType { + basic = 'basic', + bearer = 'bearer', + custom = 'custom', +} + +export type Authorization = { + type: AuthorizationType + config?: { + type: APIType + api_key: string + header?: string + } | null +} + +export type Timeout = { + connect?: number + read?: number + write?: number + max_connect_timeout?: number + max_read_timeout?: number + max_write_timeout?: number +} + +export type HttpNodeType = CommonNodeType & { + variables: Variable[] + method: Method + url: string + headers: string + params: string + body: Body + authorization: Authorization + timeout: Timeout + ssl_verify?: boolean + retry_config?: { + max_retries: number + retry_interval: number + retry_enabled: boolean + } +} \ No newline at end of file From d24c72d900dca6f5059dde9460ce9c64df784f46 Mon Sep 17 00:00:00 2001 From: Wenbo2105 Date: Mon, 16 Mar 2026 04:38:18 +0800 Subject: [PATCH 2/6] fix:http-request-retry-interval --- api/dify_graph/entities/base_node_data.py | 435 +++++++++++----------- api/dify_graph/nodes/http_request/node.py | 21 +- 2 files changed, 218 insertions(+), 238 deletions(-) diff --git a/api/dify_graph/entities/base_node_data.py b/api/dify_graph/entities/base_node_data.py index 88146e552b..28f4bfc9ec 100644 --- a/api/dify_graph/entities/base_node_data.py +++ b/api/dify_graph/entities/base_node_data.py @@ -1,219 +1,216 @@ -from 
__future__ import annotations - -import json -import random -from abc import ABC -from builtins import type as type_ -from enum import StrEnum -from typing import Any, Union - -from pydantic import BaseModel, ConfigDict, Field, model_validator - -from dify_graph.entities.exc import DefaultValueTypeError -from dify_graph.enums import ErrorStrategy, NodeType - -# Project supports Python 3.11+, where `typing.Union[...]` is valid in `isinstance`. -_NumberType = Union[int, float] - - -class RetryConfig(BaseModel): - """node retry config""" - - max_retries: int = 0 # max retry times - retry_interval: int = 0 # retry interval in milliseconds (base interval) - retry_enabled: bool = False # whether retry is enabled - - # Exponential backoff configuration - retry_max_interval: int = 10000 # max retry interval in milliseconds (10 seconds) - retry_jitter_ratio: float = 0.1 # jitter ratio (10% of interval) - - @property - def retry_interval_seconds(self) -> float: - return self.retry_interval / 1000 - - def calculate_retry_interval(self, retry_count: int) -> float: - """ - Calculate retry interval using exponential backoff with jitter. 
- - Args: - retry_count: Current retry attempt count (0-based) - - Returns: - Retry interval in seconds - - Formula: - interval = base * (2 ** retry_count) - interval += random.uniform(-jitter, jitter) - return min(interval, max_interval) - """ - # Convert base interval to seconds - base_interval = self.retry_interval / 1000.0 - - # Calculate exponential backoff - interval = base_interval * (2 ** retry_count) - - # Add jitter to avoid thundering herd problem - jitter_amount = interval * self.retry_jitter_ratio - jitter = random.uniform(-jitter_amount, jitter_amount) - interval += jitter - - # Ensure minimum interval (at least base interval) - interval = max(interval, base_interval) - - # Cap at maximum interval - max_interval = self.retry_max_interval / 1000.0 - interval = min(interval, max_interval) - - return interval - - -class DefaultValueType(StrEnum): - STRING = "string" - NUMBER = "number" - OBJECT = "object" - ARRAY_NUMBER = "array[number]" - ARRAY_STRING = "array[string]" - ARRAY_OBJECT = "array[object]" - ARRAY_FILES = "array[file]" - - -class DefaultValue(BaseModel): - value: Any = None - type: DefaultValueType - key: str - - @staticmethod - def _parse_json(value: str): - """Unified JSON parsing handler""" - try: - return json.loads(value) - except json.JSONDecodeError: - raise DefaultValueTypeError(f"Invalid JSON format for value: {value}") - - @staticmethod - def _validate_array(value: Any, element_type: type_ | tuple[type_, ...]) -> bool: - """Unified array type validation""" - return isinstance(value, list) and all(isinstance(x, element_type) for x in value) - - @staticmethod - def _convert_number(value: str) -> float: - """Unified number conversion handler""" - try: - return float(value) - except ValueError: - raise DefaultValueTypeError(f"Cannot convert to number: {value}") - - @model_validator(mode="after") - def validate_value_type(self) -> DefaultValue: - # Type validation configuration - type_validators: dict[DefaultValueType, dict[str, Any]] = 
{ - DefaultValueType.STRING: { - "type": str, - "converter": lambda x: x, - }, - DefaultValueType.NUMBER: { - "type": _NumberType, - "converter": self._convert_number, - }, - DefaultValueType.OBJECT: { - "type": dict, - "converter": self._parse_json, - }, - DefaultValueType.ARRAY_NUMBER: { - "type": list, - "element_type": _NumberType, - "converter": self._parse_json, - }, - DefaultValueType.ARRAY_STRING: { - "type": list, - "element_type": str, - "converter": self._parse_json, - }, - DefaultValueType.ARRAY_OBJECT: { - "type": list, - "element_type": dict, - "converter": self._parse_json, - }, - } - - validator: dict[str, Any] = type_validators.get(self.type, {}) - if not validator: - if self.type == DefaultValueType.ARRAY_FILES: - # Handle files type - return self - raise DefaultValueTypeError(f"Unsupported type: {self.type}") - - # Handle string input cases - if isinstance(self.value, str) and self.type != DefaultValueType.STRING: - self.value = validator["converter"](self.value) - - # Validate base type - if not isinstance(self.value, validator["type"]): - raise DefaultValueTypeError(f"Value must be {validator['type'].__name__} type for {self.value}") - - # Validate array element types - if validator["type"] == list and not self._validate_array(self.value, validator["element_type"]): - raise DefaultValueTypeError(f"All elements must be {validator['element_type'].__name__} for {self.value}") - - return self - - -class BaseNodeData(ABC, BaseModel): - # Raw graph payloads are first validated through `NodeConfigDictAdapter`, where - # `node["data"]` is typed as `BaseNodeData` before the concrete node class is known. - # `type` therefore accepts downstream string node kinds; unknown node implementations - # are rejected later when the node factory resolves the node registry. 
- # At that boundary, node-specific fields are still "extra" relative to this shared DTO, - # and persisted templates/workflows also carry undeclared compatibility keys such as - # `selected`, `params`, `paramSchemas`, and `datasource_label`. Keep extras permissive - # here until graph parsing becomes discriminated by node type or those legacy payloads - # are normalized. - model_config = ConfigDict(extra="allow") - - type: NodeType - title: str = "" - desc: str | None = None - version: str = "1" - error_strategy: ErrorStrategy | None = None - default_value: list[DefaultValue] | None = None - retry_config: RetryConfig = Field(default_factory=RetryConfig) - - @property - def default_value_dict(self) -> dict[str, Any]: - if self.default_value: - return {item.key: item.value for item in self.default_value} - return {} - - def __getitem__(self, key: str) -> Any: - """ - Dict-style access without calling model_dump() on every lookup. - Prefer using model fields and Pydantic's extra storage. - """ - # First, check declared model fields - if key in self.__class__.model_fields: - return getattr(self, key) - - # Then, check undeclared compatibility fields stored in Pydantic's extra dict. - extras = getattr(self, "__pydantic_extra__", None) - if extras is None: - extras = getattr(self, "model_extra", None) - if extras is not None and key in extras: - return extras[key] - - raise KeyError(key) - - def get(self, key: str, default: Any = None) -> Any: - """ - Dict-style .get() without calling model_dump() on every lookup. 
- """ - if key in self.__class__.model_fields: - return getattr(self, key) - - extras = getattr(self, "__pydantic_extra__", None) - if extras is None: - extras = getattr(self, "model_extra", None) - if extras is not None and key in extras: - return extras.get(key, default) - - return default - \ No newline at end of file +from __future__ import annotations + +import json +import random +from abc import ABC +from builtins import type as type_ +from enum import StrEnum +from typing import Any, Union + +from pydantic import BaseModel, ConfigDict, Field, model_validator + +from dify_graph.entities.exc import DefaultValueTypeError +from dify_graph.enums import ErrorStrategy, NodeType + +# Project supports Python 3.11+, where `typing.Union[...]` is valid in `isinstance`. +_NumberType = Union[int, float] + + +class RetryConfig(BaseModel): + """node retry config""" + + max_retries: int = 0 # max retry times + retry_interval: int = 0 # retry interval in milliseconds (base interval) + retry_enabled: bool = False # whether retry is enabled + + # Exponential backoff configuration + retry_max_interval: int = 10000 # max retry interval in milliseconds (10 seconds) + retry_jitter_ratio: float = 0.1 # jitter ratio (10% of interval) + + @property + def retry_interval_seconds(self) -> float: + return self.retry_interval / 1000 + + def calculate_retry_interval(self, retry_count: int) -> float: + """ + Calculate retry interval using exponential backoff with jitter. 
+ + Args: + retry_count: Current retry attempt count (0-based) + + Returns: + Retry interval in seconds + + Formula: + interval = base * (2 ** retry_count) + interval += random.uniform(-jitter, jitter) + return min(interval, max_interval) + """ + # Convert base interval to seconds + base_interval = self.retry_interval / 1000.0 + + # Calculate exponential backoff + interval = base_interval * (2 ** retry_count) + + # Add jitter to avoid thundering herd problem + jitter_amount = interval * self.retry_jitter_ratio + jitter = random.uniform(-jitter_amount, jitter_amount) + interval += jitter + + # Cap at maximum interval + max_interval = self.retry_max_interval / 1000.0 + interval = min(interval, max_interval) + + return interval + + +class DefaultValueType(StrEnum): + STRING = "string" + NUMBER = "number" + OBJECT = "object" + ARRAY_NUMBER = "array[number]" + ARRAY_STRING = "array[string]" + ARRAY_OBJECT = "array[object]" + ARRAY_FILES = "array[file]" + + +class DefaultValue(BaseModel): + value: Any = None + type: DefaultValueType + key: str + + @staticmethod + def _parse_json(value: str): + """Unified JSON parsing handler""" + try: + return json.loads(value) + except json.JSONDecodeError: + raise DefaultValueTypeError(f"Invalid JSON format for value: {value}") + + @staticmethod + def _validate_array(value: Any, element_type: type_ | tuple[type_, ...]) -> bool: + """Unified array type validation""" + return isinstance(value, list) and all(isinstance(x, element_type) for x in value) + + @staticmethod + def _convert_number(value: str) -> float: + """Unified number conversion handler""" + try: + return float(value) + except ValueError: + raise DefaultValueTypeError(f"Cannot convert to number: {value}") + + @model_validator(mode="after") + def validate_value_type(self) -> DefaultValue: + # Type validation configuration + type_validators: dict[DefaultValueType, dict[str, Any]] = { + DefaultValueType.STRING: { + "type": str, + "converter": lambda x: x, + }, + 
DefaultValueType.NUMBER: { + "type": _NumberType, + "converter": self._convert_number, + }, + DefaultValueType.OBJECT: { + "type": dict, + "converter": self._parse_json, + }, + DefaultValueType.ARRAY_NUMBER: { + "type": list, + "element_type": _NumberType, + "converter": self._parse_json, + }, + DefaultValueType.ARRAY_STRING: { + "type": list, + "element_type": str, + "converter": self._parse_json, + }, + DefaultValueType.ARRAY_OBJECT: { + "type": list, + "element_type": dict, + "converter": self._parse_json, + }, + } + + validator: dict[str, Any] = type_validators.get(self.type, {}) + if not validator: + if self.type == DefaultValueType.ARRAY_FILES: + # Handle files type + return self + raise DefaultValueTypeError(f"Unsupported type: {self.type}") + + # Handle string input cases + if isinstance(self.value, str) and self.type != DefaultValueType.STRING: + self.value = validator["converter"](self.value) + + # Validate base type + if not isinstance(self.value, validator["type"]): + raise DefaultValueTypeError(f"Value must be {validator['type'].__name__} type for {self.value}") + + # Validate array element types + if validator["type"] == list and not self._validate_array(self.value, validator["element_type"]): + raise DefaultValueTypeError(f"All elements must be {validator['element_type'].__name__} for {self.value}") + + return self + + +class BaseNodeData(ABC, BaseModel): + # Raw graph payloads are first validated through `NodeConfigDictAdapter`, where + # `node["data"]` is typed as `BaseNodeData` before the concrete node class is known. + # `type` therefore accepts downstream string node kinds; unknown node implementations + # are rejected later when the node factory resolves the node registry. + # At that boundary, node-specific fields are still "extra" relative to this shared DTO, + # and persisted templates/workflows also carry undeclared compatibility keys such as + # `selected`, `params`, `paramSchemas`, and `datasource_label`. 
Keep extras permissive + # here until graph parsing becomes discriminated by node type or those legacy payloads + # are normalized. + model_config = ConfigDict(extra="allow") + + type: NodeType + title: str = "" + desc: str | None = None + version: str = "1" + error_strategy: ErrorStrategy | None = None + default_value: list[DefaultValue] | None = None + retry_config: RetryConfig = Field(default_factory=RetryConfig) + + @property + def default_value_dict(self) -> dict[str, Any]: + if self.default_value: + return {item.key: item.value for item in self.default_value} + return {} + + def __getitem__(self, key: str) -> Any: + """ + Dict-style access without calling model_dump() on every lookup. + Prefer using model fields and Pydantic's extra storage. + """ + # First, check declared model fields + if key in self.__class__.model_fields: + return getattr(self, key) + + # Then, check undeclared compatibility fields stored in Pydantic's extra dict. + extras = getattr(self, "__pydantic_extra__", None) + if extras is None: + extras = getattr(self, "model_extra", None) + if extras is not None and key in extras: + return extras[key] + + raise KeyError(key) + + def get(self, key: str, default: Any = None) -> Any: + """ + Dict-style .get() without calling model_dump() on every lookup. 
+ """ + if key in self.__class__.model_fields: + return getattr(self, key) + + extras = getattr(self, "__pydantic_extra__", None) + if extras is None: + extras = getattr(self, "model_extra", None) + if extras is not None and key in extras: + return extras.get(key, default) + + return default + diff --git a/api/dify_graph/nodes/http_request/node.py b/api/dify_graph/nodes/http_request/node.py index 1bfde52e36..3e07ea6bab 100644 --- a/api/dify_graph/nodes/http_request/node.py +++ b/api/dify_graph/nodes/http_request/node.py @@ -84,7 +84,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]): }, "retry_config": { "max_retries": http_request_config.ssrf_default_max_retries, - "retry_interval": 0.5 * (2**2), + "retry_interval": 100, # Base interval: 100ms (will grow exponentially) "retry_enabled": True, }, } @@ -256,21 +256,4 @@ class HttpRequestNode(Node[HttpRequestNodeData]): @property def retry(self) -> bool: - return self.node_data.retry_config.retry_enabled - - def get_default_config(self) -> dict[str, Any]: - return { - "method": "get", - "authorization": {"type": "no-auth", "config": None}, - "body": {"type": "none", "data": None}, - "headers": [], - "params": [], - "timeout": {"connect": 10, "read": 60, "write": 60}, - "ssl_verify": {"enable": True, "max": None}, - # Use exponential backoff with 100ms base interval - "retry_config": { - "max_retries": http_request_config.ssrf_default_max_retries, - "retry_interval": 100, # Base interval: 100ms (will grow exponentially) - "retry_enabled": True, - }, - } \ No newline at end of file + return self.node_data.retry_config.retry_enabled \ No newline at end of file From 91a7a7d1d151c0c697e1bd07a415f19e1c7cbef0 Mon Sep 17 00:00:00 2001 From: Wenbo2105 Date: Mon, 16 Mar 2026 04:46:54 +0800 Subject: [PATCH 3/6] fix(base_node_data): code check --- api/dify_graph/entities/base_node_data.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/api/dify_graph/entities/base_node_data.py 
b/api/dify_graph/entities/base_node_data.py
index 28f4bfc9ec..8f44f81120 100644
--- a/api/dify_graph/entities/base_node_data.py
+++ b/api/dify_graph/entities/base_node_data.py
@@ -61,7 +61,8 @@ class RetryConfig(BaseModel):
         max_interval = self.retry_max_interval / 1000.0
         interval = min(interval, max_interval)
 
-        return interval
+        # Ensure non-negative interval (defensive programming)
+        return max(0.0, interval)
 
 
 class DefaultValueType(StrEnum):
@@ -214,3 +215,4 @@ class BaseNodeData(ABC, BaseModel):
             return extras.get(key, default)
 
         return default
+

From 88ce4a823772a7f50f2197134a7442a4bb1fd789 Mon Sep 17 00:00:00 2001
From: Wenbo2105
Date: Mon, 16 Mar 2026 19:22:59 +0800
Subject: [PATCH 4/6] docs(base_node_data): remove empty lines & add safety
 comments for jitter

---
 api/dify_graph/entities/base_node_data.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/api/dify_graph/entities/base_node_data.py b/api/dify_graph/entities/base_node_data.py
index 8f44f81120..46a25c0331 100644
--- a/api/dify_graph/entities/base_node_data.py
+++ b/api/dify_graph/entities/base_node_data.py
@@ -53,6 +53,7 @@ class RetryConfig(BaseModel):
         interval = base_interval * (2 ** retry_count)
 
         # Add jitter to avoid thundering herd problem
+        # random.uniform is intentional here: retry jitter timing, not a cryptographic use
         jitter_amount = interval * self.retry_jitter_ratio
-        jitter = random.uniform(-jitter_amount, jitter_amount)
+        jitter = random.uniform(-jitter_amount, jitter_amount)  # nosec B311
         interval += jitter
@@ -210,9 +211,12 @@ class BaseNodeData(ABC, BaseModel):
             return getattr(self, key)
 
         extras = getattr(self, "__pydantic_extra__", None)
         if extras is None:
             extras = getattr(self, "model_extra", None)
-        if extras is not None and key in extras:
+        if extras is not None:
             return extras.get(key, default)
 
         return default
+
+
+

From 009eb249984ded0fba70dccdc21eb08ac5045923 Mon Sep 17 00:00:00 2001
From: Wenbo2105
Date: Mon, 16 Mar 2026 23:39:13 +0800
Subject: [PATCH 5/6] Fix: MockHttpRequestNode._run() return type mismatch
 (Generator -> NodeRunResult)

---
 .../workflow/graph_engine/test_mock_nodes.py | 1742
++++++++--------- 1 file changed, 869 insertions(+), 873 deletions(-) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py index e117f81ff9..73b8224831 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py @@ -1,873 +1,869 @@ -""" -Mock node implementations for testing. - -This module provides mock implementations of nodes that require third-party services, -allowing tests to run without external dependencies. -""" - -import time -from collections.abc import Generator, Mapping -from typing import TYPE_CHECKING, Any, Optional -from unittest.mock import MagicMock - -from core.model_manager import ModelInstance -from core.workflow.nodes.agent import AgentNode -from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode -from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent -from dify_graph.nodes.code import CodeNode -from dify_graph.nodes.document_extractor import DocumentExtractorNode -from dify_graph.nodes.http_request import HttpRequestNode -from dify_graph.nodes.llm import LLMNode -from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory -from dify_graph.nodes.parameter_extractor import ParameterExtractorNode -from dify_graph.nodes.protocols import HttpClientProtocol, ToolFileManagerProtocol -from dify_graph.nodes.question_classifier import QuestionClassifierNode -from dify_graph.nodes.template_transform import TemplateTransformNode -from dify_graph.nodes.template_transform.template_renderer import ( - Jinja2TemplateRenderer, - TemplateRenderError, -) -from dify_graph.nodes.tool import ToolNode - -if TYPE_CHECKING: - from 
dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState - - from .test_mock_config import MockConfig - - -class _TestJinja2Renderer(Jinja2TemplateRenderer): - """Simple Jinja2 renderer for tests (avoids code executor).""" - - def render_template(self, template: str, variables: Mapping[str, Any]) -> str: - from jinja2 import Template as _Jinja2Template - - try: - return _Jinja2Template(template).render(**variables) - except Exception as exc: # pragma: no cover - pass through as contract error - raise TemplateRenderError(str(exc)) from exc - - -class MockNodeMixin: - """Mixin providing common mock functionality.""" - - def __init__( - self, - id: str, - config: Mapping[str, Any], - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - mock_config: Optional["MockConfig"] = None, - **kwargs: Any, - ): - if isinstance(self, (LLMNode, QuestionClassifierNode, ParameterExtractorNode)): - kwargs.setdefault("credentials_provider", MagicMock(spec=CredentialsProvider)) - kwargs.setdefault("model_factory", MagicMock(spec=ModelFactory)) - kwargs.setdefault("model_instance", MagicMock(spec=ModelInstance)) - # LLM-like nodes now require an http_client; provide a mock by default for tests. 
- kwargs.setdefault("http_client", MagicMock(spec=HttpClientProtocol)) - - # Ensure TemplateTransformNode receives a renderer now required by constructor - if isinstance(self, TemplateTransformNode): - kwargs.setdefault("template_renderer", _TestJinja2Renderer()) - - # Provide default tool_file_manager_factory for ToolNode subclasses - from dify_graph.nodes.tool import ToolNode as _ToolNode # local import to avoid cycles - - if isinstance(self, _ToolNode): - kwargs.setdefault("tool_file_manager_factory", MagicMock(spec=ToolFileManagerProtocol)) - - if isinstance(self, AgentNode): - presentation_provider = MagicMock() - presentation_provider.get_icon.return_value = None - kwargs.setdefault("strategy_resolver", MagicMock()) - kwargs.setdefault("presentation_provider", presentation_provider) - kwargs.setdefault("runtime_support", MagicMock()) - kwargs.setdefault("message_transformer", MagicMock()) - - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - **kwargs, - ) - self.mock_config = mock_config - - def _get_mock_outputs(self, default_outputs: dict[str, Any]) -> dict[str, Any]: - """Get mock outputs for this node.""" - if not self.mock_config: - return default_outputs - - # Check for node-specific configuration - node_config = self.mock_config.get_node_config(self._node_id) - if node_config and node_config.outputs: - return node_config.outputs - - # Check for custom handler - if node_config and node_config.custom_handler: - return node_config.custom_handler(self) - - return default_outputs - - def _should_simulate_error(self) -> str | None: - """Check if this node should simulate an error.""" - if not self.mock_config: - return None - - node_config = self.mock_config.get_node_config(self._node_id) - if node_config: - return node_config.error - - return None - - def _simulate_delay(self) -> None: - """Simulate execution delay if configured.""" - if not self.mock_config or not 
self.mock_config.simulate_delays: - return - - node_config = self.mock_config.get_node_config(self._node_id) - if node_config and node_config.delay > 0: - time.sleep(node_config.delay) - - -class MockLLMNode(MockNodeMixin, LLMNode): - """Mock implementation of LLMNode for testing.""" - - @classmethod - def version(cls) -> str: - """Return the version of this mock node.""" - return "1" - - def _run(self) -> Generator: - """Execute mock LLM node.""" - # Simulate delay if configured - self._simulate_delay() - - # Check for simulated error - error = self._should_simulate_error() - if error: - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=error, - inputs={}, - process_data={}, - error_type="MockError", - ) - ) - return - - # Get mock response - default_response = self.mock_config.default_llm_response if self.mock_config else "Mocked LLM response" - outputs = self._get_mock_outputs( - { - "text": default_response, - "usage": { - "prompt_tokens": 10, - "completion_tokens": 5, - "total_tokens": 15, - }, - "finish_reason": "stop", - } - ) - - # Simulate streaming if text output exists - if "text" in outputs: - text = str(outputs["text"]) - # Split text into words and stream with spaces between them - # To match test expectation of text.count(" ") + 2 chunks - words = text.split(" ") - for i, word in enumerate(words): - # Add space before word (except for first word) to reconstruct text properly - if i > 0: - chunk = " " + word - else: - chunk = word - - yield StreamChunkEvent( - selector=[self._node_id, "text"], - chunk=chunk, - is_final=False, - ) - - # Send final chunk - yield StreamChunkEvent( - selector=[self._node_id, "text"], - chunk="", - is_final=True, - ) - - # Create mock usage with all required fields - usage = LLMUsage.empty_usage() - usage.prompt_tokens = outputs.get("usage", {}).get("prompt_tokens", 10) - usage.completion_tokens = outputs.get("usage", {}).get("completion_tokens", 5) - 
usage.total_tokens = outputs.get("usage", {}).get("total_tokens", 15) - - # Send completion event - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs={"mock": "inputs"}, - process_data={ - "model_mode": "chat", - "prompts": [], - "usage": outputs.get("usage", {}), - "finish_reason": outputs.get("finish_reason", "stop"), - "model_provider": "mock_provider", - "model_name": "mock_model", - }, - outputs=outputs, - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: 0.0, - WorkflowNodeExecutionMetadataKey.CURRENCY: "USD", - }, - llm_usage=usage, - ) - ) - - -class MockAgentNode(MockNodeMixin, AgentNode): - """Mock implementation of AgentNode for testing.""" - - @classmethod - def version(cls) -> str: - """Return the version of this mock node.""" - return "1" - - def _run(self) -> Generator: - """Execute mock agent node.""" - # Simulate delay if configured - self._simulate_delay() - - # Check for simulated error - error = self._should_simulate_error() - if error: - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=error, - inputs={}, - process_data={}, - error_type="MockError", - ) - ) - return - - # Get mock response - default_response = self.mock_config.default_agent_response if self.mock_config else "Mocked agent response" - outputs = self._get_mock_outputs( - { - "output": default_response, - "files": [], - } - ) - - # Send completion event - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs={"mock": "inputs"}, - process_data={ - "agent_log": "Mock agent executed successfully", - }, - outputs=outputs, - metadata={ - WorkflowNodeExecutionMetadataKey.AGENT_LOG: "Mock agent log", - }, - ) - ) - - -class MockToolNode(MockNodeMixin, ToolNode): - """Mock implementation of ToolNode for testing.""" - - 
@classmethod - def version(cls) -> str: - """Return the version of this mock node.""" - return "1" - - def _run(self) -> Generator: - """Execute mock tool node.""" - # Simulate delay if configured - self._simulate_delay() - - # Check for simulated error - error = self._should_simulate_error() - if error: - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=error, - inputs={}, - process_data={}, - error_type="MockError", - ) - ) - return - - # Get mock response - default_response = ( - self.mock_config.default_tool_response if self.mock_config else {"result": "mocked tool output"} - ) - outputs = self._get_mock_outputs(default_response) - - # Send completion event - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs={"mock": "inputs"}, - process_data={ - "tool_name": "mock_tool", - "tool_parameters": {}, - }, - outputs=outputs, - metadata={ - WorkflowNodeExecutionMetadataKey.TOOL_INFO: { - "tool_name": "mock_tool", - "tool_label": "Mock Tool", - }, - }, - ) - ) - - -class MockKnowledgeRetrievalNode(MockNodeMixin, KnowledgeRetrievalNode): - """Mock implementation of KnowledgeRetrievalNode for testing.""" - - @classmethod - def version(cls) -> str: - """Return the version of this mock node.""" - return "1" - - def _run(self) -> Generator: - """Execute mock knowledge retrieval node.""" - # Simulate delay if configured - self._simulate_delay() - - # Check for simulated error - error = self._should_simulate_error() - if error: - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=error, - inputs={}, - process_data={}, - error_type="MockError", - ) - ) - return - - # Get mock response - default_response = ( - self.mock_config.default_retrieval_response if self.mock_config else "Mocked retrieval content" - ) - outputs = self._get_mock_outputs( - { - "result": [ - { - "content": 
default_response, - "score": 0.95, - "metadata": {"source": "mock_source"}, - } - ], - } - ) - - # Send completion event - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs={"query": "mock query"}, - process_data={ - "retrieval_method": "mock", - "documents_count": 1, - }, - outputs=outputs, - ) - ) - - -class MockHttpRequestNode(MockNodeMixin, HttpRequestNode): - """Mock implementation of HttpRequestNode for testing.""" - - @classmethod - def version(cls) -> str: - """Return the version of this mock node.""" - return "1" - - def _run(self) -> Generator: - """Execute mock HTTP request node.""" - # Simulate delay if configured - self._simulate_delay() - - # Check for simulated error - error = self._should_simulate_error() - if error: - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=error, - inputs={}, - process_data={}, - error_type="MockError", - ) - ) - return - - # Get mock response - default_response = ( - self.mock_config.default_http_response - if self.mock_config - else { - "status_code": 200, - "body": "mocked response", - "headers": {}, - } - ) - outputs = self._get_mock_outputs(default_response) - - # Send completion event - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs={"url": "http://mock.url", "method": "GET"}, - process_data={ - "request_url": "http://mock.url", - "request_method": "GET", - }, - outputs=outputs, - ) - ) - - -class MockQuestionClassifierNode(MockNodeMixin, QuestionClassifierNode): - """Mock implementation of QuestionClassifierNode for testing.""" - - @classmethod - def version(cls) -> str: - """Return the version of this mock node.""" - return "1" - - def _run(self) -> Generator: - """Execute mock question classifier node.""" - # Simulate delay if configured - self._simulate_delay() - - # Check for simulated error - error = 
self._should_simulate_error() - if error: - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=error, - inputs={}, - process_data={}, - error_type="MockError", - ) - ) - return - - # Get mock response - default to first class - outputs = self._get_mock_outputs( - { - "class_name": "class_1", - } - ) - - # Send completion event - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs={"query": "mock query"}, - process_data={ - "classification": outputs.get("class_name", "class_1"), - }, - outputs=outputs, - edge_source_handle=outputs.get("class_name", "class_1"), # Branch based on classification - ) - ) - - -class MockParameterExtractorNode(MockNodeMixin, ParameterExtractorNode): - """Mock implementation of ParameterExtractorNode for testing.""" - - @classmethod - def version(cls) -> str: - """Return the version of this mock node.""" - return "1" - - def _run(self) -> Generator: - """Execute mock parameter extractor node.""" - # Simulate delay if configured - self._simulate_delay() - - # Check for simulated error - error = self._should_simulate_error() - if error: - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=error, - inputs={}, - process_data={}, - error_type="MockError", - ) - ) - return - - # Get mock response - outputs = self._get_mock_outputs( - { - "parameters": { - "param1": "value1", - "param2": "value2", - }, - } - ) - - # Send completion event - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs={"text": "mock text"}, - process_data={ - "extracted_parameters": outputs.get("parameters", {}), - }, - outputs=outputs, - ) - ) - - -class MockDocumentExtractorNode(MockNodeMixin, DocumentExtractorNode): - """Mock implementation of DocumentExtractorNode for testing.""" - - @classmethod - def version(cls) 
-> str: - """Return the version of this mock node.""" - return "1" - - def _run(self) -> Generator: - """Execute mock document extractor node.""" - # Simulate delay if configured - self._simulate_delay() - - # Check for simulated error - error = self._should_simulate_error() - if error: - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=error, - inputs={}, - process_data={}, - error_type="MockError", - ) - ) - return - - # Get mock response - outputs = self._get_mock_outputs( - { - "text": "Mocked extracted document content", - "metadata": { - "pages": 1, - "format": "mock", - }, - } - ) - - # Send completion event - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs={"file": "mock_file.pdf"}, - process_data={ - "extraction_method": "mock", - }, - outputs=outputs, - ) - ) - - -from dify_graph.nodes.iteration import IterationNode -from dify_graph.nodes.loop import LoopNode - - -class MockIterationNode(MockNodeMixin, IterationNode): - """Mock implementation of IterationNode that preserves mock configuration.""" - - @classmethod - def version(cls) -> str: - """Return the version of this mock node.""" - return "1" - - def _create_graph_engine(self, index: int, item: Any): - """Create a graph engine with MockNodeFactory instead of DifyNodeFactory.""" - # Import dependencies - from dify_graph.entities import GraphInitParams - from dify_graph.graph import Graph - from dify_graph.graph_engine import GraphEngine, GraphEngineConfig - from dify_graph.graph_engine.command_channels import InMemoryChannel - from dify_graph.runtime import GraphRuntimeState - - # Import our MockNodeFactory instead of DifyNodeFactory - from .test_mock_factory import MockNodeFactory - - # Create GraphInitParams from node attributes - graph_init_params = GraphInitParams( - workflow_id=self.workflow_id, - graph_config=self.graph_config, - run_context=self.run_context, - 
call_depth=self.workflow_call_depth, - ) - - # Create a deep copy of the variable pool for each iteration - variable_pool_copy = self.graph_runtime_state.variable_pool.model_copy(deep=True) - - # append iteration variable (item, index) to variable pool - variable_pool_copy.add([self._node_id, "index"], index) - variable_pool_copy.add([self._node_id, "item"], item) - - # Create a new GraphRuntimeState for this iteration - graph_runtime_state_copy = GraphRuntimeState( - variable_pool=variable_pool_copy, - start_at=self.graph_runtime_state.start_at, - total_tokens=0, - node_run_steps=0, - ) - - # Create a MockNodeFactory with the same mock_config - node_factory = MockNodeFactory( - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state_copy, - mock_config=self.mock_config, # Pass the mock configuration - ) - - # Initialize the iteration graph with the mock node factory - iteration_graph = Graph.init( - graph_config=self.graph_config, node_factory=node_factory, root_node_id=self._node_data.start_node_id - ) - - if not iteration_graph: - from dify_graph.nodes.iteration.exc import IterationGraphNotFoundError - - raise IterationGraphNotFoundError("iteration graph not found") - - # Create a new GraphEngine for this iteration - graph_engine = GraphEngine( - workflow_id=self.workflow_id, - graph=iteration_graph, - graph_runtime_state=graph_runtime_state_copy, - command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs - config=GraphEngineConfig(), - ) - - return graph_engine - - -class MockLoopNode(MockNodeMixin, LoopNode): - """Mock implementation of LoopNode that preserves mock configuration.""" - - @classmethod - def version(cls) -> str: - """Return the version of this mock node.""" - return "1" - - def _create_graph_engine(self, start_at, root_node_id: str): - """Create a graph engine with MockNodeFactory instead of DifyNodeFactory.""" - # Import dependencies - from dify_graph.entities import GraphInitParams - from dify_graph.graph 
import Graph - from dify_graph.graph_engine import GraphEngine, GraphEngineConfig - from dify_graph.graph_engine.command_channels import InMemoryChannel - from dify_graph.runtime import GraphRuntimeState - - # Import our MockNodeFactory instead of DifyNodeFactory - from .test_mock_factory import MockNodeFactory - - # Create GraphInitParams from node attributes - graph_init_params = GraphInitParams( - workflow_id=self.workflow_id, - graph_config=self.graph_config, - run_context=self.run_context, - call_depth=self.workflow_call_depth, - ) - - # Create a new GraphRuntimeState for this iteration - graph_runtime_state_copy = GraphRuntimeState( - variable_pool=self.graph_runtime_state.variable_pool, - start_at=start_at.timestamp(), - ) - - # Create a MockNodeFactory with the same mock_config - node_factory = MockNodeFactory( - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state_copy, - mock_config=self.mock_config, # Pass the mock configuration - ) - - # Initialize the loop graph with the mock node factory - loop_graph = Graph.init(graph_config=self.graph_config, node_factory=node_factory, root_node_id=root_node_id) - - if not loop_graph: - raise ValueError("loop graph not found") - - # Create a new GraphEngine for this iteration - graph_engine = GraphEngine( - workflow_id=self.workflow_id, - graph=loop_graph, - graph_runtime_state=graph_runtime_state_copy, - command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs - config=GraphEngineConfig(), - ) - - return graph_engine - - -class MockTemplateTransformNode(MockNodeMixin, TemplateTransformNode): - """Mock implementation of TemplateTransformNode for testing.""" - - @classmethod - def version(cls) -> str: - """Return the version of this mock node.""" - return "1" - - def _run(self) -> NodeRunResult: - """Execute mock template transform node.""" - # Simulate delay if configured - self._simulate_delay() - - # Check for simulated error - error = self._should_simulate_error() - if 
error: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=error, - inputs={}, - error_type="MockError", - ) - - # Get variables from the node data - variables: dict[str, Any] = {} - if hasattr(self._node_data, "variables"): - for variable_selector in self._node_data.variables: - variable_name = variable_selector.variable - value = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) - variables[variable_name] = value.to_object() if value else None - - # Check if we have custom mock outputs configured - if self.mock_config: - node_config = self.mock_config.get_node_config(self._node_id) - if node_config and node_config.outputs: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=variables, - outputs=node_config.outputs, - ) - - # Try to actually process the template using Jinja2 directly - try: - if hasattr(self._node_data, "template"): - # Import jinja2 here to avoid dependency issues - from jinja2 import Template - - template = Template(self._node_data.template) - result_text = template.render(**variables) - - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": result_text} - ) - except Exception as e: - # If direct Jinja2 fails, try CodeExecutor as fallback - try: - from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage - - if hasattr(self._node_data, "template"): - result = CodeExecutor.execute_workflow_code_template( - language=CodeLanguage.JINJA2, code=self._node_data.template, inputs=variables - ) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=variables, - outputs={"output": result["result"]}, - ) - except Exception: - # Both methods failed, fall back to default mock output - pass - - # Fall back to default mock output - default_response = ( - self.mock_config.default_template_transform_response if self.mock_config else "mocked template output" - ) - default_outputs = 
{"output": default_response} - outputs = self._get_mock_outputs(default_outputs) - - # Return result - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=variables, - outputs=outputs, - ) - - -class MockCodeNode(MockNodeMixin, CodeNode): - """Mock implementation of CodeNode for testing.""" - - @classmethod - def version(cls) -> str: - """Return the version of this mock node.""" - return "1" - - def _run(self) -> NodeRunResult: - """Execute mock code node.""" - # Simulate delay if configured - self._simulate_delay() - - # Check for simulated error - error = self._should_simulate_error() - if error: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=error, - inputs={}, - error_type="MockError", - ) - - # Get mock outputs - use configured outputs or default based on output schema - default_outputs = {} - if hasattr(self._node_data, "outputs") and self._node_data.outputs: - # Generate default outputs based on schema - for output_name, output_config in self._node_data.outputs.items(): - if output_config.type == "string": - default_outputs[output_name] = f"mocked_{output_name}" - elif output_config.type == "number": - default_outputs[output_name] = 42 - elif output_config.type == "object": - default_outputs[output_name] = {"key": "value"} - elif output_config.type == "array[string]": - default_outputs[output_name] = ["item1", "item2"] - elif output_config.type == "array[number]": - default_outputs[output_name] = [1, 2, 3] - elif output_config.type == "array[object]": - default_outputs[output_name] = [{"key": "value1"}, {"key": "value2"}] - else: - # Default output when no schema is defined - default_outputs = ( - self.mock_config.default_code_response - if self.mock_config - else {"result": "mocked code execution result"} - ) - - outputs = self._get_mock_outputs(default_outputs) - - # Return result - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs={}, - outputs=outputs, - ) +""" +Mock 
node implementations for testing. + +This module provides mock implementations of nodes that require third-party services, +allowing tests to run without external dependencies. +""" + +import time +from collections.abc import Generator, Mapping +from typing import TYPE_CHECKING, Any, Optional +from unittest.mock import MagicMock + +from core.model_manager import ModelInstance +from core.workflow.nodes.agent import AgentNode +from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode +from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent +from dify_graph.nodes.code import CodeNode +from dify_graph.nodes.document_extractor import DocumentExtractorNode +from dify_graph.nodes.http_request import HttpRequestNode +from dify_graph.nodes.llm import LLMNode +from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory +from dify_graph.nodes.parameter_extractor import ParameterExtractorNode +from dify_graph.nodes.protocols import HttpClientProtocol, ToolFileManagerProtocol +from dify_graph.nodes.question_classifier import QuestionClassifierNode +from dify_graph.nodes.template_transform import TemplateTransformNode +from dify_graph.nodes.template_transform.template_renderer import ( + Jinja2TemplateRenderer, + TemplateRenderError, +) +from dify_graph.nodes.tool import ToolNode + +if TYPE_CHECKING: + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState + + from .test_mock_config import MockConfig + + +class _TestJinja2Renderer(Jinja2TemplateRenderer): + """Simple Jinja2 renderer for tests (avoids code executor).""" + + def render_template(self, template: str, variables: Mapping[str, Any]) -> str: + from jinja2 import Template as _Jinja2Template + + try: + return 
_Jinja2Template(template).render(**variables) + except Exception as exc: # pragma: no cover - pass through as contract error + raise TemplateRenderError(str(exc)) from exc + + +class MockNodeMixin: + """Mixin providing common mock functionality.""" + + def __init__( + self, + id: str, + config: Mapping[str, Any], + graph_init_params: "GraphInitParams", + graph_runtime_state: "GraphRuntimeState", + mock_config: Optional["MockConfig"] = None, + **kwargs: Any, + ): + if isinstance(self, (LLMNode, QuestionClassifierNode, ParameterExtractorNode)): + kwargs.setdefault("credentials_provider", MagicMock(spec=CredentialsProvider)) + kwargs.setdefault("model_factory", MagicMock(spec=ModelFactory)) + kwargs.setdefault("model_instance", MagicMock(spec=ModelInstance)) + # LLM-like nodes now require an http_client; provide a mock by default for tests. + kwargs.setdefault("http_client", MagicMock(spec=HttpClientProtocol)) + + # Ensure TemplateTransformNode receives a renderer now required by constructor + if isinstance(self, TemplateTransformNode): + kwargs.setdefault("template_renderer", _TestJinja2Renderer()) + + # Provide default tool_file_manager_factory for ToolNode subclasses + from dify_graph.nodes.tool import ToolNode as _ToolNode # local import to avoid cycles + + if isinstance(self, _ToolNode): + kwargs.setdefault("tool_file_manager_factory", MagicMock(spec=ToolFileManagerProtocol)) + + if isinstance(self, AgentNode): + presentation_provider = MagicMock() + presentation_provider.get_icon.return_value = None + kwargs.setdefault("strategy_resolver", MagicMock()) + kwargs.setdefault("presentation_provider", presentation_provider) + kwargs.setdefault("runtime_support", MagicMock()) + kwargs.setdefault("message_transformer", MagicMock()) + + super().__init__( + id=id, + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + **kwargs, + ) + self.mock_config = mock_config + + def _get_mock_outputs(self, default_outputs: dict[str, 
Any]) -> dict[str, Any]: + """Get mock outputs for this node.""" + if not self.mock_config: + return default_outputs + + # Check for node-specific configuration + node_config = self.mock_config.get_node_config(self._node_id) + if node_config and node_config.outputs: + return node_config.outputs + + # Check for custom handler + if node_config and node_config.custom_handler: + return node_config.custom_handler(self) + + return default_outputs + + def _should_simulate_error(self) -> str | None: + """Check if this node should simulate an error.""" + if not self.mock_config: + return None + + node_config = self.mock_config.get_node_config(self._node_id) + if node_config: + return node_config.error + + return None + + def _simulate_delay(self) -> None: + """Simulate execution delay if configured.""" + if not self.mock_config or not self.mock_config.simulate_delays: + return + + node_config = self.mock_config.get_node_config(self._node_id) + if node_config and node_config.delay > 0: + time.sleep(node_config.delay) + + +class MockLLMNode(MockNodeMixin, LLMNode): + """Mock implementation of LLMNode for testing.""" + + @classmethod + def version(cls) -> str: + """Return the version of this mock node.""" + return "1" + + def _run(self) -> Generator: + """Execute mock LLM node.""" + # Simulate delay if configured + self._simulate_delay() + + # Check for simulated error + error = self._should_simulate_error() + if error: + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=error, + inputs={}, + process_data={}, + error_type="MockError", + ) + ) + return + + # Get mock response + default_response = self.mock_config.default_llm_response if self.mock_config else "Mocked LLM response" + outputs = self._get_mock_outputs( + { + "text": default_response, + "usage": { + "prompt_tokens": 10, + "completion_tokens": 5, + "total_tokens": 15, + }, + "finish_reason": "stop", + } + ) + + # Simulate streaming if text output exists 
+ if "text" in outputs: + text = str(outputs["text"]) + # Split text into words and stream with spaces between them + # To match test expectation of text.count(" ") + 2 chunks + words = text.split(" ") + for i, word in enumerate(words): + # Add space before word (except for first word) to reconstruct text properly + if i > 0: + chunk = " " + word + else: + chunk = word + + yield StreamChunkEvent( + selector=[self._node_id, "text"], + chunk=chunk, + is_final=False, + ) + + # Send final chunk + yield StreamChunkEvent( + selector=[self._node_id, "text"], + chunk="", + is_final=True, + ) + + # Create mock usage with all required fields + usage = LLMUsage.empty_usage() + usage.prompt_tokens = outputs.get("usage", {}).get("prompt_tokens", 10) + usage.completion_tokens = outputs.get("usage", {}).get("completion_tokens", 5) + usage.total_tokens = outputs.get("usage", {}).get("total_tokens", 15) + + # Send completion event + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={"mock": "inputs"}, + process_data={ + "model_mode": "chat", + "prompts": [], + "usage": outputs.get("usage", {}), + "finish_reason": outputs.get("finish_reason", "stop"), + "model_provider": "mock_provider", + "model_name": "mock_model", + }, + outputs=outputs, + metadata={ + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: 0.0, + WorkflowNodeExecutionMetadataKey.CURRENCY: "USD", + }, + llm_usage=usage, + ) + ) + + +class MockAgentNode(MockNodeMixin, AgentNode): + """Mock implementation of AgentNode for testing.""" + + @classmethod + def version(cls) -> str: + """Return the version of this mock node.""" + return "1" + + def _run(self) -> Generator: + """Execute mock agent node.""" + # Simulate delay if configured + self._simulate_delay() + + # Check for simulated error + error = self._should_simulate_error() + if error: + yield StreamCompletedEvent( + 
node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=error, + inputs={}, + process_data={}, + error_type="MockError", + ) + ) + return + + # Get mock response + default_response = self.mock_config.default_agent_response if self.mock_config else "Mocked agent response" + outputs = self._get_mock_outputs( + { + "output": default_response, + "files": [], + } + ) + + # Send completion event + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={"mock": "inputs"}, + process_data={ + "agent_log": "Mock agent executed successfully", + }, + outputs=outputs, + metadata={ + WorkflowNodeExecutionMetadataKey.AGENT_LOG: "Mock agent log", + }, + ) + ) + + +class MockToolNode(MockNodeMixin, ToolNode): + """Mock implementation of ToolNode for testing.""" + + @classmethod + def version(cls) -> str: + """Return the version of this mock node.""" + return "1" + + def _run(self) -> Generator: + """Execute mock tool node.""" + # Simulate delay if configured + self._simulate_delay() + + # Check for simulated error + error = self._should_simulate_error() + if error: + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=error, + inputs={}, + process_data={}, + error_type="MockError", + ) + ) + return + + # Get mock response + default_response = ( + self.mock_config.default_tool_response if self.mock_config else {"result": "mocked tool output"} + ) + outputs = self._get_mock_outputs(default_response) + + # Send completion event + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={"mock": "inputs"}, + process_data={ + "tool_name": "mock_tool", + "tool_parameters": {}, + }, + outputs=outputs, + metadata={ + WorkflowNodeExecutionMetadataKey.TOOL_INFO: { + "tool_name": "mock_tool", + "tool_label": "Mock Tool", + }, + }, + ) + ) + + +class 
MockKnowledgeRetrievalNode(MockNodeMixin, KnowledgeRetrievalNode): + """Mock implementation of KnowledgeRetrievalNode for testing.""" + + @classmethod + def version(cls) -> str: + """Return the version of this mock node.""" + return "1" + + def _run(self) -> Generator: + """Execute mock knowledge retrieval node.""" + # Simulate delay if configured + self._simulate_delay() + + # Check for simulated error + error = self._should_simulate_error() + if error: + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=error, + inputs={}, + process_data={}, + error_type="MockError", + ) + ) + return + + # Get mock response + default_response = ( + self.mock_config.default_retrieval_response if self.mock_config else "Mocked retrieval content" + ) + outputs = self._get_mock_outputs( + { + "result": [ + { + "content": default_response, + "score": 0.95, + "metadata": {"source": "mock_source"}, + } + ], + } + ) + + # Send completion event + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={"query": "mock query"}, + process_data={ + "retrieval_method": "mock", + "documents_count": 1, + }, + outputs=outputs, + ) + ) + + +class MockHttpRequestNode(MockNodeMixin, HttpRequestNode): + """Mock implementation of HttpRequestNode for testing.""" + + @classmethod + def version(cls) -> str: + """Return the version of this mock node.""" + return "1" + + def _run(self) -> NodeRunResult: + """Execute mock HTTP request node.""" + # Simulate delay if configured + self._simulate_delay() + + # Check for simulated error + error = self._should_simulate_error() + if error: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=error, + inputs={}, + process_data={}, + error_type="MockError", + ) + + # Get mock response + default_response = ( + self.mock_config.default_http_response + if self.mock_config + else { + "status_code": 200, + "body": "mocked 
response", + "headers": {}, + } + ) + outputs = self._get_mock_outputs(default_response) + + # Return result + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={"url": "http://mock.url", "method": "GET"}, + process_data={ + "request_url": "http://mock.url", + "request_method": "GET", + }, + outputs=outputs, + ) + + +class MockQuestionClassifierNode(MockNodeMixin, QuestionClassifierNode): + """Mock implementation of QuestionClassifierNode for testing.""" + + @classmethod + def version(cls) -> str: + """Return the version of this mock node.""" + return "1" + + def _run(self) -> Generator: + """Execute mock question classifier node.""" + # Simulate delay if configured + self._simulate_delay() + + # Check for simulated error + error = self._should_simulate_error() + if error: + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=error, + inputs={}, + process_data={}, + error_type="MockError", + ) + ) + return + + # Get mock response - default to first class + outputs = self._get_mock_outputs( + { + "class_name": "class_1", + } + ) + + # Send completion event + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={"query": "mock query"}, + process_data={ + "classification": outputs.get("class_name", "class_1"), + }, + outputs=outputs, + edge_source_handle=outputs.get("class_name", "class_1"), # Branch based on classification + ) + ) + + +class MockParameterExtractorNode(MockNodeMixin, ParameterExtractorNode): + """Mock implementation of ParameterExtractorNode for testing.""" + + @classmethod + def version(cls) -> str: + """Return the version of this mock node.""" + return "1" + + def _run(self) -> Generator: + """Execute mock parameter extractor node.""" + # Simulate delay if configured + self._simulate_delay() + + # Check for simulated error + error = self._should_simulate_error() + if error: + yield 
StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=error, + inputs={}, + process_data={}, + error_type="MockError", + ) + ) + return + + # Get mock response + outputs = self._get_mock_outputs( + { + "parameters": { + "param1": "value1", + "param2": "value2", + }, + } + ) + + # Send completion event + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={"text": "mock text"}, + process_data={ + "extracted_parameters": outputs.get("parameters", {}), + }, + outputs=outputs, + ) + ) + + +class MockDocumentExtractorNode(MockNodeMixin, DocumentExtractorNode): + """Mock implementation of DocumentExtractorNode for testing.""" + + @classmethod + def version(cls) -> str: + """Return the version of this mock node.""" + return "1" + + def _run(self) -> Generator: + """Execute mock document extractor node.""" + # Simulate delay if configured + self._simulate_delay() + + # Check for simulated error + error = self._should_simulate_error() + if error: + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=error, + inputs={}, + process_data={}, + error_type="MockError", + ) + ) + return + + # Get mock response + outputs = self._get_mock_outputs( + { + "text": "Mocked extracted document content", + "metadata": { + "pages": 1, + "format": "mock", + }, + } + ) + + # Send completion event + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={"file": "mock_file.pdf"}, + process_data={ + "extraction_method": "mock", + }, + outputs=outputs, + ) + ) + + +from dify_graph.nodes.iteration import IterationNode +from dify_graph.nodes.loop import LoopNode + + +class MockIterationNode(MockNodeMixin, IterationNode): + """Mock implementation of IterationNode that preserves mock configuration.""" + + @classmethod + def version(cls) -> str: + """Return 
the version of this mock node.""" + return "1" + + def _create_graph_engine(self, index: int, item: Any): + """Create a graph engine with MockNodeFactory instead of DifyNodeFactory.""" + # Import dependencies + from dify_graph.entities import GraphInitParams + from dify_graph.graph import Graph + from dify_graph.graph_engine import GraphEngine, GraphEngineConfig + from dify_graph.graph_engine.command_channels import InMemoryChannel + from dify_graph.runtime import GraphRuntimeState + + # Import our MockNodeFactory instead of DifyNodeFactory + from .test_mock_factory import MockNodeFactory + + # Create GraphInitParams from node attributes + graph_init_params = GraphInitParams( + workflow_id=self.workflow_id, + graph_config=self.graph_config, + run_context=self.run_context, + call_depth=self.workflow_call_depth, + ) + + # Create a deep copy of the variable pool for each iteration + variable_pool_copy = self.graph_runtime_state.variable_pool.model_copy(deep=True) + + # append iteration variable (item, index) to variable pool + variable_pool_copy.add([self._node_id, "index"], index) + variable_pool_copy.add([self._node_id, "item"], item) + + # Create a new GraphRuntimeState for this iteration + graph_runtime_state_copy = GraphRuntimeState( + variable_pool=variable_pool_copy, + start_at=self.graph_runtime_state.start_at, + total_tokens=0, + node_run_steps=0, + ) + + # Create a MockNodeFactory with the same mock_config + node_factory = MockNodeFactory( + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state_copy, + mock_config=self.mock_config, # Pass the mock configuration + ) + + # Initialize the iteration graph with the mock node factory + iteration_graph = Graph.init( + graph_config=self.graph_config, node_factory=node_factory, root_node_id=self._node_data.start_node_id + ) + + if not iteration_graph: + from dify_graph.nodes.iteration.exc import IterationGraphNotFoundError + + raise IterationGraphNotFoundError("iteration graph not found") + + 
# Create a new GraphEngine for this iteration + graph_engine = GraphEngine( + workflow_id=self.workflow_id, + graph=iteration_graph, + graph_runtime_state=graph_runtime_state_copy, + command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs + config=GraphEngineConfig(), + ) + + return graph_engine + + +class MockLoopNode(MockNodeMixin, LoopNode): + """Mock implementation of LoopNode that preserves mock configuration.""" + + @classmethod + def version(cls) -> str: + """Return the version of this mock node.""" + return "1" + + def _create_graph_engine(self, start_at, root_node_id: str): + """Create a graph engine with MockNodeFactory instead of DifyNodeFactory.""" + # Import dependencies + from dify_graph.entities import GraphInitParams + from dify_graph.graph import Graph + from dify_graph.graph_engine import GraphEngine, GraphEngineConfig + from dify_graph.graph_engine.command_channels import InMemoryChannel + from dify_graph.runtime import GraphRuntimeState + + # Import our MockNodeFactory instead of DifyNodeFactory + from .test_mock_factory import MockNodeFactory + + # Create GraphInitParams from node attributes + graph_init_params = GraphInitParams( + workflow_id=self.workflow_id, + graph_config=self.graph_config, + run_context=self.run_context, + call_depth=self.workflow_call_depth, + ) + + # Create a new GraphRuntimeState for this iteration + graph_runtime_state_copy = GraphRuntimeState( + variable_pool=self.graph_runtime_state.variable_pool, + start_at=start_at.timestamp(), + ) + + # Create a MockNodeFactory with the same mock_config + node_factory = MockNodeFactory( + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state_copy, + mock_config=self.mock_config, # Pass the mock configuration + ) + + # Initialize the loop graph with the mock node factory + loop_graph = Graph.init(graph_config=self.graph_config, node_factory=node_factory, root_node_id=root_node_id) + + if not loop_graph: + raise ValueError("loop graph not 
found") + + # Create a new GraphEngine for this iteration + graph_engine = GraphEngine( + workflow_id=self.workflow_id, + graph=loop_graph, + graph_runtime_state=graph_runtime_state_copy, + command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs + config=GraphEngineConfig(), + ) + + return graph_engine + + +class MockTemplateTransformNode(MockNodeMixin, TemplateTransformNode): + """Mock implementation of TemplateTransformNode for testing.""" + + @classmethod + def version(cls) -> str: + """Return the version of this mock node.""" + return "1" + + def _run(self) -> NodeRunResult: + """Execute mock template transform node.""" + # Simulate delay if configured + self._simulate_delay() + + # Check for simulated error + error = self._should_simulate_error() + if error: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=error, + inputs={}, + error_type="MockError", + ) + + # Get variables from the node data + variables: dict[str, Any] = {} + if hasattr(self._node_data, "variables"): + for variable_selector in self._node_data.variables: + variable_name = variable_selector.variable + value = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) + variables[variable_name] = value.to_object() if value else None + + # Check if we have custom mock outputs configured + if self.mock_config: + node_config = self.mock_config.get_node_config(self._node_id) + if node_config and node_config.outputs: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=variables, + outputs=node_config.outputs, + ) + + # Try to actually process the template using Jinja2 directly + try: + if hasattr(self._node_data, "template"): + # Import jinja2 here to avoid dependency issues + from jinja2 import Template + + template = Template(self._node_data.template) + result_text = template.render(**variables) + + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": 
result_text} + ) + except Exception as e: + # If direct Jinja2 fails, try CodeExecutor as fallback + try: + from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage + + if hasattr(self._node_data, "template"): + result = CodeExecutor.execute_workflow_code_template( + language=CodeLanguage.JINJA2, code=self._node_data.template, inputs=variables + ) + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=variables, + outputs={"output": result["result"]}, + ) + except Exception: + # Both methods failed, fall back to default mock output + pass + + # Fall back to default mock output + default_response = ( + self.mock_config.default_template_transform_response if self.mock_config else "mocked template output" + ) + default_outputs = {"output": default_response} + outputs = self._get_mock_outputs(default_outputs) + + # Return result + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=variables, + outputs=outputs, + ) + + +class MockCodeNode(MockNodeMixin, CodeNode): + """Mock implementation of CodeNode for testing.""" + + @classmethod + def version(cls) -> str: + """Return the version of this mock node.""" + return "1" + + def _run(self) -> NodeRunResult: + """Execute mock code node.""" + # Simulate delay if configured + self._simulate_delay() + + # Check for simulated error + error = self._should_simulate_error() + if error: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=error, + inputs={}, + error_type="MockError", + ) + + # Get mock outputs - use configured outputs or default based on output schema + default_outputs = {} + if hasattr(self._node_data, "outputs") and self._node_data.outputs: + # Generate default outputs based on schema + for output_name, output_config in self._node_data.outputs.items(): + if output_config.type == "string": + default_outputs[output_name] = f"mocked_{output_name}" + elif output_config.type == "number": + default_outputs[output_name] = 
42 + elif output_config.type == "object": + default_outputs[output_name] = {"key": "value"} + elif output_config.type == "array[string]": + default_outputs[output_name] = ["item1", "item2"] + elif output_config.type == "array[number]": + default_outputs[output_name] = [1, 2, 3] + elif output_config.type == "array[object]": + default_outputs[output_name] = [{"key": "value1"}, {"key": "value2"}] + else: + # Default output when no schema is defined + default_outputs = ( + self.mock_config.default_code_response + if self.mock_config + else {"result": "mocked code execution result"} + ) + + outputs = self._get_mock_outputs(default_outputs) + + # Return result + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={}, + outputs=outputs, + ) + \ No newline at end of file From b163a39b663d263cf8c55130bde5eee945caee53 Mon Sep 17 00:00:00 2001 From: Wenbo2105 Date: Wed, 18 Mar 2026 21:18:23 +0800 Subject: [PATCH 6/6] fix: fixed the wrong CRLF files --- api/dify_graph/graph_engine/error_handler.py | 424 +++++++++--------- .../components/workflow/nodes/http/types.ts | 178 ++++---- 2 files changed, 301 insertions(+), 301 deletions(-) diff --git a/api/dify_graph/graph_engine/error_handler.py b/api/dify_graph/graph_engine/error_handler.py index 9b8af83942..e6e786687a 100644 --- a/api/dify_graph/graph_engine/error_handler.py +++ b/api/dify_graph/graph_engine/error_handler.py @@ -1,213 +1,213 @@ -""" -Main error handler that coordinates error strategies. 
-""" - -import logging -import time -from typing import TYPE_CHECKING, final - -from dify_graph.enums import ( - ErrorStrategy as ErrorStrategyEnum, -) -from dify_graph.enums import ( - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from dify_graph.graph import Graph -from dify_graph.graph_events import ( - GraphNodeEventBase, - NodeRunExceptionEvent, - NodeRunFailedEvent, - NodeRunRetryEvent, -) -from dify_graph.node_events import NodeRunResult - -if TYPE_CHECKING: - from .domain import GraphExecution - -logger = logging.getLogger(__name__) - - -@final -class ErrorHandler: - """ - Coordinates error handling strategies for node failures. - - This acts as a facade for the various error strategies, - selecting and applying the appropriate strategy based on - node configuration. - """ - - def __init__(self, graph: Graph, graph_execution: "GraphExecution") -> None: - """ - Initialize the error handler. - - Args: - graph: The workflow graph - graph_execution: The graph execution state - """ - self._graph = graph - self._graph_execution = graph_execution - - def handle_node_failure(self, event: NodeRunFailedEvent) -> GraphNodeEventBase | None: - """ - Handle a node failure event. - - Selects and applies the appropriate error strategy based on - the node's configuration. 
- - Args: - event: The node failure event - - Returns: - Optional new event to process, or None to abort - """ - node = self._graph.nodes[event.node_id] - # Get retry count from NodeExecution - node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) - retry_count = node_execution.retry_count - - # First check if retry is configured and not exhausted - if node.retry and retry_count < node.retry_config.max_retries: - result = self._handle_retry(event, retry_count) - if result: - # Retry count will be incremented when NodeRunRetryEvent is handled - return result - - # Apply configured error strategy - strategy = node.error_strategy - - match strategy: - case None: - return self._handle_abort(event) - case ErrorStrategyEnum.FAIL_BRANCH: - return self._handle_fail_branch(event) - case ErrorStrategyEnum.DEFAULT_VALUE: - return self._handle_default_value(event) - - def _handle_abort(self, event: NodeRunFailedEvent): - """ - Handle error by aborting execution. - - This is the default strategy when no other strategy is specified. - It stops the entire graph execution when a node fails. - - Args: - event: The failure event - - Returns: - None - signals abortion - """ - logger.error("Node %s failed with ABORT strategy: %s", event.node_id, event.error) - # Return None to signal that execution should stop - - def _handle_retry(self, event: NodeRunFailedEvent, retry_count: int): - """ - Handle error by retrying the node. - - This strategy re-attempts node execution up to a configured - maximum number of retries with exponential backoff intervals. 
- - Args: - event: The failure event - retry_count: Current retry attempt count - - Returns: - NodeRunRetryEvent if retry should occur, None otherwise - """ - node = self._graph.nodes[event.node_id] - - # Check if we've exceeded max retries - if not node.retry or retry_count >= node.retry_config.max_retries: - return None - - # Calculate retry interval using exponential backoff with jitter - retry_interval = node.retry_config.calculate_retry_interval(retry_count) - time.sleep(retry_interval) - - # Create retry event - return NodeRunRetryEvent( - id=event.id, - node_title=node.title, - node_id=event.node_id, - node_type=event.node_type, - node_run_result=event.node_run_result, - start_at=event.start_at, - error=event.error, - retry_index=retry_count + 1, - ) - - def _handle_fail_branch(self, event: NodeRunFailedEvent): - """ - Handle error by taking the fail branch. - - This strategy converts failures to exceptions and routes execution - through a designated fail-branch edge. - - Args: - event: The failure event - - Returns: - NodeRunExceptionEvent to continue via fail branch - """ - outputs = { - "error_message": event.node_run_result.error, - "error_type": event.node_run_result.error_type, - } - - return NodeRunExceptionEvent( - id=event.id, - node_id=event.node_id, - node_type=event.node_type, - start_at=event.start_at, - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.EXCEPTION, - inputs=event.node_run_result.inputs, - process_data=event.node_run_result.process_data, - outputs=outputs, - edge_source_handle="fail-branch", - metadata={ - WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: ErrorStrategyEnum.FAIL_BRANCH, - }, - ), - error=event.error, - ) - - def _handle_default_value(self, event: NodeRunFailedEvent): - """ - Handle error by using default values. - - This strategy allows nodes to fail gracefully by providing - predefined default output values. 
- - Args: - event: The failure event - - Returns: - NodeRunExceptionEvent with default values - """ - node = self._graph.nodes[event.node_id] - - outputs = { - **node.default_value_dict, - "error_message": event.node_run_result.error, - "error_type": event.node_run_result.error_type, - } - - return NodeRunExceptionEvent( - id=event.id, - node_id=event.node_id, - node_type=event.node_type, - start_at=event.start_at, - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.EXCEPTION, - inputs=event.node_run_result.inputs, - process_data=event.node_run_result.process_data, - outputs=outputs, - metadata={ - WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: ErrorStrategyEnum.DEFAULT_VALUE, - }, - ), - error=event.error, - ) +""" +Main error handler that coordinates error strategies. +""" + +import logging +import time +from typing import TYPE_CHECKING, final + +from dify_graph.enums import ( + ErrorStrategy as ErrorStrategyEnum, +) +from dify_graph.enums import ( + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from dify_graph.graph import Graph +from dify_graph.graph_events import ( + GraphNodeEventBase, + NodeRunExceptionEvent, + NodeRunFailedEvent, + NodeRunRetryEvent, +) +from dify_graph.node_events import NodeRunResult + +if TYPE_CHECKING: + from .domain import GraphExecution + +logger = logging.getLogger(__name__) + + +@final +class ErrorHandler: + """ + Coordinates error handling strategies for node failures. + + This acts as a facade for the various error strategies, + selecting and applying the appropriate strategy based on + node configuration. + """ + + def __init__(self, graph: Graph, graph_execution: "GraphExecution") -> None: + """ + Initialize the error handler. 
+ + Args: + graph: The workflow graph + graph_execution: The graph execution state + """ + self._graph = graph + self._graph_execution = graph_execution + + def handle_node_failure(self, event: NodeRunFailedEvent) -> GraphNodeEventBase | None: + """ + Handle a node failure event. + + Selects and applies the appropriate error strategy based on + the node's configuration. + + Args: + event: The node failure event + + Returns: + Optional new event to process, or None to abort + """ + node = self._graph.nodes[event.node_id] + # Get retry count from NodeExecution + node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) + retry_count = node_execution.retry_count + + # First check if retry is configured and not exhausted + if node.retry and retry_count < node.retry_config.max_retries: + result = self._handle_retry(event, retry_count) + if result: + # Retry count will be incremented when NodeRunRetryEvent is handled + return result + + # Apply configured error strategy + strategy = node.error_strategy + + match strategy: + case None: + return self._handle_abort(event) + case ErrorStrategyEnum.FAIL_BRANCH: + return self._handle_fail_branch(event) + case ErrorStrategyEnum.DEFAULT_VALUE: + return self._handle_default_value(event) + + def _handle_abort(self, event: NodeRunFailedEvent): + """ + Handle error by aborting execution. + + This is the default strategy when no other strategy is specified. + It stops the entire graph execution when a node fails. + + Args: + event: The failure event + + Returns: + None - signals abortion + """ + logger.error("Node %s failed with ABORT strategy: %s", event.node_id, event.error) + # Return None to signal that execution should stop + + def _handle_retry(self, event: NodeRunFailedEvent, retry_count: int): + """ + Handle error by retrying the node. + + This strategy re-attempts node execution up to a configured + maximum number of retries with exponential backoff intervals. 
+ + Args: + event: The failure event + retry_count: Current retry attempt count + + Returns: + NodeRunRetryEvent if retry should occur, None otherwise + """ + node = self._graph.nodes[event.node_id] + + # Check if we've exceeded max retries + if not node.retry or retry_count >= node.retry_config.max_retries: + return None + + # Calculate retry interval using exponential backoff with jitter + retry_interval = node.retry_config.calculate_retry_interval(retry_count) + time.sleep(retry_interval) + + # Create retry event + return NodeRunRetryEvent( + id=event.id, + node_title=node.title, + node_id=event.node_id, + node_type=event.node_type, + node_run_result=event.node_run_result, + start_at=event.start_at, + error=event.error, + retry_index=retry_count + 1, + ) + + def _handle_fail_branch(self, event: NodeRunFailedEvent): + """ + Handle error by taking the fail branch. + + This strategy converts failures to exceptions and routes execution + through a designated fail-branch edge. + + Args: + event: The failure event + + Returns: + NodeRunExceptionEvent to continue via fail branch + """ + outputs = { + "error_message": event.node_run_result.error, + "error_type": event.node_run_result.error_type, + } + + return NodeRunExceptionEvent( + id=event.id, + node_id=event.node_id, + node_type=event.node_type, + start_at=event.start_at, + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.EXCEPTION, + inputs=event.node_run_result.inputs, + process_data=event.node_run_result.process_data, + outputs=outputs, + edge_source_handle="fail-branch", + metadata={ + WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: ErrorStrategyEnum.FAIL_BRANCH, + }, + ), + error=event.error, + ) + + def _handle_default_value(self, event: NodeRunFailedEvent): + """ + Handle error by using default values. + + This strategy allows nodes to fail gracefully by providing + predefined default output values. 
+ + Args: + event: The failure event + + Returns: + NodeRunExceptionEvent with default values + """ + node = self._graph.nodes[event.node_id] + + outputs = { + **node.default_value_dict, + "error_message": event.node_run_result.error, + "error_type": event.node_run_result.error_type, + } + + return NodeRunExceptionEvent( + id=event.id, + node_id=event.node_id, + node_type=event.node_type, + start_at=event.start_at, + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.EXCEPTION, + inputs=event.node_run_result.inputs, + process_data=event.node_run_result.process_data, + outputs=outputs, + metadata={ + WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: ErrorStrategyEnum.DEFAULT_VALUE, + }, + ), + error=event.error, + ) \ No newline at end of file diff --git a/web/app/components/workflow/nodes/http/types.ts b/web/app/components/workflow/nodes/http/types.ts index a81fd9016d..a91f755ea6 100644 --- a/web/app/components/workflow/nodes/http/types.ts +++ b/web/app/components/workflow/nodes/http/types.ts @@ -1,90 +1,90 @@ -import type { CommonNodeType, ValueSelector, Variable } from '@/app/components/workflow/types' - -export enum Method { - get = 'get', - post = 'post', - head = 'head', - patch = 'patch', - put = 'put', - delete = 'delete', -} - -export enum BodyType { - none = 'none', - formData = 'form-data', - xWwwFormUrlencoded = 'x-www-form-urlencoded', - rawText = 'raw-text', - json = 'json', - binary = 'binary', -} - -export type KeyValue = { - id?: string - key: string - value: string - type?: string - file?: ValueSelector -} - -export enum BodyPayloadValueType { - text = 'text', - file = 'file', -} - -export type BodyPayload = { - id?: string - key?: string - type: BodyPayloadValueType - file?: ValueSelector // when type is file - value?: string // when type is text -}[] -export type Body = { - type: BodyType - data: string | BodyPayload // string is deprecated, it would convert to BodyPayload after loaded -} - -export enum AuthorizationType { - none 
= 'no-auth', - apiKey = 'api-key', -} - -export enum APIType { - basic = 'basic', - bearer = 'bearer', - custom = 'custom', -} - -export type Authorization = { - type: AuthorizationType - config?: { - type: APIType - api_key: string - header?: string - } | null -} - -export type Timeout = { - connect?: number - read?: number - write?: number - max_connect_timeout?: number - max_read_timeout?: number - max_write_timeout?: number -} - -export type HttpNodeType = CommonNodeType & { - variables: Variable[] - method: Method - url: string - headers: string - params: string - body: Body - authorization: Authorization - timeout: Timeout - ssl_verify?: boolean - retry_config?: { - max_retries: number - retry_interval: number - retry_enabled: boolean - } +import type { CommonNodeType, ValueSelector, Variable } from '@/app/components/workflow/types' + +export enum Method { + get = 'get', + post = 'post', + head = 'head', + patch = 'patch', + put = 'put', + delete = 'delete', +} + +export enum BodyType { + none = 'none', + formData = 'form-data', + xWwwFormUrlencoded = 'x-www-form-urlencoded', + rawText = 'raw-text', + json = 'json', + binary = 'binary', +} + +export type KeyValue = { + id?: string + key: string + value: string + type?: string + file?: ValueSelector +} + +export enum BodyPayloadValueType { + text = 'text', + file = 'file', +} + +export type BodyPayload = { + id?: string + key?: string + type: BodyPayloadValueType + file?: ValueSelector // when type is file + value?: string // when type is text +}[] +export type Body = { + type: BodyType + data: string | BodyPayload // string is deprecated, it would convert to BodyPayload after loaded +} + +export enum AuthorizationType { + none = 'no-auth', + apiKey = 'api-key', +} + +export enum APIType { + basic = 'basic', + bearer = 'bearer', + custom = 'custom', +} + +export type Authorization = { + type: AuthorizationType + config?: { + type: APIType + api_key: string + header?: string + } | null +} + +export type Timeout 
= { + connect?: number + read?: number + write?: number + max_connect_timeout?: number + max_read_timeout?: number + max_write_timeout?: number +} + +export type HttpNodeType = CommonNodeType & { + variables: Variable[] + method: Method + url: string + headers: string + params: string + body: Body + authorization: Authorization + timeout: Timeout + ssl_verify?: boolean + retry_config?: { + max_retries: number + retry_interval: number + retry_enabled: boolean + } } \ No newline at end of file