feat: change xread to xreadgroup

This commit is contained in:
fatelei 2026-03-17 16:07:26 +08:00
parent d956b919a0
commit 94815c4e8b
No known key found for this signature in database
GPG Key ID: 2F91DA05646F4EED
5 changed files with 298 additions and 28 deletions

View File

@ -5,10 +5,12 @@ import queue
import threading
from collections.abc import Iterator
from typing import Self
from uuid import uuid4
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
from libs.broadcast_channel.exc import SubscriptionClosedError
from redis import Redis, RedisCluster
from redis.exceptions import ResponseError
logger = logging.getLogger(__name__)
@ -18,9 +20,11 @@ class StreamsBroadcastChannel:
Redis Streams based broadcast channel implementation.
Characteristics:
- At-least-once delivery for late subscribers within the stream retention window.
- Each topic is stored as a dedicated Redis Stream key.
- The stream key expires `retention_seconds` after the last event is published (to bound storage).
- Shared consumer group per topic (all subscribers to the same topic share one group).
- Uses XREADGROUP with NOACK starting at "0" (reads from the beginning for new subscribers).
- Delivery is best-effort/at-most-once per subscriber.
- Each topic is stored as a dedicated Redis Stream key; key expires `retention_seconds` after last publish.
- Multiple tabs/subscribers to the same topic will see consistent state from the group's position.
"""
def __init__(self, redis_client: Redis | RedisCluster, *, retention_seconds: int = 600):
@ -69,10 +73,57 @@ class _StreamsSubscription(Subscription):
self._start_lock = threading.Lock()
self._listener: threading.Thread | None = None
self._group_name = f"grp:{self._key}:{uuid4().hex}"
self._consumer_name = f"c:{uuid4().hex}"
self._group_ready = False
def _ensure_group(self) -> None:
    """Idempotently create the consumer group for this subscription's stream.

    Sets ``self._group_ready`` once the group is known to exist; a BUSYGROUP
    reply from Redis counts as success (the group already exists).  Any other
    error is logged and left for a later retry by the caller.
    """
    if self._group_ready:
        return
    try:
        # Start the group at id "0" so a freshly created group consumes the
        # stream from the beginning; mkstream=True creates the stream key if
        # it does not exist yet.
        self._client.xgroup_create(self._key, self._group_name, id="0", mkstream=True)
    except ResponseError as exc:
        if "BUSYGROUP" in str(exc):
            # Group was already created (e.g. recreated quickly) - that is fine.
            self._group_ready = True
        else:
            logger.warning(
                "xgroup create failed for %s/%s: %s", self._key, self._group_name, exc, exc_info=True
            )
    except Exception as exc:  # pragma: no cover - safety net for different redis-py versions
        logger.warning(
            "xgroup create unexpected error for %s/%s: %s",
            self._key,
            self._group_name,
            exc,
            exc_info=True,
        )
    else:
        self._group_ready = True
def _listen(self) -> None:
try:
self._ensure_group()
while not self._closed.is_set():
streams = self._client.xread({self._key: self._last_id}, block=1000, count=100)
try:
streams = self._client.xreadgroup(
self._group_name,
self._consumer_name,
{self._key: ">"},
block=1000,
count=100,
noack=True,
)
except ResponseError as e:
msg = str(e)
# Handle group/key disappearances gracefully (key expired/evicted or group destroyed elsewhere)
if "NOGROUP" in msg or "No such key" in msg:
self._group_ready = False
self._ensure_group()
continue
raise
if not streams:
continue
@ -101,6 +152,8 @@ class _StreamsSubscription(Subscription):
with self._start_lock:
if self._listener is not None or self._closed.is_set():
return
self._ensure_group()
self._listener = threading.Thread(
target=self._listen,
name=f"redis-streams-sub-{self._key}",
@ -149,6 +202,40 @@ class _StreamsSubscription(Subscription):
else:
self._listener = None
try:
self._client.xgroup_delconsumer(self._key, self._group_name, self._consumer_name)
except ResponseError as e:
msg = str(e)
if not ("NOGROUP" in msg or "NOKEY" in msg or "No such key" in msg):
logger.warning(
"xgroup delconsumer failed for %s/%s: %s", self._key, self._group_name, e, exc_info=True
)
except Exception as e:
logger.warning(
"xgroup delconsumer unexpected error for %s/%s: %s", self._key, self._group_name, e, exc_info=True
)
try:
self._client.xgroup_destroy(self._key, self._group_name)
except ResponseError as e:
msg = str(e)
if not ("NOGROUP" in msg or "NOKEY" in msg or "No such key" in msg):
logger.warning(
"xgroup_destroy failed for %s/%s: %s",
self._key,
self._group_name,
e,
exc_info=True
)
except Exception as e:
logger.warning(
"xgroup_destroy unexpected error for %s/%s: %s",
self._key,
self._group_name,
e,
exc_info=True,
)
# Context manager helpers
def __enter__(self) -> Self:
self._start_if_needed()

View File

@ -63,14 +63,8 @@ class AppGenerateService:
channel_type = dify_config.PUBSUB_REDIS_CHANNEL_TYPE
if channel_type == "streams":
# With Redis Streams, we can safely start right away; consumers can read past events.
_try_start()
# Keep return type Callable[[], None] consistent while allowing an extra (no-op) call.
def _on_subscribe_streams() -> None:
_try_start()
return _on_subscribe_streams
return lambda: None
# Pub/Sub modes (at-most-once): subscribe-gated start with a tiny fallback.
timer = threading.Timer(SSE_TASK_START_FALLBACK_MS / 1000.0, _try_start)

View File

@ -13,8 +13,8 @@ class FakeStreamsRedis:
"""Minimal in-memory Redis Streams stub for unit tests.
- Stores entries per key as [(id, {b"data": bytes}), ...]
- xadd appends entries and returns an auto-increment id like "1-0"
- xread returns entries strictly greater than last_id
- Supports xgroup_create/xreadgroup/xack for consumer-group flow
- xread returns entries strictly greater than last_id (fallback path)
- expire is recorded but has no effect on behavior
"""
@ -22,6 +22,8 @@ class FakeStreamsRedis:
self._store: dict[str, list[tuple[str, dict]]] = {}
self._next_id: dict[str, int] = {}
self._expire_calls: dict[str, int] = {}
# key -> group -> {"last_id": str}
self._groups: dict[str, dict[str, dict[str, str]]] = {}
# Publisher API
def xadd(self, key: str, fields: dict, *, maxlen: int | None = None) -> str:
@ -38,7 +40,7 @@ class FakeStreamsRedis:
def expire(self, key: str, seconds: int) -> None:
self._expire_calls[key] = self._expire_calls.get(key, 0) + 1
# Consumer API
# Consumer API (fallback without groups)
def xread(self, streams: dict, block: int | None = None, count: int | None = None):
# Expect a single key
assert len(streams) == 1
@ -62,6 +64,90 @@ class FakeStreamsRedis:
batch = entries[start_idx:end_idx]
return [(key, batch)]
# Consumer group API
def xgroup_create(self, key: str, group: str, id: str = "$", mkstream: bool = False):
    """Register *group* on *key*, resolving special start ids like real Redis.

    ``"$"`` resolves to the current tail (only new entries), ``"0"`` to the
    canonical beginning sentinel ``"0-0"``; a duplicate group name raises a
    BUSYGROUP error.
    """
    if mkstream and key not in self._store:
        # mkstream=True creates an empty stream when the key is absent.
        self._store[key] = []
        self._next_id[key] = 0
    key_groups = self._groups.setdefault(key, {})
    if group in key_groups:
        raise RuntimeError("BUSYGROUP Consumer Group name already exists")
    if id == "0":
        # "0": deliver everything from the beginning of the stream.
        start = "0-0"
    elif id == "$":
        # "$": only entries added after group creation.
        existing = self._store.get(key, [])
        start = existing[-1][0] if existing else "0-0"
    else:
        start = id
    key_groups[group] = {"last_id": start}
def xreadgroup(
    self,
    group: str,
    consumer: str,
    streams: dict,
    count: int | None = None,
    block: int | None = None,
    noack: bool | None = None,
):
    """Deliver entries newer than the group's position for a single ">" stream.

    Advances the group's ``last_id`` past whatever was delivered; sleeps
    briefly (bounded by *block*) when there is nothing to read, mimicking a
    blocking XREADGROUP call.
    """
    assert len(streams) == 1
    key, marker = next(iter(streams.items()))
    assert marker == ">"
    entries = self._store.get(key, [])
    info = self._groups.setdefault(key, {}).setdefault(group, {"last_id": "0-0"})
    cursor = info["last_id"]
    if cursor == "$":
        # Start from the end: only entries added after this point.
        begin = len(entries)
    elif cursor == "0-0":
        begin = 0
    else:
        # Resume just past the last delivered entry id.
        begin = 0
        for pos, (entry_id, _fields) in enumerate(entries):
            if entry_id == cursor:
                begin = pos + 1
                break
    if begin >= len(entries):
        if block and block > 0:
            time.sleep(min(0.01, block / 1000.0))
        return []
    stop = len(entries) if count is None else min(len(entries), begin + count)
    delivered = entries[begin:stop]
    if delivered:
        info["last_id"] = delivered[-1][0]
    return [(key, delivered)]
def xack(self, key: str, group: str, *ids: str):
    """Acknowledge *ids*; this fake keeps no PEL, so just report the count."""
    return len(ids)
def xautoclaim(
self,
key: str,
group: str,
consumer: str,
min_idle_time: int,
start_id: str = "0-0",
count: int | None = None,
justid: bool | None = None,
):
# Minimal fake: no PEL tracking; return no entries
return start_id, []
def xgroup_delconsumer(self, key: str, group: str, consumer: str):
    """Pretend to remove *consumer*; echo the arguments back for inspection."""
    return {"key": key, "group": group, "consumer": consumer}
def xgroup_destroy(self, key: str, group: str):
    """Drop *group* from *key*; return 1 when something was removed, else 0."""
    groups_for_key = self._groups.get(key)
    if groups_for_key is not None and group in groups_for_key:
        del groups_for_key[group]
        return 1
    return 0
@pytest.fixture
def fake_redis() -> FakeStreamsRedis:
@ -96,9 +182,9 @@ class TestStreamsBroadcastChannel:
class TestStreamsSubscription:
def test_subscribe_and_receive_from_beginning(self, streams_channel: StreamsBroadcastChannel):
def test_streams_receive_from_beginning_on_subscribe(self, streams_channel: StreamsBroadcastChannel):
topic = streams_channel.topic("gamma")
# Pre-publish events before subscribing (late subscriber)
# Pre-publish events before subscribing (SHOULD be received with xreadgroup starting at '0')
topic.publish(b"e1")
topic.publish(b"e2")
@ -107,16 +193,56 @@ class TestStreamsSubscription:
received: list[bytes] = []
with sub:
# Give listener thread a moment to xread
# Publish after subscription; these should also be received
topic.publish(b"n1")
topic.publish(b"n2")
# Give listener thread a moment to read
time.sleep(0.05)
# Drain using receive() to avoid indefinite iteration in tests
for _ in range(5):
msg = sub.receive(timeout=0.1)
for _ in range(10):
msg = sub.receive(timeout=0.2)
if msg is None:
break
received.append(msg)
assert received == [b"e1", b"e2"]
# Should receive both pre-existing and new messages
assert received == [b"e1", b"e2", b"n1", b"n2"]
def test_multiple_subscribers_share_group_position(self, streams_channel: StreamsBroadcastChannel):
    """Test that multiple subscribers to the same topic share the group's position."""
    # NOTE(review): the subscription's group name embeds a uuid4 per
    # subscription, so "shared group" may not hold - confirm against the
    # channel implementation.
    topic = streams_channel.topic("shared-group-test")
    # Seed the stream before anyone subscribes.
    topic.publish(b"msg1")
    topic.publish(b"msg2")
    # First subscriber consumes (at most) one message.
    first_sub = topic.subscribe()
    got_first: list[bytes] = []
    with first_sub:
        time.sleep(0.05)
        item = first_sub.receive(timeout=0.2)
        if item:
            got_first.append(item)
    # Second subscriber should start from the group's position.
    second_sub = topic.subscribe()
    got_second: list[bytes] = []
    with second_sub:
        # Publish more messages while the second subscriber is live.
        topic.publish(b"msg3")
        time.sleep(0.05)
        for _ in range(5):
            item = second_sub.receive(timeout=0.2)
            if item is None:
                break
            got_second.append(item)
    # With noack=True messages are not retained in the PEL; between the two
    # subscribers at least one message must have been observed.
    assert len(got_first) > 0 or len(got_second) > 0
def test_receive_timeout_returns_none(self, streams_channel: StreamsBroadcastChannel):
topic = streams_channel.topic("delta")

View File

@ -137,21 +137,21 @@ class TestBuildStreamingTaskOnSubscribe:
assert called == [1]
def test_exception_in_start_task_returns_false(self, monkeypatch):
"""When start_task raises, _try_start returns False and next call retries."""
"""When start_task raises in streams mode, _try_start returns False and callback is a no-op."""
monkeypatch.setattr(ags_module.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "streams")
call_count = 0
def _bad():
nonlocal call_count
call_count += 1
if call_count == 1:
raise RuntimeError("boom")
raise RuntimeError("boom")
cb = AppGenerateService._build_streaming_task_on_subscribe(_bad)
# first call inside build raised, but is caught; second call via cb succeeds
# In streams mode, the callback is a no-op since _try_start was already called
assert call_count == 1
cb()
assert call_count == 2
# Callback does nothing in streams mode, so count stays at 1
assert call_count == 1
def test_concurrent_subscribe_only_starts_once(self, monkeypatch):
monkeypatch.setattr(ags_module.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "pubsub")

View File

@ -59,6 +59,8 @@ class _FakeStreams:
# key -> list[(id, {field: value})]
self._data: dict[str, list[tuple[str, dict]]] = defaultdict(list)
self._seq: dict[str, int] = defaultdict(int)
# key -> group -> {"last_id": str}
self._groups: dict[str, dict[str, dict[str, str]]] = defaultdict(dict)
def xadd(self, key: str, fields: dict, *, maxlen: int | None = None) -> str:
# maxlen is accepted for API compatibility with redis-py; ignored in this test double
@ -71,6 +73,68 @@ class _FakeStreams:
# no-op for tests
return None
def xgroup_create(self, key: str, group: str, id: str = "$", mkstream: bool = False):
    """Create consumer *group* on *key*, mirroring real Redis semantics.

    ``"$"`` resolves to the current tail (only new entries are delivered),
    ``"0"`` to the canonical beginning sentinel ``"0-0"``; creating a group
    that already exists raises a BUSYGROUP error.
    """
    if mkstream and key not in self._data:
        # mkstream=True creates the (empty) stream when the key is missing.
        self._data[key] = []
        self._seq[key] = 0
    if group in self._groups[key]:
        raise RuntimeError("BUSYGROUP Consumer Group name already exists")
    if id == "$":
        entries = self._data.get(key, [])
        resolved = entries[-1][0] if entries else "0-0"
    elif id == "0":
        # Normalize "0" to the "0-0" sentinel that xreadgroup special-cases,
        # consistent with the FakeStreamsRedis double used by the sibling
        # tests (previously "0" was stored verbatim and only worked by
        # falling through the id-lookup loop).
        resolved = "0-0"
    else:
        resolved = id
    self._groups[key][group] = {"last_id": resolved}
def xreadgroup(
    self,
    group: str,
    consumer: str,
    streams: dict,
    count: int | None = None,
    block: int | None = None,
    noack: bool | None = None,
):
    """Return entries newer than the group's position for a single ">" stream,
    advancing the group's ``last_id`` past whatever was delivered."""
    assert len(streams) == 1
    key, marker = next(iter(streams.items()))
    assert marker == ">"
    entries = self._data.get(key, [])
    info = self._groups[key].setdefault(group, {"last_id": "0-0"})
    cursor = info["last_id"]
    if cursor == "$":
        # Start from the end: only entries added after group creation.
        first = len(entries)
    elif cursor == "0-0":
        first = 0
    else:
        # Resume just past the last delivered entry id.
        first = 0
        for pos, (entry_id, _fields) in enumerate(entries):
            if entry_id == cursor:
                first = pos + 1
                break
    if first >= len(entries):
        return []
    last = len(entries) if count is None else min(len(entries), first + count)
    batch = entries[first:last]
    if batch:
        info["last_id"] = batch[-1][0]
    return [(key, batch)]
def xack(self, key: str, group: str, *ids: str):
    """No pending entries are tracked; acknowledge by reporting the count."""
    return len(ids)
def xautoclaim(
self,
key: str,
group: str,
consumer: str,
min_idle_time: int,
start_id: str = "0-0",
count: int | None = None,
justid: bool | None = None,
):
# Minimal fake: no PEL tracking; return no entries
return start_id, []
# Fallback path if xreadgroup not used
def xread(self, streams: dict, block: int | None = None, count: int | None = None):
assert len(streams) == 1
key, last_id = next(iter(streams.items()))
@ -134,11 +198,11 @@ def _publish_events(app_mode: AppMode, run_id: str, events: list[dict]):
@pytest.mark.usefixtures("_patch_get_channel_streams")
def test_streams_full_flow_prepublish_and_replay():
def test_streams_full_flow_start_on_subscribe_only_new():
app_mode = AppMode.WORKFLOW
run_id = str(uuid.uuid4())
# Build start_task that publishes two events immediately
# Publish on subscribe; with XREADGROUP + NOACK + "$", we receive only new events
events = [{"event": "workflow_started"}, {"event": "workflow_finished"}]
def start_task():
@ -146,7 +210,6 @@ def test_streams_full_flow_prepublish_and_replay():
on_subscribe = AppGenerateService._build_streaming_task_on_subscribe(start_task)
# Start retrieving BEFORE subscription is established; in streams mode, we also started immediately
gen = MessageGenerator.retrieve_events(app_mode, run_id, idle_timeout=2.0, on_subscribe=on_subscribe)
received = []