From abb0408b082da09388b052595f2fe812ae962eeb Mon Sep 17 00:00:00 2001 From: skhe Date: Tue, 3 Mar 2026 17:01:24 +0800 Subject: [PATCH] fix(tools): resolve Pyrefly type check errors in streaming tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Use typing.cast via _resp() helper for FakeStreamResponse → httpx.Response - Replace @contextlib.contextmanager with _RaisingContextManager class to eliminate unreachable yield statements --- .../core/tools/test_api_tool_streaming.py | 55 +++++++++++-------- 1 file changed, 33 insertions(+), 22 deletions(-) diff --git a/api/tests/unit_tests/core/tools/test_api_tool_streaming.py b/api/tests/unit_tests/core/tools/test_api_tool_streaming.py index ca76d312f9..30f461cdbf 100644 --- a/api/tests/unit_tests/core/tools/test_api_tool_streaming.py +++ b/api/tests/unit_tests/core/tools/test_api_tool_streaming.py @@ -2,6 +2,7 @@ import contextlib import json +from typing import cast from unittest.mock import MagicMock, patch import httpx @@ -83,11 +84,29 @@ class FakeStreamResponse: pass +def _resp(fake: FakeStreamResponse) -> httpx.Response: + """Cast FakeStreamResponse to httpx.Response for type checkers.""" + return cast(httpx.Response, fake) + + @contextlib.contextmanager def _fake_stream_ctx(response: FakeStreamResponse): yield response +class _RaisingContextManager: + """Context manager that raises on __enter__. Used to test exception paths.""" + + def __init__(self, exception: BaseException): + self.exception = exception + + def __enter__(self): + raise self.exception + + def __exit__(self, *args): + return False + + # --------------------------------------------------------------------------- # ApiToolBundle.streaming default value # --------------------------------------------------------------------------- @@ -132,7 +151,7 @@ class TestParseSSEStream: "data: [DONE]", ] resp = FakeStreamResponse(200, "text/event-stream", lines=lines) - messages = list(tool._parse_sse_stream(resp)) + messages = list(tool._parse_sse_stream(_resp(resp))) assert len(messages) == 2 assert _msg_text(messages[0]) == "Hello" assert _msg_text(messages[1]) == " world" @@ -141,7 +160,7 @@ class TestParseSSEStream: tool = _make_api_tool(streaming=True) lines = ["data: chunk1", "data: chunk2", "", "data: [DONE]"] resp = FakeStreamResponse(200, "text/event-stream", lines=lines) - messages = list(tool._parse_sse_stream(resp)) + messages = list(tool._parse_sse_stream(_resp(resp))) assert len(messages) == 2 assert _msg_text(messages[0]) == "chunk1" assert _msg_text(messages[1]) == "chunk2" @@ -154,7 +173,7 @@ class TestParseSSEStream: 'data: {"message":"from message"}', ] resp = FakeStreamResponse(200, "text/event-stream", lines=lines) - messages = list(tool._parse_sse_stream(resp)) + messages = list(tool._parse_sse_stream(_resp(resp))) assert len(messages) == 3 assert _msg_text(messages[0]) == "from content" assert _msg_text(messages[1]) == "from text" @@ -164,14 +183,14 @@ class TestParseSSEStream: tool = _make_api_tool(streaming=True) lines = ["", "", "data: hello", ""] resp = FakeStreamResponse(200, "text/event-stream", lines=lines) - messages = list(tool._parse_sse_stream(resp)) + messages = list(tool._parse_sse_stream(_resp(resp))) assert len(messages) == 1 def test_non_data_lines_skipped(self): tool = _make_api_tool(streaming=True) lines = ["event: message", "id: 1", "retry: 3000", "data: actual content"] resp = FakeStreamResponse(200, "text/event-stream", lines=lines) - messages = list(tool._parse_sse_stream(resp)) + messages = list(tool._parse_sse_stream(_resp(resp))) assert len(messages) == 1 assert _msg_text(messages[0]) == "actual content" @@ -186,7 +205,7 @@ class TestParseNdjsonStream: tool = _make_api_tool(streaming=True) lines = [json.dumps({"content": "line1"}), json.dumps({"content": "line2"})] resp = FakeStreamResponse(200, "application/x-ndjson", lines=lines) - messages = list(tool._parse_ndjson_stream(resp)) + messages = list(tool._parse_ndjson_stream(_resp(resp))) assert len(messages) == 2 assert _msg_text(messages[0]) == "line1" @@ -194,7 +213,7 @@ class TestParseNdjsonStream: tool = _make_api_tool(streaming=True) lines = [json.dumps({"result": 42})] resp = FakeStreamResponse(200, "application/x-ndjson", lines=lines) - messages = list(tool._parse_ndjson_stream(resp)) + messages = list(tool._parse_ndjson_stream(_resp(resp))) assert len(messages) == 1 assert "42" in _msg_text(messages[0]) @@ -202,7 +221,7 @@ class TestParseNdjsonStream: tool = _make_api_tool(streaming=True) lines = ["not json", ""] resp = FakeStreamResponse(200, "application/x-ndjson", lines=lines) - messages = list(tool._parse_ndjson_stream(resp)) + messages = list(tool._parse_ndjson_stream(_resp(resp))) assert len(messages) == 1 assert _msg_text(messages[0]) == "not json" @@ -216,14 +235,14 @@ class TestParseTextStream: def test_text_chunks(self): tool = _make_api_tool(streaming=True) resp = FakeStreamResponse(200, "text/plain", chunks=["chunk1", "chunk2", "chunk3"]) - messages = list(tool._parse_text_stream(resp)) + messages = list(tool._parse_text_stream(_resp(resp))) assert len(messages) == 3 assert _msg_text(messages[0]) == "chunk1" def test_empty_chunks_skipped(self): tool = _make_api_tool(streaming=True) resp = FakeStreamResponse(200, "text/plain", chunks=["", "data", ""]) - messages = list(tool._parse_text_stream(resp)) + messages = list(tool._parse_text_stream(_resp(resp))) assert len(messages) == 1 @@ -278,25 +297,17 @@ class TestStreamingErrorHandling: def test_stream_error_raises(self): tool = _make_api_tool(streaming=True) + ctx = _RaisingContextManager(httpx.StreamError("connection reset")) - @contextlib.contextmanager - def _raise_stream_error(): - raise httpx.StreamError("connection reset") - yield - - with patch("core.helper.ssrf_proxy.stream_request", return_value=_raise_stream_error()): + with patch("core.helper.ssrf_proxy.stream_request", return_value=ctx): with pytest.raises(ToolInvokeError, match="Stream request failed"): list(tool.do_http_request_streaming("https://example.com", "POST", {}, {})) def test_timeout_raises(self): tool = _make_api_tool(streaming=True) + ctx = _RaisingContextManager(httpx.ReadTimeout("read timed out")) - @contextlib.contextmanager - def _raise_timeout(): - raise httpx.ReadTimeout("read timed out") - yield - - with patch("core.helper.ssrf_proxy.stream_request", return_value=_raise_timeout()): + with patch("core.helper.ssrf_proxy.stream_request", return_value=ctx): with pytest.raises(ToolInvokeError, match="Stream request failed"): list(tool.do_http_request_streaming("https://example.com", "POST", {}, {}))