refactor(api): simplify response session eligibility (#33538)

This commit is contained in:
-LAN- 2026-03-16 21:22:37 +08:00 committed by GitHub
parent c7f86dba09
commit e445f69604
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 6 additions and 41 deletions

View File

@ -6,6 +6,5 @@ of responses based on upstream node outputs and constants.
"""
from .coordinator import ResponseStreamCoordinator
from .session import RESPONSE_SESSION_NODE_TYPES
__all__ = ["RESPONSE_SESSION_NODE_TYPES", "ResponseStreamCoordinator"]
__all__ = ["ResponseStreamCoordinator"]

View File

@ -3,10 +3,6 @@ Internal response session management for response coordinator.
This module contains the private ResponseSession class used internally
by ResponseStreamCoordinator to manage streaming sessions.
`RESPONSE_SESSION_NODE_TYPES` is intentionally mutable so downstream applications
can opt additional response-capable node types into session creation without
patching the coordinator.
"""
from __future__ import annotations
@ -14,7 +10,6 @@ from __future__ import annotations
from dataclasses import dataclass
from typing import Protocol, cast
from dify_graph.enums import BuiltinNodeTypes, NodeType
from dify_graph.nodes.base.template import Template
from dify_graph.runtime.graph_runtime_state import NodeProtocol
@ -25,12 +20,6 @@ class _ResponseSessionNodeProtocol(NodeProtocol, Protocol):
def get_streaming_template(self) -> Template: ...
RESPONSE_SESSION_NODE_TYPES: list[NodeType] = [
BuiltinNodeTypes.ANSWER,
BuiltinNodeTypes.END,
]
@dataclass
class ResponseSession:
"""
@ -49,8 +38,8 @@ class ResponseSession:
Create a ResponseSession from a response-capable node.
The parameter is typed as `NodeProtocol` because the graph is exposed behind a protocol at the runtime layer.
At runtime this must be a node whose `node_type` is listed in `RESPONSE_SESSION_NODE_TYPES`
and which implements `get_streaming_template()`.
At runtime this must be a node that implements `get_streaming_template()`. The coordinator decides which
graph nodes should be treated as response-capable before they reach this factory.
Args:
node: Node from the materialized workflow graph.
@ -59,15 +48,8 @@ class ResponseSession:
ResponseSession configured with the node's streaming template
Raises:
TypeError: If node is not a supported response node type.
TypeError: If node does not implement the response-session streaming contract.
"""
if node.node_type not in RESPONSE_SESSION_NODE_TYPES:
supported_node_types = ", ".join(RESPONSE_SESSION_NODE_TYPES)
raise TypeError(
"ResponseSession.from_node only supports node types in "
f"RESPONSE_SESSION_NODE_TYPES: {supported_node_types}"
)
response_node = cast(_ResponseSessionNodeProtocol, node)
try:
template = response_node.get_streaming_template()

View File

@ -4,9 +4,7 @@ from __future__ import annotations
import pytest
import dify_graph.graph_engine.response_coordinator.session as response_session_module
from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, NodeState, NodeType
from dify_graph.graph_engine.response_coordinator import RESPONSE_SESSION_NODE_TYPES
from dify_graph.graph_engine.response_coordinator.session import ResponseSession
from dify_graph.nodes.base.template import Template, TextSegment
@ -35,28 +33,14 @@ class DummyNodeWithoutStreamingTemplate:
self.state = NodeState.UNKNOWN
def test_response_session_from_node_rejects_node_types_outside_allowlist() -> None:
"""Unsupported node types are rejected even if they expose a template."""
def test_response_session_from_node_accepts_nodes_outside_previous_allowlist() -> None:
"""Session creation depends on the streaming-template contract rather than node type."""
node = DummyResponseNode(
node_id="llm-node",
node_type=BuiltinNodeTypes.LLM,
template=Template(segments=[TextSegment(text="hello")]),
)
with pytest.raises(TypeError, match="RESPONSE_SESSION_NODE_TYPES"):
ResponseSession.from_node(node)
def test_response_session_from_node_supports_downstream_allowlist_extension(monkeypatch) -> None:
"""Downstream applications can extend the supported node-type list."""
node = DummyResponseNode(
node_id="llm-node",
node_type=BuiltinNodeTypes.LLM,
template=Template(segments=[TextSegment(text="hello")]),
)
extended_node_types = [*RESPONSE_SESSION_NODE_TYPES, BuiltinNodeTypes.LLM]
monkeypatch.setattr(response_session_module, "RESPONSE_SESSION_NODE_TYPES", extended_node_types)
session = ResponseSession.from_node(node)
assert session.node_id == "llm-node"