From 7c207327e303a6549918f538023a2f285a2a7232 Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Tue, 24 Mar 2026 22:28:54 +0800 Subject: [PATCH 1/2] fix(api): StreamsBroadcastChannel start reading messages from the end Setting initial `_last_id` to `0-0` would cause every subscription to receive one copy of the event stream, which is not compatible with the current frontend / backend communication protocol. --- api/libs/broadcast_channel/redis/streams_channel.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/api/libs/broadcast_channel/redis/streams_channel.py b/api/libs/broadcast_channel/redis/streams_channel.py index d6ec5504ca..aaeaf76f7b 100644 --- a/api/libs/broadcast_channel/redis/streams_channel.py +++ b/api/libs/broadcast_channel/redis/streams_channel.py @@ -64,7 +64,10 @@ class _StreamsSubscription(Subscription): self._client = client self._key = key self._closed = threading.Event() - self._last_id = "0-0" + # Setting initial last id to `$` to signal redis that we only want new messages. 
+ # + # ref: https://redis.io/docs/latest/commands/xread/#the-special--id + self._last_id = "$" self._queue: queue.Queue[object] = queue.Queue() self._start_lock = threading.Lock() self._listener: threading.Thread | None = None From b85e11309940c5120a7cd79cb191aa771a597553 Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Tue, 24 Mar 2026 22:38:16 +0800 Subject: [PATCH 2/2] test(api): add unit tests and integration tests for StreamsBroadcastChannel --- .../redis/test_streams_channel.py | 227 +++++++++++++++++ .../redis/test_streams_channel_unit_tests.py | 228 +++++++++++++++++- 2 files changed, 445 insertions(+), 10 deletions(-) create mode 100644 api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_streams_channel.py diff --git a/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_streams_channel.py b/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_streams_channel.py new file mode 100644 index 0000000000..a79208f649 --- /dev/null +++ b/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_streams_channel.py @@ -0,0 +1,227 @@ +""" +Integration tests for Redis Streams broadcast channel implementation using TestContainers. + +This suite focuses on the semantics that differ from Redis Pub/Sub: +- Every active subscription should receive each newly published message. +- Each subscription should only observe messages published after its listener starts. 
+""" + +import threading +import time +import uuid +from collections.abc import Iterator +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest +import redis +from testcontainers.redis import RedisContainer + +from libs.broadcast_channel.channel import BroadcastChannel, Subscription, Topic +from libs.broadcast_channel.exc import SubscriptionClosedError +from libs.broadcast_channel.redis.streams_channel import StreamsBroadcastChannel + + +class TestRedisStreamsBroadcastChannelIntegration: + """Integration tests for Redis Streams broadcast channel with a real Redis instance.""" + + @pytest.fixture(scope="class") + def redis_container(self) -> Iterator[RedisContainer]: + """Create a Redis container for integration testing.""" + with RedisContainer(image="redis:6-alpine") as container: + yield container + + @pytest.fixture(scope="class") + def redis_client(self, redis_container: RedisContainer) -> redis.Redis: + """Create a Redis client connected to the test container.""" + host = redis_container.get_container_host_ip() + port = redis_container.get_exposed_port(6379) + return redis.Redis(host=host, port=port, decode_responses=False) + + @pytest.fixture + def broadcast_channel(self, redis_client: redis.Redis) -> BroadcastChannel: + """Create a StreamsBroadcastChannel instance with a real Redis client.""" + return StreamsBroadcastChannel(redis_client) + + @classmethod + def _get_test_topic_name(cls) -> str: + return f"test_streams_topic_{uuid.uuid4()}" + + @staticmethod + def _start_subscription(subscription: Subscription) -> None: + """Start the background listener and confirm the subscription queue is empty.""" + assert subscription.receive(timeout=0.05) is None + + @staticmethod + def _receive_message(subscription: Subscription, *, timeout_seconds: float = 2.0) -> bytes: + """Poll until a message is received or the timeout expires.""" + deadline = time.monotonic() + timeout_seconds + while time.monotonic() < deadline: + message = 
subscription.receive(timeout=0.1) + if message is not None: + return message + pytest.fail("Timed out waiting for a message") + + def test_close_an_active_subscription_should_stop_iteration(self, broadcast_channel: BroadcastChannel) -> None: + """Closing an active subscription should terminate the iterator cleanly.""" + topic = broadcast_channel.topic(self._get_test_topic_name()) + subscription = topic.subscribe() + consuming_event = threading.Event() + + def consume() -> list[bytes]: + messages: list[bytes] = [] + consuming_event.set() + for message in subscription: + messages.append(message) + return messages + + with ThreadPoolExecutor(max_workers=1) as executor: + consumer_future = executor.submit(consume) + assert consuming_event.wait(timeout=1.0) + subscription.close() + assert consumer_future.result(timeout=2.0) == [] + + def test_end_to_end_messaging(self, broadcast_channel: BroadcastChannel) -> None: + """A producer should publish a message that a live subscription can consume.""" + topic = broadcast_channel.topic(self._get_test_topic_name()) + producer = topic.as_producer() + subscription = topic.subscribe() + message = b"hello streams" + + try: + self._start_subscription(subscription) + producer.publish(message) + + assert self._receive_message(subscription) == message + assert subscription.receive(timeout=0.1) is None + finally: + subscription.close() + + def test_multiple_subscriptions_each_receive_each_new_message(self, broadcast_channel: BroadcastChannel) -> None: + """Each active subscription should receive the same newly published message.""" + topic = broadcast_channel.topic(self._get_test_topic_name()) + subscriptions = [topic.subscribe() for _ in range(3)] + new_message = b"message-visible-to-every-subscriber" + + try: + for subscription in subscriptions: + self._start_subscription(subscription) + + topic.publish(new_message) + + for subscription in subscriptions: + assert self._receive_message(subscription) == new_message + assert 
subscription.receive(timeout=0.1) is None + finally: + for subscription in subscriptions: + subscription.close() + + def test_each_subscription_only_receives_messages_published_after_it_starts( + self, + broadcast_channel: BroadcastChannel, + ) -> None: + """A late subscription should not replay messages that existed before its listener started.""" + topic = broadcast_channel.topic(self._get_test_topic_name()) + first_subscription = topic.subscribe() + second_subscription = topic.subscribe() + message_before_any_subscription = b"before-any-subscription" + message_after_first_subscription = b"after-first-subscription" + message_after_second_subscription = b"after-second-subscription" + + try: + topic.publish(message_before_any_subscription) + + self._start_subscription(first_subscription) + topic.publish(message_after_first_subscription) + + assert self._receive_message(first_subscription) == message_after_first_subscription + assert first_subscription.receive(timeout=0.1) is None + + self._start_subscription(second_subscription) + topic.publish(message_after_second_subscription) + + assert self._receive_message(first_subscription) == message_after_second_subscription + assert self._receive_message(second_subscription) == message_after_second_subscription + assert first_subscription.receive(timeout=0.1) is None + assert second_subscription.receive(timeout=0.1) is None + finally: + first_subscription.close() + second_subscription.close() + + def test_topic_isolation(self, broadcast_channel: BroadcastChannel) -> None: + """Messages from different topics should remain isolated.""" + topic1 = broadcast_channel.topic(self._get_test_topic_name()) + topic2 = broadcast_channel.topic(self._get_test_topic_name()) + message1 = b"message-for-topic-1" + message2 = b"message-for-topic-2" + + def consume_single_message(topic: Topic) -> bytes: + subscription = topic.subscribe() + try: + self._start_subscription(subscription) + return self._receive_message(subscription) + finally: + 
subscription.close() + + with ThreadPoolExecutor(max_workers=3) as executor: + consumer1_future = executor.submit(consume_single_message, topic1) + consumer2_future = executor.submit(consume_single_message, topic2) + time.sleep(0.1) + topic1.publish(message1) + topic2.publish(message2) + + assert consumer1_future.result(timeout=5.0) == message1 + assert consumer2_future.result(timeout=5.0) == message2 + + def test_concurrent_producers_publish_all_messages(self, broadcast_channel: BroadcastChannel) -> None: + """Concurrent producers should not lose messages for a live subscription.""" + topic = broadcast_channel.topic(self._get_test_topic_name()) + subscription = topic.subscribe() + producer_count = 4 + messages_per_producer = 4 + expected_total = producer_count * messages_per_producer + consumer_ready = threading.Event() + + def produce_messages(producer_idx: int) -> set[bytes]: + producer = topic.as_producer() + produced: set[bytes] = set() + for message_idx in range(messages_per_producer): + payload = f"producer-{producer_idx}-message-{message_idx}".encode() + produced.add(payload) + producer.publish(payload) + time.sleep(0.001) + return produced + + def consume_messages() -> set[bytes]: + received: set[bytes] = set() + try: + self._start_subscription(subscription) + consumer_ready.set() + while len(received) < expected_total: + message = subscription.receive(timeout=0.2) + if message is not None: + received.add(message) + return received + finally: + subscription.close() + + with ThreadPoolExecutor(max_workers=producer_count + 1) as executor: + consumer_future = executor.submit(consume_messages) + assert consumer_ready.wait(timeout=2.0) + + producer_futures = [executor.submit(produce_messages, idx) for idx in range(producer_count)] + expected_messages: set[bytes] = set() + for future in as_completed(producer_futures, timeout=10.0): + expected_messages.update(future.result()) + + assert consumer_future.result(timeout=10.0) == expected_messages + + def 
test_receive_raises_subscription_closed_after_close(self, broadcast_channel: BroadcastChannel) -> None: + """Calling receive on a closed subscription should raise SubscriptionClosedError.""" + topic = broadcast_channel.topic(self._get_test_topic_name()) + subscription = topic.subscribe() + + self._start_subscription(subscription) + subscription.close() + + with pytest.raises(SubscriptionClosedError): + subscription.receive(timeout=0.1) diff --git a/api/tests/unit_tests/libs/broadcast_channel/redis/test_streams_channel_unit_tests.py b/api/tests/unit_tests/libs/broadcast_channel/redis/test_streams_channel_unit_tests.py index 248aa0b145..bf548f69cf 100644 --- a/api/tests/unit_tests/libs/broadcast_channel/redis/test_streams_channel_unit_tests.py +++ b/api/tests/unit_tests/libs/broadcast_channel/redis/test_streams_channel_unit_tests.py @@ -1,7 +1,11 @@ +import threading import time +from dataclasses import dataclass +from typing import cast import pytest +from libs.broadcast_channel.exc import SubscriptionClosedError from libs.broadcast_channel.redis.streams_channel import ( StreamsBroadcastChannel, StreamsTopic, @@ -22,6 +26,7 @@ class FakeStreamsRedis: self._store: dict[str, list[tuple[str, dict]]] = {} self._next_id: dict[str, int] = {} self._expire_calls: dict[str, int] = {} + self._dollar_snapshots: dict[str, int] = {} # Publisher API def xadd(self, key: str, fields: dict, *, maxlen: int | None = None) -> str: @@ -47,7 +52,9 @@ class FakeStreamsRedis: # Find position strictly greater than last_id start_idx = 0 - if last_id != "0-0": + if last_id == "$": + start_idx = self._dollar_snapshots.setdefault(key, len(entries)) + elif last_id != "0-0": for i, (eid, _f) in enumerate(entries): if eid == last_id: start_idx = i + 1 @@ -63,6 +70,55 @@ class FakeStreamsRedis: return [(key, batch)] +class FailExpireRedis(FakeStreamsRedis): + def expire(self, key: str, seconds: int) -> None: + raise RuntimeError("expire failed") + + +class BlockingRedis: + def __init__(self) -> 
None: + self._release = threading.Event() + + def xread(self, streams: dict, block: int | None = None, count: int | None = None): + self._release.wait(timeout=block / 1000.0 if block else None) + return [] + + def release(self) -> None: + self._release.set() + + +@dataclass(frozen=True) +class ListenPayloadCase: + name: str + fields: object + expected_messages: list[bytes] + + +def build_listen_payload_cases() -> list[ListenPayloadCase]: + return [ + ListenPayloadCase( + name="string_payload_is_encoded", + fields={b"data": "hello"}, + expected_messages=[b"hello"], + ), + ListenPayloadCase( + name="bytearray_payload_is_converted", + fields={b"data": bytearray(b"world")}, + expected_messages=[b"world"], + ), + ListenPayloadCase( + name="non_dict_fields_are_ignored", + fields=[("data", b"ignored")], + expected_messages=[], + ), + ListenPayloadCase( + name="missing_payload_is_ignored", + fields={b"other": b"ignored"}, + expected_messages=[], + ), + ] + + @pytest.fixture def fake_redis() -> FakeStreamsRedis: return FakeStreamsRedis() @@ -94,21 +150,37 @@ class TestStreamsBroadcastChannel: # Expire called after publish assert fake_redis._expire_calls.get("stream:beta", 0) >= 1 + def test_topic_exposes_self_as_producer_and_subscriber(self, streams_channel: StreamsBroadcastChannel): + topic = streams_channel.topic("producer-subscriber") + + assert topic.as_producer() is topic + assert topic.as_subscriber() is topic + + def test_publish_logs_warning_when_expire_fails(self, caplog: pytest.LogCaptureFixture): + channel = StreamsBroadcastChannel(FailExpireRedis(), retention_seconds=60) + topic = channel.topic("expire-warning") + + topic.publish(b"payload") + + assert "Failed to set expire for stream key" in caplog.text + class TestStreamsSubscription: - def test_subscribe_and_receive_from_beginning(self, streams_channel: StreamsBroadcastChannel): + def test_subscribe_only_receives_messages_published_after_subscription_starts( + self, + streams_channel: StreamsBroadcastChannel, 
+ ): topic = streams_channel.topic("gamma") - # Pre-publish events before subscribing (late subscriber) - topic.publish(b"e1") - topic.publish(b"e2") + topic.publish(b"before-subscribe") sub = topic.subscribe() assert isinstance(sub, _StreamsSubscription) received: list[bytes] = [] with sub: - # Give listener thread a moment to xread - time.sleep(0.05) + assert sub.receive(timeout=0.05) is None + topic.publish(b"after-subscribe-1") + topic.publish(b"after-subscribe-2") # Drain using receive() to avoid indefinite iteration in tests for _ in range(5): msg = sub.receive(timeout=0.1) @@ -116,7 +188,7 @@ class TestStreamsSubscription: break received.append(msg) - assert received == [b"e1", b"e2"] + assert received == [b"after-subscribe-1", b"after-subscribe-2"] def test_receive_timeout_returns_none(self, streams_channel: StreamsBroadcastChannel): topic = streams_channel.topic("delta") @@ -132,8 +204,6 @@ class TestStreamsSubscription: # Listener running; now close and ensure no crash sub.close() # After close, receive should raise SubscriptionClosedError - from libs.broadcast_channel.exc import SubscriptionClosedError - with pytest.raises(SubscriptionClosedError): sub.receive() @@ -143,3 +213,141 @@ class TestStreamsSubscription: topic.publish(b"payload") # No expire recorded when retention is disabled assert fake_redis._expire_calls.get("stream:zeta") is None + + @pytest.mark.parametrize( + ("case"), + build_listen_payload_cases(), + ids=lambda case: cast(ListenPayloadCase, case).name, + ) + def test_listener_normalizes_supported_payloads_and_ignores_unsupported_shapes(self, case: ListenPayloadCase): + class OneShotRedis: + def __init__(self, fields: object) -> None: + self._fields = fields + self._calls = 0 + + def xread(self, streams: dict, block: int | None = None, count: int | None = None): + self._calls += 1 + if self._calls == 1: + key = next(iter(streams)) + return [(key, [("1-0", self._fields)])] + subscription._closed.set() + return [] + + subscription = 
_StreamsSubscription(OneShotRedis(case.fields), "stream:payload-shape") + subscription._listen() + + received: list[bytes] = [] + while not subscription._queue.empty(): + item = subscription._queue.get_nowait() + if item is subscription._SENTINEL: + break + received.append(bytes(item)) + + assert received == case.expected_messages + assert subscription._last_id == "1-0" + + def test_iterator_yields_messages_until_subscription_is_closed(self, streams_channel: StreamsBroadcastChannel): + topic = streams_channel.topic("iter") + subscription = topic.subscribe() + iterator = iter(subscription) + + def publish_later() -> None: + time.sleep(0.05) + topic.publish(b"iter-message") + + publisher = threading.Thread(target=publish_later, daemon=True) + publisher.start() + + assert next(iterator) == b"iter-message" + + subscription.close() + publisher.join(timeout=1) + with pytest.raises(StopIteration): + next(iterator) + + def test_receive_with_none_timeout_blocks_until_message_arrives(self, streams_channel: StreamsBroadcastChannel): + topic = streams_channel.topic("blocking") + subscription = topic.subscribe() + + def publish_later() -> None: + time.sleep(0.05) + topic.publish(b"blocking-message") + + publisher = threading.Thread(target=publish_later, daemon=True) + publisher.start() + + try: + assert subscription.receive(timeout=None) == b"blocking-message" + finally: + subscription.close() + publisher.join(timeout=1) + + def test_receive_raises_when_queue_contains_close_sentinel(self): + subscription = _StreamsSubscription(FakeStreamsRedis(), "stream:sentinel") + subscription._listener = threading.current_thread() + subscription._queue.put_nowait(subscription._SENTINEL) + + with pytest.raises(SubscriptionClosedError): + subscription.receive(timeout=0.01) + + def test_close_before_listener_starts_is_a_noop(self): + subscription = _StreamsSubscription(FakeStreamsRedis(), "stream:not-started") + + subscription.close() + + assert subscription._listener is None + with 
pytest.raises(SubscriptionClosedError): + subscription.receive(timeout=0.01) + + def test_start_if_needed_returns_immediately_for_closed_subscription(self): + subscription = _StreamsSubscription(FakeStreamsRedis(), "stream:already-closed") + subscription._closed.set() + + subscription._start_if_needed() + + assert subscription._listener is None + + def test_iterator_skips_none_results_and_keeps_polling(self): + subscription = _StreamsSubscription(FakeStreamsRedis(), "stream:iterator-none") + items = iter([None, b"event"]) + + subscription._start_if_needed = lambda: None # type: ignore[method-assign] + + def fake_receive(timeout: float | None = 0.1) -> bytes | None: + value = next(items) + if value is not None: + subscription._closed.set() + return value + + subscription.receive = fake_receive # type: ignore[method-assign] + + assert next(iter(subscription)) == b"event" + + def test_close_logs_warning_when_listener_does_not_stop_in_time( + self, + caplog: pytest.LogCaptureFixture, + ): + blocking_redis = BlockingRedis() + subscription = _StreamsSubscription(blocking_redis, "stream:slow-close") + + subscription._start_if_needed() + listener = subscription._listener + assert listener is not None + + original_join = listener.join + original_is_alive = listener.is_alive + + def delayed_join(timeout: float | None = None) -> None: + original_join(0.01) + + listener.join = delayed_join # type: ignore[method-assign] + listener.is_alive = lambda: True # type: ignore[method-assign] + + try: + subscription.close() + assert "did not stop within timeout" in caplog.text + finally: + listener.join = original_join # type: ignore[method-assign] + listener.is_alive = original_is_alive # type: ignore[method-assign] + blocking_redis.release() + original_join(timeout=1)