mirror of https://github.com/langgenius/dify.git
Merge abb0408b08 into fbd558762d
This commit is contained in:
commit
8bbbb90df9
|
|
@ -2,8 +2,10 @@
|
|||
Proxy requests to avoid SSRF
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Iterator
|
||||
from typing import Any, TypeAlias
|
||||
|
||||
import httpx
|
||||
|
|
@ -268,3 +270,54 @@ class SSRFProxy:
|
|||
|
||||
|
||||
ssrf_proxy = SSRFProxy()
|
||||
|
||||
|
||||
@contextlib.contextmanager
def stream_request(method: str, url: str, **kwargs: Any) -> Iterator[httpx.Response]:
    """Streaming HTTP request context manager with SSRF protection.

    Unlike make_request(), this does not implement retry logic because retries
    are meaningless for streaming responses -- the stream cannot be replayed.

    Args:
        method: HTTP method name (e.g. "GET", "POST").
        url: Target URL; the request goes through the SSRF-safe client.
        **kwargs: Forwarded to ``httpx.Client.stream``. Recognized extras:
            ``allow_redirects`` (requests-style alias for ``follow_redirects``),
            ``ssl_verify`` (bool), and ``timeout`` (defaults from dify_config
            when omitted).

    Yields:
        The open ``httpx.Response``; the underlying stream is closed when the
        context exits.

    Raises:
        ValueError: if ``ssl_verify`` is not a bool, or ``headers`` is not a
            mapping of string keys to string values.
        ToolSSRFError: when a 401/403 response appears to come from the Squid
            proxy enforcing SSRF protection.
    """
    # Translate the requests-style "allow_redirects" flag into httpx's
    # "follow_redirects", without clobbering an explicit follow_redirects.
    if "allow_redirects" in kwargs:
        allow_redirects = kwargs.pop("allow_redirects")
        if "follow_redirects" not in kwargs:
            kwargs["follow_redirects"] = allow_redirects

    # Apply the configured default timeouts unless the caller supplied one.
    if "timeout" not in kwargs:
        kwargs["timeout"] = httpx.Timeout(
            timeout=dify_config.SSRF_DEFAULT_TIME_OUT,
            connect=dify_config.SSRF_DEFAULT_CONNECT_TIME_OUT,
            read=dify_config.SSRF_DEFAULT_READ_TIME_OUT,
            write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT,
        )

    verify_option = kwargs.pop("ssl_verify", dify_config.HTTP_REQUEST_NODE_SSL_VERIFY)
    if not isinstance(verify_option, bool):
        raise ValueError("ssl_verify must be a boolean")
    client = _get_ssrf_client(verify_option)

    # Validate headers up front so a malformed mapping fails with a clear
    # error here instead of deep inside httpx.
    try:
        headers: Headers = _HEADERS_ADAPTER.validate_python(kwargs.get("headers") or {})
    except ValidationError as e:
        raise ValueError("headers must be a mapping of string keys to string values") from e
    headers = _inject_trace_headers(headers)
    kwargs["headers"] = headers

    # Preserve a caller-supplied Host header under one canonical lowercase
    # key, dropping any differently-cased duplicates.
    user_provided_host = _get_user_provided_host_header(headers)
    if user_provided_host is not None:
        headers = {k: v for k, v in headers.items() if k.lower() != "host"}
        headers["host"] = user_provided_host
        kwargs["headers"] = headers

    with client.stream(method=method, url=url, **kwargs) as response:
        # Check for SSRF protection by Squid proxy (same logic as make_request)
        if response.status_code in (401, 403):
            server_header = response.headers.get("server", "").lower()
            via_header = response.headers.get("via", "").lower()
            if "squid" in server_header or "squid" in via_header:
                raise ToolSSRFError(
                    f"Access to '{url}' was blocked by SSRF protection. "
                    f"The URL may point to a private or local network address. "
                )
        yield response
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from dataclasses import dataclass
|
||||
from os import getenv
|
||||
|
|
@ -15,6 +16,8 @@ from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, Too
|
|||
from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError
|
||||
from dify_graph.file.file_manager import download
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
API_TOOL_DEFAULT_TIMEOUT = (
|
||||
int(getenv("API_TOOL_DEFAULT_CONNECT_TIMEOUT", "10")),
|
||||
int(getenv("API_TOOL_DEFAULT_READ_TIMEOUT", "60")),
|
||||
|
|
@ -168,14 +171,15 @@ class ApiTool(Tool):
|
|||
else:
|
||||
return (parameter.get("schema", {}) or {}).get("default", "")
|
||||
|
||||
def do_http_request(
|
||||
def _prepare_request_parts(
|
||||
self, url: str, method: str, headers: dict[str, Any], parameters: dict[str, Any]
|
||||
) -> httpx.Response:
|
||||
) -> tuple[dict, Any, dict, list, str, dict[str, Any]]:
|
||||
"""
|
||||
do http request depending on api bundle
|
||||
"""
|
||||
method = method.lower()
|
||||
Assemble request parts (params, body, cookies, files, url, headers) from
|
||||
the OpenAPI bundle and supplied parameters.
|
||||
|
||||
This is shared by both the streaming and non-streaming request paths.
|
||||
"""
|
||||
params = {}
|
||||
path_params = {}
|
||||
# FIXME: body should be a dict[str, Any] but it changed a lot in this function
|
||||
|
|
@ -275,6 +279,24 @@ class ApiTool(Tool):
|
|||
if files:
|
||||
headers.pop("Content-Type", None)
|
||||
|
||||
return params, body, cookies, files, url, headers
|
||||
|
||||
def do_http_request(
|
||||
self, url: str, method: str, headers: dict[str, Any], parameters: dict[str, Any]
|
||||
) -> httpx.Response:
|
||||
"""
|
||||
Do a non-streaming HTTP request depending on api bundle.
|
||||
|
||||
For streaming requests, see do_http_request_streaming().
|
||||
Both methods share _prepare_request_parts() for request assembly.
|
||||
"""
|
||||
params, body, cookies, files, url, headers = self._prepare_request_parts(url, method, headers, parameters)
|
||||
|
||||
_VALID_METHODS = {"get", "head", "post", "put", "delete", "patch"}
|
||||
method_lc = method.lower()
|
||||
if method_lc not in _VALID_METHODS:
|
||||
raise ValueError(f"Invalid http method {method}")
|
||||
|
||||
_METHOD_MAP = {
|
||||
"get": ssrf_proxy.get,
|
||||
"head": ssrf_proxy.head,
|
||||
|
|
@ -283,9 +305,6 @@ class ApiTool(Tool):
|
|||
"delete": ssrf_proxy.delete,
|
||||
"patch": ssrf_proxy.patch,
|
||||
}
|
||||
method_lc = method.lower()
|
||||
if method_lc not in _METHOD_MAP:
|
||||
raise ValueError(f"Invalid http method {method}")
|
||||
response: httpx.Response = _METHOD_MAP[
|
||||
method_lc
|
||||
]( # https://discuss.python.org/t/type-inference-for-function-return-types/42926
|
||||
|
|
@ -301,6 +320,129 @@ class ApiTool(Tool):
|
|||
)
|
||||
return response
|
||||
|
||||
def do_http_request_streaming(
|
||||
self, url: str, method: str, headers: dict[str, Any], parameters: dict[str, Any]
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""Streaming HTTP request, yields ToolInvokeMessage as response chunks arrive."""
|
||||
params, body, cookies, files, url, headers = self._prepare_request_parts(url, method, headers, parameters)
|
||||
|
||||
try:
|
||||
with ssrf_proxy.stream_request(
|
||||
method=method.upper(),
|
||||
url=url,
|
||||
params=params,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
data=body,
|
||||
files=files or None,
|
||||
timeout=API_TOOL_DEFAULT_TIMEOUT,
|
||||
follow_redirects=True,
|
||||
) as response:
|
||||
if response.status_code >= 400:
|
||||
# Read a bounded amount to avoid unbounded memory usage on large error bodies
|
||||
_MAX_ERROR_BYTES = 1_048_576 # 1 MB
|
||||
error_body = b""
|
||||
for chunk in response.iter_bytes(chunk_size=8192):
|
||||
error_body += chunk
|
||||
if len(error_body) >= _MAX_ERROR_BYTES:
|
||||
error_body = error_body[:_MAX_ERROR_BYTES]
|
||||
break
|
||||
error_text = error_body.decode("utf-8", errors="replace")
|
||||
raise ToolInvokeError(f"Request failed with status code {response.status_code} and {error_text}")
|
||||
|
||||
content_type = response.headers.get("content-type", "").lower()
|
||||
|
||||
if "text/event-stream" in content_type:
|
||||
yield from self._parse_sse_stream(response)
|
||||
elif "application/x-ndjson" in content_type or "application/jsonl" in content_type:
|
||||
yield from self._parse_ndjson_stream(response)
|
||||
else:
|
||||
yield from self._parse_text_stream(response)
|
||||
|
||||
except (httpx.StreamError, httpx.TimeoutException) as e:
|
||||
raise ToolInvokeError(f"Stream request failed: {e}")
|
||||
|
||||
def _parse_sse_stream(self, response: httpx.Response) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""Parse Server-Sent Events stream, yielding text messages."""
|
||||
for line in response.iter_lines():
|
||||
if not line:
|
||||
continue
|
||||
|
||||
# Handle data: prefix
|
||||
if line.startswith("data:"):
|
||||
data = line[5:].strip()
|
||||
|
||||
# Handle [DONE] terminator (OpenAI convention)
|
||||
if data == "[DONE]":
|
||||
return
|
||||
|
||||
# Try to parse as JSON and extract content
|
||||
text = self._extract_text_from_sse_data(data)
|
||||
if text:
|
||||
yield self.create_text_message(text)
|
||||
|
||||
def _extract_text_from_sse_data(self, data: str) -> str:
|
||||
"""Extract text content from SSE data line. Supports OpenAI format and common field names."""
|
||||
try:
|
||||
obj = json.loads(data)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
logger.debug("SSE data is not valid JSON, treating as plain text: %s", data[:200])
|
||||
return data
|
||||
|
||||
if not isinstance(obj, dict):
|
||||
return data
|
||||
|
||||
# OpenAI chat completion format: choices[].delta.content
|
||||
choices = obj.get("choices")
|
||||
if isinstance(choices, list) and len(choices) > 0:
|
||||
delta = choices[0].get("delta", {})
|
||||
if isinstance(delta, dict):
|
||||
content = delta.get("content")
|
||||
if content:
|
||||
return str(content)
|
||||
# Non-delta format: choices[].text (completion API)
|
||||
text = choices[0].get("text")
|
||||
if text:
|
||||
return str(text)
|
||||
|
||||
# Common field names
|
||||
for field in ("content", "text", "message", "data"):
|
||||
value = obj.get(field)
|
||||
if value and isinstance(value, str):
|
||||
return value
|
||||
|
||||
# Fallback: return raw data
|
||||
return data
|
||||
|
||||
    def _parse_ndjson_stream(self, response: httpx.Response) -> Generator[ToolInvokeMessage, None, None]:
        """Parse newline-delimited JSON stream.

        Each non-empty line is decoded independently:
        - JSON object with a known text field -> that field's value is yielded.
        - JSON object without one -> the whole object re-serialized as JSON.
        - Other JSON values -> stringified.
        - Non-JSON lines -> passed through as plain text (blank lines dropped).
        """
        for line in response.iter_lines():
            if not line:
                continue
            try:
                obj = json.loads(line)
                if isinstance(obj, dict):
                    # Try common text fields
                    for field in ("content", "text", "message", "data"):
                        value = obj.get(field)
                        if value and isinstance(value, str):
                            yield self.create_text_message(value)
                            # Only the first matching field is emitted per line.
                            break
                    else:
                        # for/else: no field matched -> emit the whole object as JSON.
                        yield self.create_text_message(json.dumps(obj, ensure_ascii=False))
                else:
                    yield self.create_text_message(str(obj))
            except (json.JSONDecodeError, ValueError):
                logger.debug("NDJSON line is not valid JSON, treating as plain text: %s", line[:200])
                if line.strip():
                    yield self.create_text_message(line)
|
||||
|
||||
def _parse_text_stream(self, response: httpx.Response) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""Parse plain text stream in chunks."""
|
||||
for chunk in response.iter_text(chunk_size=4096):
|
||||
if chunk:
|
||||
yield self.create_text_message(chunk)
|
||||
|
||||
def _convert_body_property_any_of(
|
||||
self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]], max_recursive=10
|
||||
):
|
||||
|
|
@ -384,10 +526,18 @@ class ApiTool(Tool):
|
|||
"""
|
||||
invoke http request
|
||||
"""
|
||||
response: httpx.Response | str = ""
|
||||
# assemble request
|
||||
headers = self.assembling_request(tool_parameters)
|
||||
|
||||
# streaming path
|
||||
if self.api_bundle.streaming:
|
||||
yield from self.do_http_request_streaming(
|
||||
self.api_bundle.server_url, self.api_bundle.method, headers, tool_parameters
|
||||
)
|
||||
return
|
||||
|
||||
# non-streaming path (original)
|
||||
response: httpx.Response | str = ""
|
||||
# do http request
|
||||
response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, tool_parameters)
|
||||
|
||||
|
|
|
|||
|
|
@ -29,3 +29,5 @@ class ApiToolBundle(BaseModel):
|
|||
openapi: dict
|
||||
# output schema
|
||||
output_schema: Mapping[str, object] = Field(default_factory=dict)
|
||||
# whether this operation supports streaming response
|
||||
streaming: bool = Field(default=False)
|
||||
|
|
|
|||
|
|
@ -202,6 +202,9 @@ class ApiBasedToolSchemaParser:
|
|||
|
||||
interface["operation"]["operationId"] = f"{path}_{interface['method']}"
|
||||
|
||||
# Extract x-dify-streaming extension field
|
||||
streaming = bool(interface["operation"].get("x-dify-streaming", False))
|
||||
|
||||
bundles.append(
|
||||
ApiToolBundle(
|
||||
server_url=server_url + interface["path"],
|
||||
|
|
@ -214,6 +217,7 @@ class ApiBasedToolSchemaParser:
|
|||
author="",
|
||||
icon=None,
|
||||
openapi=interface["operation"],
|
||||
streaming=streaming,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,472 @@
|
|||
"""Tests for ApiTool streaming support."""
|
||||
|
||||
import contextlib
|
||||
import json
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from core.helper import ssrf_proxy
|
||||
from core.tools.custom_tool.tool import ApiTool
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage
|
||||
from core.tools.errors import ToolInvokeError, ToolSSRFError
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures / helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _msg_text(msg: ToolInvokeMessage) -> str:
    """Return the plain-text payload of *msg*, whatever its inner message type.

    Inner messages exposing a ``text`` attribute yield that attribute;
    anything else is stringified.
    """
    payload = msg.message
    return payload.text if hasattr(payload, "text") else str(payload)
|
||||
|
||||
|
||||
def _make_api_tool(streaming: bool = False) -> ApiTool:
    """Build a minimal ApiTool wired to mock entity/runtime objects for tests."""
    # Mocked runtime: no auth, no extra runtime parameters.
    mock_runtime = MagicMock()
    mock_runtime.credentials = {"auth_type": "none"}
    mock_runtime.runtime_parameters = {}

    # Mocked tool entity with just enough identity for message creation.
    mock_entity = MagicMock(spec=ToolEntity)
    mock_entity.identity = MagicMock()
    mock_entity.identity.name = "test_tool"
    mock_entity.identity.author = "test"
    mock_entity.output_schema = {}

    # Real bundle so the streaming flag is exercised end to end.
    api_bundle = ApiToolBundle(
        server_url="https://api.example.com/v1/chat",
        method="post",
        summary="test",
        operation_id="test_op",
        parameters=[],
        author="test",
        icon=None,
        openapi={},
        streaming=streaming,
    )

    return ApiTool(entity=mock_entity, api_bundle=api_bundle, runtime=mock_runtime, provider_id="test_provider")
|
||||
|
||||
|
||||
class FakeStreamResponse:
|
||||
def __init__(
|
||||
self,
|
||||
status_code: int,
|
||||
content_type: str,
|
||||
lines: list[str] | None = None,
|
||||
chunks: list[str] | None = None,
|
||||
text: str = "",
|
||||
extra_headers: dict[str, str] | None = None,
|
||||
):
|
||||
self.status_code = status_code
|
||||
self.headers = {"content-type": content_type}
|
||||
if extra_headers:
|
||||
self.headers.update(extra_headers)
|
||||
self.text = text
|
||||
self._lines = lines or []
|
||||
self._chunks = chunks or []
|
||||
|
||||
def iter_lines(self):
|
||||
yield from self._lines
|
||||
|
||||
def iter_text(self, chunk_size=4096):
|
||||
yield from self._chunks
|
||||
|
||||
def iter_bytes(self, chunk_size=8192):
|
||||
yield self.text.encode("utf-8")
|
||||
|
||||
def read(self):
|
||||
pass
|
||||
|
||||
|
||||
def _resp(fake: FakeStreamResponse) -> httpx.Response:
    """Cast FakeStreamResponse to httpx.Response for type checkers."""
    # typing.cast is a no-op at runtime; the fake object passes through unchanged.
    return cast(httpx.Response, fake)
|
||||
|
||||
|
||||
@contextlib.contextmanager
def _fake_stream_ctx(response: FakeStreamResponse):
    """Context manager yielding *response*; mimics ssrf_proxy.stream_request()."""
    yield response
|
||||
|
||||
|
||||
class _RaisingContextManager:
|
||||
"""Context manager that raises on __enter__. Used to test exception paths."""
|
||||
|
||||
def __init__(self, exception: BaseException):
|
||||
self.exception = exception
|
||||
|
||||
def __enter__(self):
|
||||
raise self.exception
|
||||
|
||||
def __exit__(self, *args):
|
||||
return False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ApiToolBundle.streaming default value
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestApiToolBundleStreaming:
    """ApiToolBundle.streaming field: default and explicit values."""

    def test_default_false(self):
        """streaming defaults to False when the field is not supplied."""
        bundle = ApiToolBundle(
            server_url="https://example.com",
            method="get",
            operation_id="op",
            parameters=[],
            author="",
            openapi={},
        )
        assert bundle.streaming is False

    def test_explicit_true(self):
        """streaming=True is preserved when passed explicitly."""
        bundle = ApiToolBundle(
            server_url="https://example.com",
            method="get",
            operation_id="op",
            parameters=[],
            author="",
            openapi={},
            streaming=True,
        )
        assert bundle.streaming is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SSE parser
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseSSEStream:
    """ApiTool._parse_sse_stream: SSE line handling and content extraction."""

    def test_openai_format(self):
        """OpenAI delta chunks become one text message each; [DONE] ends the stream."""
        tool = _make_api_tool(streaming=True)
        lines = [
            'data: {"choices":[{"delta":{"content":"Hello"}}]}',
            'data: {"choices":[{"delta":{"content":" world"}}]}',
            "data: [DONE]",
        ]
        resp = FakeStreamResponse(200, "text/event-stream", lines=lines)
        messages = list(tool._parse_sse_stream(_resp(resp)))
        assert len(messages) == 2
        assert _msg_text(messages[0]) == "Hello"
        assert _msg_text(messages[1]) == " world"

    def test_plain_text_sse(self):
        """Non-JSON data payloads are passed through as plain text."""
        tool = _make_api_tool(streaming=True)
        lines = ["data: chunk1", "data: chunk2", "", "data: [DONE]"]
        resp = FakeStreamResponse(200, "text/event-stream", lines=lines)
        messages = list(tool._parse_sse_stream(_resp(resp)))
        assert len(messages) == 2
        assert _msg_text(messages[0]) == "chunk1"
        assert _msg_text(messages[1]) == "chunk2"

    def test_common_field_names(self):
        """content/text/message fields are all recognized in JSON payloads."""
        tool = _make_api_tool(streaming=True)
        lines = [
            'data: {"content":"from content"}',
            'data: {"text":"from text"}',
            'data: {"message":"from message"}',
        ]
        resp = FakeStreamResponse(200, "text/event-stream", lines=lines)
        messages = list(tool._parse_sse_stream(_resp(resp)))
        assert len(messages) == 3
        assert _msg_text(messages[0]) == "from content"
        assert _msg_text(messages[1]) == "from text"
        assert _msg_text(messages[2]) == "from message"

    def test_empty_lines_skipped(self):
        """Blank keep-alive lines produce no messages."""
        tool = _make_api_tool(streaming=True)
        lines = ["", "", "data: hello", ""]
        resp = FakeStreamResponse(200, "text/event-stream", lines=lines)
        messages = list(tool._parse_sse_stream(_resp(resp)))
        assert len(messages) == 1

    def test_non_data_lines_skipped(self):
        """SSE fields other than data: (event/id/retry) are ignored."""
        tool = _make_api_tool(streaming=True)
        lines = ["event: message", "id: 1", "retry: 3000", "data: actual content"]
        resp = FakeStreamResponse(200, "text/event-stream", lines=lines)
        messages = list(tool._parse_sse_stream(_resp(resp)))
        assert len(messages) == 1
        assert _msg_text(messages[0]) == "actual content"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# NDJSON parser
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseNdjsonStream:
    """ApiTool._parse_ndjson_stream: newline-delimited JSON handling."""

    def test_json_with_content_field(self):
        """Each NDJSON object with a 'content' field yields one text message."""
        tool = _make_api_tool(streaming=True)
        lines = [json.dumps({"content": "line1"}), json.dumps({"content": "line2"})]
        resp = FakeStreamResponse(200, "application/x-ndjson", lines=lines)
        messages = list(tool._parse_ndjson_stream(_resp(resp)))
        assert len(messages) == 2
        assert _msg_text(messages[0]) == "line1"

    def test_fallback_to_full_json(self):
        """Objects without a known text field are re-serialized whole."""
        tool = _make_api_tool(streaming=True)
        lines = [json.dumps({"result": 42})]
        resp = FakeStreamResponse(200, "application/x-ndjson", lines=lines)
        messages = list(tool._parse_ndjson_stream(_resp(resp)))
        assert len(messages) == 1
        assert "42" in _msg_text(messages[0])

    def test_invalid_json_lines(self):
        """Non-JSON lines pass through as plain text; empty lines are dropped."""
        tool = _make_api_tool(streaming=True)
        lines = ["not json", ""]
        resp = FakeStreamResponse(200, "application/x-ndjson", lines=lines)
        messages = list(tool._parse_ndjson_stream(_resp(resp)))
        assert len(messages) == 1
        assert _msg_text(messages[0]) == "not json"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Text stream parser
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseTextStream:
    """ApiTool._parse_text_stream: plain-text chunk handling."""

    def test_text_chunks(self):
        """Each text chunk becomes one message, in order."""
        tool = _make_api_tool(streaming=True)
        resp = FakeStreamResponse(200, "text/plain", chunks=["chunk1", "chunk2", "chunk3"])
        messages = list(tool._parse_text_stream(_resp(resp)))
        assert len(messages) == 3
        assert _msg_text(messages[0]) == "chunk1"

    def test_empty_chunks_skipped(self):
        """Empty chunks produce no messages."""
        tool = _make_api_tool(streaming=True)
        resp = FakeStreamResponse(200, "text/plain", chunks=["", "data", ""])
        messages = list(tool._parse_text_stream(_resp(resp)))
        assert len(messages) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _invoke() streaming branch
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestInvokeStreamingBranch:
    """ApiTool._invoke dispatch: streaming flag selects the request path."""

    # Note: patch.object decorators apply bottom-up, so the bottom patch
    # (do_http_request) binds to the first mock parameter.
    @patch.object(ApiTool, "do_http_request_streaming")
    @patch.object(ApiTool, "do_http_request")
    def test_streaming_true_uses_streaming_path(self, mock_non_stream, mock_stream):
        """streaming=True routes through do_http_request_streaming only."""
        tool = _make_api_tool(streaming=True)
        mock_stream.return_value = iter([tool.create_text_message("streamed")])

        messages = list(tool._invoke(user_id="u1", tool_parameters={}))
        mock_stream.assert_called_once()
        mock_non_stream.assert_not_called()
        assert len(messages) == 1
        assert _msg_text(messages[0]) == "streamed"

    @patch.object(ApiTool, "do_http_request_streaming")
    @patch.object(ApiTool, "do_http_request")
    def test_streaming_false_uses_non_streaming_path(self, mock_non_stream, mock_stream):
        """streaming=False routes through do_http_request only."""
        tool = _make_api_tool(streaming=False)
        mock_response = MagicMock(spec=httpx.Response)
        mock_response.status_code = 200
        mock_response.content = b'{"result": "ok"}'
        mock_response.headers = {"content-type": "text/plain"}
        mock_response.text = "ok"
        mock_response.json.side_effect = Exception("not json")
        mock_non_stream.return_value = mock_response

        messages = list(tool._invoke(user_id="u1", tool_parameters={}))
        mock_non_stream.assert_called_once()
        mock_stream.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Error handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestStreamingErrorHandling:
    """do_http_request_streaming error paths: HTTP errors and transport failures."""

    def test_http_error_raises(self):
        """An HTTP >= 400 response raises ToolInvokeError with the status code."""
        tool = _make_api_tool(streaming=True)
        error_resp = FakeStreamResponse(500, "text/plain", text="Internal Server Error")

        with patch("core.helper.ssrf_proxy.stream_request", return_value=_fake_stream_ctx(error_resp)):
            with pytest.raises(ToolInvokeError, match="500"):
                list(tool.do_http_request_streaming("https://example.com", "POST", {}, {}))

    def test_stream_error_raises(self):
        """httpx.StreamError is wrapped in ToolInvokeError."""
        tool = _make_api_tool(streaming=True)
        ctx = _RaisingContextManager(httpx.StreamError("connection reset"))

        with patch("core.helper.ssrf_proxy.stream_request", return_value=ctx):
            with pytest.raises(ToolInvokeError, match="Stream request failed"):
                list(tool.do_http_request_streaming("https://example.com", "POST", {}, {}))

    def test_timeout_raises(self):
        """httpx timeouts are wrapped in ToolInvokeError."""
        tool = _make_api_tool(streaming=True)
        ctx = _RaisingContextManager(httpx.ReadTimeout("read timed out"))

        with patch("core.helper.ssrf_proxy.stream_request", return_value=ctx):
            with pytest.raises(ToolInvokeError, match="Stream request failed"):
                list(tool.do_http_request_streaming("https://example.com", "POST", {}, {}))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SSRF Squid proxy detection in streaming
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _FakeHttpxStreamCM:
|
||||
"""Mimics the context manager returned by httpx.Client.stream()."""
|
||||
|
||||
def __init__(self, response):
|
||||
self.response = response
|
||||
|
||||
def __enter__(self):
|
||||
return self.response
|
||||
|
||||
def __exit__(self, *args):
|
||||
return False
|
||||
|
||||
|
||||
class TestStreamingSSRFProtection:
    """stream_request Squid-proxy detection on 401/403 responses."""

    def _make_mock_response(self, status_code: int, extra_headers: dict[str, str] | None = None):
        """Build a mock response with the given status and optional extra headers."""
        resp = MagicMock()
        resp.status_code = status_code
        resp.headers = {"content-type": "text/html"}
        if extra_headers:
            resp.headers.update(extra_headers)
        return resp

    def test_squid_403_raises_ssrf_error(self):
        """stream_request should detect Squid proxy 403 and raise ToolSSRFError."""
        fake_resp = self._make_mock_response(403, {"server": "squid/5.7"})
        mock_client = MagicMock()
        mock_client.stream.return_value = _FakeHttpxStreamCM(fake_resp)

        with patch("core.helper.ssrf_proxy._get_ssrf_client", return_value=mock_client):
            with pytest.raises(ToolSSRFError, match="SSRF protection"):
                with ssrf_proxy.stream_request("POST", "https://internal.example.com"):
                    pass

    def test_squid_401_via_header_raises_ssrf_error(self):
        """stream_request should detect Squid in Via header and raise ToolSSRFError."""
        fake_resp = self._make_mock_response(401, {"via": "1.1 squid-proxy"})
        mock_client = MagicMock()
        mock_client.stream.return_value = _FakeHttpxStreamCM(fake_resp)

        with patch("core.helper.ssrf_proxy._get_ssrf_client", return_value=mock_client):
            with pytest.raises(ToolSSRFError, match="SSRF protection"):
                with ssrf_proxy.stream_request("POST", "https://internal.example.com"):
                    pass

    def test_non_squid_403_passes_through(self):
        """403 from non-Squid server should NOT raise ToolSSRFError but be handled by caller."""
        tool = _make_api_tool(streaming=True)
        resp = FakeStreamResponse(403, "text/plain", text="Forbidden")

        with patch("core.helper.ssrf_proxy.stream_request", return_value=_fake_stream_ctx(resp)):
            with pytest.raises(ToolInvokeError, match="403"):
                list(tool.do_http_request_streaming("https://example.com", "POST", {}, {}))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Content-type routing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestContentTypeRouting:
    """do_http_request_streaming: parser selection by response content type."""

    def test_sse_content_type(self):
        """text/event-stream (with charset suffix) routes to the SSE parser."""
        tool = _make_api_tool(streaming=True)
        resp = FakeStreamResponse(200, "text/event-stream; charset=utf-8", lines=["data: hello"])

        with patch("core.helper.ssrf_proxy.stream_request", return_value=_fake_stream_ctx(resp)):
            messages = list(tool.do_http_request_streaming("https://example.com", "POST", {}, {}))
        assert len(messages) == 1

    def test_ndjson_content_type(self):
        """application/x-ndjson routes to the NDJSON parser."""
        tool = _make_api_tool(streaming=True)
        resp = FakeStreamResponse(200, "application/x-ndjson", lines=[json.dumps({"text": "hi"})])

        with patch("core.helper.ssrf_proxy.stream_request", return_value=_fake_stream_ctx(resp)):
            messages = list(tool.do_http_request_streaming("https://example.com", "POST", {}, {}))
        assert len(messages) == 1
        assert _msg_text(messages[0]) == "hi"

    def test_jsonl_content_type(self):
        """application/jsonl also routes to the NDJSON parser."""
        tool = _make_api_tool(streaming=True)
        resp = FakeStreamResponse(200, "application/jsonl", lines=[json.dumps({"content": "hey"})])

        with patch("core.helper.ssrf_proxy.stream_request", return_value=_fake_stream_ctx(resp)):
            messages = list(tool.do_http_request_streaming("https://example.com", "POST", {}, {}))
        assert len(messages) == 1

    def test_fallback_text_stream(self):
        """Unknown content types fall back to the raw text-chunk parser."""
        tool = _make_api_tool(streaming=True)
        resp = FakeStreamResponse(200, "application/octet-stream", chunks=["raw data"])

        with patch("core.helper.ssrf_proxy.stream_request", return_value=_fake_stream_ctx(resp)):
            messages = list(tool.do_http_request_streaming("https://example.com", "POST", {}, {}))
        assert len(messages) == 1
        assert _msg_text(messages[0]) == "raw data"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OpenAPI parser: x-dify-streaming extraction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestOpenAPIParserStreaming:
    """OpenAPI parser: extraction of the x-dify-streaming extension field."""

    def test_x_dify_streaming_true(self):
        """x-dify-streaming: true on an operation sets bundle.streaming."""
        from flask import Flask

        from core.tools.utils.parser import ApiBasedToolSchemaParser

        openapi = {
            "openapi": "3.0.0",
            "info": {"title": "Test", "version": "1.0"},
            "servers": [{"url": "https://api.example.com"}],
            "paths": {
                "/chat": {
                    "post": {
                        "x-dify-streaming": True,
                        "operationId": "chat",
                        "summary": "Chat",
                        "responses": {"200": {"description": "OK"}},
                    }
                }
            },
        }
        app = Flask(__name__)
        # Parser reads request context (e.g. for localization), so wrap in one.
        with app.test_request_context():
            bundles = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi)

        assert len(bundles) == 1
        assert bundles[0].streaming is True

    def test_x_dify_streaming_absent(self):
        """Operations without x-dify-streaming default to streaming=False."""
        from flask import Flask

        from core.tools.utils.parser import ApiBasedToolSchemaParser

        openapi = {
            "openapi": "3.0.0",
            "info": {"title": "Test", "version": "1.0"},
            "servers": [{"url": "https://api.example.com"}],
            "paths": {
                "/query": {
                    "get": {
                        "operationId": "query",
                        "summary": "Query",
                        "responses": {"200": {"description": "OK"}},
                    }
                }
            },
        }
        app = Flask(__name__)
        with app.test_request_context():
            bundles = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi)

        assert len(bundles) == 1
        assert bundles[0].streaming is False
|
||||
Loading…
Reference in New Issue