diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py
index 54068fc28d..4a4a2688a6 100644
--- a/api/core/helper/ssrf_proxy.py
+++ b/api/core/helper/ssrf_proxy.py
@@ -2,8 +2,10 @@
 Proxy requests to avoid SSRF
 """
 
+import contextlib
 import logging
 import time
+from collections.abc import Iterator
 from typing import Any, TypeAlias
 
 import httpx
@@ -268,3 +270,54 @@ class SSRFProxy:
 
 
 ssrf_proxy = SSRFProxy()
+
+
+@contextlib.contextmanager
+def stream_request(method: str, url: str, **kwargs: Any) -> Iterator[httpx.Response]:
+    """Streaming HTTP request context manager with SSRF protection.
+
+    Unlike make_request(), this does not implement retry logic because retries
+    are meaningless for streaming responses -- the stream cannot be replayed.
+    """
+    if "allow_redirects" in kwargs:
+        allow_redirects = kwargs.pop("allow_redirects")
+        if "follow_redirects" not in kwargs:
+            kwargs["follow_redirects"] = allow_redirects
+
+    if "timeout" not in kwargs:
+        kwargs["timeout"] = httpx.Timeout(
+            timeout=dify_config.SSRF_DEFAULT_TIME_OUT,
+            connect=dify_config.SSRF_DEFAULT_CONNECT_TIME_OUT,
+            read=dify_config.SSRF_DEFAULT_READ_TIME_OUT,
+            write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT,
+        )
+
+    verify_option = kwargs.pop("ssl_verify", dify_config.HTTP_REQUEST_NODE_SSL_VERIFY)
+    if not isinstance(verify_option, bool):
+        raise ValueError("ssl_verify must be a boolean")
+    client = _get_ssrf_client(verify_option)
+
+    try:
+        headers: Headers = _HEADERS_ADAPTER.validate_python(kwargs.get("headers") or {})
+    except ValidationError as e:
+        raise ValueError("headers must be a mapping of string keys to string values") from e
+    headers = _inject_trace_headers(headers)
+    kwargs["headers"] = headers
+
+    user_provided_host = _get_user_provided_host_header(headers)
+    if user_provided_host is not None:
+        headers = {k: v for k, v in headers.items() if k.lower() != "host"}
+        headers["host"] = user_provided_host
+        kwargs["headers"] = headers
+
+    with client.stream(method=method, url=url, **kwargs) as response:
+        # Check for SSRF protection by Squid proxy (same logic as make_request)
+        if response.status_code in (401, 403):
+            server_header = response.headers.get("server", "").lower()
+            via_header = response.headers.get("via", "").lower()
+            if "squid" in server_header or "squid" in via_header:
+                raise ToolSSRFError(
+                    f"Access to '{url}' was blocked by SSRF protection. "
+                    f"The URL may point to a private or local network address. "
+                )
+        yield response
diff --git a/api/core/tools/custom_tool/tool.py b/api/core/tools/custom_tool/tool.py
index c6a84e27c6..8a9e9a18db 100644
--- a/api/core/tools/custom_tool/tool.py
+++ b/api/core/tools/custom_tool/tool.py
@@ -1,4 +1,5 @@
 import json
+import logging
 from collections.abc import Generator
 from dataclasses import dataclass
 from os import getenv
@@ -15,6 +16,8 @@ from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, Too
 from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError
 from dify_graph.file.file_manager import download
 
+logger = logging.getLogger(__name__)
+
 API_TOOL_DEFAULT_TIMEOUT = (
     int(getenv("API_TOOL_DEFAULT_CONNECT_TIMEOUT", "10")),
     int(getenv("API_TOOL_DEFAULT_READ_TIMEOUT", "60")),
@@ -168,14 +171,15 @@ class ApiTool(Tool):
         else:
             return (parameter.get("schema", {}) or {}).get("default", "")
 
-    def do_http_request(
+    def _prepare_request_parts(
         self, url: str, method: str, headers: dict[str, Any], parameters: dict[str, Any]
-    ) -> httpx.Response:
+    ) -> tuple[dict, Any, dict, list, str, dict[str, Any]]:
         """
-        do http request depending on api bundle
-        """
-        method = method.lower()
+        Assemble request parts (params, body, cookies, files, url, headers) from
+        the OpenAPI bundle and supplied parameters.
 
+        This is shared by both the streaming and non-streaming request paths.
+        """
         params = {}
         path_params = {}
         # FIXME: body should be a dict[str, Any] but it changed a lot in this function
@@ -275,6 +279,24 @@ class ApiTool(Tool):
         if files:
             headers.pop("Content-Type", None)
 
+        return params, body, cookies, files, url, headers
+
+    def do_http_request(
+        self, url: str, method: str, headers: dict[str, Any], parameters: dict[str, Any]
+    ) -> httpx.Response:
+        """
+        Do a non-streaming HTTP request depending on api bundle.
+
+        For streaming requests, see do_http_request_streaming().
+        Both methods share _prepare_request_parts() for request assembly.
+        """
+        params, body, cookies, files, url, headers = self._prepare_request_parts(url, method, headers, parameters)
+
+        _VALID_METHODS = {"get", "head", "post", "put", "delete", "patch"}
+        method_lc = method.lower()
+        if method_lc not in _VALID_METHODS:
+            raise ValueError(f"Invalid http method {method}")
+
         _METHOD_MAP = {
             "get": ssrf_proxy.get,
             "head": ssrf_proxy.head,
@@ -283,9 +305,6 @@ class ApiTool(Tool):
             "delete": ssrf_proxy.delete,
             "patch": ssrf_proxy.patch,
         }
-        method_lc = method.lower()
-        if method_lc not in _METHOD_MAP:
-            raise ValueError(f"Invalid http method {method}")
         response: httpx.Response = _METHOD_MAP[
             method_lc
         ](  # https://discuss.python.org/t/type-inference-for-function-return-types/42926
@@ -301,6 +320,129 @@ class ApiTool(Tool):
         )
         return response
 
+    def do_http_request_streaming(
+        self, url: str, method: str, headers: dict[str, Any], parameters: dict[str, Any]
+    ) -> Generator[ToolInvokeMessage, None, None]:
+        """Streaming HTTP request, yields ToolInvokeMessage as response chunks arrive."""
+        params, body, cookies, files, url, headers = self._prepare_request_parts(url, method, headers, parameters)
+
+        try:
+            with ssrf_proxy.stream_request(
+                method=method.upper(),
+                url=url,
+                params=params,
+                headers=headers,
+                cookies=cookies,
+                data=body,
+                files=files or None,
+                timeout=API_TOOL_DEFAULT_TIMEOUT,
+                follow_redirects=True,
+            ) as response:
+                if response.status_code >= 400:
+                    # Read a bounded amount to avoid unbounded memory usage on large error bodies
+                    _MAX_ERROR_BYTES = 1_048_576  # 1 MB
+                    error_body = b""
+                    for chunk in response.iter_bytes(chunk_size=8192):
+                        error_body += chunk
+                        if len(error_body) >= _MAX_ERROR_BYTES:
+                            error_body = error_body[:_MAX_ERROR_BYTES]
+                            break
+                    error_text = error_body.decode("utf-8", errors="replace")
+                    raise ToolInvokeError(f"Request failed with status code {response.status_code} and {error_text}")
+
+                content_type = response.headers.get("content-type", "").lower()
+
+                if "text/event-stream" in content_type:
+                    yield from self._parse_sse_stream(response)
+                elif "application/x-ndjson" in content_type or "application/jsonl" in content_type:
+                    yield from self._parse_ndjson_stream(response)
+                else:
+                    yield from self._parse_text_stream(response)
+
+        except (httpx.StreamError, httpx.RequestError) as e:
+            raise ToolInvokeError(f"Stream request failed: {e}") from e
+
+    def _parse_sse_stream(self, response: httpx.Response) -> Generator[ToolInvokeMessage, None, None]:
+        """Parse a Server-Sent Events stream, yielding one text message per non-empty ``data:`` line."""
+        for line in response.iter_lines():
+            if not line:
+                continue
+
+            # Per the SSE spec, drop only the single optional space after "data:" (payload whitespace is significant)
+            if line.startswith("data:"):
+                data = line[5:].removeprefix(" ")
+
+                # Handle [DONE] terminator (OpenAI convention)
+                if data.strip() == "[DONE]":
+                    return
+
+                # Try to parse as JSON and extract content
+                text = self._extract_text_from_sse_data(data)
+                if text:
+                    yield self.create_text_message(text)
+
+    def _extract_text_from_sse_data(self, data: str) -> str:
+        """Extract text content from SSE data line. Supports OpenAI format and common field names."""
+        try:
+            obj = json.loads(data)
+        except (json.JSONDecodeError, ValueError):
+            logger.debug("SSE data is not valid JSON, treating as plain text: %s", data[:200])
+            return data
+
+        if not isinstance(obj, dict):
+            return data
+
+        # OpenAI chat completion format: choices[].delta.content (skipped unless choices[0] is a JSON object)
+        choices = obj.get("choices")
+        if isinstance(choices, list) and choices and isinstance(choices[0], dict):
+            delta = choices[0].get("delta", {})
+            if isinstance(delta, dict):
+                content = delta.get("content")
+                if content:
+                    return str(content)
+            # Non-delta format: choices[].text (completion API)
+            text = choices[0].get("text")
+            if text:
+                return str(text)
+
+        # Common field names
+        for field in ("content", "text", "message", "data"):
+            value = obj.get(field)
+            if value and isinstance(value, str):
+                return value
+
+        # Fallback: return raw data
+        return data
+
+    def _parse_ndjson_stream(self, response: httpx.Response) -> Generator[ToolInvokeMessage, None, None]:
+        """Parse newline-delimited JSON stream."""
+        for line in response.iter_lines():
+            if not line:
+                continue
+            try:
+                obj = json.loads(line)
+                if isinstance(obj, dict):
+                    # Try common text fields
+                    for field in ("content", "text", "message", "data"):
+                        value = obj.get(field)
+                        if value and isinstance(value, str):
+                            yield self.create_text_message(value)
+                            break
+                    else:
+                        yield self.create_text_message(json.dumps(obj, ensure_ascii=False))
+                else:
+                    yield self.create_text_message(str(obj))
+            except (json.JSONDecodeError, ValueError):
+                logger.debug("NDJSON line is not valid JSON, treating as plain text: %s", line[:200])
+                if line.strip():
+                    yield self.create_text_message(line)
+
+    def _parse_text_stream(self, response: httpx.Response) -> Generator[ToolInvokeMessage, None, None]:
+        """Parse plain text stream in chunks."""
+        for chunk in response.iter_text(chunk_size=4096):
+            if chunk:
+                yield self.create_text_message(chunk)
+
     def _convert_body_property_any_of(
         self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]], max_recursive=10
     ):
@@ -384,10 +526,18 @@ class ApiTool(Tool):
         """
         invoke http request
         """
-        response: httpx.Response | str = ""
         # assemble request
         headers = self.assembling_request(tool_parameters)
 
+        # streaming path
+        if self.api_bundle.streaming:
+            yield from self.do_http_request_streaming(
+                self.api_bundle.server_url, self.api_bundle.method, headers, tool_parameters
+            )
+            return
+
+        # non-streaming path (original)
+        response: httpx.Response | str = ""
         # do http request
         response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, tool_parameters)
 
diff --git a/api/core/tools/entities/tool_bundle.py b/api/core/tools/entities/tool_bundle.py
index 10710c4376..e6879721b2 100644
--- a/api/core/tools/entities/tool_bundle.py
+++ b/api/core/tools/entities/tool_bundle.py
@@ -29,3 +29,5 @@ class ApiToolBundle(BaseModel):
     openapi: dict
     # output schema
     output_schema: Mapping[str, object] = Field(default_factory=dict)
+    # whether this operation supports streaming response
+    streaming: bool = Field(default=False)
diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py
index f7484b93fb..16ec9530bc 100644
--- a/api/core/tools/utils/parser.py
+++ b/api/core/tools/utils/parser.py
@@ -202,6 +202,9 @@ class ApiBasedToolSchemaParser:
 
                 interface["operation"]["operationId"] = f"{path}_{interface['method']}"
 
+            # Extract x-dify-streaming extension field
+            streaming = bool(interface["operation"].get("x-dify-streaming", False))
+
             bundles.append(
                 ApiToolBundle(
                     server_url=server_url + interface["path"],
@@ -214,6 +217,7 @@ class ApiBasedToolSchemaParser:
                     author="",
                     icon=None,
                     openapi=interface["operation"],
+                    streaming=streaming,
                 )
             )
 
diff --git a/api/tests/unit_tests/core/tools/test_api_tool_streaming.py b/api/tests/unit_tests/core/tools/test_api_tool_streaming.py
new file mode 100644
index 0000000000..30f461cdbf
--- /dev/null
+++ b/api/tests/unit_tests/core/tools/test_api_tool_streaming.py
@@ -0,0 +1,472 @@
+"""Tests for ApiTool streaming support."""
+
+import contextlib
+import json
+from typing import cast
+from unittest.mock import MagicMock, patch
+
+import httpx
+import pytest
+
+from core.helper import ssrf_proxy
+from core.tools.custom_tool.tool import ApiTool
+from core.tools.entities.tool_bundle import ApiToolBundle
+from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage
+from core.tools.errors import ToolInvokeError, ToolSSRFError
+
+# ---------------------------------------------------------------------------
+# Fixtures / helpers
+# ---------------------------------------------------------------------------
+
+
+def _msg_text(msg: ToolInvokeMessage) -> str:
+    """Extract plain text from a ToolInvokeMessage regardless of inner type."""
+    inner = msg.message
+    if hasattr(inner, "text"):
+        return inner.text
+    return str(inner)
+
+
+def _make_api_tool(streaming: bool = False) -> ApiTool:
+    bundle = ApiToolBundle(
+        server_url="https://api.example.com/v1/chat",
+        method="post",
+        summary="test",
+        operation_id="test_op",
+        parameters=[],
+        author="test",
+        icon=None,
+        openapi={},
+        streaming=streaming,
+    )
+
+    entity = MagicMock(spec=ToolEntity)
+    entity.identity = MagicMock()
+    entity.identity.name = "test_tool"
+    entity.identity.author = "test"
+    entity.output_schema = {}
+
+    runtime = MagicMock()
+    runtime.credentials = {"auth_type": "none"}
+    runtime.runtime_parameters = {}
+
+    return ApiTool(entity=entity, api_bundle=bundle, runtime=runtime, provider_id="test_provider")
+
+
+class FakeStreamResponse:
+    def __init__(
+        self,
+        status_code: int,
+        content_type: str,
+        lines: list[str] | None = None,
+        chunks: list[str] | None = None,
+        text: str = "",
+        extra_headers: dict[str, str] | None = None,
+    ):
+        self.status_code = status_code
+        self.headers = {"content-type": content_type}
+        if extra_headers:
+            self.headers.update(extra_headers)
+        self.text = text
+        self._lines = lines or []
+        self._chunks = chunks or []
+
+    def iter_lines(self):
+        yield from self._lines
+
+    def iter_text(self, chunk_size=4096):
+        yield from self._chunks
+
+    def iter_bytes(self, chunk_size=8192):
+        yield self.text.encode("utf-8")
+
+    def read(self):
+        pass
+
+
+def _resp(fake: FakeStreamResponse) -> httpx.Response:
+    """Cast FakeStreamResponse to httpx.Response for type checkers."""
+    return cast(httpx.Response, fake)
+
+
+@contextlib.contextmanager
+def _fake_stream_ctx(response: FakeStreamResponse):
+    yield response
+
+
+class _RaisingContextManager:
+    """Context manager that raises on __enter__. Used to test exception paths."""
+
+    def __init__(self, exception: BaseException):
+        self.exception = exception
+
+    def __enter__(self):
+        raise self.exception
+
+    def __exit__(self, *args):
+        return False
+
+
+# ---------------------------------------------------------------------------
+# ApiToolBundle.streaming default value
+# ---------------------------------------------------------------------------
+
+
+class TestApiToolBundleStreaming:
+    def test_default_false(self):
+        bundle = ApiToolBundle(
+            server_url="https://example.com",
+            method="get",
+            operation_id="op",
+            parameters=[],
+            author="",
+            openapi={},
+        )
+        assert bundle.streaming is False
+
+    def test_explicit_true(self):
+        bundle = ApiToolBundle(
+            server_url="https://example.com",
+            method="get",
+            operation_id="op",
+            parameters=[],
+            author="",
+            openapi={},
+            streaming=True,
+        )
+        assert bundle.streaming is True
+
+
+# ---------------------------------------------------------------------------
+# SSE parser
+# ---------------------------------------------------------------------------
+
+
+class TestParseSSEStream:
+    def test_openai_format(self):
+        tool = _make_api_tool(streaming=True)
+        lines = [
+            'data: {"choices":[{"delta":{"content":"Hello"}}]}',
+            'data: {"choices":[{"delta":{"content":" world"}}]}',
+            "data: [DONE]",
+        ]
+        resp = FakeStreamResponse(200, "text/event-stream", lines=lines)
+        messages = list(tool._parse_sse_stream(_resp(resp)))
+        assert len(messages) == 2
+        assert _msg_text(messages[0]) == "Hello"
+        assert _msg_text(messages[1]) == " world"
+
+    def test_plain_text_sse(self):
+        tool = _make_api_tool(streaming=True)
+        lines = ["data: chunk1", "data: chunk2", "", "data: [DONE]"]
+        resp = FakeStreamResponse(200, "text/event-stream", lines=lines)
+        messages = list(tool._parse_sse_stream(_resp(resp)))
+        assert len(messages) == 2
+        assert _msg_text(messages[0]) == "chunk1"
+        assert _msg_text(messages[1]) == "chunk2"
+
+    def test_common_field_names(self):
+        tool = _make_api_tool(streaming=True)
+        lines = [
+            'data: {"content":"from content"}',
+            'data: {"text":"from text"}',
+            'data: {"message":"from message"}',
+        ]
+        resp = FakeStreamResponse(200, "text/event-stream", lines=lines)
+        messages = list(tool._parse_sse_stream(_resp(resp)))
+        assert len(messages) == 3
+        assert _msg_text(messages[0]) == "from content"
+        assert _msg_text(messages[1]) == "from text"
+        assert _msg_text(messages[2]) == "from message"
+
+    def test_empty_lines_skipped(self):
+        tool = _make_api_tool(streaming=True)
+        lines = ["", "", "data: hello", ""]
+        resp = FakeStreamResponse(200, "text/event-stream", lines=lines)
+        messages = list(tool._parse_sse_stream(_resp(resp)))
+        assert len(messages) == 1
+
+    def test_non_data_lines_skipped(self):
+        tool = _make_api_tool(streaming=True)
+        lines = ["event: message", "id: 1", "retry: 3000", "data: actual content"]
+        resp = FakeStreamResponse(200, "text/event-stream", lines=lines)
+        messages = list(tool._parse_sse_stream(_resp(resp)))
+        assert len(messages) == 1
+        assert _msg_text(messages[0]) == "actual content"
+
+
+# ---------------------------------------------------------------------------
+# NDJSON parser
+# ---------------------------------------------------------------------------
+
+
+class TestParseNdjsonStream:
+    def test_json_with_content_field(self):
+        tool = _make_api_tool(streaming=True)
+        lines = [json.dumps({"content": "line1"}), json.dumps({"content": "line2"})]
+        resp = FakeStreamResponse(200, "application/x-ndjson", lines=lines)
+        messages = list(tool._parse_ndjson_stream(_resp(resp)))
+        assert len(messages) == 2
+        assert _msg_text(messages[0]) == "line1"
+
+    def test_fallback_to_full_json(self):
+        tool = _make_api_tool(streaming=True)
+        lines = [json.dumps({"result": 42})]
+        resp = FakeStreamResponse(200, "application/x-ndjson", lines=lines)
+        messages = list(tool._parse_ndjson_stream(_resp(resp)))
+        assert len(messages) == 1
+        assert "42" in _msg_text(messages[0])
+
+    def test_invalid_json_lines(self):
+        tool = _make_api_tool(streaming=True)
+        lines = ["not json", ""]
+        resp = FakeStreamResponse(200, "application/x-ndjson", lines=lines)
+        messages = list(tool._parse_ndjson_stream(_resp(resp)))
+        assert len(messages) == 1
+        assert _msg_text(messages[0]) == "not json"
+
+
+# ---------------------------------------------------------------------------
+# Text stream parser
+# ---------------------------------------------------------------------------
+
+
+class TestParseTextStream:
+    def test_text_chunks(self):
+        tool = _make_api_tool(streaming=True)
+        resp = FakeStreamResponse(200, "text/plain", chunks=["chunk1", "chunk2", "chunk3"])
+        messages = list(tool._parse_text_stream(_resp(resp)))
+        assert len(messages) == 3
+        assert _msg_text(messages[0]) == "chunk1"
+
+    def test_empty_chunks_skipped(self):
+        tool = _make_api_tool(streaming=True)
+        resp = FakeStreamResponse(200, "text/plain", chunks=["", "data", ""])
+        messages = list(tool._parse_text_stream(_resp(resp)))
+        assert len(messages) == 1
+
+
+# ---------------------------------------------------------------------------
+# _invoke() streaming branch
+# ---------------------------------------------------------------------------
+
+
+class TestInvokeStreamingBranch:
+    @patch.object(ApiTool, "do_http_request_streaming")
+    @patch.object(ApiTool, "do_http_request")
+    def test_streaming_true_uses_streaming_path(self, mock_non_stream, mock_stream):
+        tool = _make_api_tool(streaming=True)
+        mock_stream.return_value = iter([tool.create_text_message("streamed")])
+
+        messages = list(tool._invoke(user_id="u1", tool_parameters={}))
+        mock_stream.assert_called_once()
+        mock_non_stream.assert_not_called()
+        assert len(messages) == 1
+        assert _msg_text(messages[0]) == "streamed"
+
+    @patch.object(ApiTool, "do_http_request_streaming")
+    @patch.object(ApiTool, "do_http_request")
+    def test_streaming_false_uses_non_streaming_path(self, mock_non_stream, mock_stream):
+        tool = _make_api_tool(streaming=False)
+        mock_response = MagicMock(spec=httpx.Response)
+        mock_response.status_code = 200
+        mock_response.content = b'{"result": "ok"}'
+        mock_response.headers = {"content-type": "text/plain"}
+        mock_response.text = "ok"
+        mock_response.json.side_effect = Exception("not json")
+        mock_non_stream.return_value = mock_response
+
+        messages = list(tool._invoke(user_id="u1", tool_parameters={}))
+        mock_non_stream.assert_called_once()
+        mock_stream.assert_not_called()
+
+
+# ---------------------------------------------------------------------------
+# Error handling
+# ---------------------------------------------------------------------------
+
+
+class TestStreamingErrorHandling:
+    def test_http_error_raises(self):
+        tool = _make_api_tool(streaming=True)
+        error_resp = FakeStreamResponse(500, "text/plain", text="Internal Server Error")
+
+        with patch("core.helper.ssrf_proxy.stream_request", return_value=_fake_stream_ctx(error_resp)):
+            with pytest.raises(ToolInvokeError, match="500"):
+                list(tool.do_http_request_streaming("https://example.com", "POST", {}, {}))
+
+    def test_stream_error_raises(self):
+        tool = _make_api_tool(streaming=True)
+        ctx = _RaisingContextManager(httpx.StreamError("connection reset"))
+
+        with patch("core.helper.ssrf_proxy.stream_request", return_value=ctx):
+            with pytest.raises(ToolInvokeError, match="Stream request failed"):
+                list(tool.do_http_request_streaming("https://example.com", "POST", {}, {}))
+
+    def test_timeout_raises(self):
+        tool = _make_api_tool(streaming=True)
+        ctx = _RaisingContextManager(httpx.ReadTimeout("read timed out"))
+
+        with patch("core.helper.ssrf_proxy.stream_request", return_value=ctx):
+            with pytest.raises(ToolInvokeError, match="Stream request failed"):
+                list(tool.do_http_request_streaming("https://example.com", "POST", {}, {}))
+
+
+# ---------------------------------------------------------------------------
+# SSRF Squid proxy detection in streaming
+# ---------------------------------------------------------------------------
+
+
+class _FakeHttpxStreamCM:
+    """Mimics the context manager returned by httpx.Client.stream()."""
+
+    def __init__(self, response):
+        self.response = response
+
+    def __enter__(self):
+        return self.response
+
+    def __exit__(self, *args):
+        return False
+
+
+class TestStreamingSSRFProtection:
+    def _make_mock_response(self, status_code: int, extra_headers: dict[str, str] | None = None):
+        resp = MagicMock()
+        resp.status_code = status_code
+        resp.headers = {"content-type": "text/html"}
+        if extra_headers:
+            resp.headers.update(extra_headers)
+        return resp
+
+    def test_squid_403_raises_ssrf_error(self):
+        """stream_request should detect Squid proxy 403 and raise ToolSSRFError."""
+        fake_resp = self._make_mock_response(403, {"server": "squid/5.7"})
+        mock_client = MagicMock()
+        mock_client.stream.return_value = _FakeHttpxStreamCM(fake_resp)
+
+        with patch("core.helper.ssrf_proxy._get_ssrf_client", return_value=mock_client):
+            with pytest.raises(ToolSSRFError, match="SSRF protection"):
+                with ssrf_proxy.stream_request("POST", "https://internal.example.com"):
+                    pass
+
+    def test_squid_401_via_header_raises_ssrf_error(self):
+        """stream_request should detect Squid in Via header and raise ToolSSRFError."""
+        fake_resp = self._make_mock_response(401, {"via": "1.1 squid-proxy"})
+        mock_client = MagicMock()
+        mock_client.stream.return_value = _FakeHttpxStreamCM(fake_resp)
+
+        with patch("core.helper.ssrf_proxy._get_ssrf_client", return_value=mock_client):
+            with pytest.raises(ToolSSRFError, match="SSRF protection"):
+                with ssrf_proxy.stream_request("POST", "https://internal.example.com"):
+                    pass
+
+    def test_non_squid_403_passes_through(self):
+        """403 from non-Squid server should NOT raise ToolSSRFError but be handled by caller."""
+        tool = _make_api_tool(streaming=True)
+        resp = FakeStreamResponse(403, "text/plain", text="Forbidden")
+
+        with patch("core.helper.ssrf_proxy.stream_request", return_value=_fake_stream_ctx(resp)):
+            with pytest.raises(ToolInvokeError, match="403"):
+                list(tool.do_http_request_streaming("https://example.com", "POST", {}, {}))
+
+
+# ---------------------------------------------------------------------------
+# Content-type routing
+# ---------------------------------------------------------------------------
+
+
+class TestContentTypeRouting:
+    def test_sse_content_type(self):
+        tool = _make_api_tool(streaming=True)
+        resp = FakeStreamResponse(200, "text/event-stream; charset=utf-8", lines=["data: hello"])
+
+        with patch("core.helper.ssrf_proxy.stream_request", return_value=_fake_stream_ctx(resp)):
+            messages = list(tool.do_http_request_streaming("https://example.com", "POST", {}, {}))
+        assert len(messages) == 1
+
+    def test_ndjson_content_type(self):
+        tool = _make_api_tool(streaming=True)
+        resp = FakeStreamResponse(200, "application/x-ndjson", lines=[json.dumps({"text": "hi"})])
+
+        with patch("core.helper.ssrf_proxy.stream_request", return_value=_fake_stream_ctx(resp)):
+            messages = list(tool.do_http_request_streaming("https://example.com", "POST", {}, {}))
+        assert len(messages) == 1
+        assert _msg_text(messages[0]) == "hi"
+
+    def test_jsonl_content_type(self):
+        tool = _make_api_tool(streaming=True)
+        resp = FakeStreamResponse(200, "application/jsonl", lines=[json.dumps({"content": "hey"})])
+
+        with patch("core.helper.ssrf_proxy.stream_request", return_value=_fake_stream_ctx(resp)):
+            messages = list(tool.do_http_request_streaming("https://example.com", "POST", {}, {}))
+        assert len(messages) == 1
+
+    def test_fallback_text_stream(self):
+        tool = _make_api_tool(streaming=True)
+        resp = FakeStreamResponse(200, "application/octet-stream", chunks=["raw data"])
+
+        with patch("core.helper.ssrf_proxy.stream_request", return_value=_fake_stream_ctx(resp)):
+            messages = list(tool.do_http_request_streaming("https://example.com", "POST", {}, {}))
+        assert len(messages) == 1
+        assert _msg_text(messages[0]) == "raw data"
+
+
+# ---------------------------------------------------------------------------
+# OpenAPI parser: x-dify-streaming extraction
+# ---------------------------------------------------------------------------
+
+
+class TestOpenAPIParserStreaming:
+    def test_x_dify_streaming_true(self):
+        from flask import Flask
+
+        from core.tools.utils.parser import ApiBasedToolSchemaParser
+
+        openapi = {
+            "openapi": "3.0.0",
+            "info": {"title": "Test", "version": "1.0"},
+            "servers": [{"url": "https://api.example.com"}],
+            "paths": {
+                "/chat": {
+                    "post": {
+                        "x-dify-streaming": True,
+                        "operationId": "chat",
+                        "summary": "Chat",
+                        "responses": {"200": {"description": "OK"}},
+                    }
+                }
+            },
+        }
+        app = Flask(__name__)
+        with app.test_request_context():
+            bundles = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi)
+
+        assert len(bundles) == 1
+        assert bundles[0].streaming is True
+
+    def test_x_dify_streaming_absent(self):
+        from flask import Flask
+
+        from core.tools.utils.parser import ApiBasedToolSchemaParser
+
+        openapi = {
+            "openapi": "3.0.0",
+            "info": {"title": "Test", "version": "1.0"},
+            "servers": [{"url": "https://api.example.com"}],
+            "paths": {
+                "/query": {
+                    "get": {
+                        "operationId": "query",
+                        "summary": "Query",
+                        "responses": {"200": {"description": "OK"}},
+                    }
+                }
+            },
+        }
+        app = Flask(__name__)
+        with app.test_request_context():
+            bundles = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi)
+
+        assert len(bundles) == 1
+        assert bundles[0].streaming is False