mirror of https://github.com/langgenius/dify.git
fix(tools): resolve Pyrefly type check errors in streaming tests
- Use typing.cast via _resp() helper for FakeStreamResponse → httpx.Response - Replace @contextlib.contextmanager with _RaisingContextManager class to eliminate unreachable yield statements
This commit is contained in:
parent
bf0d7fd57c
commit
abb0408b08
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
import contextlib
|
||||
import json
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
|
|
@ -83,11 +84,29 @@ class FakeStreamResponse:
|
|||
pass
|
||||
|
||||
|
||||
def _resp(fake: FakeStreamResponse) -> httpx.Response:
|
||||
"""Cast FakeStreamResponse to httpx.Response for type checkers."""
|
||||
return cast(httpx.Response, fake)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _fake_stream_ctx(response: FakeStreamResponse):
|
||||
yield response
|
||||
|
||||
|
||||
class _RaisingContextManager:
|
||||
"""Context manager that raises on __enter__. Used to test exception paths."""
|
||||
|
||||
def __init__(self, exception: BaseException):
|
||||
self.exception = exception
|
||||
|
||||
def __enter__(self):
|
||||
raise self.exception
|
||||
|
||||
def __exit__(self, *args):
|
||||
return False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ApiToolBundle.streaming default value
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -132,7 +151,7 @@ class TestParseSSEStream:
|
|||
"data: [DONE]",
|
||||
]
|
||||
resp = FakeStreamResponse(200, "text/event-stream", lines=lines)
|
||||
messages = list(tool._parse_sse_stream(resp))
|
||||
messages = list(tool._parse_sse_stream(_resp(resp)))
|
||||
assert len(messages) == 2
|
||||
assert _msg_text(messages[0]) == "Hello"
|
||||
assert _msg_text(messages[1]) == " world"
|
||||
|
|
@ -141,7 +160,7 @@ class TestParseSSEStream:
|
|||
tool = _make_api_tool(streaming=True)
|
||||
lines = ["data: chunk1", "data: chunk2", "", "data: [DONE]"]
|
||||
resp = FakeStreamResponse(200, "text/event-stream", lines=lines)
|
||||
messages = list(tool._parse_sse_stream(resp))
|
||||
messages = list(tool._parse_sse_stream(_resp(resp)))
|
||||
assert len(messages) == 2
|
||||
assert _msg_text(messages[0]) == "chunk1"
|
||||
assert _msg_text(messages[1]) == "chunk2"
|
||||
|
|
@ -154,7 +173,7 @@ class TestParseSSEStream:
|
|||
'data: {"message":"from message"}',
|
||||
]
|
||||
resp = FakeStreamResponse(200, "text/event-stream", lines=lines)
|
||||
messages = list(tool._parse_sse_stream(resp))
|
||||
messages = list(tool._parse_sse_stream(_resp(resp)))
|
||||
assert len(messages) == 3
|
||||
assert _msg_text(messages[0]) == "from content"
|
||||
assert _msg_text(messages[1]) == "from text"
|
||||
|
|
@ -164,14 +183,14 @@ class TestParseSSEStream:
|
|||
tool = _make_api_tool(streaming=True)
|
||||
lines = ["", "", "data: hello", ""]
|
||||
resp = FakeStreamResponse(200, "text/event-stream", lines=lines)
|
||||
messages = list(tool._parse_sse_stream(resp))
|
||||
messages = list(tool._parse_sse_stream(_resp(resp)))
|
||||
assert len(messages) == 1
|
||||
|
||||
def test_non_data_lines_skipped(self):
|
||||
tool = _make_api_tool(streaming=True)
|
||||
lines = ["event: message", "id: 1", "retry: 3000", "data: actual content"]
|
||||
resp = FakeStreamResponse(200, "text/event-stream", lines=lines)
|
||||
messages = list(tool._parse_sse_stream(resp))
|
||||
messages = list(tool._parse_sse_stream(_resp(resp)))
|
||||
assert len(messages) == 1
|
||||
assert _msg_text(messages[0]) == "actual content"
|
||||
|
||||
|
|
@ -186,7 +205,7 @@ class TestParseNdjsonStream:
|
|||
tool = _make_api_tool(streaming=True)
|
||||
lines = [json.dumps({"content": "line1"}), json.dumps({"content": "line2"})]
|
||||
resp = FakeStreamResponse(200, "application/x-ndjson", lines=lines)
|
||||
messages = list(tool._parse_ndjson_stream(resp))
|
||||
messages = list(tool._parse_ndjson_stream(_resp(resp)))
|
||||
assert len(messages) == 2
|
||||
assert _msg_text(messages[0]) == "line1"
|
||||
|
||||
|
|
@ -194,7 +213,7 @@ class TestParseNdjsonStream:
|
|||
tool = _make_api_tool(streaming=True)
|
||||
lines = [json.dumps({"result": 42})]
|
||||
resp = FakeStreamResponse(200, "application/x-ndjson", lines=lines)
|
||||
messages = list(tool._parse_ndjson_stream(resp))
|
||||
messages = list(tool._parse_ndjson_stream(_resp(resp)))
|
||||
assert len(messages) == 1
|
||||
assert "42" in _msg_text(messages[0])
|
||||
|
||||
|
|
@ -202,7 +221,7 @@ class TestParseNdjsonStream:
|
|||
tool = _make_api_tool(streaming=True)
|
||||
lines = ["not json", ""]
|
||||
resp = FakeStreamResponse(200, "application/x-ndjson", lines=lines)
|
||||
messages = list(tool._parse_ndjson_stream(resp))
|
||||
messages = list(tool._parse_ndjson_stream(_resp(resp)))
|
||||
assert len(messages) == 1
|
||||
assert _msg_text(messages[0]) == "not json"
|
||||
|
||||
|
|
@ -216,14 +235,14 @@ class TestParseTextStream:
|
|||
def test_text_chunks(self):
|
||||
tool = _make_api_tool(streaming=True)
|
||||
resp = FakeStreamResponse(200, "text/plain", chunks=["chunk1", "chunk2", "chunk3"])
|
||||
messages = list(tool._parse_text_stream(resp))
|
||||
messages = list(tool._parse_text_stream(_resp(resp)))
|
||||
assert len(messages) == 3
|
||||
assert _msg_text(messages[0]) == "chunk1"
|
||||
|
||||
def test_empty_chunks_skipped(self):
|
||||
tool = _make_api_tool(streaming=True)
|
||||
resp = FakeStreamResponse(200, "text/plain", chunks=["", "data", ""])
|
||||
messages = list(tool._parse_text_stream(resp))
|
||||
messages = list(tool._parse_text_stream(_resp(resp)))
|
||||
assert len(messages) == 1
|
||||
|
||||
|
||||
|
|
@ -278,25 +297,17 @@ class TestStreamingErrorHandling:
|
|||
|
||||
def test_stream_error_raises(self):
|
||||
tool = _make_api_tool(streaming=True)
|
||||
ctx = _RaisingContextManager(httpx.StreamError("connection reset"))
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _raise_stream_error():
|
||||
raise httpx.StreamError("connection reset")
|
||||
yield
|
||||
|
||||
with patch("core.helper.ssrf_proxy.stream_request", return_value=_raise_stream_error()):
|
||||
with patch("core.helper.ssrf_proxy.stream_request", return_value=ctx):
|
||||
with pytest.raises(ToolInvokeError, match="Stream request failed"):
|
||||
list(tool.do_http_request_streaming("https://example.com", "POST", {}, {}))
|
||||
|
||||
def test_timeout_raises(self):
|
||||
tool = _make_api_tool(streaming=True)
|
||||
ctx = _RaisingContextManager(httpx.ReadTimeout("read timed out"))
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _raise_timeout():
|
||||
raise httpx.ReadTimeout("read timed out")
|
||||
yield
|
||||
|
||||
with patch("core.helper.ssrf_proxy.stream_request", return_value=_raise_timeout()):
|
||||
with patch("core.helper.ssrf_proxy.stream_request", return_value=ctx):
|
||||
with pytest.raises(ToolInvokeError, match="Stream request failed"):
|
||||
list(tool.do_http_request_streaming("https://example.com", "POST", {}, {}))
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue