mirror of https://github.com/langgenius/dify.git
Merge 27443c219a into b15d312f68
This commit is contained in:
commit
5934ea4c30
|
|
@ -64,7 +64,10 @@ class _StreamsSubscription(Subscription):
|
|||
self._client = client
|
||||
self._key = key
|
||||
self._closed = threading.Event()
|
||||
self._last_id = "0-0"
|
||||
# Setting initial last id to `$` to signal redis that we only want new messages.
|
||||
#
|
||||
# ref: https://redis.io/docs/latest/commands/xread/#the-special--id
|
||||
self._last_id = "$"
|
||||
self._queue: queue.Queue[object] = queue.Queue()
|
||||
self._start_lock = threading.Lock()
|
||||
self._listener: threading.Thread | None = None
|
||||
|
|
|
|||
|
|
@ -0,0 +1,227 @@
|
|||
"""
|
||||
Integration tests for Redis Streams broadcast channel implementation using TestContainers.
|
||||
|
||||
This suite focuses on the semantics that differ from Redis Pub/Sub:
|
||||
- Every active subscription should receive each newly published message.
|
||||
- Each subscription should only observe messages published after its listener starts.
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Iterator
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
import pytest
|
||||
import redis
|
||||
from testcontainers.redis import RedisContainer
|
||||
|
||||
from libs.broadcast_channel.channel import BroadcastChannel, Subscription, Topic
|
||||
from libs.broadcast_channel.exc import SubscriptionClosedError
|
||||
from libs.broadcast_channel.redis.streams_channel import StreamsBroadcastChannel
|
||||
|
||||
|
||||
class TestRedisStreamsBroadcastChannelIntegration:
    """Integration tests for Redis Streams broadcast channel with a real Redis instance."""

    @pytest.fixture(scope="class")
    def redis_container(self) -> Iterator[RedisContainer]:
        """Create a Redis container for integration testing."""
        with RedisContainer(image="redis:6-alpine") as container:
            yield container

    @pytest.fixture(scope="class")
    def redis_client(self, redis_container: RedisContainer) -> Iterator[redis.Redis]:
        """Create a Redis client connected to the test container and close it on teardown."""
        host = redis_container.get_container_host_ip()
        port = redis_container.get_exposed_port(6379)
        client = redis.Redis(host=host, port=port, decode_responses=False)
        try:
            yield client
        finally:
            # Release pooled connections before the container itself is torn down.
            client.close()

    @pytest.fixture
    def broadcast_channel(self, redis_client: redis.Redis) -> BroadcastChannel:
        """Create a StreamsBroadcastChannel instance with a real Redis client."""
        return StreamsBroadcastChannel(redis_client)

    @classmethod
    def _get_test_topic_name(cls) -> str:
        """Return a unique topic name so tests sharing one Redis instance do not collide."""
        return f"test_streams_topic_{uuid.uuid4()}"

    @staticmethod
    def _start_subscription(subscription: Subscription) -> None:
        """Start the background listener and confirm the subscription queue is empty."""
        assert subscription.receive(timeout=0.05) is None

    @staticmethod
    def _receive_message(subscription: Subscription, *, timeout_seconds: float = 2.0) -> bytes:
        """Poll until a message is received or the timeout expires."""
        deadline = time.monotonic() + timeout_seconds
        while time.monotonic() < deadline:
            message = subscription.receive(timeout=0.1)
            if message is not None:
                return message
        pytest.fail("Timed out waiting for a message")

    def test_close_an_active_subscription_should_stop_iteration(self, broadcast_channel: BroadcastChannel) -> None:
        """Closing an active subscription should terminate the iterator cleanly."""
        topic = broadcast_channel.topic(self._get_test_topic_name())
        subscription = topic.subscribe()
        consuming_event = threading.Event()

        def consume() -> list[bytes]:
            messages: list[bytes] = []
            consuming_event.set()
            for message in subscription:
                messages.append(message)
            return messages

        with ThreadPoolExecutor(max_workers=1) as executor:
            consumer_future = executor.submit(consume)
            try:
                assert consuming_event.wait(timeout=1.0)
            finally:
                # Always close: leaving the executor block waits for the consumer,
                # which only finishes iterating once the subscription is closed.
                subscription.close()
            assert consumer_future.result(timeout=2.0) == []

    def test_end_to_end_messaging(self, broadcast_channel: BroadcastChannel) -> None:
        """A producer should publish a message that a live subscription can consume."""
        topic = broadcast_channel.topic(self._get_test_topic_name())
        producer = topic.as_producer()
        subscription = topic.subscribe()
        message = b"hello streams"

        try:
            self._start_subscription(subscription)
            producer.publish(message)

            assert self._receive_message(subscription) == message
            assert subscription.receive(timeout=0.1) is None
        finally:
            subscription.close()

    def test_multiple_subscriptions_each_receive_each_new_message(self, broadcast_channel: BroadcastChannel) -> None:
        """Each active subscription should receive the same newly published message."""
        topic = broadcast_channel.topic(self._get_test_topic_name())
        subscriptions = [topic.subscribe() for _ in range(3)]
        new_message = b"message-visible-to-every-subscriber"

        try:
            for subscription in subscriptions:
                self._start_subscription(subscription)

            topic.publish(new_message)

            for subscription in subscriptions:
                assert self._receive_message(subscription) == new_message
                assert subscription.receive(timeout=0.1) is None
        finally:
            for subscription in subscriptions:
                subscription.close()

    def test_each_subscription_only_receives_messages_published_after_it_starts(
        self,
        broadcast_channel: BroadcastChannel,
    ) -> None:
        """A late subscription should not replay messages that existed before its listener started."""
        topic = broadcast_channel.topic(self._get_test_topic_name())
        first_subscription = topic.subscribe()
        second_subscription = topic.subscribe()
        message_before_any_subscription = b"before-any-subscription"
        message_after_first_subscription = b"after-first-subscription"
        message_after_second_subscription = b"after-second-subscription"

        try:
            topic.publish(message_before_any_subscription)

            self._start_subscription(first_subscription)
            topic.publish(message_after_first_subscription)

            assert self._receive_message(first_subscription) == message_after_first_subscription
            assert first_subscription.receive(timeout=0.1) is None

            self._start_subscription(second_subscription)
            topic.publish(message_after_second_subscription)

            assert self._receive_message(first_subscription) == message_after_second_subscription
            assert self._receive_message(second_subscription) == message_after_second_subscription
            assert first_subscription.receive(timeout=0.1) is None
            assert second_subscription.receive(timeout=0.1) is None
        finally:
            first_subscription.close()
            second_subscription.close()

    def test_topic_isolation(self, broadcast_channel: BroadcastChannel) -> None:
        """Messages from different topics should remain isolated."""
        topic1 = broadcast_channel.topic(self._get_test_topic_name())
        topic2 = broadcast_channel.topic(self._get_test_topic_name())
        message1 = b"message-for-topic-1"
        message2 = b"message-for-topic-2"
        ready1 = threading.Event()
        ready2 = threading.Event()

        def consume_single_message(topic: Topic, ready: threading.Event) -> bytes:
            subscription = topic.subscribe()
            try:
                self._start_subscription(subscription)
                # Signal the publisher explicitly instead of relying on a fixed sleep:
                # subscriptions only observe messages published after they start.
                ready.set()
                return self._receive_message(subscription)
            finally:
                subscription.close()

        with ThreadPoolExecutor(max_workers=2) as executor:
            consumer1_future = executor.submit(consume_single_message, topic1, ready1)
            consumer2_future = executor.submit(consume_single_message, topic2, ready2)

            assert ready1.wait(timeout=2.0)
            assert ready2.wait(timeout=2.0)

            topic1.publish(message1)
            topic2.publish(message2)

            assert consumer1_future.result(timeout=5.0) == message1
            assert consumer2_future.result(timeout=5.0) == message2

    def test_concurrent_producers_publish_all_messages(self, broadcast_channel: BroadcastChannel) -> None:
        """Concurrent producers should not lose messages for a live subscription."""
        topic = broadcast_channel.topic(self._get_test_topic_name())
        subscription = topic.subscribe()
        producer_count = 4
        messages_per_producer = 4
        expected_total = producer_count * messages_per_producer
        consumer_ready = threading.Event()

        def produce_messages(producer_idx: int) -> set[bytes]:
            producer = topic.as_producer()
            produced: set[bytes] = set()
            for message_idx in range(messages_per_producer):
                payload = f"producer-{producer_idx}-message-{message_idx}".encode()
                produced.add(payload)
                producer.publish(payload)
                time.sleep(0.001)
            return produced

        def consume_messages() -> set[bytes]:
            received: set[bytes] = set()
            try:
                self._start_subscription(subscription)
                consumer_ready.set()
                # Bound the wait: if a message is lost the loop must still end, otherwise
                # leaving the executor block would wait on this worker forever. The
                # final set comparison then reports exactly which payloads are missing.
                deadline = time.monotonic() + 8.0
                while len(received) < expected_total and time.monotonic() < deadline:
                    message = subscription.receive(timeout=0.2)
                    if message is not None:
                        received.add(message)
                return received
            finally:
                subscription.close()

        with ThreadPoolExecutor(max_workers=producer_count + 1) as executor:
            consumer_future = executor.submit(consume_messages)
            assert consumer_ready.wait(timeout=2.0)

            producer_futures = [executor.submit(produce_messages, idx) for idx in range(producer_count)]
            expected_messages: set[bytes] = set()
            for future in as_completed(producer_futures, timeout=10.0):
                expected_messages.update(future.result())

            assert consumer_future.result(timeout=10.0) == expected_messages

    def test_receive_raises_subscription_closed_after_close(self, broadcast_channel: BroadcastChannel) -> None:
        """Calling receive on a closed subscription should raise SubscriptionClosedError."""
        topic = broadcast_channel.topic(self._get_test_topic_name())
        subscription = topic.subscribe()

        try:
            self._start_subscription(subscription)
        finally:
            # Close even when start-up fails so the listener thread is not leaked.
            subscription.close()

        with pytest.raises(SubscriptionClosedError):
            subscription.receive(timeout=0.1)
|
||||
|
|
@ -1,7 +1,11 @@
|
|||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
|
||||
from libs.broadcast_channel.exc import SubscriptionClosedError
|
||||
from libs.broadcast_channel.redis.streams_channel import (
|
||||
StreamsBroadcastChannel,
|
||||
StreamsTopic,
|
||||
|
|
@ -22,6 +26,7 @@ class FakeStreamsRedis:
|
|||
self._store: dict[str, list[tuple[str, dict]]] = {}
|
||||
self._next_id: dict[str, int] = {}
|
||||
self._expire_calls: dict[str, int] = {}
|
||||
self._dollar_snapshots: dict[str, int] = {}
|
||||
|
||||
# Publisher API
|
||||
def xadd(self, key: str, fields: dict, *, maxlen: int | None = None) -> str:
|
||||
|
|
@ -47,7 +52,9 @@ class FakeStreamsRedis:
|
|||
|
||||
# Find position strictly greater than last_id
|
||||
start_idx = 0
|
||||
if last_id != "0-0":
|
||||
if last_id == "$":
|
||||
start_idx = self._dollar_snapshots.setdefault(key, len(entries))
|
||||
elif last_id != "0-0":
|
||||
for i, (eid, _f) in enumerate(entries):
|
||||
if eid == last_id:
|
||||
start_idx = i + 1
|
||||
|
|
@ -63,6 +70,55 @@ class FakeStreamsRedis:
|
|||
return [(key, batch)]
|
||||
|
||||
|
||||
class FailExpireRedis(FakeStreamsRedis):
    """Fake Redis variant whose ``expire`` call always fails."""

    def expire(self, key: str, seconds: int) -> None:
        """Raise ``RuntimeError`` regardless of the key or TTL supplied."""
        failure = RuntimeError("expire failed")
        raise failure
|
||||
|
||||
|
||||
class BlockingRedis:
|
||||
def __init__(self) -> None:
|
||||
self._release = threading.Event()
|
||||
|
||||
def xread(self, streams: dict, block: int | None = None, count: int | None = None):
|
||||
self._release.wait(timeout=block / 1000.0 if block else None)
|
||||
return []
|
||||
|
||||
def release(self) -> None:
|
||||
self._release.set()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ListenPayloadCase:
    """Immutable description of one listener payload scenario."""

    # Short label identifying the scenario.
    name: str
    # Raw fields value delivered with the stream entry (any shape, valid or not).
    fields: object
    # Messages expected to come out of the listener for this entry.
    expected_messages: list[bytes]
|
||||
|
||||
|
||||
def build_listen_payload_cases() -> list[ListenPayloadCase]:
    """Return the payload scenarios as ``ListenPayloadCase`` objects, in a fixed order."""
    # (label, raw fields, messages expected from the listener)
    scenario_table: list[tuple[str, object, list[bytes]]] = [
        ("string_payload_is_encoded", {b"data": "hello"}, [b"hello"]),
        ("bytearray_payload_is_converted", {b"data": bytearray(b"world")}, [b"world"]),
        ("non_dict_fields_are_ignored", [("data", b"ignored")], []),
        ("missing_payload_is_ignored", {b"other": b"ignored"}, []),
    ]
    return [
        ListenPayloadCase(name=label, fields=raw_fields, expected_messages=expected)
        for label, raw_fields, expected in scenario_table
    ]
|
||||
|
||||
|
||||
@pytest.fixture
def fake_redis() -> FakeStreamsRedis:
    """Provide a newly constructed ``FakeStreamsRedis`` instance."""
    redis_double = FakeStreamsRedis()
    return redis_double
|
||||
|
|
@ -94,21 +150,37 @@ class TestStreamsBroadcastChannel:
|
|||
# Expire called after publish
|
||||
assert fake_redis._expire_calls.get("stream:beta", 0) >= 1
|
||||
|
||||
def test_topic_exposes_self_as_producer_and_subscriber(self, streams_channel: StreamsBroadcastChannel):
    """Both role accessors of a topic should hand back the topic object itself."""
    topic = streams_channel.topic("producer-subscriber")

    # Check the producer view first, then the subscriber view.
    for role_accessor in (topic.as_producer, topic.as_subscriber):
        assert role_accessor() is topic
|
||||
|
||||
def test_publish_logs_warning_when_expire_fails(self, caplog: pytest.LogCaptureFixture):
    """Publishing through a client whose ``expire`` raises should leave a warning in the log."""
    failing_client = FailExpireRedis()
    topic = StreamsBroadcastChannel(failing_client, retention_seconds=60).topic("expire-warning")

    topic.publish(b"payload")

    captured_log = caplog.text
    assert "Failed to set expire for stream key" in captured_log
|
||||
|
||||
|
||||
class TestStreamsSubscription:
|
||||
def test_subscribe_and_receive_from_beginning(self, streams_channel: StreamsBroadcastChannel):
|
||||
def test_subscribe_only_receives_messages_published_after_subscription_starts(
|
||||
self,
|
||||
streams_channel: StreamsBroadcastChannel,
|
||||
):
|
||||
topic = streams_channel.topic("gamma")
|
||||
# Pre-publish events before subscribing (late subscriber)
|
||||
topic.publish(b"e1")
|
||||
topic.publish(b"e2")
|
||||
topic.publish(b"before-subscribe")
|
||||
|
||||
sub = topic.subscribe()
|
||||
assert isinstance(sub, _StreamsSubscription)
|
||||
|
||||
received: list[bytes] = []
|
||||
with sub:
|
||||
# Give listener thread a moment to xread
|
||||
time.sleep(0.05)
|
||||
assert sub.receive(timeout=0.05) is None
|
||||
topic.publish(b"after-subscribe-1")
|
||||
topic.publish(b"after-subscribe-2")
|
||||
# Drain using receive() to avoid indefinite iteration in tests
|
||||
for _ in range(5):
|
||||
msg = sub.receive(timeout=0.1)
|
||||
|
|
@ -116,7 +188,7 @@ class TestStreamsSubscription:
|
|||
break
|
||||
received.append(msg)
|
||||
|
||||
assert received == [b"e1", b"e2"]
|
||||
assert received == [b"after-subscribe-1", b"after-subscribe-2"]
|
||||
|
||||
def test_receive_timeout_returns_none(self, streams_channel: StreamsBroadcastChannel):
|
||||
topic = streams_channel.topic("delta")
|
||||
|
|
@ -132,8 +204,6 @@ class TestStreamsSubscription:
|
|||
# Listener running; now close and ensure no crash
|
||||
sub.close()
|
||||
# After close, receive should raise SubscriptionClosedError
|
||||
from libs.broadcast_channel.exc import SubscriptionClosedError
|
||||
|
||||
with pytest.raises(SubscriptionClosedError):
|
||||
sub.receive()
|
||||
|
||||
|
|
@ -143,3 +213,141 @@ class TestStreamsSubscription:
|
|||
topic.publish(b"payload")
|
||||
# No expire recorded when retention is disabled
|
||||
assert fake_redis._expire_calls.get("stream:zeta") is None
|
||||
|
||||
@pytest.mark.parametrize(
    "case",
    build_listen_payload_cases(),
    ids=lambda case: cast(ListenPayloadCase, case).name,
)
def test_listener_normalizes_supported_payloads_and_ignores_unsupported_shapes(self, case: ListenPayloadCase):
    """Run the listener over one stream entry and check what lands in the queue."""

    class OneShotRedis:
        """Serves a single entry on the first read, then closes the subscription."""

        def __init__(self, fields: object) -> None:
            self._fields = fields
            self._calls = 0

        def xread(self, streams: dict, block: int | None = None, count: int | None = None):
            self._calls += 1
            if self._calls != 1:
                # Second and later reads: stop the listen loop and return nothing.
                subscription._closed.set()
                return []
            stream_key = next(iter(streams))
            return [(stream_key, [("1-0", self._fields)])]

    subscription = _StreamsSubscription(OneShotRedis(case.fields), "stream:payload-shape")
    subscription._listen()

    # Drain whatever the listener queued, stopping at the close sentinel if present.
    drained: list[bytes] = []
    pending = subscription._queue
    while not pending.empty():
        entry = pending.get_nowait()
        if entry is subscription._SENTINEL:
            break
        drained.append(bytes(entry))

    assert drained == case.expected_messages
    assert subscription._last_id == "1-0"
|
||||
|
||||
def test_iterator_yields_messages_until_subscription_is_closed(self, streams_channel: StreamsBroadcastChannel):
    """The iterator should yield a published message and stop once the subscription is closed."""
    topic = streams_channel.topic("iter")
    subscription = topic.subscribe()
    iterator = iter(subscription)

    def publish_later() -> None:
        time.sleep(0.05)
        topic.publish(b"iter-message")

    publisher = threading.Thread(target=publish_later, daemon=True)
    publisher.start()

    try:
        assert next(iterator) == b"iter-message"
    finally:
        # Clean up even when the assertion fails, so the subscription and the
        # publisher thread do not outlive this test.
        subscription.close()
        publisher.join(timeout=1)

    with pytest.raises(StopIteration):
        next(iterator)
|
||||
|
||||
def test_receive_with_none_timeout_blocks_until_message_arrives(self, streams_channel: StreamsBroadcastChannel):
    """``receive(timeout=None)`` should return the message published shortly afterwards."""
    topic = streams_channel.topic("blocking")
    subscription = topic.subscribe()

    def publish_later() -> None:
        # Delay the publish so that receive() is already waiting when it happens.
        time.sleep(0.05)
        topic.publish(b"blocking-message")

    background_publisher = threading.Thread(target=publish_later, daemon=True)
    background_publisher.start()

    try:
        delivered = subscription.receive(timeout=None)
        assert delivered == b"blocking-message"
    finally:
        subscription.close()
        background_publisher.join(timeout=1)
|
||||
|
||||
def test_receive_raises_when_queue_contains_close_sentinel(self):
    """``receive`` should raise ``SubscriptionClosedError`` when the close sentinel is queued."""
    subscription = _StreamsSubscription(FakeStreamsRedis(), "stream:sentinel")
    # Put a thread object in the listener slot (the current thread) before queueing the sentinel.
    subscription._listener = threading.current_thread()
    close_marker = subscription._SENTINEL
    subscription._queue.put_nowait(close_marker)

    with pytest.raises(SubscriptionClosedError):
        subscription.receive(timeout=0.01)
|
||||
|
||||
def test_close_before_listener_starts_is_a_noop(self):
    """Closing a never-started subscription should not create a listener; later receives raise."""
    fresh_subscription = _StreamsSubscription(FakeStreamsRedis(), "stream:not-started")

    fresh_subscription.close()

    listener_after_close = fresh_subscription._listener
    assert listener_after_close is None
    with pytest.raises(SubscriptionClosedError):
        fresh_subscription.receive(timeout=0.01)
|
||||
|
||||
def test_start_if_needed_returns_immediately_for_closed_subscription(self):
    """``_start_if_needed`` should leave the listener unset when the closed flag is already set."""
    closed_subscription = _StreamsSubscription(FakeStreamsRedis(), "stream:already-closed")
    # Mark the subscription closed directly, without going through close().
    closed_subscription._closed.set()

    closed_subscription._start_if_needed()

    listener_slot = closed_subscription._listener
    assert listener_slot is None
|
||||
|
||||
def test_iterator_skips_none_results_and_keeps_polling(self):
    """Iteration should pass over a ``None`` from ``receive`` and yield the next real message."""
    subscription = _StreamsSubscription(FakeStreamsRedis(), "stream:iterator-none")
    scripted_results = iter([None, b"event"])

    def fake_receive(timeout: float | None = 0.1) -> bytes | None:
        outcome = next(scripted_results)
        if outcome is None:
            return outcome
        # Mark the subscription closed as soon as a real message is handed out.
        subscription._closed.set()
        return outcome

    # Replace the listener start-up and receive with stubs so no thread is involved.
    subscription._start_if_needed = lambda: None  # type: ignore[method-assign]
    subscription.receive = fake_receive  # type: ignore[method-assign]

    first_yielded = next(iter(subscription))
    assert first_yielded == b"event"
|
||||
|
||||
def test_close_logs_warning_when_listener_does_not_stop_in_time(
    self,
    caplog: pytest.LogCaptureFixture,
):
    """``close`` should log a warning when the listener still reports alive after the join."""
    blocking_redis = BlockingRedis()
    subscription = _StreamsSubscription(blocking_redis, "stream:slow-close")

    subscription._start_if_needed()
    listener = subscription._listener
    assert listener is not None

    # Keep the real methods so they can be restored, and used for the final join.
    real_join = listener.join
    real_is_alive = listener.is_alive

    def delayed_join(timeout: float | None = None) -> None:
        # Ignore the requested timeout and only wait briefly.
        real_join(0.01)

    def always_alive() -> bool:
        return True

    listener.join = delayed_join  # type: ignore[method-assign]
    listener.is_alive = always_alive  # type: ignore[method-assign]

    try:
        subscription.close()
        assert "did not stop within timeout" in caplog.text
    finally:
        # Undo the patches, unblock the fake xread, and wait for the thread to end.
        listener.join = real_join  # type: ignore[method-assign]
        listener.is_alive = real_is_alive  # type: ignore[method-assign]
        blocking_redis.release()
        real_join(timeout=1)
|
||||
|
|
|
|||
Loading…
Reference in New Issue