diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index e035db8b00..4a4a2688a6 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -5,7 +5,7 @@ Proxy requests to avoid SSRF import contextlib import logging import time -from collections.abc import Generator +from collections.abc import Iterator from typing import Any, TypeAlias import httpx @@ -273,7 +273,7 @@ ssrf_proxy = SSRFProxy() @contextlib.contextmanager -def stream_request(method: str, url: str, **kwargs: Any) -> Generator[httpx.Response, None, None]: +def stream_request(method: str, url: str, **kwargs: Any) -> Iterator[httpx.Response]: """Streaming HTTP request context manager with SSRF protection. Unlike make_request(), this does not implement retry logic because retries @@ -311,4 +311,13 @@ def stream_request(method: str, url: str, **kwargs: Any) -> Generator[httpx.Resp kwargs["headers"] = headers with client.stream(method=method, url=url, **kwargs) as response: + # Check for SSRF protection by Squid proxy (same logic as make_request) + if response.status_code in (401, 403): + server_header = response.headers.get("server", "").lower() + via_header = response.headers.get("via", "").lower() + if "squid" in server_header or "squid" in via_header: + raise ToolSSRFError( + f"Access to '{url}' was blocked by SSRF protection. " + f"The URL may point to a private or local network address. " + ) yield response diff --git a/api/core/tools/custom_tool/tool.py b/api/core/tools/custom_tool/tool.py index a56b75c052..8a9e9a18db 100644 --- a/api/core/tools/custom_tool/tool.py +++ b/api/core/tools/custom_tool/tool.py @@ -285,7 +285,10 @@ class ApiTool(Tool): self, url: str, method: str, headers: dict[str, Any], parameters: dict[str, Any] ) -> httpx.Response: """ - do http request depending on api bundle + Do a non-streaming HTTP request depending on api bundle. + + For streaming requests, see do_http_request_streaming(). + Both methods share _prepare_request_parts() for request assembly. """ params, body, cookies, files, url, headers = self._prepare_request_parts(url, method, headers, parameters) @@ -336,8 +339,16 @@ class ApiTool(Tool): follow_redirects=True, ) as response: if response.status_code >= 400: - response.read() - raise ToolInvokeError(f"Request failed with status code {response.status_code} and {response.text}") + # Read a bounded amount to avoid unbounded memory usage on large error bodies + _MAX_ERROR_BYTES = 1_048_576 # 1 MB + error_body = b"" + for chunk in response.iter_bytes(chunk_size=8192): + error_body += chunk + if len(error_body) >= _MAX_ERROR_BYTES: + error_body = error_body[:_MAX_ERROR_BYTES] + break + error_text = error_body.decode("utf-8", errors="replace") + raise ToolInvokeError(f"Request failed with status code {response.status_code} and {error_text}") content_type = response.headers.get("content-type", "").lower() @@ -375,7 +386,7 @@ class ApiTool(Tool): try: obj = json.loads(data) except (json.JSONDecodeError, ValueError): - # Not JSON, return raw data + logger.debug("SSE data is not valid JSON, treating as plain text: %s", data[:200]) return data if not isinstance(obj, dict): @@ -422,6 +433,7 @@ class ApiTool(Tool): else: yield self.create_text_message(str(obj)) except (json.JSONDecodeError, ValueError): + logger.debug("NDJSON line is not valid JSON, treating as plain text: %s", line[:200]) if line.strip(): yield self.create_text_message(line) diff --git a/api/tests/unit_tests/core/tools/test_api_tool_streaming.py b/api/tests/unit_tests/core/tools/test_api_tool_streaming.py index 4a0d8d74b4..ca76d312f9 100644 --- a/api/tests/unit_tests/core/tools/test_api_tool_streaming.py +++ b/api/tests/unit_tests/core/tools/test_api_tool_streaming.py @@ -7,10 +7,11 @@ from unittest.mock import MagicMock, patch import httpx import pytest +from core.helper import ssrf_proxy from core.tools.custom_tool.tool import ApiTool from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage -from core.tools.errors import ToolInvokeError +from core.tools.errors import ToolInvokeError, ToolSSRFError # --------------------------------------------------------------------------- # Fixtures / helpers @@ -59,9 +60,12 @@ class FakeStreamResponse: lines: list[str] | None = None, chunks: list[str] | None = None, text: str = "", + extra_headers: dict[str, str] | None = None, ): self.status_code = status_code self.headers = {"content-type": content_type} + if extra_headers: + self.headers.update(extra_headers) self.text = text self._lines = lines or [] self._chunks = chunks or [] @@ -72,6 +76,9 @@ class FakeStreamResponse: def iter_text(self, chunk_size=4096): yield from self._chunks + def iter_bytes(self, chunk_size=8192): + yield self.text.encode("utf-8") + def read(self): pass @@ -275,7 +282,7 @@ class TestStreamingErrorHandling: @contextlib.contextmanager def _raise_stream_error(): raise httpx.StreamError("connection reset") - yield # noqa: RET503 + yield with patch("core.helper.ssrf_proxy.stream_request", return_value=_raise_stream_error()): with pytest.raises(ToolInvokeError, match="Stream request failed"): @@ -287,13 +294,72 @@ class TestStreamingErrorHandling: @contextlib.contextmanager def _raise_timeout(): raise httpx.ReadTimeout("read timed out") - yield # noqa: RET503 + yield with patch("core.helper.ssrf_proxy.stream_request", return_value=_raise_timeout()): with pytest.raises(ToolInvokeError, match="Stream request failed"): list(tool.do_http_request_streaming("https://example.com", "POST", {}, {})) +# --------------------------------------------------------------------------- +# SSRF Squid proxy detection in streaming +# --------------------------------------------------------------------------- + + +class _FakeHttpxStreamCM: + """Mimics the context manager returned by httpx.Client.stream().""" + + def __init__(self, response): + self.response = response + + def __enter__(self): + return self.response + + def __exit__(self, *args): + return False + + +class TestStreamingSSRFProtection: + def _make_mock_response(self, status_code: int, extra_headers: dict[str, str] | None = None): + resp = MagicMock() + resp.status_code = status_code + resp.headers = {"content-type": "text/html"} + if extra_headers: + resp.headers.update(extra_headers) + return resp + + def test_squid_403_raises_ssrf_error(self): + """stream_request should detect Squid proxy 403 and raise ToolSSRFError.""" + fake_resp = self._make_mock_response(403, {"server": "squid/5.7"}) + mock_client = MagicMock() + mock_client.stream.return_value = _FakeHttpxStreamCM(fake_resp) + + with patch("core.helper.ssrf_proxy._get_ssrf_client", return_value=mock_client): + with pytest.raises(ToolSSRFError, match="SSRF protection"): + with ssrf_proxy.stream_request("POST", "https://internal.example.com"): + pass + + def test_squid_401_via_header_raises_ssrf_error(self): + """stream_request should detect Squid in Via header and raise ToolSSRFError.""" + fake_resp = self._make_mock_response(401, {"via": "1.1 squid-proxy"}) + mock_client = MagicMock() + mock_client.stream.return_value = _FakeHttpxStreamCM(fake_resp) + + with patch("core.helper.ssrf_proxy._get_ssrf_client", return_value=mock_client): + with pytest.raises(ToolSSRFError, match="SSRF protection"): + with ssrf_proxy.stream_request("POST", "https://internal.example.com"): + pass + + def test_non_squid_403_passes_through(self): + """403 from non-Squid server should NOT raise ToolSSRFError but be handled by caller.""" + tool = _make_api_tool(streaming=True) + resp = FakeStreamResponse(403, "text/plain", text="Forbidden") + + with patch("core.helper.ssrf_proxy.stream_request", return_value=_fake_stream_ctx(resp)): + with pytest.raises(ToolInvokeError, match="403"): + list(tool.do_http_request_streaming("https://example.com", "POST", {}, {})) + + # --------------------------------------------------------------------------- # Content-type routing # ---------------------------------------------------------------------------