From a230cf9970cfe160fe89f686a219c111a0464bb1 Mon Sep 17 00:00:00 2001
From: skhe
Date: Tue, 3 Mar 2026 13:52:45 +0800
Subject: [PATCH 1/3] feat(tools): add streaming response support for API tools

When an OpenAPI operation includes the `x-dify-streaming: true` extension
field, the API tool now makes a streaming HTTP request and yields response
chunks in real time instead of waiting for the complete response.

Supported streaming formats:
- SSE (text/event-stream) with OpenAI-compatible format auto-detection
- NDJSON (application/x-ndjson, application/jsonl)
- Plain text (fallback for any other content type)

Changes:
- ApiToolBundle: added `streaming: bool` field (defaults to false)
- ssrf_proxy: added `stream_request()` context manager
- ApiTool: refactored request assembly into `_prepare_request_parts()`,
  added streaming request path and response parsers
- parser.py: extract `x-dify-streaming` from OpenAPI operations
- Added 23 unit tests covering all streaming paths

Backward compatible: streaming defaults to false, existing tools unaffected.
Fixes #32886 --- api/core/helper/ssrf_proxy.py | 44 ++ api/core/tools/custom_tool/tool.py | 156 ++++++- api/core/tools/entities/tool_bundle.py | 2 + api/core/tools/utils/parser.py | 4 + .../core/tools/test_api_tool_streaming.py | 395 ++++++++++++++++++ 5 files changed, 592 insertions(+), 9 deletions(-) create mode 100644 api/tests/unit_tests/core/tools/test_api_tool_streaming.py diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index 54068fc28d..e035db8b00 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -2,8 +2,10 @@ Proxy requests to avoid SSRF """ +import contextlib import logging import time +from collections.abc import Generator from typing import Any, TypeAlias import httpx @@ -268,3 +270,45 @@ class SSRFProxy: ssrf_proxy = SSRFProxy() + + +@contextlib.contextmanager +def stream_request(method: str, url: str, **kwargs: Any) -> Generator[httpx.Response, None, None]: + """Streaming HTTP request context manager with SSRF protection. + + Unlike make_request(), this does not implement retry logic because retries + are meaningless for streaming responses -- the stream cannot be replayed. 
+ """ + if "allow_redirects" in kwargs: + allow_redirects = kwargs.pop("allow_redirects") + if "follow_redirects" not in kwargs: + kwargs["follow_redirects"] = allow_redirects + + if "timeout" not in kwargs: + kwargs["timeout"] = httpx.Timeout( + timeout=dify_config.SSRF_DEFAULT_TIME_OUT, + connect=dify_config.SSRF_DEFAULT_CONNECT_TIME_OUT, + read=dify_config.SSRF_DEFAULT_READ_TIME_OUT, + write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT, + ) + + verify_option = kwargs.pop("ssl_verify", dify_config.HTTP_REQUEST_NODE_SSL_VERIFY) + if not isinstance(verify_option, bool): + raise ValueError("ssl_verify must be a boolean") + client = _get_ssrf_client(verify_option) + + try: + headers: Headers = _HEADERS_ADAPTER.validate_python(kwargs.get("headers") or {}) + except ValidationError as e: + raise ValueError("headers must be a mapping of string keys to string values") from e + headers = _inject_trace_headers(headers) + kwargs["headers"] = headers + + user_provided_host = _get_user_provided_host_header(headers) + if user_provided_host is not None: + headers = {k: v for k, v in headers.items() if k.lower() != "host"} + headers["host"] = user_provided_host + kwargs["headers"] = headers + + with client.stream(method=method, url=url, **kwargs) as response: + yield response diff --git a/api/core/tools/custom_tool/tool.py b/api/core/tools/custom_tool/tool.py index c6a84e27c6..a56b75c052 100644 --- a/api/core/tools/custom_tool/tool.py +++ b/api/core/tools/custom_tool/tool.py @@ -1,4 +1,5 @@ import json +import logging from collections.abc import Generator from dataclasses import dataclass from os import getenv @@ -15,6 +16,8 @@ from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, Too from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError from dify_graph.file.file_manager import download +logger = logging.getLogger(__name__) + API_TOOL_DEFAULT_TIMEOUT = ( int(getenv("API_TOOL_DEFAULT_CONNECT_TIMEOUT", 
"10")), int(getenv("API_TOOL_DEFAULT_READ_TIMEOUT", "60")), @@ -168,14 +171,15 @@ class ApiTool(Tool): else: return (parameter.get("schema", {}) or {}).get("default", "") - def do_http_request( + def _prepare_request_parts( self, url: str, method: str, headers: dict[str, Any], parameters: dict[str, Any] - ) -> httpx.Response: + ) -> tuple[dict, Any, dict, list, str, dict[str, Any]]: """ - do http request depending on api bundle - """ - method = method.lower() + Assemble request parts (params, body, cookies, files, url, headers) from + the OpenAPI bundle and supplied parameters. + This is shared by both the streaming and non-streaming request paths. + """ params = {} path_params = {} # FIXME: body should be a dict[str, Any] but it changed a lot in this function @@ -275,6 +279,21 @@ class ApiTool(Tool): if files: headers.pop("Content-Type", None) + return params, body, cookies, files, url, headers + + def do_http_request( + self, url: str, method: str, headers: dict[str, Any], parameters: dict[str, Any] + ) -> httpx.Response: + """ + do http request depending on api bundle + """ + params, body, cookies, files, url, headers = self._prepare_request_parts(url, method, headers, parameters) + + _VALID_METHODS = {"get", "head", "post", "put", "delete", "patch"} + method_lc = method.lower() + if method_lc not in _VALID_METHODS: + raise ValueError(f"Invalid http method {method}") + _METHOD_MAP = { "get": ssrf_proxy.get, "head": ssrf_proxy.head, @@ -283,9 +302,6 @@ class ApiTool(Tool): "delete": ssrf_proxy.delete, "patch": ssrf_proxy.patch, } - method_lc = method.lower() - if method_lc not in _METHOD_MAP: - raise ValueError(f"Invalid http method {method}") response: httpx.Response = _METHOD_MAP[ method_lc ]( # https://discuss.python.org/t/type-inference-for-function-return-types/42926 @@ -301,6 +317,120 @@ class ApiTool(Tool): ) return response + def do_http_request_streaming( + self, url: str, method: str, headers: dict[str, Any], parameters: dict[str, Any] + ) -> 
Generator[ToolInvokeMessage, None, None]: + """Streaming HTTP request, yields ToolInvokeMessage as response chunks arrive.""" + params, body, cookies, files, url, headers = self._prepare_request_parts(url, method, headers, parameters) + + try: + with ssrf_proxy.stream_request( + method=method.upper(), + url=url, + params=params, + headers=headers, + cookies=cookies, + data=body, + files=files or None, + timeout=API_TOOL_DEFAULT_TIMEOUT, + follow_redirects=True, + ) as response: + if response.status_code >= 400: + response.read() + raise ToolInvokeError(f"Request failed with status code {response.status_code} and {response.text}") + + content_type = response.headers.get("content-type", "").lower() + + if "text/event-stream" in content_type: + yield from self._parse_sse_stream(response) + elif "application/x-ndjson" in content_type or "application/jsonl" in content_type: + yield from self._parse_ndjson_stream(response) + else: + yield from self._parse_text_stream(response) + + except (httpx.StreamError, httpx.TimeoutException) as e: + raise ToolInvokeError(f"Stream request failed: {e}") + + def _parse_sse_stream(self, response: httpx.Response) -> Generator[ToolInvokeMessage, None, None]: + """Parse Server-Sent Events stream, yielding text messages.""" + for line in response.iter_lines(): + if not line: + continue + + # Handle data: prefix + if line.startswith("data:"): + data = line[5:].strip() + + # Handle [DONE] terminator (OpenAI convention) + if data == "[DONE]": + return + + # Try to parse as JSON and extract content + text = self._extract_text_from_sse_data(data) + if text: + yield self.create_text_message(text) + + def _extract_text_from_sse_data(self, data: str) -> str: + """Extract text content from SSE data line. 
Supports OpenAI format and common field names.""" + try: + obj = json.loads(data) + except (json.JSONDecodeError, ValueError): + # Not JSON, return raw data + return data + + if not isinstance(obj, dict): + return data + + # OpenAI chat completion format: choices[].delta.content + choices = obj.get("choices") + if isinstance(choices, list) and len(choices) > 0: + delta = choices[0].get("delta", {}) + if isinstance(delta, dict): + content = delta.get("content") + if content: + return str(content) + # Non-delta format: choices[].text (completion API) + text = choices[0].get("text") + if text: + return str(text) + + # Common field names + for field in ("content", "text", "message", "data"): + value = obj.get(field) + if value and isinstance(value, str): + return value + + # Fallback: return raw data + return data + + def _parse_ndjson_stream(self, response: httpx.Response) -> Generator[ToolInvokeMessage, None, None]: + """Parse newline-delimited JSON stream.""" + for line in response.iter_lines(): + if not line: + continue + try: + obj = json.loads(line) + if isinstance(obj, dict): + # Try common text fields + for field in ("content", "text", "message", "data"): + value = obj.get(field) + if value and isinstance(value, str): + yield self.create_text_message(value) + break + else: + yield self.create_text_message(json.dumps(obj, ensure_ascii=False)) + else: + yield self.create_text_message(str(obj)) + except (json.JSONDecodeError, ValueError): + if line.strip(): + yield self.create_text_message(line) + + def _parse_text_stream(self, response: httpx.Response) -> Generator[ToolInvokeMessage, None, None]: + """Parse plain text stream in chunks.""" + for chunk in response.iter_text(chunk_size=4096): + if chunk: + yield self.create_text_message(chunk) + def _convert_body_property_any_of( self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]], max_recursive=10 ): @@ -384,10 +514,18 @@ class ApiTool(Tool): """ invoke http request """ - response: 
httpx.Response | str = "" # assemble request headers = self.assembling_request(tool_parameters) + # streaming path + if self.api_bundle.streaming: + yield from self.do_http_request_streaming( + self.api_bundle.server_url, self.api_bundle.method, headers, tool_parameters + ) + return + + # non-streaming path (original) + response: httpx.Response | str = "" # do http request response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, tool_parameters) diff --git a/api/core/tools/entities/tool_bundle.py b/api/core/tools/entities/tool_bundle.py index 10710c4376..e6879721b2 100644 --- a/api/core/tools/entities/tool_bundle.py +++ b/api/core/tools/entities/tool_bundle.py @@ -29,3 +29,5 @@ class ApiToolBundle(BaseModel): openapi: dict # output schema output_schema: Mapping[str, object] = Field(default_factory=dict) + # whether this operation supports streaming response + streaming: bool = Field(default=False) diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index fc2b41d960..e5a1cc45b8 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -193,6 +193,9 @@ class ApiBasedToolSchemaParser: interface["operation"]["operationId"] = f"{path}_{interface['method']}" + # Extract x-dify-streaming extension field + streaming = bool(interface["operation"].get("x-dify-streaming", False)) + bundles.append( ApiToolBundle( server_url=server_url + interface["path"], @@ -205,6 +208,7 @@ class ApiBasedToolSchemaParser: author="", icon=None, openapi=interface["operation"], + streaming=streaming, ) ) diff --git a/api/tests/unit_tests/core/tools/test_api_tool_streaming.py b/api/tests/unit_tests/core/tools/test_api_tool_streaming.py new file mode 100644 index 0000000000..4a0d8d74b4 --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_api_tool_streaming.py @@ -0,0 +1,395 @@ +"""Tests for ApiTool streaming support.""" + +import contextlib +import json +from unittest.mock import MagicMock, patch + +import 
httpx +import pytest + +from core.tools.custom_tool.tool import ApiTool +from core.tools.entities.tool_bundle import ApiToolBundle +from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage +from core.tools.errors import ToolInvokeError + +# --------------------------------------------------------------------------- +# Fixtures / helpers +# --------------------------------------------------------------------------- + + +def _msg_text(msg: ToolInvokeMessage) -> str: + """Extract plain text from a ToolInvokeMessage regardless of inner type.""" + inner = msg.message + if hasattr(inner, "text"): + return inner.text + return str(inner) + + +def _make_api_tool(streaming: bool = False) -> ApiTool: + bundle = ApiToolBundle( + server_url="https://api.example.com/v1/chat", + method="post", + summary="test", + operation_id="test_op", + parameters=[], + author="test", + icon=None, + openapi={}, + streaming=streaming, + ) + + entity = MagicMock(spec=ToolEntity) + entity.identity = MagicMock() + entity.identity.name = "test_tool" + entity.identity.author = "test" + entity.output_schema = {} + + runtime = MagicMock() + runtime.credentials = {"auth_type": "none"} + runtime.runtime_parameters = {} + + return ApiTool(entity=entity, api_bundle=bundle, runtime=runtime, provider_id="test_provider") + + +class FakeStreamResponse: + def __init__( + self, + status_code: int, + content_type: str, + lines: list[str] | None = None, + chunks: list[str] | None = None, + text: str = "", + ): + self.status_code = status_code + self.headers = {"content-type": content_type} + self.text = text + self._lines = lines or [] + self._chunks = chunks or [] + + def iter_lines(self): + yield from self._lines + + def iter_text(self, chunk_size=4096): + yield from self._chunks + + def read(self): + pass + + +@contextlib.contextmanager +def _fake_stream_ctx(response: FakeStreamResponse): + yield response + + +# --------------------------------------------------------------------------- +# 
ApiToolBundle.streaming default value +# --------------------------------------------------------------------------- + + +class TestApiToolBundleStreaming: + def test_default_false(self): + bundle = ApiToolBundle( + server_url="https://example.com", + method="get", + operation_id="op", + parameters=[], + author="", + openapi={}, + ) + assert bundle.streaming is False + + def test_explicit_true(self): + bundle = ApiToolBundle( + server_url="https://example.com", + method="get", + operation_id="op", + parameters=[], + author="", + openapi={}, + streaming=True, + ) + assert bundle.streaming is True + + +# --------------------------------------------------------------------------- +# SSE parser +# --------------------------------------------------------------------------- + + +class TestParseSSEStream: + def test_openai_format(self): + tool = _make_api_tool(streaming=True) + lines = [ + 'data: {"choices":[{"delta":{"content":"Hello"}}]}', + 'data: {"choices":[{"delta":{"content":" world"}}]}', + "data: [DONE]", + ] + resp = FakeStreamResponse(200, "text/event-stream", lines=lines) + messages = list(tool._parse_sse_stream(resp)) + assert len(messages) == 2 + assert _msg_text(messages[0]) == "Hello" + assert _msg_text(messages[1]) == " world" + + def test_plain_text_sse(self): + tool = _make_api_tool(streaming=True) + lines = ["data: chunk1", "data: chunk2", "", "data: [DONE]"] + resp = FakeStreamResponse(200, "text/event-stream", lines=lines) + messages = list(tool._parse_sse_stream(resp)) + assert len(messages) == 2 + assert _msg_text(messages[0]) == "chunk1" + assert _msg_text(messages[1]) == "chunk2" + + def test_common_field_names(self): + tool = _make_api_tool(streaming=True) + lines = [ + 'data: {"content":"from content"}', + 'data: {"text":"from text"}', + 'data: {"message":"from message"}', + ] + resp = FakeStreamResponse(200, "text/event-stream", lines=lines) + messages = list(tool._parse_sse_stream(resp)) + assert len(messages) == 3 + assert 
_msg_text(messages[0]) == "from content" + assert _msg_text(messages[1]) == "from text" + assert _msg_text(messages[2]) == "from message" + + def test_empty_lines_skipped(self): + tool = _make_api_tool(streaming=True) + lines = ["", "", "data: hello", ""] + resp = FakeStreamResponse(200, "text/event-stream", lines=lines) + messages = list(tool._parse_sse_stream(resp)) + assert len(messages) == 1 + + def test_non_data_lines_skipped(self): + tool = _make_api_tool(streaming=True) + lines = ["event: message", "id: 1", "retry: 3000", "data: actual content"] + resp = FakeStreamResponse(200, "text/event-stream", lines=lines) + messages = list(tool._parse_sse_stream(resp)) + assert len(messages) == 1 + assert _msg_text(messages[0]) == "actual content" + + +# --------------------------------------------------------------------------- +# NDJSON parser +# --------------------------------------------------------------------------- + + +class TestParseNdjsonStream: + def test_json_with_content_field(self): + tool = _make_api_tool(streaming=True) + lines = [json.dumps({"content": "line1"}), json.dumps({"content": "line2"})] + resp = FakeStreamResponse(200, "application/x-ndjson", lines=lines) + messages = list(tool._parse_ndjson_stream(resp)) + assert len(messages) == 2 + assert _msg_text(messages[0]) == "line1" + + def test_fallback_to_full_json(self): + tool = _make_api_tool(streaming=True) + lines = [json.dumps({"result": 42})] + resp = FakeStreamResponse(200, "application/x-ndjson", lines=lines) + messages = list(tool._parse_ndjson_stream(resp)) + assert len(messages) == 1 + assert "42" in _msg_text(messages[0]) + + def test_invalid_json_lines(self): + tool = _make_api_tool(streaming=True) + lines = ["not json", ""] + resp = FakeStreamResponse(200, "application/x-ndjson", lines=lines) + messages = list(tool._parse_ndjson_stream(resp)) + assert len(messages) == 1 + assert _msg_text(messages[0]) == "not json" + + +# 
--------------------------------------------------------------------------- +# Text stream parser +# --------------------------------------------------------------------------- + + +class TestParseTextStream: + def test_text_chunks(self): + tool = _make_api_tool(streaming=True) + resp = FakeStreamResponse(200, "text/plain", chunks=["chunk1", "chunk2", "chunk3"]) + messages = list(tool._parse_text_stream(resp)) + assert len(messages) == 3 + assert _msg_text(messages[0]) == "chunk1" + + def test_empty_chunks_skipped(self): + tool = _make_api_tool(streaming=True) + resp = FakeStreamResponse(200, "text/plain", chunks=["", "data", ""]) + messages = list(tool._parse_text_stream(resp)) + assert len(messages) == 1 + + +# --------------------------------------------------------------------------- +# _invoke() streaming branch +# --------------------------------------------------------------------------- + + +class TestInvokeStreamingBranch: + @patch.object(ApiTool, "do_http_request_streaming") + @patch.object(ApiTool, "do_http_request") + def test_streaming_true_uses_streaming_path(self, mock_non_stream, mock_stream): + tool = _make_api_tool(streaming=True) + mock_stream.return_value = iter([tool.create_text_message("streamed")]) + + messages = list(tool._invoke(user_id="u1", tool_parameters={})) + mock_stream.assert_called_once() + mock_non_stream.assert_not_called() + assert len(messages) == 1 + assert _msg_text(messages[0]) == "streamed" + + @patch.object(ApiTool, "do_http_request_streaming") + @patch.object(ApiTool, "do_http_request") + def test_streaming_false_uses_non_streaming_path(self, mock_non_stream, mock_stream): + tool = _make_api_tool(streaming=False) + mock_response = MagicMock(spec=httpx.Response) + mock_response.status_code = 200 + mock_response.content = b'{"result": "ok"}' + mock_response.headers = {"content-type": "text/plain"} + mock_response.text = "ok" + mock_response.json.side_effect = Exception("not json") + mock_non_stream.return_value = 
mock_response + + messages = list(tool._invoke(user_id="u1", tool_parameters={})) + mock_non_stream.assert_called_once() + mock_stream.assert_not_called() + + +# --------------------------------------------------------------------------- +# Error handling +# --------------------------------------------------------------------------- + + +class TestStreamingErrorHandling: + def test_http_error_raises(self): + tool = _make_api_tool(streaming=True) + error_resp = FakeStreamResponse(500, "text/plain", text="Internal Server Error") + + with patch("core.helper.ssrf_proxy.stream_request", return_value=_fake_stream_ctx(error_resp)): + with pytest.raises(ToolInvokeError, match="500"): + list(tool.do_http_request_streaming("https://example.com", "POST", {}, {})) + + def test_stream_error_raises(self): + tool = _make_api_tool(streaming=True) + + @contextlib.contextmanager + def _raise_stream_error(): + raise httpx.StreamError("connection reset") + yield # noqa: RET503 + + with patch("core.helper.ssrf_proxy.stream_request", return_value=_raise_stream_error()): + with pytest.raises(ToolInvokeError, match="Stream request failed"): + list(tool.do_http_request_streaming("https://example.com", "POST", {}, {})) + + def test_timeout_raises(self): + tool = _make_api_tool(streaming=True) + + @contextlib.contextmanager + def _raise_timeout(): + raise httpx.ReadTimeout("read timed out") + yield # noqa: RET503 + + with patch("core.helper.ssrf_proxy.stream_request", return_value=_raise_timeout()): + with pytest.raises(ToolInvokeError, match="Stream request failed"): + list(tool.do_http_request_streaming("https://example.com", "POST", {}, {})) + + +# --------------------------------------------------------------------------- +# Content-type routing +# --------------------------------------------------------------------------- + + +class TestContentTypeRouting: + def test_sse_content_type(self): + tool = _make_api_tool(streaming=True) + resp = FakeStreamResponse(200, "text/event-stream; 
charset=utf-8", lines=["data: hello"]) + + with patch("core.helper.ssrf_proxy.stream_request", return_value=_fake_stream_ctx(resp)): + messages = list(tool.do_http_request_streaming("https://example.com", "POST", {}, {})) + assert len(messages) == 1 + + def test_ndjson_content_type(self): + tool = _make_api_tool(streaming=True) + resp = FakeStreamResponse(200, "application/x-ndjson", lines=[json.dumps({"text": "hi"})]) + + with patch("core.helper.ssrf_proxy.stream_request", return_value=_fake_stream_ctx(resp)): + messages = list(tool.do_http_request_streaming("https://example.com", "POST", {}, {})) + assert len(messages) == 1 + assert _msg_text(messages[0]) == "hi" + + def test_jsonl_content_type(self): + tool = _make_api_tool(streaming=True) + resp = FakeStreamResponse(200, "application/jsonl", lines=[json.dumps({"content": "hey"})]) + + with patch("core.helper.ssrf_proxy.stream_request", return_value=_fake_stream_ctx(resp)): + messages = list(tool.do_http_request_streaming("https://example.com", "POST", {}, {})) + assert len(messages) == 1 + + def test_fallback_text_stream(self): + tool = _make_api_tool(streaming=True) + resp = FakeStreamResponse(200, "application/octet-stream", chunks=["raw data"]) + + with patch("core.helper.ssrf_proxy.stream_request", return_value=_fake_stream_ctx(resp)): + messages = list(tool.do_http_request_streaming("https://example.com", "POST", {}, {})) + assert len(messages) == 1 + assert _msg_text(messages[0]) == "raw data" + + +# --------------------------------------------------------------------------- +# OpenAPI parser: x-dify-streaming extraction +# --------------------------------------------------------------------------- + + +class TestOpenAPIParserStreaming: + def test_x_dify_streaming_true(self): + from flask import Flask + + from core.tools.utils.parser import ApiBasedToolSchemaParser + + openapi = { + "openapi": "3.0.0", + "info": {"title": "Test", "version": "1.0"}, + "servers": [{"url": "https://api.example.com"}], + 
"paths": { + "/chat": { + "post": { + "x-dify-streaming": True, + "operationId": "chat", + "summary": "Chat", + "responses": {"200": {"description": "OK"}}, + } + } + }, + } + app = Flask(__name__) + with app.test_request_context(): + bundles = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi) + + assert len(bundles) == 1 + assert bundles[0].streaming is True + + def test_x_dify_streaming_absent(self): + from flask import Flask + + from core.tools.utils.parser import ApiBasedToolSchemaParser + + openapi = { + "openapi": "3.0.0", + "info": {"title": "Test", "version": "1.0"}, + "servers": [{"url": "https://api.example.com"}], + "paths": { + "/query": { + "get": { + "operationId": "query", + "summary": "Query", + "responses": {"200": {"description": "OK"}}, + } + } + }, + } + app = Flask(__name__) + with app.test_request_context(): + bundles = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi) + + assert len(bundles) == 1 + assert bundles[0].streaming is False From bf0d7fd57c691df66d941187b99d08554a2a8881 Mon Sep 17 00:00:00 2001 From: skhe Date: Tue, 3 Mar 2026 14:25:32 +0800 Subject: [PATCH 2/3] fix(tools): address review feedback for streaming support - Add Squid proxy SSRF detection in stream_request() (security) - Limit error response body read to 1 MB to prevent DoS (security) - Add debug logging for JSON decode failures in SSE/NDJSON parsers - Update do_http_request() docstring to reference streaming counterpart - Use Iterator type hint instead of Generator for context manager - Add SSRF protection tests for streaming path --- api/core/helper/ssrf_proxy.py | 13 +++- api/core/tools/custom_tool/tool.py | 20 ++++-- .../core/tools/test_api_tool_streaming.py | 72 ++++++++++++++++++- 3 files changed, 96 insertions(+), 9 deletions(-) diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index e035db8b00..4a4a2688a6 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -5,7 +5,7 @@ Proxy 
requests to avoid SSRF import contextlib import logging import time -from collections.abc import Generator +from collections.abc import Iterator from typing import Any, TypeAlias import httpx @@ -273,7 +273,7 @@ ssrf_proxy = SSRFProxy() @contextlib.contextmanager -def stream_request(method: str, url: str, **kwargs: Any) -> Generator[httpx.Response, None, None]: +def stream_request(method: str, url: str, **kwargs: Any) -> Iterator[httpx.Response]: """Streaming HTTP request context manager with SSRF protection. Unlike make_request(), this does not implement retry logic because retries @@ -311,4 +311,13 @@ def stream_request(method: str, url: str, **kwargs: Any) -> Generator[httpx.Resp kwargs["headers"] = headers with client.stream(method=method, url=url, **kwargs) as response: + # Check for SSRF protection by Squid proxy (same logic as make_request) + if response.status_code in (401, 403): + server_header = response.headers.get("server", "").lower() + via_header = response.headers.get("via", "").lower() + if "squid" in server_header or "squid" in via_header: + raise ToolSSRFError( + f"Access to '{url}' was blocked by SSRF protection. " + f"The URL may point to a private or local network address. " + ) yield response diff --git a/api/core/tools/custom_tool/tool.py b/api/core/tools/custom_tool/tool.py index a56b75c052..8a9e9a18db 100644 --- a/api/core/tools/custom_tool/tool.py +++ b/api/core/tools/custom_tool/tool.py @@ -285,7 +285,10 @@ class ApiTool(Tool): self, url: str, method: str, headers: dict[str, Any], parameters: dict[str, Any] ) -> httpx.Response: """ - do http request depending on api bundle + Do a non-streaming HTTP request depending on api bundle. + + For streaming requests, see do_http_request_streaming(). + Both methods share _prepare_request_parts() for request assembly. 
""" params, body, cookies, files, url, headers = self._prepare_request_parts(url, method, headers, parameters) @@ -336,8 +339,16 @@ class ApiTool(Tool): follow_redirects=True, ) as response: if response.status_code >= 400: - response.read() - raise ToolInvokeError(f"Request failed with status code {response.status_code} and {response.text}") + # Read a bounded amount to avoid unbounded memory usage on large error bodies + _MAX_ERROR_BYTES = 1_048_576 # 1 MB + error_body = b"" + for chunk in response.iter_bytes(chunk_size=8192): + error_body += chunk + if len(error_body) >= _MAX_ERROR_BYTES: + error_body = error_body[:_MAX_ERROR_BYTES] + break + error_text = error_body.decode("utf-8", errors="replace") + raise ToolInvokeError(f"Request failed with status code {response.status_code} and {error_text}") content_type = response.headers.get("content-type", "").lower() @@ -375,7 +386,7 @@ class ApiTool(Tool): try: obj = json.loads(data) except (json.JSONDecodeError, ValueError): - # Not JSON, return raw data + logger.debug("SSE data is not valid JSON, treating as plain text: %s", data[:200]) return data if not isinstance(obj, dict): @@ -422,6 +433,7 @@ class ApiTool(Tool): else: yield self.create_text_message(str(obj)) except (json.JSONDecodeError, ValueError): + logger.debug("NDJSON line is not valid JSON, treating as plain text: %s", line[:200]) if line.strip(): yield self.create_text_message(line) diff --git a/api/tests/unit_tests/core/tools/test_api_tool_streaming.py b/api/tests/unit_tests/core/tools/test_api_tool_streaming.py index 4a0d8d74b4..ca76d312f9 100644 --- a/api/tests/unit_tests/core/tools/test_api_tool_streaming.py +++ b/api/tests/unit_tests/core/tools/test_api_tool_streaming.py @@ -7,10 +7,11 @@ from unittest.mock import MagicMock, patch import httpx import pytest +from core.helper import ssrf_proxy from core.tools.custom_tool.tool import ApiTool from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import 
ToolEntity, ToolInvokeMessage -from core.tools.errors import ToolInvokeError +from core.tools.errors import ToolInvokeError, ToolSSRFError # --------------------------------------------------------------------------- # Fixtures / helpers @@ -59,9 +60,12 @@ class FakeStreamResponse: lines: list[str] | None = None, chunks: list[str] | None = None, text: str = "", + extra_headers: dict[str, str] | None = None, ): self.status_code = status_code self.headers = {"content-type": content_type} + if extra_headers: + self.headers.update(extra_headers) self.text = text self._lines = lines or [] self._chunks = chunks or [] @@ -72,6 +76,9 @@ class FakeStreamResponse: def iter_text(self, chunk_size=4096): yield from self._chunks + def iter_bytes(self, chunk_size=8192): + yield self.text.encode("utf-8") + def read(self): pass @@ -275,7 +282,7 @@ class TestStreamingErrorHandling: @contextlib.contextmanager def _raise_stream_error(): raise httpx.StreamError("connection reset") - yield # noqa: RET503 + yield with patch("core.helper.ssrf_proxy.stream_request", return_value=_raise_stream_error()): with pytest.raises(ToolInvokeError, match="Stream request failed"): @@ -287,13 +294,72 @@ class TestStreamingErrorHandling: @contextlib.contextmanager def _raise_timeout(): raise httpx.ReadTimeout("read timed out") - yield # noqa: RET503 + yield with patch("core.helper.ssrf_proxy.stream_request", return_value=_raise_timeout()): with pytest.raises(ToolInvokeError, match="Stream request failed"): list(tool.do_http_request_streaming("https://example.com", "POST", {}, {})) +# --------------------------------------------------------------------------- +# SSRF Squid proxy detection in streaming +# --------------------------------------------------------------------------- + + +class _FakeHttpxStreamCM: + """Mimics the context manager returned by httpx.Client.stream().""" + + def __init__(self, response): + self.response = response + + def __enter__(self): + return self.response + + def 
__exit__(self, *args): + return False + + +class TestStreamingSSRFProtection: + def _make_mock_response(self, status_code: int, extra_headers: dict[str, str] | None = None): + resp = MagicMock() + resp.status_code = status_code + resp.headers = {"content-type": "text/html"} + if extra_headers: + resp.headers.update(extra_headers) + return resp + + def test_squid_403_raises_ssrf_error(self): + """stream_request should detect Squid proxy 403 and raise ToolSSRFError.""" + fake_resp = self._make_mock_response(403, {"server": "squid/5.7"}) + mock_client = MagicMock() + mock_client.stream.return_value = _FakeHttpxStreamCM(fake_resp) + + with patch("core.helper.ssrf_proxy._get_ssrf_client", return_value=mock_client): + with pytest.raises(ToolSSRFError, match="SSRF protection"): + with ssrf_proxy.stream_request("POST", "https://internal.example.com"): + pass + + def test_squid_401_via_header_raises_ssrf_error(self): + """stream_request should detect Squid in Via header and raise ToolSSRFError.""" + fake_resp = self._make_mock_response(401, {"via": "1.1 squid-proxy"}) + mock_client = MagicMock() + mock_client.stream.return_value = _FakeHttpxStreamCM(fake_resp) + + with patch("core.helper.ssrf_proxy._get_ssrf_client", return_value=mock_client): + with pytest.raises(ToolSSRFError, match="SSRF protection"): + with ssrf_proxy.stream_request("POST", "https://internal.example.com"): + pass + + def test_non_squid_403_passes_through(self): + """403 from non-Squid server should NOT raise ToolSSRFError but be handled by caller.""" + tool = _make_api_tool(streaming=True) + resp = FakeStreamResponse(403, "text/plain", text="Forbidden") + + with patch("core.helper.ssrf_proxy.stream_request", return_value=_fake_stream_ctx(resp)): + with pytest.raises(ToolInvokeError, match="403"): + list(tool.do_http_request_streaming("https://example.com", "POST", {}, {})) + + # --------------------------------------------------------------------------- # Content-type routing # 
--------------------------------------------------------------------------- From abb0408b082da09388b052595f2fe812ae962eeb Mon Sep 17 00:00:00 2001 From: skhe Date: Tue, 3 Mar 2026 17:01:24 +0800 Subject: [PATCH 3/3] fix(tools): resolve Pyrefly type check errors in streaming tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Use typing.cast via _resp() helper for FakeStreamResponse → httpx.Response - Replace @contextlib.contextmanager with _RaisingContextManager class to eliminate unreachable yield statements --- .../core/tools/test_api_tool_streaming.py | 55 +++++++++++-------- 1 file changed, 33 insertions(+), 22 deletions(-) diff --git a/api/tests/unit_tests/core/tools/test_api_tool_streaming.py b/api/tests/unit_tests/core/tools/test_api_tool_streaming.py index ca76d312f9..30f461cdbf 100644 --- a/api/tests/unit_tests/core/tools/test_api_tool_streaming.py +++ b/api/tests/unit_tests/core/tools/test_api_tool_streaming.py @@ -2,6 +2,7 @@ import contextlib import json +from typing import cast from unittest.mock import MagicMock, patch import httpx @@ -83,11 +84,29 @@ class FakeStreamResponse: pass +def _resp(fake: FakeStreamResponse) -> httpx.Response: + """Cast FakeStreamResponse to httpx.Response for type checkers.""" + return cast(httpx.Response, fake) + + @contextlib.contextmanager def _fake_stream_ctx(response: FakeStreamResponse): yield response +class _RaisingContextManager: + """Context manager that raises on __enter__. 
Used to test exception paths.""" + + def __init__(self, exception: BaseException): + self.exception = exception + + def __enter__(self): + raise self.exception + + def __exit__(self, *args): + return False + + # --------------------------------------------------------------------------- # ApiToolBundle.streaming default value # --------------------------------------------------------------------------- @@ -132,7 +151,7 @@ class TestParseSSEStream: "data: [DONE]", ] resp = FakeStreamResponse(200, "text/event-stream", lines=lines) - messages = list(tool._parse_sse_stream(resp)) + messages = list(tool._parse_sse_stream(_resp(resp))) assert len(messages) == 2 assert _msg_text(messages[0]) == "Hello" assert _msg_text(messages[1]) == " world" @@ -141,7 +160,7 @@ class TestParseSSEStream: tool = _make_api_tool(streaming=True) lines = ["data: chunk1", "data: chunk2", "", "data: [DONE]"] resp = FakeStreamResponse(200, "text/event-stream", lines=lines) - messages = list(tool._parse_sse_stream(resp)) + messages = list(tool._parse_sse_stream(_resp(resp))) assert len(messages) == 2 assert _msg_text(messages[0]) == "chunk1" assert _msg_text(messages[1]) == "chunk2" @@ -154,7 +173,7 @@ class TestParseSSEStream: 'data: {"message":"from message"}', ] resp = FakeStreamResponse(200, "text/event-stream", lines=lines) - messages = list(tool._parse_sse_stream(resp)) + messages = list(tool._parse_sse_stream(_resp(resp))) assert len(messages) == 3 assert _msg_text(messages[0]) == "from content" assert _msg_text(messages[1]) == "from text" @@ -164,14 +183,14 @@ class TestParseSSEStream: tool = _make_api_tool(streaming=True) lines = ["", "", "data: hello", ""] resp = FakeStreamResponse(200, "text/event-stream", lines=lines) - messages = list(tool._parse_sse_stream(resp)) + messages = list(tool._parse_sse_stream(_resp(resp))) assert len(messages) == 1 def test_non_data_lines_skipped(self): tool = _make_api_tool(streaming=True) lines = ["event: message", "id: 1", "retry: 3000", "data: actual 
content"] resp = FakeStreamResponse(200, "text/event-stream", lines=lines) - messages = list(tool._parse_sse_stream(resp)) + messages = list(tool._parse_sse_stream(_resp(resp))) assert len(messages) == 1 assert _msg_text(messages[0]) == "actual content" @@ -186,7 +205,7 @@ class TestParseNdjsonStream: tool = _make_api_tool(streaming=True) lines = [json.dumps({"content": "line1"}), json.dumps({"content": "line2"})] resp = FakeStreamResponse(200, "application/x-ndjson", lines=lines) - messages = list(tool._parse_ndjson_stream(resp)) + messages = list(tool._parse_ndjson_stream(_resp(resp))) assert len(messages) == 2 assert _msg_text(messages[0]) == "line1" @@ -194,7 +213,7 @@ class TestParseNdjsonStream: tool = _make_api_tool(streaming=True) lines = [json.dumps({"result": 42})] resp = FakeStreamResponse(200, "application/x-ndjson", lines=lines) - messages = list(tool._parse_ndjson_stream(resp)) + messages = list(tool._parse_ndjson_stream(_resp(resp))) assert len(messages) == 1 assert "42" in _msg_text(messages[0]) @@ -202,7 +221,7 @@ class TestParseNdjsonStream: tool = _make_api_tool(streaming=True) lines = ["not json", ""] resp = FakeStreamResponse(200, "application/x-ndjson", lines=lines) - messages = list(tool._parse_ndjson_stream(resp)) + messages = list(tool._parse_ndjson_stream(_resp(resp))) assert len(messages) == 1 assert _msg_text(messages[0]) == "not json" @@ -216,14 +235,14 @@ class TestParseTextStream: def test_text_chunks(self): tool = _make_api_tool(streaming=True) resp = FakeStreamResponse(200, "text/plain", chunks=["chunk1", "chunk2", "chunk3"]) - messages = list(tool._parse_text_stream(resp)) + messages = list(tool._parse_text_stream(_resp(resp))) assert len(messages) == 3 assert _msg_text(messages[0]) == "chunk1" def test_empty_chunks_skipped(self): tool = _make_api_tool(streaming=True) resp = FakeStreamResponse(200, "text/plain", chunks=["", "data", ""]) - messages = list(tool._parse_text_stream(resp)) + messages = 
list(tool._parse_text_stream(_resp(resp))) assert len(messages) == 1 @@ -278,25 +297,17 @@ class TestStreamingErrorHandling: def test_stream_error_raises(self): tool = _make_api_tool(streaming=True) + ctx = _RaisingContextManager(httpx.StreamError("connection reset")) - @contextlib.contextmanager - def _raise_stream_error(): - raise httpx.StreamError("connection reset") - yield - - with patch("core.helper.ssrf_proxy.stream_request", return_value=_raise_stream_error()): + with patch("core.helper.ssrf_proxy.stream_request", return_value=ctx): with pytest.raises(ToolInvokeError, match="Stream request failed"): list(tool.do_http_request_streaming("https://example.com", "POST", {}, {})) def test_timeout_raises(self): tool = _make_api_tool(streaming=True) + ctx = _RaisingContextManager(httpx.ReadTimeout("read timed out")) - @contextlib.contextmanager - def _raise_timeout(): - raise httpx.ReadTimeout("read timed out") - yield - - with patch("core.helper.ssrf_proxy.stream_request", return_value=_raise_timeout()): + with patch("core.helper.ssrf_proxy.stream_request", return_value=ctx): with pytest.raises(ToolInvokeError, match="Stream request failed"): list(tool.do_http_request_streaming("https://example.com", "POST", {}, {}))