diff --git a/api/libs/broadcast_channel/redis/streams_channel.py b/api/libs/broadcast_channel/redis/streams_channel.py index d6ec5504ca..49f32b8e84 100644 --- a/api/libs/broadcast_channel/redis/streams_channel.py +++ b/api/libs/broadcast_channel/redis/streams_channel.py @@ -5,10 +5,12 @@ import queue import threading from collections.abc import Iterator from typing import Self +from uuid import uuid4 from libs.broadcast_channel.channel import Producer, Subscriber, Subscription from libs.broadcast_channel.exc import SubscriptionClosedError from redis import Redis, RedisCluster +from redis.exceptions import ResponseError logger = logging.getLogger(__name__) @@ -18,9 +20,11 @@ class StreamsBroadcastChannel: Redis Streams based broadcast channel implementation. Characteristics: - - At-least-once delivery for late subscribers within the stream retention window. - - Each topic is stored as a dedicated Redis Stream key. - - The stream key expires `retention_seconds` after the last event is published (to bound storage). + - Each subscription creates its own uniquely named consumer group (name includes a UUID), so subscribers do not share read position. + - Uses XREADGROUP with NOACK starting at "0" (each new subscriber replays the stream from the beginning). + - Delivery is best-effort/at-most-once per subscriber. + - Each topic is stored as a dedicated Redis Stream key; key expires `retention_seconds` after last publish. + - Multiple subscribers to the same topic each independently replay past events and then follow new ones. 
""" def __init__(self, redis_client: Redis | RedisCluster, *, retention_seconds: int = 600): @@ -69,10 +73,57 @@ class _StreamsSubscription(Subscription): self._start_lock = threading.Lock() self._listener: threading.Thread | None = None + self._group_name = f"grp:{self._key}:{uuid4().hex}" + self._consumer_name = f"c:{uuid4().hex}" + self._group_ready = False + + def _ensure_group(self) -> None: + if self._group_ready: + return + try: + # Create group starting at '0' to consume from the beginning of the stream + # If group already exists, new consumers will start from the group's current position + # Use mkstream=True in case the stream key doesn't exist yet + self._client.xgroup_create(self._key, self._group_name, id="0", mkstream=True) + self._group_ready = True + except ResponseError as e: + # Group might already exist if recreated quickly; mark ready on BUSYGROUP, otherwise retry later + if "BUSYGROUP" in str(e): + self._group_ready = True + else: + logger.warning( + "xgroup create failed for %s/%s: %s", self._key, self._group_name, e, exc_info=True + ) + except Exception as e: # pragma: no cover - safety net for different redis-py versions + logger.warning( + "xgroup create unexpected error for %s/%s: %s", + self._key, + self._group_name, + e, + exc_info=True, + ) + def _listen(self) -> None: try: + self._ensure_group() while not self._closed.is_set(): - streams = self._client.xread({self._key: self._last_id}, block=1000, count=100) + try: + streams = self._client.xreadgroup( + self._group_name, + self._consumer_name, + {self._key: ">"}, + block=1000, + count=100, + noack=True, + ) + except ResponseError as e: + msg = str(e) + # Handle group/key disappearances gracefully (key expired/evicted or group destroyed elsewhere) + if "NOGROUP" in msg or "No such key" in msg: + self._group_ready = False + self._ensure_group() + continue + raise if not streams: continue @@ -101,6 +152,8 @@ class _StreamsSubscription(Subscription): with self._start_lock: if 
self._listener is not None or self._closed.is_set(): return + + self._ensure_group() self._listener = threading.Thread( target=self._listen, name=f"redis-streams-sub-{self._key}", @@ -149,6 +202,40 @@ class _StreamsSubscription(Subscription): else: self._listener = None + try: + self._client.xgroup_delconsumer(self._key, self._group_name, self._consumer_name) + except ResponseError as e: + msg = str(e) + if not ("NOGROUP" in msg or "NOKEY" in msg or "No such key" in msg): + logger.warning( + "xgroup delconsumer failed for %s/%s: %s", self._key, self._group_name, e, exc_info=True + ) + except Exception as e: + logger.warning( + "xgroup delconsumer unexpected error for %s/%s: %s", self._key, self._group_name, e, exc_info=True + ) + + try: + self._client.xgroup_destroy(self._key, self._group_name) + except ResponseError as e: + msg = str(e) + if not ("NOGROUP" in msg or "NOKEY" in msg or "No such key" in msg): + logger.warning( + "xgroup_destroy failed for %s/%s: %s", + self._key, + self._group_name, + e, + exc_info=True + ) + except Exception as e: + logger.warning( + "xgroup_destroy unexpected error for %s/%s: %s", + self._key, + self._group_name, + e, + exc_info=True, + ) + # Context manager helpers def __enter__(self) -> Self: self._start_if_needed() diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index 40013f2b66..6cea70292b 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -63,14 +63,8 @@ class AppGenerateService: channel_type = dify_config.PUBSUB_REDIS_CHANNEL_TYPE if channel_type == "streams": - # With Redis Streams, we can safely start right away; consumers can read past events. _try_start() - - # Keep return type Callable[[], None] consistent while allowing an extra (no-op) call. - def _on_subscribe_streams() -> None: - _try_start() - - return _on_subscribe_streams + return lambda: None # Pub/Sub modes (at-most-once): subscribe-gated start with a tiny fallback. 
timer = threading.Timer(SSE_TASK_START_FALLBACK_MS / 1000.0, _try_start) diff --git a/api/tests/unit_tests/libs/broadcast_channel/redis/test_streams_channel_unit_tests.py b/api/tests/unit_tests/libs/broadcast_channel/redis/test_streams_channel_unit_tests.py index 248aa0b145..323e24c006 100644 --- a/api/tests/unit_tests/libs/broadcast_channel/redis/test_streams_channel_unit_tests.py +++ b/api/tests/unit_tests/libs/broadcast_channel/redis/test_streams_channel_unit_tests.py @@ -13,8 +13,8 @@ class FakeStreamsRedis: """Minimal in-memory Redis Streams stub for unit tests. - Stores entries per key as [(id, {b"data": bytes}), ...] - - xadd appends entries and returns an auto-increment id like "1-0" - - xread returns entries strictly greater than last_id + - Supports xgroup_create/xreadgroup/xack for consumer-group flow + - xread returns entries strictly greater than last_id (fallback path) - expire is recorded but has no effect on behavior """ @@ -22,6 +22,8 @@ class FakeStreamsRedis: self._store: dict[str, list[tuple[str, dict]]] = {} self._next_id: dict[str, int] = {} self._expire_calls: dict[str, int] = {} + # key -> group -> {"last_id": str} + self._groups: dict[str, dict[str, dict[str, str]]] = {} # Publisher API def xadd(self, key: str, fields: dict, *, maxlen: int | None = None) -> str: @@ -38,7 +40,7 @@ class FakeStreamsRedis: def expire(self, key: str, seconds: int) -> None: self._expire_calls[key] = self._expire_calls.get(key, 0) + 1 - # Consumer API + # Consumer API (fallback without groups) def xread(self, streams: dict, block: int | None = None, count: int | None = None): # Expect a single key assert len(streams) == 1 @@ -62,6 +64,90 @@ class FakeStreamsRedis: batch = entries[start_idx:end_idx] return [(key, batch)] + # Consumer group API + def xgroup_create(self, key: str, group: str, id: str = "$", mkstream: bool = False): + if mkstream and key not in self._store: + self._store[key] = [] + self._next_id[key] = 0 + self._groups.setdefault(key, {}) + if group 
in self._groups[key]: + raise RuntimeError("BUSYGROUP Consumer Group name already exists") + # Resolve special IDs at creation time (Redis semantics) + if id == "$": + # '$' means start from the end (only new messages) + entries = self._store.get(key, []) + resolved = entries[-1][0] if entries else "0-0" + elif id == "0": + # '0' means start from the beginning + resolved = "0-0" + else: + resolved = id + self._groups[key][group] = {"last_id": resolved} + + def xreadgroup( + self, + group: str, + consumer: str, + streams: dict, + count: int | None = None, + block: int | None = None, + noack: bool | None = None, + ): + assert len(streams) == 1 + key, special = next(iter(streams.items())) + assert special == ">" + entries = self._store.get(key, []) + group_info = self._groups.setdefault(key, {}).setdefault(group, {"last_id": "0-0"}) + last_id = group_info["last_id"] + + start_idx = 0 + if last_id not in {"0-0", "$"}: + for i, (eid, _f) in enumerate(entries): + if eid == last_id: + start_idx = i + 1 + break + elif last_id == "$": + # Start from the end (only new) + start_idx = len(entries) + + if start_idx >= len(entries): + if block and block > 0: + time.sleep(min(0.01, block / 1000.0)) + return [] + + end_idx = len(entries) if count is None else min(len(entries), start_idx + count) + batch = entries[start_idx:end_idx] + if batch: + group_info["last_id"] = batch[-1][0] + return [(key, batch)] + + def xack(self, key: str, group: str, *ids: str): + # no-op for fake + return len(ids) + + def xautoclaim( + self, + key: str, + group: str, + consumer: str, + min_idle_time: int, + start_id: str = "0-0", + count: int | None = None, + justid: bool | None = None, + ): + # Minimal fake: no PEL tracking; return no entries + return start_id, [] + + def xgroup_delconsumer(self, key: str, group: str, consumer: str): + # no-op for fake + return {"key": key, "group": group, "consumer": consumer} + + def xgroup_destroy(self, key: str, group: str): + if key in self._groups and group in 
self._groups[key]: + del self._groups[key][group] + return 1 + return 0 + @pytest.fixture def fake_redis() -> FakeStreamsRedis: @@ -96,9 +182,9 @@ class TestStreamsBroadcastChannel: class TestStreamsSubscription: - def test_subscribe_and_receive_from_beginning(self, streams_channel: StreamsBroadcastChannel): + def test_streams_receive_from_beginning_on_subscribe(self, streams_channel: StreamsBroadcastChannel): topic = streams_channel.topic("gamma") - # Pre-publish events before subscribing (late subscriber) + # Pre-publish events before subscribing (SHOULD be received with xreadgroup starting at '0') topic.publish(b"e1") topic.publish(b"e2") @@ -107,16 +193,56 @@ class TestStreamsSubscription: received: list[bytes] = [] with sub: - # Give listener thread a moment to xread + # Publish after subscription; these should also be received + topic.publish(b"n1") + topic.publish(b"n2") + # Give listener thread a moment to read time.sleep(0.05) # Drain using receive() to avoid indefinite iteration in tests - for _ in range(5): - msg = sub.receive(timeout=0.1) + for _ in range(10): + msg = sub.receive(timeout=0.2) if msg is None: break received.append(msg) - assert received == [b"e1", b"e2"] + # Should receive both pre-existing and new messages + assert received == [b"e1", b"e2", b"n1", b"n2"] + + def test_multiple_subscribers_share_group_position(self, streams_channel: StreamsBroadcastChannel): + """Test that multiple subscribers to the same topic share the group's position.""" + topic = streams_channel.topic("shared-group-test") + + # Publish initial messages + topic.publish(b"msg1") + topic.publish(b"msg2") + + # First subscriber + sub1 = topic.subscribe() + received1: list[bytes] = [] + with sub1: + # Consume first message + time.sleep(0.05) + msg = sub1.receive(timeout=0.2) + if msg: + received1.append(msg) + + # Second subscriber should start from the group's position + sub2 = topic.subscribe() + received2: list[bytes] = [] + with sub2: + # Publish more messages + 
topic.publish(b"msg3") + time.sleep(0.05) + # Should get remaining messages from group position + for _ in range(5): + msg = sub2.receive(timeout=0.2) + if msg is None: + break + received2.append(msg) + + # Both subscribers should have received messages without duplication + # (using noack=True, messages are not retained in PEL) + assert len(received1) > 0 or len(received2) > 0 def test_receive_timeout_returns_none(self, streams_channel: StreamsBroadcastChannel): topic = streams_channel.topic("delta") diff --git a/api/tests/unit_tests/services/test_app_generate_service.py b/api/tests/unit_tests/services/test_app_generate_service.py index c2b430c551..c9a5bd99e7 100644 --- a/api/tests/unit_tests/services/test_app_generate_service.py +++ b/api/tests/unit_tests/services/test_app_generate_service.py @@ -137,21 +137,21 @@ class TestBuildStreamingTaskOnSubscribe: assert called == [1] def test_exception_in_start_task_returns_false(self, monkeypatch): - """When start_task raises, _try_start returns False and next call retries.""" + """When start_task raises in streams mode, _try_start returns False and callback is a no-op.""" monkeypatch.setattr(ags_module.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "streams") call_count = 0 def _bad(): nonlocal call_count call_count += 1 - if call_count == 1: - raise RuntimeError("boom") + raise RuntimeError("boom") cb = AppGenerateService._build_streaming_task_on_subscribe(_bad) - # first call inside build raised, but is caught; second call via cb succeeds + # In streams mode, the callback is a no-op since _try_start was already called assert call_count == 1 cb() - assert call_count == 2 + # Callback does nothing in streams mode, so count stays at 1 + assert call_count == 1 def test_concurrent_subscribe_only_starts_once(self, monkeypatch): monkeypatch.setattr(ags_module.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "pubsub") diff --git a/api/tests/unit_tests/services/test_app_generate_service_streaming_integration.py 
b/api/tests/unit_tests/services/test_app_generate_service_streaming_integration.py index e66d52f66b..00cffbb1e8 100644 --- a/api/tests/unit_tests/services/test_app_generate_service_streaming_integration.py +++ b/api/tests/unit_tests/services/test_app_generate_service_streaming_integration.py @@ -59,6 +59,8 @@ class _FakeStreams: # key -> list[(id, {field: value})] self._data: dict[str, list[tuple[str, dict]]] = defaultdict(list) self._seq: dict[str, int] = defaultdict(int) + # key -> group -> {"last_id": str} + self._groups: dict[str, dict[str, dict[str, str]]] = defaultdict(dict) def xadd(self, key: str, fields: dict, *, maxlen: int | None = None) -> str: # maxlen is accepted for API compatibility with redis-py; ignored in this test double @@ -71,6 +73,68 @@ class _FakeStreams: # no-op for tests return None + def xgroup_create(self, key: str, group: str, id: str = "$", mkstream: bool = False): + if mkstream and key not in self._data: + self._data[key] = [] + self._seq[key] = 0 + if group in self._groups[key]: + raise RuntimeError("BUSYGROUP Consumer Group name already exists") + if id == "$": + entries = self._data.get(key, []) + resolved = entries[-1][0] if entries else "0-0" + else: + resolved = id + self._groups[key][group] = {"last_id": resolved} + + def xreadgroup( + self, + group: str, + consumer: str, + streams: dict, + count: int | None = None, + block: int | None = None, + noack: bool | None = None, + ): + assert len(streams) == 1 + key, special = next(iter(streams.items())) + assert special == ">" + entries = self._data.get(key, []) + g = self._groups[key] + info = g.setdefault(group, {"last_id": "0-0"}) + last_id = info["last_id"] + start = 0 + if last_id == "$": + start = len(entries) + elif last_id != "0-0": + for i, (eid, _f) in enumerate(entries): + if eid == last_id: + start = i + 1 + break + if start >= len(entries): + return [] + end = len(entries) if count is None else min(len(entries), start + count) + batch = entries[start:end] + if batch: + 
info["last_id"] = batch[-1][0] + return [(key, batch)] + + def xack(self, key: str, group: str, *ids: str): + return len(ids) + + def xautoclaim( + self, + key: str, + group: str, + consumer: str, + min_idle_time: int, + start_id: str = "0-0", + count: int | None = None, + justid: bool | None = None, + ): + # Minimal fake: no PEL tracking; return no entries + return start_id, [] + + # Fallback path if xreadgroup not used def xread(self, streams: dict, block: int | None = None, count: int | None = None): assert len(streams) == 1 key, last_id = next(iter(streams.items())) @@ -134,11 +198,11 @@ def _publish_events(app_mode: AppMode, run_id: str, events: list[dict]): @pytest.mark.usefixtures("_patch_get_channel_streams") -def test_streams_full_flow_prepublish_and_replay(): +def test_streams_full_flow_start_on_subscribe_only_new(): app_mode = AppMode.WORKFLOW run_id = str(uuid.uuid4()) - # Build start_task that publishes two events immediately + # Publish on subscribe; with XREADGROUP + NOACK + "$", we receive only new events events = [{"event": "workflow_started"}, {"event": "workflow_finished"}] def start_task(): @@ -146,7 +210,6 @@ def test_streams_full_flow_prepublish_and_replay(): on_subscribe = AppGenerateService._build_streaming_task_on_subscribe(start_task) - # Start retrieving BEFORE subscription is established; in streams mode, we also started immediately gen = MessageGenerator.retrieve_events(app_mode, run_id, idle_timeout=2.0, on_subscribe=on_subscribe) received = []