mirror of https://github.com/langgenius/dify.git
fix(tools): address review feedback for streaming support
- Add Squid proxy SSRF detection in stream_request() (security) - Limit error response body read to 1 MB to prevent DoS (security) - Add debug logging for JSON decode failures in SSE/NDJSON parsers - Update do_http_request() docstring to reference streaming counterpart - Use Iterator type hint instead of Generator for context manager - Add SSRF protection tests for streaming path
This commit is contained in:
parent
a230cf9970
commit
bf0d7fd57c
|
|
@ -5,7 +5,7 @@ Proxy requests to avoid SSRF
|
|||
import contextlib
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
from typing import Any, TypeAlias
|
||||
|
||||
import httpx
|
||||
|
|
@ -273,7 +273,7 @@ ssrf_proxy = SSRFProxy()
|
|||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def stream_request(method: str, url: str, **kwargs: Any) -> Generator[httpx.Response, None, None]:
|
||||
def stream_request(method: str, url: str, **kwargs: Any) -> Iterator[httpx.Response]:
|
||||
"""Streaming HTTP request context manager with SSRF protection.
|
||||
|
||||
Unlike make_request(), this does not implement retry logic because retries
|
||||
|
|
@ -311,4 +311,13 @@ def stream_request(method: str, url: str, **kwargs: Any) -> Generator[httpx.Resp
|
|||
kwargs["headers"] = headers
|
||||
|
||||
with client.stream(method=method, url=url, **kwargs) as response:
|
||||
# Check for SSRF protection by Squid proxy (same logic as make_request)
|
||||
if response.status_code in (401, 403):
|
||||
server_header = response.headers.get("server", "").lower()
|
||||
via_header = response.headers.get("via", "").lower()
|
||||
if "squid" in server_header or "squid" in via_header:
|
||||
raise ToolSSRFError(
|
||||
f"Access to '{url}' was blocked by SSRF protection. "
|
||||
f"The URL may point to a private or local network address. "
|
||||
)
|
||||
yield response
|
||||
|
|
|
|||
|
|
@ -285,7 +285,10 @@ class ApiTool(Tool):
|
|||
self, url: str, method: str, headers: dict[str, Any], parameters: dict[str, Any]
|
||||
) -> httpx.Response:
|
||||
"""
|
||||
do http request depending on api bundle
|
||||
Do a non-streaming HTTP request depending on api bundle.
|
||||
|
||||
For streaming requests, see do_http_request_streaming().
|
||||
Both methods share _prepare_request_parts() for request assembly.
|
||||
"""
|
||||
params, body, cookies, files, url, headers = self._prepare_request_parts(url, method, headers, parameters)
|
||||
|
||||
|
|
@ -336,8 +339,16 @@ class ApiTool(Tool):
|
|||
follow_redirects=True,
|
||||
) as response:
|
||||
if response.status_code >= 400:
|
||||
response.read()
|
||||
raise ToolInvokeError(f"Request failed with status code {response.status_code} and {response.text}")
|
||||
# Read a bounded amount to avoid unbounded memory usage on large error bodies
|
||||
_MAX_ERROR_BYTES = 1_048_576 # 1 MB
|
||||
error_body = b""
|
||||
for chunk in response.iter_bytes(chunk_size=8192):
|
||||
error_body += chunk
|
||||
if len(error_body) >= _MAX_ERROR_BYTES:
|
||||
error_body = error_body[:_MAX_ERROR_BYTES]
|
||||
break
|
||||
error_text = error_body.decode("utf-8", errors="replace")
|
||||
raise ToolInvokeError(f"Request failed with status code {response.status_code} and {error_text}")
|
||||
|
||||
content_type = response.headers.get("content-type", "").lower()
|
||||
|
||||
|
|
@ -375,7 +386,7 @@ class ApiTool(Tool):
|
|||
try:
|
||||
obj = json.loads(data)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
# Not JSON, return raw data
|
||||
logger.debug("SSE data is not valid JSON, treating as plain text: %s", data[:200])
|
||||
return data
|
||||
|
||||
if not isinstance(obj, dict):
|
||||
|
|
@ -422,6 +433,7 @@ class ApiTool(Tool):
|
|||
else:
|
||||
yield self.create_text_message(str(obj))
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
logger.debug("NDJSON line is not valid JSON, treating as plain text: %s", line[:200])
|
||||
if line.strip():
|
||||
yield self.create_text_message(line)
|
||||
|
||||
|
|
|
|||
|
|
@ -7,10 +7,11 @@ from unittest.mock import MagicMock, patch
|
|||
import httpx
|
||||
import pytest
|
||||
|
||||
from core.helper import ssrf_proxy
|
||||
from core.tools.custom_tool.tool import ApiTool
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage
|
||||
from core.tools.errors import ToolInvokeError
|
||||
from core.tools.errors import ToolInvokeError, ToolSSRFError
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures / helpers
|
||||
|
|
@ -59,9 +60,12 @@ class FakeStreamResponse:
|
|||
lines: list[str] | None = None,
|
||||
chunks: list[str] | None = None,
|
||||
text: str = "",
|
||||
extra_headers: dict[str, str] | None = None,
|
||||
):
|
||||
self.status_code = status_code
|
||||
self.headers = {"content-type": content_type}
|
||||
if extra_headers:
|
||||
self.headers.update(extra_headers)
|
||||
self.text = text
|
||||
self._lines = lines or []
|
||||
self._chunks = chunks or []
|
||||
|
|
@ -72,6 +76,9 @@ class FakeStreamResponse:
|
|||
def iter_text(self, chunk_size=4096):
|
||||
yield from self._chunks
|
||||
|
||||
def iter_bytes(self, chunk_size=8192):
|
||||
yield self.text.encode("utf-8")
|
||||
|
||||
def read(self):
|
||||
pass
|
||||
|
||||
|
|
@ -275,7 +282,7 @@ class TestStreamingErrorHandling:
|
|||
@contextlib.contextmanager
|
||||
def _raise_stream_error():
|
||||
raise httpx.StreamError("connection reset")
|
||||
yield # noqa: RET503
|
||||
yield
|
||||
|
||||
with patch("core.helper.ssrf_proxy.stream_request", return_value=_raise_stream_error()):
|
||||
with pytest.raises(ToolInvokeError, match="Stream request failed"):
|
||||
|
|
@ -287,13 +294,72 @@ class TestStreamingErrorHandling:
|
|||
@contextlib.contextmanager
|
||||
def _raise_timeout():
|
||||
raise httpx.ReadTimeout("read timed out")
|
||||
yield # noqa: RET503
|
||||
yield
|
||||
|
||||
with patch("core.helper.ssrf_proxy.stream_request", return_value=_raise_timeout()):
|
||||
with pytest.raises(ToolInvokeError, match="Stream request failed"):
|
||||
list(tool.do_http_request_streaming("https://example.com", "POST", {}, {}))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SSRF Squid proxy detection in streaming
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _FakeHttpxStreamCM:
|
||||
"""Mimics the context manager returned by httpx.Client.stream()."""
|
||||
|
||||
def __init__(self, response):
|
||||
self.response = response
|
||||
|
||||
def __enter__(self):
|
||||
return self.response
|
||||
|
||||
def __exit__(self, *args):
|
||||
return False
|
||||
|
||||
|
||||
class TestStreamingSSRFProtection:
|
||||
def _make_mock_response(self, status_code: int, extra_headers: dict[str, str] | None = None):
|
||||
resp = MagicMock()
|
||||
resp.status_code = status_code
|
||||
resp.headers = {"content-type": "text/html"}
|
||||
if extra_headers:
|
||||
resp.headers.update(extra_headers)
|
||||
return resp
|
||||
|
||||
def test_squid_403_raises_ssrf_error(self):
|
||||
"""stream_request should detect Squid proxy 403 and raise ToolSSRFError."""
|
||||
fake_resp = self._make_mock_response(403, {"server": "squid/5.7"})
|
||||
mock_client = MagicMock()
|
||||
mock_client.stream.return_value = _FakeHttpxStreamCM(fake_resp)
|
||||
|
||||
with patch("core.helper.ssrf_proxy._get_ssrf_client", return_value=mock_client):
|
||||
with pytest.raises(ToolSSRFError, match="SSRF protection"):
|
||||
with ssrf_proxy.stream_request("POST", "https://internal.example.com"):
|
||||
pass
|
||||
|
||||
def test_squid_401_via_header_raises_ssrf_error(self):
|
||||
"""stream_request should detect Squid in Via header and raise ToolSSRFError."""
|
||||
fake_resp = self._make_mock_response(401, {"via": "1.1 squid-proxy"})
|
||||
mock_client = MagicMock()
|
||||
mock_client.stream.return_value = _FakeHttpxStreamCM(fake_resp)
|
||||
|
||||
with patch("core.helper.ssrf_proxy._get_ssrf_client", return_value=mock_client):
|
||||
with pytest.raises(ToolSSRFError, match="SSRF protection"):
|
||||
with ssrf_proxy.stream_request("POST", "https://internal.example.com"):
|
||||
pass
|
||||
|
||||
def test_non_squid_403_passes_through(self):
|
||||
"""403 from non-Squid server should NOT raise ToolSSRFError but be handled by caller."""
|
||||
tool = _make_api_tool(streaming=True)
|
||||
resp = FakeStreamResponse(403, "text/plain", text="Forbidden")
|
||||
|
||||
with patch("core.helper.ssrf_proxy.stream_request", return_value=_fake_stream_ctx(resp)):
|
||||
with pytest.raises(ToolInvokeError, match="403"):
|
||||
list(tool.do_http_request_streaming("https://example.com", "POST", {}, {}))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Content-type routing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
Loading…
Reference in New Issue