fix(tools): address review feedback for streaming support

- Add Squid proxy SSRF detection in stream_request() (security)
- Limit error response body read to 1 MB to prevent DoS (security)
- Add debug logging for JSON decode failures in SSE/NDJSON parsers
- Update do_http_request() docstring to reference streaming counterpart
- Use Iterator type hint instead of Generator for context manager
- Add SSRF protection tests for streaming path
This commit is contained in:
skhe 2026-03-03 14:25:32 +08:00
parent a230cf9970
commit bf0d7fd57c
3 changed files with 96 additions and 9 deletions

View File

@ -5,7 +5,7 @@ Proxy requests to avoid SSRF
import contextlib
import logging
import time
from collections.abc import Generator
from collections.abc import Iterator
from typing import Any, TypeAlias
import httpx
@ -273,7 +273,7 @@ ssrf_proxy = SSRFProxy()
@contextlib.contextmanager
def stream_request(method: str, url: str, **kwargs: Any) -> Generator[httpx.Response, None, None]:
def stream_request(method: str, url: str, **kwargs: Any) -> Iterator[httpx.Response]:
"""Streaming HTTP request context manager with SSRF protection.
Unlike make_request(), this does not implement retry logic because retries
@ -311,4 +311,13 @@ def stream_request(method: str, url: str, **kwargs: Any) -> Generator[httpx.Resp
kwargs["headers"] = headers
with client.stream(method=method, url=url, **kwargs) as response:
# Check for SSRF protection by Squid proxy (same logic as make_request)
if response.status_code in (401, 403):
server_header = response.headers.get("server", "").lower()
via_header = response.headers.get("via", "").lower()
if "squid" in server_header or "squid" in via_header:
raise ToolSSRFError(
f"Access to '{url}' was blocked by SSRF protection. "
f"The URL may point to a private or local network address. "
)
yield response

View File

@ -285,7 +285,10 @@ class ApiTool(Tool):
self, url: str, method: str, headers: dict[str, Any], parameters: dict[str, Any]
) -> httpx.Response:
"""
do http request depending on api bundle
Do a non-streaming HTTP request depending on api bundle.
For streaming requests, see do_http_request_streaming().
Both methods share _prepare_request_parts() for request assembly.
"""
params, body, cookies, files, url, headers = self._prepare_request_parts(url, method, headers, parameters)
@ -336,8 +339,16 @@ class ApiTool(Tool):
follow_redirects=True,
) as response:
if response.status_code >= 400:
response.read()
raise ToolInvokeError(f"Request failed with status code {response.status_code} and {response.text}")
# Read a bounded amount to avoid unbounded memory usage on large error bodies
_MAX_ERROR_BYTES = 1_048_576 # 1 MB
error_body = b""
for chunk in response.iter_bytes(chunk_size=8192):
error_body += chunk
if len(error_body) >= _MAX_ERROR_BYTES:
error_body = error_body[:_MAX_ERROR_BYTES]
break
error_text = error_body.decode("utf-8", errors="replace")
raise ToolInvokeError(f"Request failed with status code {response.status_code} and {error_text}")
content_type = response.headers.get("content-type", "").lower()
@ -375,7 +386,7 @@ class ApiTool(Tool):
try:
obj = json.loads(data)
except (json.JSONDecodeError, ValueError):
# Not JSON, return raw data
logger.debug("SSE data is not valid JSON, treating as plain text: %s", data[:200])
return data
if not isinstance(obj, dict):
@ -422,6 +433,7 @@ class ApiTool(Tool):
else:
yield self.create_text_message(str(obj))
except (json.JSONDecodeError, ValueError):
logger.debug("NDJSON line is not valid JSON, treating as plain text: %s", line[:200])
if line.strip():
yield self.create_text_message(line)

View File

@ -7,10 +7,11 @@ from unittest.mock import MagicMock, patch
import httpx
import pytest
from core.helper import ssrf_proxy
from core.tools.custom_tool.tool import ApiTool
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage
from core.tools.errors import ToolInvokeError
from core.tools.errors import ToolInvokeError, ToolSSRFError
# ---------------------------------------------------------------------------
# Fixtures / helpers
@ -59,9 +60,12 @@ class FakeStreamResponse:
lines: list[str] | None = None,
chunks: list[str] | None = None,
text: str = "",
extra_headers: dict[str, str] | None = None,
):
self.status_code = status_code
self.headers = {"content-type": content_type}
if extra_headers:
self.headers.update(extra_headers)
self.text = text
self._lines = lines or []
self._chunks = chunks or []
@ -72,6 +76,9 @@ class FakeStreamResponse:
def iter_text(self, chunk_size=4096):
yield from self._chunks
def iter_bytes(self, chunk_size=8192):
yield self.text.encode("utf-8")
def read(self):
pass
@ -275,7 +282,7 @@ class TestStreamingErrorHandling:
@contextlib.contextmanager
def _raise_stream_error():
raise httpx.StreamError("connection reset")
yield # noqa: RET503
yield
with patch("core.helper.ssrf_proxy.stream_request", return_value=_raise_stream_error()):
with pytest.raises(ToolInvokeError, match="Stream request failed"):
@ -287,13 +294,72 @@ class TestStreamingErrorHandling:
@contextlib.contextmanager
def _raise_timeout():
raise httpx.ReadTimeout("read timed out")
yield # noqa: RET503
yield
with patch("core.helper.ssrf_proxy.stream_request", return_value=_raise_timeout()):
with pytest.raises(ToolInvokeError, match="Stream request failed"):
list(tool.do_http_request_streaming("https://example.com", "POST", {}, {}))
# ---------------------------------------------------------------------------
# SSRF Squid proxy detection in streaming
# ---------------------------------------------------------------------------
class _FakeHttpxStreamCM:
"""Mimics the context manager returned by httpx.Client.stream()."""
def __init__(self, response):
self.response = response
def __enter__(self):
return self.response
def __exit__(self, *args):
return False
class TestStreamingSSRFProtection:
def _make_mock_response(self, status_code: int, extra_headers: dict[str, str] | None = None):
resp = MagicMock()
resp.status_code = status_code
resp.headers = {"content-type": "text/html"}
if extra_headers:
resp.headers.update(extra_headers)
return resp
def test_squid_403_raises_ssrf_error(self):
"""stream_request should detect Squid proxy 403 and raise ToolSSRFError."""
fake_resp = self._make_mock_response(403, {"server": "squid/5.7"})
mock_client = MagicMock()
mock_client.stream.return_value = _FakeHttpxStreamCM(fake_resp)
with patch("core.helper.ssrf_proxy._get_ssrf_client", return_value=mock_client):
with pytest.raises(ToolSSRFError, match="SSRF protection"):
with ssrf_proxy.stream_request("POST", "https://internal.example.com"):
pass
def test_squid_401_via_header_raises_ssrf_error(self):
"""stream_request should detect Squid in Via header and raise ToolSSRFError."""
fake_resp = self._make_mock_response(401, {"via": "1.1 squid-proxy"})
mock_client = MagicMock()
mock_client.stream.return_value = _FakeHttpxStreamCM(fake_resp)
with patch("core.helper.ssrf_proxy._get_ssrf_client", return_value=mock_client):
with pytest.raises(ToolSSRFError, match="SSRF protection"):
with ssrf_proxy.stream_request("POST", "https://internal.example.com"):
pass
def test_non_squid_403_passes_through(self):
"""403 from non-Squid server should NOT raise ToolSSRFError but be handled by caller."""
tool = _make_api_tool(streaming=True)
resp = FakeStreamResponse(403, "text/plain", text="Forbidden")
with patch("core.helper.ssrf_proxy.stream_request", return_value=_fake_stream_ctx(resp)):
with pytest.raises(ToolInvokeError, match="403"):
list(tool.do_http_request_streaming("https://example.com", "POST", {}, {}))
# ---------------------------------------------------------------------------
# Content-type routing
# ---------------------------------------------------------------------------