mirror of https://github.com/langgenius/dify.git
feat: change xread to xreadgroup
This commit is contained in:
parent
d956b919a0
commit
94815c4e8b
|
|
@ -5,10 +5,12 @@ import queue
|
|||
import threading
|
||||
from collections.abc import Iterator
|
||||
from typing import Self
|
||||
from uuid import uuid4
|
||||
|
||||
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
|
||||
from libs.broadcast_channel.exc import SubscriptionClosedError
|
||||
from redis import Redis, RedisCluster
|
||||
from redis.exceptions import ResponseError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -18,9 +20,11 @@ class StreamsBroadcastChannel:
|
|||
Redis Streams based broadcast channel implementation.
|
||||
|
||||
Characteristics:
|
||||
- At-least-once delivery for late subscribers within the stream retention window.
|
||||
- Each topic is stored as a dedicated Redis Stream key.
|
||||
- The stream key expires `retention_seconds` after the last event is published (to bound storage).
|
||||
- One consumer group per subscription (`grp:<key>:<uuid>`); subscribers to the same topic do not share a group.
|
||||
- Uses XREADGROUP with NOACK starting at "0" (reads from the beginning for new subscribers).
|
||||
- Delivery is best-effort/at-most-once per subscriber.
|
||||
- Each topic is stored as a dedicated Redis Stream key; key expires `retention_seconds` after last publish.
|
||||
- Because every subscription owns its group (created at "0"), each tab/subscriber independently reads the topic from the start of the retained stream.
|
||||
"""
|
||||
|
||||
def __init__(self, redis_client: Redis | RedisCluster, *, retention_seconds: int = 600):
|
||||
|
|
@ -69,10 +73,57 @@ class _StreamsSubscription(Subscription):
|
|||
self._start_lock = threading.Lock()
|
||||
self._listener: threading.Thread | None = None
|
||||
|
||||
self._group_name = f"grp:{self._key}:{uuid4().hex}"
|
||||
self._consumer_name = f"c:{uuid4().hex}"
|
||||
self._group_ready = False
|
||||
|
||||
def _ensure_group(self) -> None:
    """Create this subscription's consumer group unless it is already marked ready.

    The group is created at id "0" with ``mkstream=True`` so the stream key is
    created when it does not exist yet. A ``BUSYGROUP`` response is treated as
    success. Any other failure is logged as a warning and ``_group_ready`` is
    left unset, so a later call attempts the creation again.
    """
    if not self._group_ready:
        try:
            # Start at '0' so the group consumes from the beginning of the stream;
            # mkstream=True covers the case where the stream key is absent.
            self._client.xgroup_create(self._key, self._group_name, id="0", mkstream=True)
        except ResponseError as e:
            # BUSYGROUP: the group already exists, which is as good as created.
            if "BUSYGROUP" not in str(e):
                logger.warning(
                    "xgroup create failed for %s/%s: %s", self._key, self._group_name, e, exc_info=True
                )
                return
            self._group_ready = True
        except Exception as e:  # pragma: no cover - safety net for different redis-py versions
            logger.warning(
                "xgroup create unexpected error for %s/%s: %s",
                self._key,
                self._group_name,
                e,
                exc_info=True,
            )
        else:
            self._group_ready = True
|
||||
|
||||
def _listen(self) -> None:
|
||||
try:
|
||||
self._ensure_group()
|
||||
while not self._closed.is_set():
|
||||
streams = self._client.xread({self._key: self._last_id}, block=1000, count=100)
|
||||
try:
|
||||
streams = self._client.xreadgroup(
|
||||
self._group_name,
|
||||
self._consumer_name,
|
||||
{self._key: ">"},
|
||||
block=1000,
|
||||
count=100,
|
||||
noack=True,
|
||||
)
|
||||
except ResponseError as e:
|
||||
msg = str(e)
|
||||
# Handle group/key disappearances gracefully (key expired/evicted or group destroyed elsewhere)
|
||||
if "NOGROUP" in msg or "No such key" in msg:
|
||||
self._group_ready = False
|
||||
self._ensure_group()
|
||||
continue
|
||||
raise
|
||||
|
||||
if not streams:
|
||||
continue
|
||||
|
|
@ -101,6 +152,8 @@ class _StreamsSubscription(Subscription):
|
|||
with self._start_lock:
|
||||
if self._listener is not None or self._closed.is_set():
|
||||
return
|
||||
|
||||
self._ensure_group()
|
||||
self._listener = threading.Thread(
|
||||
target=self._listen,
|
||||
name=f"redis-streams-sub-{self._key}",
|
||||
|
|
@ -149,6 +202,40 @@ class _StreamsSubscription(Subscription):
|
|||
else:
|
||||
self._listener = None
|
||||
|
||||
try:
|
||||
self._client.xgroup_delconsumer(self._key, self._group_name, self._consumer_name)
|
||||
except ResponseError as e:
|
||||
msg = str(e)
|
||||
if not ("NOGROUP" in msg or "NOKEY" in msg or "No such key" in msg):
|
||||
logger.warning(
|
||||
"xgroup delconsumer failed for %s/%s: %s", self._key, self._group_name, e, exc_info=True
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"xgroup delconsumer unexpected error for %s/%s: %s", self._key, self._group_name, e, exc_info=True
|
||||
)
|
||||
|
||||
try:
|
||||
self._client.xgroup_destroy(self._key, self._group_name)
|
||||
except ResponseError as e:
|
||||
msg = str(e)
|
||||
if not ("NOGROUP" in msg or "NOKEY" in msg or "No such key" in msg):
|
||||
logger.warning(
|
||||
"xgroup_destroy failed for %s/%s: %s",
|
||||
self._key,
|
||||
self._group_name,
|
||||
e,
|
||||
exc_info=True
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"xgroup_destroy unexpected error for %s/%s: %s",
|
||||
self._key,
|
||||
self._group_name,
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# Context manager helpers
|
||||
def __enter__(self) -> Self:
|
||||
self._start_if_needed()
|
||||
|
|
|
|||
|
|
@ -63,14 +63,8 @@ class AppGenerateService:
|
|||
|
||||
channel_type = dify_config.PUBSUB_REDIS_CHANNEL_TYPE
|
||||
if channel_type == "streams":
|
||||
# With Redis Streams, we can safely start right away; consumers can read past events.
|
||||
_try_start()
|
||||
|
||||
# Keep return type Callable[[], None] consistent while allowing an extra (no-op) call.
|
||||
def _on_subscribe_streams() -> None:
|
||||
_try_start()
|
||||
|
||||
return _on_subscribe_streams
|
||||
return lambda: None
|
||||
|
||||
# Pub/Sub modes (at-most-once): subscribe-gated start with a tiny fallback.
|
||||
timer = threading.Timer(SSE_TASK_START_FALLBACK_MS / 1000.0, _try_start)
|
||||
|
|
|
|||
|
|
@ -13,8 +13,8 @@ class FakeStreamsRedis:
|
|||
"""Minimal in-memory Redis Streams stub for unit tests.
|
||||
|
||||
- Stores entries per key as [(id, {b"data": bytes}), ...]
|
||||
- xadd appends entries and returns an auto-increment id like "1-0"
|
||||
- xread returns entries strictly greater than last_id
|
||||
- Supports xgroup_create/xreadgroup/xack for consumer-group flow
|
||||
- xread returns entries strictly greater than last_id (fallback path)
|
||||
- expire is recorded but has no effect on behavior
|
||||
"""
|
||||
|
||||
|
|
@ -22,6 +22,8 @@ class FakeStreamsRedis:
|
|||
self._store: dict[str, list[tuple[str, dict]]] = {}
|
||||
self._next_id: dict[str, int] = {}
|
||||
self._expire_calls: dict[str, int] = {}
|
||||
# key -> group -> {"last_id": str}
|
||||
self._groups: dict[str, dict[str, dict[str, str]]] = {}
|
||||
|
||||
# Publisher API
|
||||
def xadd(self, key: str, fields: dict, *, maxlen: int | None = None) -> str:
|
||||
|
|
@ -38,7 +40,7 @@ class FakeStreamsRedis:
|
|||
def expire(self, key: str, seconds: int) -> None:
|
||||
self._expire_calls[key] = self._expire_calls.get(key, 0) + 1
|
||||
|
||||
# Consumer API
|
||||
# Consumer API (fallback without groups)
|
||||
def xread(self, streams: dict, block: int | None = None, count: int | None = None):
|
||||
# Expect a single key
|
||||
assert len(streams) == 1
|
||||
|
|
@ -62,6 +64,90 @@ class FakeStreamsRedis:
|
|||
batch = entries[start_idx:end_idx]
|
||||
return [(key, batch)]
|
||||
|
||||
# Consumer group API
|
||||
def xgroup_create(self, key: str, group: str, id: str = "$", mkstream: bool = False):
    """Register consumer group ``group`` on stream ``key``.

    With ``mkstream`` set, a missing stream is created empty first. The start
    position is resolved once, at creation time: "$" becomes the id of the
    newest stored entry (or "0-0" for an empty/missing stream), "0" becomes
    "0-0", and any other id is stored as given.

    Raises:
        RuntimeError: with a "BUSYGROUP" message when the group already exists.
    """
    if mkstream and key not in self._store:
        self._store[key] = []
        self._next_id[key] = 0

    groups_for_key = self._groups.setdefault(key, {})
    if group in groups_for_key:
        raise RuntimeError("BUSYGROUP Consumer Group name already exists")

    if id == "0":
        # Read from the very beginning.
        cursor = "0-0"
    elif id != "$":
        cursor = id
    else:
        # '$' means "only entries added after this point".
        existing = self._store.get(key, [])
        cursor = existing[-1][0] if existing else "0-0"

    groups_for_key[group] = {"last_id": cursor}
|
||||
|
||||
def xreadgroup(
|
||||
self,
|
||||
group: str,
|
||||
consumer: str,
|
||||
streams: dict,
|
||||
count: int | None = None,
|
||||
block: int | None = None,
|
||||
noack: bool | None = None,
|
||||
):
|
||||
assert len(streams) == 1
|
||||
key, special = next(iter(streams.items()))
|
||||
assert special == ">"
|
||||
entries = self._store.get(key, [])
|
||||
group_info = self._groups.setdefault(key, {}).setdefault(group, {"last_id": "0-0"})
|
||||
last_id = group_info["last_id"]
|
||||
|
||||
start_idx = 0
|
||||
if last_id not in {"0-0", "$"}:
|
||||
for i, (eid, _f) in enumerate(entries):
|
||||
if eid == last_id:
|
||||
start_idx = i + 1
|
||||
break
|
||||
elif last_id == "$":
|
||||
# Start from the end (only new)
|
||||
start_idx = len(entries)
|
||||
|
||||
if start_idx >= len(entries):
|
||||
if block and block > 0:
|
||||
time.sleep(min(0.01, block / 1000.0))
|
||||
return []
|
||||
|
||||
end_idx = len(entries) if count is None else min(len(entries), start_idx + count)
|
||||
batch = entries[start_idx:end_idx]
|
||||
if batch:
|
||||
group_info["last_id"] = batch[-1][0]
|
||||
return [(key, batch)]
|
||||
|
||||
def xack(self, key: str, group: str, *ids: str):
    """Pretend to acknowledge ``ids``; nothing is tracked, only the count is returned."""
    acknowledged = len(ids)
    return acknowledged
|
||||
|
||||
def xautoclaim(
|
||||
self,
|
||||
key: str,
|
||||
group: str,
|
||||
consumer: str,
|
||||
min_idle_time: int,
|
||||
start_id: str = "0-0",
|
||||
count: int | None = None,
|
||||
justid: bool | None = None,
|
||||
):
|
||||
# Minimal fake: no PEL tracking; return no entries
|
||||
return start_id, []
|
||||
|
||||
def xgroup_delconsumer(self, key: str, group: str, consumer: str):
    """Accept a consumer deletion without changing any state.

    Returns a dict echoing the three arguments.
    NOTE(review): a real client presumably returns a count here - confirm that
    no test depends on the dict shape.
    """
    return dict(key=key, group=group, consumer=consumer)
|
||||
|
||||
def xgroup_destroy(self, key: str, group: str):
    """Forget ``group`` on stream ``key``.

    Returns:
        1 when the group existed and was removed, 0 otherwise.
    """
    registered = self._groups.get(key)
    if registered is None or group not in registered:
        return 0
    del registered[group]
    return 1
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_redis() -> FakeStreamsRedis:
|
||||
|
|
@ -96,9 +182,9 @@ class TestStreamsBroadcastChannel:
|
|||
|
||||
|
||||
class TestStreamsSubscription:
|
||||
def test_subscribe_and_receive_from_beginning(self, streams_channel: StreamsBroadcastChannel):
|
||||
def test_streams_receive_from_beginning_on_subscribe(self, streams_channel: StreamsBroadcastChannel):
|
||||
topic = streams_channel.topic("gamma")
|
||||
# Pre-publish events before subscribing (late subscriber)
|
||||
# Pre-publish events before subscribing (SHOULD be received with xreadgroup starting at '0')
|
||||
topic.publish(b"e1")
|
||||
topic.publish(b"e2")
|
||||
|
||||
|
|
@ -107,16 +193,56 @@ class TestStreamsSubscription:
|
|||
|
||||
received: list[bytes] = []
|
||||
with sub:
|
||||
# Give listener thread a moment to xread
|
||||
# Publish after subscription; these should also be received
|
||||
topic.publish(b"n1")
|
||||
topic.publish(b"n2")
|
||||
# Give listener thread a moment to read
|
||||
time.sleep(0.05)
|
||||
# Drain using receive() to avoid indefinite iteration in tests
|
||||
for _ in range(5):
|
||||
msg = sub.receive(timeout=0.1)
|
||||
for _ in range(10):
|
||||
msg = sub.receive(timeout=0.2)
|
||||
if msg is None:
|
||||
break
|
||||
received.append(msg)
|
||||
|
||||
assert received == [b"e1", b"e2"]
|
||||
# Should receive both pre-existing and new messages
|
||||
assert received == [b"e1", b"e2", b"n1", b"n2"]
|
||||
|
||||
def test_multiple_subscribers_share_group_position(self, streams_channel: StreamsBroadcastChannel):
|
||||
"""Smoke-test two subscribers on one topic: at least one of them must receive a message."""
|
||||
topic = streams_channel.topic("shared-group-test")
|
||||
|
||||
# Publish initial messages
|
||||
topic.publish(b"msg1")
|
||||
topic.publish(b"msg2")
|
||||
|
||||
# First subscriber
|
||||
sub1 = topic.subscribe()
|
||||
received1: list[bytes] = []
|
||||
with sub1:
|
||||
# Consume first message
|
||||
time.sleep(0.05)
|
||||
msg = sub1.receive(timeout=0.2)
|
||||
if msg:
|
||||
received1.append(msg)
|
||||
|
||||
# Second subscriber should start from the group's position
|
||||
sub2 = topic.subscribe()
|
||||
received2: list[bytes] = []
|
||||
with sub2:
|
||||
# Publish more messages
|
||||
topic.publish(b"msg3")
|
||||
time.sleep(0.05)
|
||||
# Should get remaining messages from group position
|
||||
for _ in range(5):
|
||||
msg = sub2.receive(timeout=0.2)
|
||||
if msg is None:
|
||||
break
|
||||
received2.append(msg)
|
||||
|
||||
# Only checks that at least one subscriber received something;
# ordering and de-duplication across subscribers are not verified here.
|
||||
assert len(received1) > 0 or len(received2) > 0
|
||||
|
||||
def test_receive_timeout_returns_none(self, streams_channel: StreamsBroadcastChannel):
|
||||
topic = streams_channel.topic("delta")
|
||||
|
|
|
|||
|
|
@ -137,21 +137,21 @@ class TestBuildStreamingTaskOnSubscribe:
|
|||
assert called == [1]
|
||||
|
||||
def test_exception_in_start_task_returns_false(self, monkeypatch):
|
||||
"""When start_task raises, _try_start returns False and next call retries."""
|
||||
"""When start_task raises in streams mode, _try_start returns False and callback is a no-op."""
|
||||
monkeypatch.setattr(ags_module.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "streams")
|
||||
call_count = 0
|
||||
|
||||
def _bad():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
raise RuntimeError("boom")
|
||||
raise RuntimeError("boom")
|
||||
|
||||
cb = AppGenerateService._build_streaming_task_on_subscribe(_bad)
|
||||
# first call inside build raised, but is caught; second call via cb succeeds
|
||||
# In streams mode, the callback is a no-op since _try_start was already called
|
||||
assert call_count == 1
|
||||
cb()
|
||||
assert call_count == 2
|
||||
# Callback does nothing in streams mode, so count stays at 1
|
||||
assert call_count == 1
|
||||
|
||||
def test_concurrent_subscribe_only_starts_once(self, monkeypatch):
|
||||
monkeypatch.setattr(ags_module.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "pubsub")
|
||||
|
|
|
|||
|
|
@ -59,6 +59,8 @@ class _FakeStreams:
|
|||
# key -> list[(id, {field: value})]
|
||||
self._data: dict[str, list[tuple[str, dict]]] = defaultdict(list)
|
||||
self._seq: dict[str, int] = defaultdict(int)
|
||||
# key -> group -> {"last_id": str}
|
||||
self._groups: dict[str, dict[str, dict[str, str]]] = defaultdict(dict)
|
||||
|
||||
def xadd(self, key: str, fields: dict, *, maxlen: int | None = None) -> str:
|
||||
# maxlen is accepted for API compatibility with redis-py; ignored in this test double
|
||||
|
|
@ -71,6 +73,68 @@ class _FakeStreams:
|
|||
# no-op for tests
|
||||
return None
|
||||
|
||||
def xgroup_create(self, key: str, group: str, id: str = "$", mkstream: bool = False):
    """Register consumer group ``group`` on stream ``key``.

    With ``mkstream`` set, a missing stream is created empty first. The start
    position is resolved at creation time: "$" becomes the id of the newest
    stored entry (or "0-0" when there is none), a bare "0" is normalised to
    "0-0" so the stored cursor always has the same form, and any other id is
    stored as given.

    Raises:
        RuntimeError: with a "BUSYGROUP" message when the group already exists.
    """
    if mkstream and key not in self._data:
        self._data[key] = []
        self._seq[key] = 0

    groups_for_key = self._groups[key]
    if group in groups_for_key:
        raise RuntimeError("BUSYGROUP Consumer Group name already exists")

    if id == "$":
        # '$' means "only entries added after this point".
        entries = self._data.get(key, [])
        resolved = entries[-1][0] if entries else "0-0"
    elif id == "0":
        # Canonical spelling of "from the beginning".
        resolved = "0-0"
    else:
        resolved = id

    groups_for_key[group] = {"last_id": resolved}
|
||||
|
||||
def xreadgroup(
|
||||
self,
|
||||
group: str,
|
||||
consumer: str,
|
||||
streams: dict,
|
||||
count: int | None = None,
|
||||
block: int | None = None,
|
||||
noack: bool | None = None,
|
||||
):
|
||||
assert len(streams) == 1
|
||||
key, special = next(iter(streams.items()))
|
||||
assert special == ">"
|
||||
entries = self._data.get(key, [])
|
||||
g = self._groups[key]
|
||||
info = g.setdefault(group, {"last_id": "0-0"})
|
||||
last_id = info["last_id"]
|
||||
start = 0
|
||||
if last_id == "$":
|
||||
start = len(entries)
|
||||
elif last_id != "0-0":
|
||||
for i, (eid, _f) in enumerate(entries):
|
||||
if eid == last_id:
|
||||
start = i + 1
|
||||
break
|
||||
if start >= len(entries):
|
||||
return []
|
||||
end = len(entries) if count is None else min(len(entries), start + count)
|
||||
batch = entries[start:end]
|
||||
if batch:
|
||||
info["last_id"] = batch[-1][0]
|
||||
return [(key, batch)]
|
||||
|
||||
def xack(self, key: str, group: str, *ids: str):
    """Report how many ids were passed; no acknowledgement state is kept."""
    id_count = len(ids)
    return id_count
|
||||
|
||||
def xautoclaim(
|
||||
self,
|
||||
key: str,
|
||||
group: str,
|
||||
consumer: str,
|
||||
min_idle_time: int,
|
||||
start_id: str = "0-0",
|
||||
count: int | None = None,
|
||||
justid: bool | None = None,
|
||||
):
|
||||
# Minimal fake: no PEL tracking; return no entries
|
||||
return start_id, []
|
||||
|
||||
# Fallback path if xreadgroup not used
|
||||
def xread(self, streams: dict, block: int | None = None, count: int | None = None):
|
||||
assert len(streams) == 1
|
||||
key, last_id = next(iter(streams.items()))
|
||||
|
|
@ -134,11 +198,11 @@ def _publish_events(app_mode: AppMode, run_id: str, events: list[dict]):
|
|||
|
||||
|
||||
@pytest.mark.usefixtures("_patch_get_channel_streams")
|
||||
def test_streams_full_flow_prepublish_and_replay():
|
||||
def test_streams_full_flow_start_on_subscribe_only_new():
|
||||
app_mode = AppMode.WORKFLOW
|
||||
run_id = str(uuid.uuid4())
|
||||
|
||||
# Build start_task that publishes two events immediately
|
||||
# Publish on subscribe; the subscription's group is created at "0" and read with XREADGROUP + NOACK, so events already in the stream are delivered as well
|
||||
events = [{"event": "workflow_started"}, {"event": "workflow_finished"}]
|
||||
|
||||
def start_task():
|
||||
|
|
@ -146,7 +210,6 @@ def test_streams_full_flow_prepublish_and_replay():
|
|||
|
||||
on_subscribe = AppGenerateService._build_streaming_task_on_subscribe(start_task)
|
||||
|
||||
# Start retrieving BEFORE subscription is established; in streams mode, we also started immediately
|
||||
gen = MessageGenerator.retrieve_events(app_mode, run_id, idle_timeout=2.0, on_subscribe=on_subscribe)
|
||||
|
||||
received = []
|
||||
|
|
|
|||
Loading…
Reference in New Issue