fix(tools): resolve Pyrefly type check errors in streaming tests

- Use typing.cast via _resp() helper for FakeStreamResponse → httpx.Response
- Replace @contextlib.contextmanager with _RaisingContextManager class
  to eliminate unreachable yield statements
This commit is contained in:
skhe 2026-03-03 17:01:24 +08:00
parent bf0d7fd57c
commit abb0408b08
1 changed files with 33 additions and 22 deletions

View File

@ -2,6 +2,7 @@
import contextlib
import json
from typing import cast
from unittest.mock import MagicMock, patch
import httpx
@ -83,11 +84,29 @@ class FakeStreamResponse:
pass
def _resp(fake: FakeStreamResponse) -> httpx.Response:
"""Cast FakeStreamResponse to httpx.Response for type checkers."""
return cast(httpx.Response, fake)
@contextlib.contextmanager
def _fake_stream_ctx(response: FakeStreamResponse):
yield response
class _RaisingContextManager:
"""Context manager that raises on __enter__. Used to test exception paths."""
def __init__(self, exception: BaseException):
self.exception = exception
def __enter__(self):
raise self.exception
def __exit__(self, *args):
return False
# ---------------------------------------------------------------------------
# ApiToolBundle.streaming default value
# ---------------------------------------------------------------------------
@ -132,7 +151,7 @@ class TestParseSSEStream:
"data: [DONE]",
]
resp = FakeStreamResponse(200, "text/event-stream", lines=lines)
messages = list(tool._parse_sse_stream(resp))
messages = list(tool._parse_sse_stream(_resp(resp)))
assert len(messages) == 2
assert _msg_text(messages[0]) == "Hello"
assert _msg_text(messages[1]) == " world"
@ -141,7 +160,7 @@ class TestParseSSEStream:
tool = _make_api_tool(streaming=True)
lines = ["data: chunk1", "data: chunk2", "", "data: [DONE]"]
resp = FakeStreamResponse(200, "text/event-stream", lines=lines)
messages = list(tool._parse_sse_stream(resp))
messages = list(tool._parse_sse_stream(_resp(resp)))
assert len(messages) == 2
assert _msg_text(messages[0]) == "chunk1"
assert _msg_text(messages[1]) == "chunk2"
@ -154,7 +173,7 @@ class TestParseSSEStream:
'data: {"message":"from message"}',
]
resp = FakeStreamResponse(200, "text/event-stream", lines=lines)
messages = list(tool._parse_sse_stream(resp))
messages = list(tool._parse_sse_stream(_resp(resp)))
assert len(messages) == 3
assert _msg_text(messages[0]) == "from content"
assert _msg_text(messages[1]) == "from text"
@ -164,14 +183,14 @@ class TestParseSSEStream:
tool = _make_api_tool(streaming=True)
lines = ["", "", "data: hello", ""]
resp = FakeStreamResponse(200, "text/event-stream", lines=lines)
messages = list(tool._parse_sse_stream(resp))
messages = list(tool._parse_sse_stream(_resp(resp)))
assert len(messages) == 1
def test_non_data_lines_skipped(self):
tool = _make_api_tool(streaming=True)
lines = ["event: message", "id: 1", "retry: 3000", "data: actual content"]
resp = FakeStreamResponse(200, "text/event-stream", lines=lines)
messages = list(tool._parse_sse_stream(resp))
messages = list(tool._parse_sse_stream(_resp(resp)))
assert len(messages) == 1
assert _msg_text(messages[0]) == "actual content"
@ -186,7 +205,7 @@ class TestParseNdjsonStream:
tool = _make_api_tool(streaming=True)
lines = [json.dumps({"content": "line1"}), json.dumps({"content": "line2"})]
resp = FakeStreamResponse(200, "application/x-ndjson", lines=lines)
messages = list(tool._parse_ndjson_stream(resp))
messages = list(tool._parse_ndjson_stream(_resp(resp)))
assert len(messages) == 2
assert _msg_text(messages[0]) == "line1"
@ -194,7 +213,7 @@ class TestParseNdjsonStream:
tool = _make_api_tool(streaming=True)
lines = [json.dumps({"result": 42})]
resp = FakeStreamResponse(200, "application/x-ndjson", lines=lines)
messages = list(tool._parse_ndjson_stream(resp))
messages = list(tool._parse_ndjson_stream(_resp(resp)))
assert len(messages) == 1
assert "42" in _msg_text(messages[0])
@ -202,7 +221,7 @@ class TestParseNdjsonStream:
tool = _make_api_tool(streaming=True)
lines = ["not json", ""]
resp = FakeStreamResponse(200, "application/x-ndjson", lines=lines)
messages = list(tool._parse_ndjson_stream(resp))
messages = list(tool._parse_ndjson_stream(_resp(resp)))
assert len(messages) == 1
assert _msg_text(messages[0]) == "not json"
@ -216,14 +235,14 @@ class TestParseTextStream:
def test_text_chunks(self):
tool = _make_api_tool(streaming=True)
resp = FakeStreamResponse(200, "text/plain", chunks=["chunk1", "chunk2", "chunk3"])
messages = list(tool._parse_text_stream(resp))
messages = list(tool._parse_text_stream(_resp(resp)))
assert len(messages) == 3
assert _msg_text(messages[0]) == "chunk1"
def test_empty_chunks_skipped(self):
tool = _make_api_tool(streaming=True)
resp = FakeStreamResponse(200, "text/plain", chunks=["", "data", ""])
messages = list(tool._parse_text_stream(resp))
messages = list(tool._parse_text_stream(_resp(resp)))
assert len(messages) == 1
@ -278,25 +297,17 @@ class TestStreamingErrorHandling:
def test_stream_error_raises(self):
tool = _make_api_tool(streaming=True)
ctx = _RaisingContextManager(httpx.StreamError("connection reset"))
@contextlib.contextmanager
def _raise_stream_error():
raise httpx.StreamError("connection reset")
yield
with patch("core.helper.ssrf_proxy.stream_request", return_value=_raise_stream_error()):
with patch("core.helper.ssrf_proxy.stream_request", return_value=ctx):
with pytest.raises(ToolInvokeError, match="Stream request failed"):
list(tool.do_http_request_streaming("https://example.com", "POST", {}, {}))
def test_timeout_raises(self):
tool = _make_api_tool(streaming=True)
ctx = _RaisingContextManager(httpx.ReadTimeout("read timed out"))
@contextlib.contextmanager
def _raise_timeout():
raise httpx.ReadTimeout("read timed out")
yield
with patch("core.helper.ssrf_proxy.stream_request", return_value=_raise_timeout()):
with patch("core.helper.ssrf_proxy.stream_request", return_value=ctx):
with pytest.raises(ToolInvokeError, match="Stream request failed"):
list(tool.do_http_request_streaming("https://example.com", "POST", {}, {}))