This commit is contained in:
skhe 2026-03-24 10:39:31 +08:00 committed by GitHub
commit 8bbbb90df9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 690 additions and 9 deletions

View File

@ -2,8 +2,10 @@
Proxy requests to avoid SSRF
"""
import contextlib
import logging
import time
from collections.abc import Iterator
from typing import Any, TypeAlias
import httpx
@ -268,3 +270,54 @@ class SSRFProxy:
ssrf_proxy = SSRFProxy()
@contextlib.contextmanager
def stream_request(method: str, url: str, **kwargs: Any) -> Iterator[httpx.Response]:
    """Open a streaming HTTP request guarded by SSRF protection.

    Retries are deliberately absent here: a partially consumed stream cannot
    be replayed, so (unlike make_request()) a mid-stream failure is surfaced
    directly to the caller.
    """
    # Translate the requests-style "allow_redirects" flag into httpx's
    # "follow_redirects", without clobbering an explicit caller value.
    _unset = object()
    legacy_redirects = kwargs.pop("allow_redirects", _unset)
    if legacy_redirects is not _unset and "follow_redirects" not in kwargs:
        kwargs["follow_redirects"] = legacy_redirects

    kwargs.setdefault(
        "timeout",
        httpx.Timeout(
            timeout=dify_config.SSRF_DEFAULT_TIME_OUT,
            connect=dify_config.SSRF_DEFAULT_CONNECT_TIME_OUT,
            read=dify_config.SSRF_DEFAULT_READ_TIME_OUT,
            write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT,
        ),
    )

    ssl_verify = kwargs.pop("ssl_verify", dify_config.HTTP_REQUEST_NODE_SSL_VERIFY)
    if not isinstance(ssl_verify, bool):
        raise ValueError("ssl_verify must be a boolean")
    http_client = _get_ssrf_client(ssl_verify)

    try:
        validated_headers: Headers = _HEADERS_ADAPTER.validate_python(kwargs.get("headers") or {})
    except ValidationError as exc:
        raise ValueError("headers must be a mapping of string keys to string values") from exc
    validated_headers = _inject_trace_headers(validated_headers)

    # If the caller supplied a Host header, collapse it to a single lowercase key.
    host_override = _get_user_provided_host_header(validated_headers)
    if host_override is not None:
        validated_headers = {
            name: value for name, value in validated_headers.items() if name.lower() != "host"
        }
        validated_headers["host"] = host_override
    kwargs["headers"] = validated_headers

    with http_client.stream(method=method, url=url, **kwargs) as response:
        # Mirror make_request(): a 401/403 whose Server/Via headers mention
        # Squid means the egress proxy refused the destination (SSRF block).
        if response.status_code in (401, 403):
            proxy_markers = (
                response.headers.get("server", "").lower(),
                response.headers.get("via", "").lower(),
            )
            if any("squid" in marker for marker in proxy_markers):
                raise ToolSSRFError(
                    f"Access to '{url}' was blocked by SSRF protection. "
                    f"The URL may point to a private or local network address. "
                )
        yield response

View File

@ -1,4 +1,5 @@
import json
import logging
from collections.abc import Generator
from dataclasses import dataclass
from os import getenv
@ -15,6 +16,8 @@ from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, Too
from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError
from dify_graph.file.file_manager import download
logger = logging.getLogger(__name__)
API_TOOL_DEFAULT_TIMEOUT = (
int(getenv("API_TOOL_DEFAULT_CONNECT_TIMEOUT", "10")),
int(getenv("API_TOOL_DEFAULT_READ_TIMEOUT", "60")),
@ -168,14 +171,15 @@ class ApiTool(Tool):
else:
return (parameter.get("schema", {}) or {}).get("default", "")
def do_http_request(
def _prepare_request_parts(
self, url: str, method: str, headers: dict[str, Any], parameters: dict[str, Any]
) -> httpx.Response:
) -> tuple[dict, Any, dict, list, str, dict[str, Any]]:
"""
do http request depending on api bundle
"""
method = method.lower()
Assemble request parts (params, body, cookies, files, url, headers) from
the OpenAPI bundle and supplied parameters.
This is shared by both the streaming and non-streaming request paths.
"""
params = {}
path_params = {}
# FIXME: body should be a dict[str, Any] but it changed a lot in this function
@ -275,6 +279,24 @@ class ApiTool(Tool):
if files:
headers.pop("Content-Type", None)
return params, body, cookies, files, url, headers
def do_http_request(
self, url: str, method: str, headers: dict[str, Any], parameters: dict[str, Any]
) -> httpx.Response:
"""
Do a non-streaming HTTP request depending on api bundle.
For streaming requests, see do_http_request_streaming().
Both methods share _prepare_request_parts() for request assembly.
"""
params, body, cookies, files, url, headers = self._prepare_request_parts(url, method, headers, parameters)
_VALID_METHODS = {"get", "head", "post", "put", "delete", "patch"}
method_lc = method.lower()
if method_lc not in _VALID_METHODS:
raise ValueError(f"Invalid http method {method}")
_METHOD_MAP = {
"get": ssrf_proxy.get,
"head": ssrf_proxy.head,
@ -283,9 +305,6 @@ class ApiTool(Tool):
"delete": ssrf_proxy.delete,
"patch": ssrf_proxy.patch,
}
method_lc = method.lower()
if method_lc not in _METHOD_MAP:
raise ValueError(f"Invalid http method {method}")
response: httpx.Response = _METHOD_MAP[
method_lc
]( # https://discuss.python.org/t/type-inference-for-function-return-types/42926
@ -301,6 +320,129 @@ class ApiTool(Tool):
)
return response
def do_http_request_streaming(
    self, url: str, method: str, headers: dict[str, Any], parameters: dict[str, Any]
) -> Generator[ToolInvokeMessage, None, None]:
    """Do a streaming HTTP request, yielding ToolInvokeMessage chunks as they arrive.

    The response is routed to a parser based on its Content-Type: SSE
    (text/event-stream), NDJSON/JSONL, or a plain-text fallback.

    :param url: base request URL (path/query params are filled from the bundle)
    :param method: HTTP method name, case-insensitive
    :param headers: request headers assembled by the caller
    :param parameters: tool parameters mapped onto the OpenAPI operation
    :raises ToolInvokeError: if the server answers with an HTTP error status,
        or the underlying stream fails or times out
    """
    params, body, cookies, files, url, headers = self._prepare_request_parts(url, method, headers, parameters)
    try:
        with ssrf_proxy.stream_request(
            method=method.upper(),
            url=url,
            params=params,
            headers=headers,
            cookies=cookies,
            data=body,
            files=files or None,
            timeout=API_TOOL_DEFAULT_TIMEOUT,
            follow_redirects=True,
        ) as response:
            if response.status_code >= 400:
                # Read a bounded amount to avoid unbounded memory usage on large error bodies
                _MAX_ERROR_BYTES = 1_048_576  # 1 MB
                error_body = b""
                for chunk in response.iter_bytes(chunk_size=8192):
                    error_body += chunk
                    if len(error_body) >= _MAX_ERROR_BYTES:
                        error_body = error_body[:_MAX_ERROR_BYTES]
                        break
                error_text = error_body.decode("utf-8", errors="replace")
                raise ToolInvokeError(f"Request failed with status code {response.status_code} and {error_text}")
            content_type = response.headers.get("content-type", "").lower()
            if "text/event-stream" in content_type:
                yield from self._parse_sse_stream(response)
            elif "application/x-ndjson" in content_type or "application/jsonl" in content_type:
                yield from self._parse_ndjson_stream(response)
            else:
                yield from self._parse_text_stream(response)
    except (httpx.StreamError, httpx.TimeoutException) as e:
        # Chain the original exception so the transport failure stays debuggable.
        raise ToolInvokeError(f"Stream request failed: {e}") from e
def _parse_sse_stream(self, response: httpx.Response) -> Generator[ToolInvokeMessage, None, None]:
    """Parse a Server-Sent Events stream and yield each extracted text chunk."""
    for raw_line in response.iter_lines():
        # Only "data:" lines carry payload; blank/event/id/retry lines are skipped.
        if not raw_line or not raw_line.startswith("data:"):
            continue
        payload = raw_line[len("data:"):].strip()
        # OpenAI-style end-of-stream sentinel.
        if payload == "[DONE]":
            break
        extracted = self._extract_text_from_sse_data(payload)
        if extracted:
            yield self.create_text_message(extracted)
def _extract_text_from_sse_data(self, data: str) -> str:
"""Extract text content from SSE data line. Supports OpenAI format and common field names."""
try:
obj = json.loads(data)
except (json.JSONDecodeError, ValueError):
logger.debug("SSE data is not valid JSON, treating as plain text: %s", data[:200])
return data
if not isinstance(obj, dict):
return data
# OpenAI chat completion format: choices[].delta.content
choices = obj.get("choices")
if isinstance(choices, list) and len(choices) > 0:
delta = choices[0].get("delta", {})
if isinstance(delta, dict):
content = delta.get("content")
if content:
return str(content)
# Non-delta format: choices[].text (completion API)
text = choices[0].get("text")
if text:
return str(text)
# Common field names
for field in ("content", "text", "message", "data"):
value = obj.get(field)
if value and isinstance(value, str):
return value
# Fallback: return raw data
return data
def _parse_ndjson_stream(self, response: httpx.Response) -> Generator[ToolInvokeMessage, None, None]:
    """Parse a newline-delimited JSON stream into text messages."""
    preferred_fields = ("content", "text", "message", "data")
    for raw_line in response.iter_lines():
        if not raw_line:
            continue
        try:
            parsed = json.loads(raw_line)
        except (json.JSONDecodeError, ValueError):
            # Not JSON after all -- pass the raw line through as plain text.
            logger.debug("NDJSON line is not valid JSON, treating as plain text: %s", raw_line[:200])
            if raw_line.strip():
                yield self.create_text_message(raw_line)
            continue
        if not isinstance(parsed, dict):
            yield self.create_text_message(str(parsed))
            continue
        # Prefer the first well-known non-empty text field; otherwise emit the whole object.
        chosen = next(
            (parsed[f] for f in preferred_fields if isinstance(parsed.get(f), str) and parsed.get(f)),
            None,
        )
        if chosen is not None:
            yield self.create_text_message(chosen)
        else:
            yield self.create_text_message(json.dumps(parsed, ensure_ascii=False))
def _parse_text_stream(self, response: httpx.Response) -> Generator[ToolInvokeMessage, None, None]:
    """Yield each non-empty decoded text chunk as its own text message."""
    yield from (
        self.create_text_message(piece)
        for piece in response.iter_text(chunk_size=4096)
        if piece
    )
def _convert_body_property_any_of(
self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]], max_recursive=10
):
@ -384,10 +526,18 @@ class ApiTool(Tool):
"""
invoke http request
"""
response: httpx.Response | str = ""
# assemble request
headers = self.assembling_request(tool_parameters)
# streaming path
if self.api_bundle.streaming:
yield from self.do_http_request_streaming(
self.api_bundle.server_url, self.api_bundle.method, headers, tool_parameters
)
return
# non-streaming path (original)
response: httpx.Response | str = ""
# do http request
response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, tool_parameters)

View File

@ -29,3 +29,5 @@ class ApiToolBundle(BaseModel):
openapi: dict
# output schema
output_schema: Mapping[str, object] = Field(default_factory=dict)
# whether this operation supports streaming response
streaming: bool = Field(default=False)

View File

@ -202,6 +202,9 @@ class ApiBasedToolSchemaParser:
interface["operation"]["operationId"] = f"{path}_{interface['method']}"
# Extract x-dify-streaming extension field
streaming = bool(interface["operation"].get("x-dify-streaming", False))
bundles.append(
ApiToolBundle(
server_url=server_url + interface["path"],
@ -214,6 +217,7 @@ class ApiBasedToolSchemaParser:
author="",
icon=None,
openapi=interface["operation"],
streaming=streaming,
)
)

View File

@ -0,0 +1,472 @@
"""Tests for ApiTool streaming support."""
import contextlib
import json
from typing import cast
from unittest.mock import MagicMock, patch
import httpx
import pytest
from core.helper import ssrf_proxy
from core.tools.custom_tool.tool import ApiTool
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage
from core.tools.errors import ToolInvokeError, ToolSSRFError
# ---------------------------------------------------------------------------
# Fixtures / helpers
# ---------------------------------------------------------------------------
def _msg_text(msg: ToolInvokeMessage) -> str:
    """Return the plain-text payload of a ToolInvokeMessage, whatever its inner type."""
    payload = msg.message
    return payload.text if hasattr(payload, "text") else str(payload)
def _make_api_tool(streaming: bool = False) -> ApiTool:
    """Build a minimal ApiTool wired to mocked identity/runtime for unit tests."""
    bundle = ApiToolBundle(
        server_url="https://api.example.com/v1/chat",
        method="post",
        summary="test",
        operation_id="test_op",
        parameters=[],
        author="test",
        icon=None,
        openapi={},
        streaming=streaming,
    )
    identity = MagicMock()
    identity.name = "test_tool"
    identity.author = "test"
    entity = MagicMock(spec=ToolEntity)
    entity.identity = identity
    entity.output_schema = {}
    runtime = MagicMock()
    runtime.credentials = {"auth_type": "none"}
    runtime.runtime_parameters = {}
    return ApiTool(entity=entity, api_bundle=bundle, runtime=runtime, provider_id="test_provider")
class FakeStreamResponse:
def __init__(
self,
status_code: int,
content_type: str,
lines: list[str] | None = None,
chunks: list[str] | None = None,
text: str = "",
extra_headers: dict[str, str] | None = None,
):
self.status_code = status_code
self.headers = {"content-type": content_type}
if extra_headers:
self.headers.update(extra_headers)
self.text = text
self._lines = lines or []
self._chunks = chunks or []
def iter_lines(self):
yield from self._lines
def iter_text(self, chunk_size=4096):
yield from self._chunks
def iter_bytes(self, chunk_size=8192):
yield self.text.encode("utf-8")
def read(self):
pass
def _resp(fake_response: FakeStreamResponse) -> httpx.Response:
    """Pretend the fake is a real httpx.Response so type checkers stay quiet."""
    return cast(httpx.Response, fake_response)
@contextlib.contextmanager
def _fake_stream_ctx(canned_response: FakeStreamResponse):
    """Context manager mimicking ssrf_proxy.stream_request for patching."""
    yield canned_response
class _RaisingContextManager:
"""Context manager that raises on __enter__. Used to test exception paths."""
def __init__(self, exception: BaseException):
self.exception = exception
def __enter__(self):
raise self.exception
def __exit__(self, *args):
return False
# ---------------------------------------------------------------------------
# ApiToolBundle.streaming default value
# ---------------------------------------------------------------------------
class TestApiToolBundleStreaming:
    """ApiToolBundle.streaming: default and explicitly-set values."""

    @staticmethod
    def _bundle(**overrides):
        # Shared minimal-valid kwargs; tests tweak only what they assert on.
        base = dict(
            server_url="https://example.com",
            method="get",
            operation_id="op",
            parameters=[],
            author="",
            openapi={},
        )
        base.update(overrides)
        return ApiToolBundle(**base)

    def test_default_false(self):
        assert self._bundle().streaming is False

    def test_explicit_true(self):
        assert self._bundle(streaming=True).streaming is True
# ---------------------------------------------------------------------------
# SSE parser
# ---------------------------------------------------------------------------
class TestParseSSEStream:
    """SSE parsing: OpenAI deltas, plain text, common fields, skipped lines."""

    @staticmethod
    def _run(lines):
        # Drive _parse_sse_stream over canned SSE lines and collect the messages.
        tool = _make_api_tool(streaming=True)
        fake = FakeStreamResponse(200, "text/event-stream", lines=lines)
        return list(tool._parse_sse_stream(_resp(fake)))

    def test_openai_format(self):
        msgs = self._run([
            'data: {"choices":[{"delta":{"content":"Hello"}}]}',
            'data: {"choices":[{"delta":{"content":" world"}}]}',
            "data: [DONE]",
        ])
        assert [_msg_text(m) for m in msgs] == ["Hello", " world"]

    def test_plain_text_sse(self):
        msgs = self._run(["data: chunk1", "data: chunk2", "", "data: [DONE]"])
        assert [_msg_text(m) for m in msgs] == ["chunk1", "chunk2"]

    def test_common_field_names(self):
        msgs = self._run([
            'data: {"content":"from content"}',
            'data: {"text":"from text"}',
            'data: {"message":"from message"}',
        ])
        assert [_msg_text(m) for m in msgs] == ["from content", "from text", "from message"]

    def test_empty_lines_skipped(self):
        assert len(self._run(["", "", "data: hello", ""])) == 1

    def test_non_data_lines_skipped(self):
        msgs = self._run(["event: message", "id: 1", "retry: 3000", "data: actual content"])
        assert [_msg_text(m) for m in msgs] == ["actual content"]
# ---------------------------------------------------------------------------
# NDJSON parser
# ---------------------------------------------------------------------------
class TestParseNdjsonStream:
    """NDJSON parsing: known text fields, full-JSON fallback, invalid lines."""

    @staticmethod
    def _run(lines):
        tool = _make_api_tool(streaming=True)
        fake = FakeStreamResponse(200, "application/x-ndjson", lines=lines)
        return list(tool._parse_ndjson_stream(_resp(fake)))

    def test_json_with_content_field(self):
        msgs = self._run([json.dumps({"content": "line1"}), json.dumps({"content": "line2"})])
        assert len(msgs) == 2
        assert _msg_text(msgs[0]) == "line1"

    def test_fallback_to_full_json(self):
        msgs = self._run([json.dumps({"result": 42})])
        assert len(msgs) == 1
        assert "42" in _msg_text(msgs[0])

    def test_invalid_json_lines(self):
        msgs = self._run(["not json", ""])
        assert len(msgs) == 1
        assert _msg_text(msgs[0]) == "not json"
# ---------------------------------------------------------------------------
# Text stream parser
# ---------------------------------------------------------------------------
class TestParseTextStream:
    """Plain-text stream parsing: chunk pass-through, empty chunks dropped."""

    @staticmethod
    def _run(chunks):
        tool = _make_api_tool(streaming=True)
        fake = FakeStreamResponse(200, "text/plain", chunks=chunks)
        return list(tool._parse_text_stream(_resp(fake)))

    def test_text_chunks(self):
        msgs = self._run(["chunk1", "chunk2", "chunk3"])
        assert len(msgs) == 3
        assert _msg_text(msgs[0]) == "chunk1"

    def test_empty_chunks_skipped(self):
        assert len(self._run(["", "data", ""])) == 1
# ---------------------------------------------------------------------------
# _invoke() streaming branch
# ---------------------------------------------------------------------------
class TestInvokeStreamingBranch:
    """_invoke() must dispatch on bundle.streaming to the right request path."""

    def test_streaming_true_uses_streaming_path(self):
        tool = _make_api_tool(streaming=True)
        with (
            patch.object(ApiTool, "do_http_request") as mock_plain,
            patch.object(ApiTool, "do_http_request_streaming") as mock_stream,
        ):
            mock_stream.return_value = iter([tool.create_text_message("streamed")])
            messages = list(tool._invoke(user_id="u1", tool_parameters={}))
        mock_stream.assert_called_once()
        mock_plain.assert_not_called()
        assert len(messages) == 1
        assert _msg_text(messages[0]) == "streamed"

    def test_streaming_false_uses_non_streaming_path(self):
        tool = _make_api_tool(streaming=False)
        fake_response = MagicMock(spec=httpx.Response)
        fake_response.status_code = 200
        fake_response.content = b'{"result": "ok"}'
        fake_response.headers = {"content-type": "text/plain"}
        fake_response.text = "ok"
        fake_response.json.side_effect = Exception("not json")
        with (
            patch.object(ApiTool, "do_http_request") as mock_plain,
            patch.object(ApiTool, "do_http_request_streaming") as mock_stream,
        ):
            mock_plain.return_value = fake_response
            list(tool._invoke(user_id="u1", tool_parameters={}))
        mock_plain.assert_called_once()
        mock_stream.assert_not_called()
# ---------------------------------------------------------------------------
# Error handling
# ---------------------------------------------------------------------------
class TestStreamingErrorHandling:
    """Transport and HTTP-status failures must surface as ToolInvokeError."""

    @staticmethod
    def _invoke_with_stream(stream_cm):
        # Run the streaming request path with ssrf_proxy.stream_request replaced.
        tool = _make_api_tool(streaming=True)
        with patch("core.helper.ssrf_proxy.stream_request", return_value=stream_cm):
            return list(tool.do_http_request_streaming("https://example.com", "POST", {}, {}))

    def test_http_error_raises(self):
        failing = FakeStreamResponse(500, "text/plain", text="Internal Server Error")
        with pytest.raises(ToolInvokeError, match="500"):
            self._invoke_with_stream(_fake_stream_ctx(failing))

    def test_stream_error_raises(self):
        with pytest.raises(ToolInvokeError, match="Stream request failed"):
            self._invoke_with_stream(_RaisingContextManager(httpx.StreamError("connection reset")))

    def test_timeout_raises(self):
        with pytest.raises(ToolInvokeError, match="Stream request failed"):
            self._invoke_with_stream(_RaisingContextManager(httpx.ReadTimeout("read timed out")))
# ---------------------------------------------------------------------------
# SSRF Squid proxy detection in streaming
# ---------------------------------------------------------------------------
class _FakeHttpxStreamCM:
"""Mimics the context manager returned by httpx.Client.stream()."""
def __init__(self, response):
self.response = response
def __enter__(self):
return self.response
def __exit__(self, *args):
return False
class TestStreamingSSRFProtection:
    """Squid-blocked responses inside stream_request must raise ToolSSRFError."""

    @staticmethod
    def _blocked_response(status_code: int, extra_headers: dict[str, str] | None = None):
        fake = MagicMock()
        fake.status_code = status_code
        fake.headers = {"content-type": "text/html", **(extra_headers or {})}
        return fake

    @staticmethod
    def _assert_blocked(fake_resp):
        # Patch the client factory so stream_request sees our canned response.
        stub_client = MagicMock()
        stub_client.stream.return_value = _FakeHttpxStreamCM(fake_resp)
        with patch("core.helper.ssrf_proxy._get_ssrf_client", return_value=stub_client):
            with pytest.raises(ToolSSRFError, match="SSRF protection"):
                with ssrf_proxy.stream_request("POST", "https://internal.example.com"):
                    pass

    def test_squid_403_raises_ssrf_error(self):
        """A 403 whose Server header names Squid is treated as an SSRF block."""
        self._assert_blocked(self._blocked_response(403, {"server": "squid/5.7"}))

    def test_squid_401_via_header_raises_ssrf_error(self):
        """A 401 whose Via header names Squid is treated as an SSRF block."""
        self._assert_blocked(self._blocked_response(401, {"via": "1.1 squid-proxy"}))

    def test_non_squid_403_passes_through(self):
        """A non-Squid 403 is left for the caller, which raises ToolInvokeError."""
        tool = _make_api_tool(streaming=True)
        forbidden = FakeStreamResponse(403, "text/plain", text="Forbidden")
        with patch("core.helper.ssrf_proxy.stream_request", return_value=_fake_stream_ctx(forbidden)):
            with pytest.raises(ToolInvokeError, match="403"):
                list(tool.do_http_request_streaming("https://example.com", "POST", {}, {}))
# ---------------------------------------------------------------------------
# Content-type routing
# ---------------------------------------------------------------------------
class TestContentTypeRouting:
    """Content-Type header must select the SSE / NDJSON / plain-text parser."""

    @staticmethod
    def _stream(fake):
        tool = _make_api_tool(streaming=True)
        with patch("core.helper.ssrf_proxy.stream_request", return_value=_fake_stream_ctx(fake)):
            return list(tool.do_http_request_streaming("https://example.com", "POST", {}, {}))

    def test_sse_content_type(self):
        msgs = self._stream(FakeStreamResponse(200, "text/event-stream; charset=utf-8", lines=["data: hello"]))
        assert len(msgs) == 1

    def test_ndjson_content_type(self):
        msgs = self._stream(FakeStreamResponse(200, "application/x-ndjson", lines=[json.dumps({"text": "hi"})]))
        assert len(msgs) == 1
        assert _msg_text(msgs[0]) == "hi"

    def test_jsonl_content_type(self):
        msgs = self._stream(FakeStreamResponse(200, "application/jsonl", lines=[json.dumps({"content": "hey"})]))
        assert len(msgs) == 1

    def test_fallback_text_stream(self):
        msgs = self._stream(FakeStreamResponse(200, "application/octet-stream", chunks=["raw data"]))
        assert len(msgs) == 1
        assert _msg_text(msgs[0]) == "raw data"
# ---------------------------------------------------------------------------
# OpenAPI parser: x-dify-streaming extraction
# ---------------------------------------------------------------------------
class TestOpenAPIParserStreaming:
    """The parser must map the x-dify-streaming vendor extension onto bundle.streaming."""

    @staticmethod
    def _parse(paths):
        # Shared spec scaffold; only "paths" varies between tests.
        from flask import Flask

        from core.tools.utils.parser import ApiBasedToolSchemaParser

        spec = {
            "openapi": "3.0.0",
            "info": {"title": "Test", "version": "1.0"},
            "servers": [{"url": "https://api.example.com"}],
            "paths": paths,
        }
        app = Flask(__name__)
        with app.test_request_context():
            return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(spec)

    def test_x_dify_streaming_true(self):
        bundles = self._parse(
            {
                "/chat": {
                    "post": {
                        "x-dify-streaming": True,
                        "operationId": "chat",
                        "summary": "Chat",
                        "responses": {"200": {"description": "OK"}},
                    }
                }
            }
        )
        assert len(bundles) == 1
        assert bundles[0].streaming is True

    def test_x_dify_streaming_absent(self):
        bundles = self._parse(
            {
                "/query": {
                    "get": {
                        "operationId": "query",
                        "summary": "Query",
                        "responses": {"200": {"description": "OK"}},
                    }
                }
            }
        )
        assert len(bundles) == 1
        assert bundles[0].streaming is False