mirror of https://github.com/langgenius/dify.git
feat: support redis xstream (#32586)
This commit is contained in:
parent
e14b09d4db
commit
2f4c740d46
|
|
@ -1,7 +1,7 @@
|
||||||
from typing import Literal, Protocol
|
from typing import Literal, Protocol
|
||||||
from urllib.parse import quote_plus, urlunparse
|
from urllib.parse import quote_plus, urlunparse
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import AliasChoices, Field
|
||||||
from pydantic_settings import BaseSettings
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -23,41 +23,56 @@ class RedisConfigDefaultsMixin:
|
||||||
|
|
||||||
class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin):
|
class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin):
|
||||||
"""
|
"""
|
||||||
Configuration settings for Redis pub/sub streaming.
|
Configuration settings for event transport between API and workers.
|
||||||
|
|
||||||
|
Supported transports:
|
||||||
|
- pubsub: Redis PUBLISH/SUBSCRIBE (at-most-once)
|
||||||
|
- sharded: Redis 7+ Sharded Pub/Sub (at-most-once, better scaling)
|
||||||
|
- streams: Redis Streams (at-least-once, supports late subscribers)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
PUBSUB_REDIS_URL: str | None = Field(
|
PUBSUB_REDIS_URL: str | None = Field(
|
||||||
alias="PUBSUB_REDIS_URL",
|
validation_alias=AliasChoices("EVENT_BUS_REDIS_URL", "PUBSUB_REDIS_URL"),
|
||||||
description=(
|
description=(
|
||||||
"Redis connection URL for pub/sub streaming events between API "
|
"Redis connection URL for streaming events between API and celery worker; "
|
||||||
"and celery worker, defaults to url constructed from "
|
"defaults to URL constructed from `REDIS_*` configurations. Also accepts ENV: EVENT_BUS_REDIS_URL."
|
||||||
"`REDIS_*` configurations"
|
|
||||||
),
|
),
|
||||||
default=None,
|
default=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
PUBSUB_REDIS_USE_CLUSTERS: bool = Field(
|
PUBSUB_REDIS_USE_CLUSTERS: bool = Field(
|
||||||
|
validation_alias=AliasChoices("EVENT_BUS_REDIS_CLUSTERS", "PUBSUB_REDIS_USE_CLUSTERS"),
|
||||||
description=(
|
description=(
|
||||||
"Enable Redis Cluster mode for pub/sub streaming. It's highly "
|
"Enable Redis Cluster mode for pub/sub or streams transport. Recommended for large deployments. "
|
||||||
"recommended to enable this for large deployments."
|
"Also accepts ENV: EVENT_BUS_REDIS_CLUSTERS."
|
||||||
),
|
),
|
||||||
default=False,
|
default=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
PUBSUB_REDIS_CHANNEL_TYPE: Literal["pubsub", "sharded"] = Field(
|
PUBSUB_REDIS_CHANNEL_TYPE: Literal["pubsub", "sharded", "streams"] = Field(
|
||||||
|
validation_alias=AliasChoices("EVENT_BUS_REDIS_CHANNEL_TYPE", "PUBSUB_REDIS_CHANNEL_TYPE"),
|
||||||
description=(
|
description=(
|
||||||
"Pub/sub channel type for streaming events. "
|
"Event transport type. Options are:\n\n"
|
||||||
"Valid options are:\n"
|
" - pubsub: normal Pub/Sub (at-most-once)\n"
|
||||||
"\n"
|
" - sharded: sharded Pub/Sub (at-most-once)\n"
|
||||||
" - pubsub: for normal Pub/Sub\n"
|
" - streams: Redis Streams (at-least-once, recommended to avoid subscriber races)\n\n"
|
||||||
" - sharded: for sharded Pub/Sub\n"
|
"Note: Before enabling 'streams' in production, estimate your expected event volume and retention needs.\n"
|
||||||
"\n"
|
"Configure Redis memory limits and stream trimming appropriately (e.g., MAXLEN and key expiry) to reduce\n"
|
||||||
"It's highly recommended to use sharded Pub/Sub AND redis cluster "
|
"the risk of data loss from Redis auto-eviction under memory pressure.\n"
|
||||||
"for large deployments."
|
"Also accepts ENV: EVENT_BUS_REDIS_CHANNEL_TYPE."
|
||||||
),
|
),
|
||||||
default="pubsub",
|
default="pubsub",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
PUBSUB_STREAMS_RETENTION_SECONDS: int = Field(
|
||||||
|
validation_alias=AliasChoices("EVENT_BUS_STREAMS_RETENTION_SECONDS", "PUBSUB_STREAMS_RETENTION_SECONDS"),
|
||||||
|
description=(
|
||||||
|
"When using 'streams', expire each stream key this many seconds after the last event is published. "
|
||||||
|
"Also accepts ENV: EVENT_BUS_STREAMS_RETENTION_SECONDS."
|
||||||
|
),
|
||||||
|
default=600,
|
||||||
|
)
|
||||||
|
|
||||||
def _build_default_pubsub_url(self) -> str:
|
def _build_default_pubsub_url(self) -> str:
|
||||||
defaults = self._redis_defaults()
|
defaults = self._redis_defaults()
|
||||||
if not defaults.REDIS_HOST or not defaults.REDIS_PORT:
|
if not defaults.REDIS_HOST or not defaults.REDIS_PORT:
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ from dify_app import DifyApp
|
||||||
from libs.broadcast_channel.channel import BroadcastChannel as BroadcastChannelProtocol
|
from libs.broadcast_channel.channel import BroadcastChannel as BroadcastChannelProtocol
|
||||||
from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel
|
from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel
|
||||||
from libs.broadcast_channel.redis.sharded_channel import ShardedRedisBroadcastChannel
|
from libs.broadcast_channel.redis.sharded_channel import ShardedRedisBroadcastChannel
|
||||||
|
from libs.broadcast_channel.redis.streams_channel import StreamsBroadcastChannel
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from redis.lock import Lock
|
from redis.lock import Lock
|
||||||
|
|
@ -288,6 +289,11 @@ def get_pubsub_broadcast_channel() -> BroadcastChannelProtocol:
|
||||||
assert _pubsub_redis_client is not None, "PubSub redis Client should be initialized here."
|
assert _pubsub_redis_client is not None, "PubSub redis Client should be initialized here."
|
||||||
if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "sharded":
|
if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "sharded":
|
||||||
return ShardedRedisBroadcastChannel(_pubsub_redis_client)
|
return ShardedRedisBroadcastChannel(_pubsub_redis_client)
|
||||||
|
if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "streams":
|
||||||
|
return StreamsBroadcastChannel(
|
||||||
|
_pubsub_redis_client,
|
||||||
|
retention_seconds=dify_config.PUBSUB_STREAMS_RETENTION_SECONDS,
|
||||||
|
)
|
||||||
return RedisBroadcastChannel(_pubsub_redis_client)
|
return RedisBroadcastChannel(_pubsub_redis_client)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,159 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import queue
|
||||||
|
import threading
|
||||||
|
from collections.abc import Iterator
|
||||||
|
from typing import Self
|
||||||
|
|
||||||
|
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
|
||||||
|
from libs.broadcast_channel.exc import SubscriptionClosedError
|
||||||
|
from redis import Redis, RedisCluster
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class StreamsBroadcastChannel:
    """
    Broadcast channel backed by Redis Streams.

    Characteristics:
    - At-least-once delivery for late subscribers within the stream retention window.
    - Every topic lives in its own Redis Stream key.
    - A stream key expires `retention_seconds` after the most recent publish, which bounds storage.
    """

    def __init__(self, redis_client: Redis | RedisCluster, *, retention_seconds: int = 600):
        """
        Args:
            redis_client: Client used for all stream operations.
            retention_seconds: Key lifetime after the last publish. Falsy values and
                negative values are normalised to 0.
        """
        self._client = redis_client
        # Coerce to int first, then clamp so the stored value is never negative.
        coerced = int(retention_seconds) if retention_seconds else 0
        self._retention_seconds = coerced if coerced > 0 else 0

    def topic(self, topic: str) -> StreamsTopic:
        """Return a topic handle that shares this channel's client and retention setting."""
        retention = self._retention_seconds
        return StreamsTopic(self._client, topic, retention_seconds=retention)
|
||||||
|
|
||||||
|
|
||||||
|
class StreamsTopic:
    """A single topic of the streams channel; acts as both producer and subscriber handle."""

    def __init__(self, redis_client: Redis | RedisCluster, topic: str, *, retention_seconds: int = 600):
        """
        Args:
            redis_client: Client used for XADD / EXPIRE and handed to subscriptions.
            topic: Logical topic name; the stream key is derived from it.
            retention_seconds: Key lifetime applied after each publish; not applied unless > 0.
        """
        self._client = redis_client
        self._topic = topic
        self._retention_seconds = retention_seconds
        # Length passed to XADD as `maxlen` on every publish.
        self.max_length = 5000
        self._key = f"stream:{topic}"

    def as_producer(self) -> Producer:
        """Return this topic as its own producer."""
        return self

    def publish(self, payload: bytes) -> None:
        """
        Append `payload` to the stream under the `data` field, then refresh the key's expiry.

        A failure to set the expiry is logged and swallowed; a failure of XADD propagates.
        """
        entry = {b"data": payload}
        self._client.xadd(self._key, entry, maxlen=self.max_length)

        if not self._retention_seconds > 0:
            # Retention disabled: leave the key without an expiry.
            return

        try:
            self._client.expire(self._key, self._retention_seconds)
        except Exception as exc:
            logger.warning("Failed to set expire for stream key %s: %s", self._key, exc, exc_info=True)

    def as_subscriber(self) -> Subscriber:
        """Return this topic as its own subscriber factory."""
        return self

    def subscribe(self) -> Subscription:
        """Create a new subscription bound to this topic's stream key."""
        client, stream_key = self._client, self._key
        return _StreamsSubscription(client, stream_key)
|
||||||
|
|
||||||
|
|
||||||
|
class _StreamsSubscription(Subscription):
    """
    Subscription over a single Redis Stream key.

    A daemon listener thread polls the stream with XREAD, beginning at ID "0-0",
    and pushes each entry's payload onto an in-process queue that `receive()` and
    iteration drain.
    """

    # Placed on the queue by the listener when it exits so blocked consumers wake up.
    _SENTINEL = object()
    # Field names the payload may arrive under: b"data" from a bytes-mode client, or the
    # str "data" when the client was created with decode_responses=True.
    _DATA_FIELDS = (b"data", "data")

    def __init__(self, client: Redis | RedisCluster, key: str):
        """
        Args:
            client: Redis client used for XREAD.
            key: Stream key to read from.
        """
        self._client = client
        self._key = key
        self._closed = threading.Event()
        # "0-0" sorts before every real entry ID, so the first XREAD replays the stream.
        self._last_id = "0-0"
        self._queue: queue.Queue[object] = queue.Queue()
        self._start_lock = threading.Lock()
        self._listener: threading.Thread | None = None

    @classmethod
    def _extract_payload(cls, fields: object) -> bytes | None:
        """Return the entry's payload as bytes, or None when the entry carries no usable payload."""
        if not isinstance(fields, dict):
            return None
        data = None
        for field_name in cls._DATA_FIELDS:
            data = fields.get(field_name)
            if data is not None:
                break
        if isinstance(data, str):
            return data.encode()
        if isinstance(data, (bytes, bytearray)):
            return bytes(data)
        return None

    def _listen(self) -> None:
        """
        Listener thread body: poll the stream until the subscription is closed.

        On exit, including when `xread` raises, the sentinel is queued so that a
        consumer blocked in `receive()` is released.
        """
        try:
            while not self._closed.is_set():
                # block=1000 bounds how long close() has to wait for this loop to notice.
                streams = self._client.xread({self._key: self._last_id}, block=1000, count=100)

                if not streams:
                    continue

                for _key, entries in streams:
                    for entry_id, fields in entries:
                        payload = self._extract_payload(fields)
                        if payload is not None:
                            self._queue.put_nowait(payload)
                        # Advance even for skipped entries so they are not re-read.
                        self._last_id = entry_id
        finally:
            self._queue.put_nowait(self._SENTINEL)
            self._listener = None

    def _start_if_needed(self) -> None:
        """Start the listener thread unless one exists or the subscription is closed."""
        if self._listener is not None:
            return
        # Ensure only one listener thread is created under concurrent calls
        with self._start_lock:
            if self._listener is not None or self._closed.is_set():
                return
            self._listener = threading.Thread(
                target=self._listen,
                name=f"redis-streams-sub-{self._key}",
                daemon=True,
            )
            self._listener.start()

    def __iter__(self) -> Iterator[bytes]:
        """
        Yield payloads until the subscription is closed.

        Iteration ends normally on closure: a `SubscriptionClosedError` raised by
        `receive()` while `close()` races with a pending read is treated as
        end-of-stream instead of escaping to the caller.
        """
        self._start_if_needed()
        while not self._closed.is_set():
            try:
                item = self.receive(timeout=1)
            except SubscriptionClosedError:
                return
            if item is not None:
                yield item

    def receive(self, timeout: float | None = 0.1) -> bytes | None:
        """
        Return the next payload, or None if none arrives within `timeout` seconds.

        Args:
            timeout: Seconds to wait; None waits indefinitely.

        Raises:
            SubscriptionClosedError: The subscription is closed or the listener has exited.
        """
        if self._closed.is_set():
            raise SubscriptionClosedError("The Redis streams subscription is closed")
        self._start_if_needed()

        try:
            if timeout is None:
                item = self._queue.get()
            else:
                item = self._queue.get(timeout=timeout)
        except queue.Empty:
            return None

        if item is self._SENTINEL or self._closed.is_set():
            raise SubscriptionClosedError("The Redis streams subscription is closed")
        assert isinstance(item, (bytes, bytearray)), "Unexpected item type in stream queue"
        return bytes(item)

    def close(self) -> None:
        """Mark the subscription closed and wait up to 2 seconds for the listener to stop."""
        if self._closed.is_set():
            return
        self._closed.set()
        listener = self._listener
        if listener is not None:
            listener.join(timeout=2.0)
            if listener.is_alive():
                logger.warning(
                    "Streams subscription listener for key %s did not stop within timeout; keeping reference.",
                    self._key,
                )
            else:
                self._listener = None

    # Context manager helpers
    def __enter__(self) -> Self:
        """Start listening and return the subscription."""
        self._start_if_needed()
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> bool | None:
        """Close the subscription; exceptions from the `with` body are not suppressed."""
        self.close()
        return None
|
||||||
|
|
@ -38,6 +38,13 @@ if TYPE_CHECKING:
|
||||||
class AppGenerateService:
|
class AppGenerateService:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _build_streaming_task_on_subscribe(start_task: Callable[[], None]) -> Callable[[], None]:
|
def _build_streaming_task_on_subscribe(start_task: Callable[[], None]) -> Callable[[], None]:
|
||||||
|
"""
|
||||||
|
Build a subscription callback that coordinates when the background task starts.
|
||||||
|
|
||||||
|
- streams transport: start immediately (events are durable; late subscribers can replay).
|
||||||
|
- pubsub/sharded transport: start on first subscribe, with a short fallback timer so the task
|
||||||
|
still runs if the client never connects.
|
||||||
|
"""
|
||||||
started = False
|
started = False
|
||||||
lock = threading.Lock()
|
lock = threading.Lock()
|
||||||
|
|
||||||
|
|
@ -54,10 +61,18 @@ class AppGenerateService:
|
||||||
started = True
|
started = True
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# XXX(QuantumGhost): dirty hacks to avoid a race between publisher and SSE subscriber.
|
channel_type = dify_config.PUBSUB_REDIS_CHANNEL_TYPE
|
||||||
# The Celery task may publish the first event before the API side actually subscribes,
|
if channel_type == "streams":
|
||||||
# causing an "at most once" drop with Redis Pub/Sub. We start the task on subscribe,
|
# With Redis Streams, we can safely start right away; consumers can read past events.
|
||||||
# but also use a short fallback timer so the task still runs if the client never consumes.
|
_try_start()
|
||||||
|
|
||||||
|
# Keep return type Callable[[], None] consistent while allowing an extra (no-op) call.
|
||||||
|
def _on_subscribe_streams() -> None:
|
||||||
|
_try_start()
|
||||||
|
|
||||||
|
return _on_subscribe_streams
|
||||||
|
|
||||||
|
# Pub/Sub modes (at-most-once): subscribe-gated start with a tiny fallback.
|
||||||
timer = threading.Timer(SSE_TASK_START_FALLBACK_MS / 1000.0, _try_start)
|
timer = threading.Timer(SSE_TASK_START_FALLBACK_MS / 1000.0, _try_start)
|
||||||
timer.daemon = True
|
timer.daemon = True
|
||||||
timer.start()
|
timer.start()
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,145 @@
|
||||||
|
import time
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from libs.broadcast_channel.redis.streams_channel import (
|
||||||
|
StreamsBroadcastChannel,
|
||||||
|
StreamsTopic,
|
||||||
|
_StreamsSubscription,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FakeStreamsRedis:
|
||||||
|
"""Minimal in-memory Redis Streams stub for unit tests.
|
||||||
|
|
||||||
|
- Stores entries per key as [(id, {b"data": bytes}), ...]
|
||||||
|
- xadd appends entries and returns an auto-increment id like "1-0"
|
||||||
|
- xread returns entries strictly greater than last_id
|
||||||
|
- expire is recorded but has no effect on behavior
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._store: dict[str, list[tuple[str, dict]]] = {}
|
||||||
|
self._next_id: dict[str, int] = {}
|
||||||
|
self._expire_calls: dict[str, int] = {}
|
||||||
|
|
||||||
|
# Publisher API
|
||||||
|
def xadd(self, key: str, fields: dict, *, maxlen: int | None = None) -> str:
|
||||||
|
"""Append entry to stream; accept optional maxlen for API compatibility.
|
||||||
|
|
||||||
|
The test double ignores maxlen trimming semantics; only records the entry.
|
||||||
|
"""
|
||||||
|
n = self._next_id.get(key, 0) + 1
|
||||||
|
self._next_id[key] = n
|
||||||
|
entry_id = f"{n}-0"
|
||||||
|
self._store.setdefault(key, []).append((entry_id, fields))
|
||||||
|
return entry_id
|
||||||
|
|
||||||
|
def expire(self, key: str, seconds: int) -> None:
|
||||||
|
self._expire_calls[key] = self._expire_calls.get(key, 0) + 1
|
||||||
|
|
||||||
|
# Consumer API
|
||||||
|
def xread(self, streams: dict, block: int | None = None, count: int | None = None):
|
||||||
|
# Expect a single key
|
||||||
|
assert len(streams) == 1
|
||||||
|
key, last_id = next(iter(streams.items()))
|
||||||
|
entries = self._store.get(key, [])
|
||||||
|
|
||||||
|
# Find position strictly greater than last_id
|
||||||
|
start_idx = 0
|
||||||
|
if last_id != "0-0":
|
||||||
|
for i, (eid, _f) in enumerate(entries):
|
||||||
|
if eid == last_id:
|
||||||
|
start_idx = i + 1
|
||||||
|
break
|
||||||
|
if start_idx >= len(entries):
|
||||||
|
# Simulate blocking wait (bounded) if requested
|
||||||
|
if block and block > 0:
|
||||||
|
time.sleep(min(0.01, block / 1000.0))
|
||||||
|
return []
|
||||||
|
|
||||||
|
end_idx = len(entries) if count is None else min(len(entries), start_idx + count)
|
||||||
|
batch = entries[start_idx:end_idx]
|
||||||
|
return [(key, batch)]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def fake_redis() -> FakeStreamsRedis:
    """Provide a fresh in-memory Redis Streams stub."""
    stub = FakeStreamsRedis()
    return stub
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def streams_channel(fake_redis: FakeStreamsRedis) -> StreamsBroadcastChannel:
    """Provide a streams channel wired to the fake client with a 60-second retention."""
    channel = StreamsBroadcastChannel(fake_redis, retention_seconds=60)
    return channel
|
||||||
|
|
||||||
|
|
||||||
|
class TestStreamsBroadcastChannel:
    """Topic construction and the publish path of the streams channel."""

    def test_topic_creation(self, streams_channel: StreamsBroadcastChannel, fake_redis: FakeStreamsRedis):
        """topic() returns a StreamsTopic bound to the channel's client and a `stream:`-prefixed key."""
        created = streams_channel.topic("alpha")

        assert isinstance(created, StreamsTopic)
        assert created._client is fake_redis
        assert created._topic == "alpha"
        assert created._key == "stream:alpha"

    def test_publish_calls_xadd_and_expire(
        self,
        streams_channel: StreamsBroadcastChannel,
        fake_redis: FakeStreamsRedis,
    ):
        """publish() stores the payload under the bytes `data` field and refreshes the expiry."""
        payload = b"hello"
        streams_channel.topic("beta").publish(payload)

        # One entry stored under stream key (bytes key for payload field)
        stored_entries = fake_redis._store["stream:beta"]
        _entry_id, stored_fields = stored_entries[0]
        assert stored_fields == {b"data": payload}

        # Expire called after publish
        expire_count = fake_redis._expire_calls.get("stream:beta", 0)
        assert expire_count >= 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestStreamsSubscription:
    """Consumer-side behaviour: replay, timeouts, closing, and disabled retention."""

    def test_subscribe_and_receive_from_beginning(self, streams_channel: StreamsBroadcastChannel):
        """A subscriber created after publishing still receives the earlier events in order."""
        topic = streams_channel.topic("gamma")
        # Pre-publish events before subscribing (late subscriber)
        for event in (b"e1", b"e2"):
            topic.publish(event)

        sub = topic.subscribe()
        assert isinstance(sub, _StreamsSubscription)

        received: list[bytes] = []
        with sub:
            # Give listener thread a moment to xread
            time.sleep(0.05)
            # Drain using receive() to avoid indefinite iteration in tests
            attempts_left = 5
            while attempts_left > 0:
                attempts_left -= 1
                msg = sub.receive(timeout=0.1)
                if msg is None:
                    break
                received.append(msg)

        assert received == [b"e1", b"e2"]

    def test_receive_timeout_returns_none(self, streams_channel: StreamsBroadcastChannel):
        """receive() returns None when nothing was published within the timeout."""
        sub = streams_channel.topic("delta").subscribe()
        with sub:
            # No messages yet
            result = sub.receive(timeout=0.05)
            assert result is None

    def test_close_stops_listener(self, streams_channel: StreamsBroadcastChannel):
        """After close(), receive() raises SubscriptionClosedError."""
        from libs.broadcast_channel.exc import SubscriptionClosedError

        sub = streams_channel.topic("epsilon").subscribe()
        with sub:
            # Listener running; now close and ensure no crash
            sub.close()
            # After close, receive should raise SubscriptionClosedError
            with pytest.raises(SubscriptionClosedError):
                sub.receive()

    def test_no_expire_when_zero_retention(self, fake_redis: FakeStreamsRedis):
        """With retention_seconds=0, publish() never calls expire."""
        no_retention_channel = StreamsBroadcastChannel(fake_redis, retention_seconds=0)
        no_retention_channel.topic("zeta").publish(b"payload")

        # No expire recorded when retention is disabled
        assert fake_redis._expire_calls.get("stream:zeta") is None
|
||||||
|
|
@ -0,0 +1,197 @@
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from collections import defaultdict, deque
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.app.apps.message_generator import MessageGenerator
|
||||||
|
from models.model import AppMode
|
||||||
|
from services.app_generate_service import AppGenerateService
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------
|
||||||
|
# Fakes for Redis Pub/Sub flow
|
||||||
|
# -----------------------------
|
||||||
|
class _FakePubSub:
|
||||||
|
def __init__(self, store: dict[str, deque[bytes]]):
|
||||||
|
self._store = store
|
||||||
|
self._subs: set[str] = set()
|
||||||
|
self._closed = False
|
||||||
|
|
||||||
|
def subscribe(self, topic: str) -> None:
|
||||||
|
self._subs.add(topic)
|
||||||
|
|
||||||
|
def unsubscribe(self, topic: str) -> None:
|
||||||
|
self._subs.discard(topic)
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
self._closed = True
|
||||||
|
|
||||||
|
def get_message(self, ignore_subscribe_messages: bool = True, timeout: int | float | None = 1):
|
||||||
|
# simulate a non-blocking poll; return first available
|
||||||
|
if self._closed:
|
||||||
|
return None
|
||||||
|
for t in list(self._subs):
|
||||||
|
q = self._store.get(t)
|
||||||
|
if q and len(q) > 0:
|
||||||
|
payload = q.popleft()
|
||||||
|
return {"type": "message", "channel": t, "data": payload}
|
||||||
|
# no message
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeRedisClient:
    """Test double for a Redis client; publishes into the store that its pub/sub handles read."""

    def __init__(self, store: dict[str, deque[bytes]]):
        self._store = store

    def pubsub(self):
        """Return a new fake pub/sub handle sharing this client's store."""
        handle = _FakePubSub(self._store)
        return handle

    def publish(self, topic: str, payload: bytes) -> None:
        """Queue `payload` for `topic`, creating the topic's queue on first use."""
        if topic not in self._store:
            self._store[topic] = deque()
        self._store[topic].append(payload)
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------
|
||||||
|
# Fakes for Redis Streams (XADD/XREAD)
|
||||||
|
# ------------------------------------
|
||||||
|
class _FakeStreams:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
# key -> list[(id, {field: value})]
|
||||||
|
self._data: dict[str, list[tuple[str, dict]]] = defaultdict(list)
|
||||||
|
self._seq: dict[str, int] = defaultdict(int)
|
||||||
|
|
||||||
|
def xadd(self, key: str, fields: dict, *, maxlen: int | None = None) -> str:
|
||||||
|
# maxlen is accepted for API compatibility with redis-py; ignored in this test double
|
||||||
|
self._seq[key] += 1
|
||||||
|
eid = f"{self._seq[key]}-0"
|
||||||
|
self._data[key].append((eid, fields))
|
||||||
|
return eid
|
||||||
|
|
||||||
|
def expire(self, key: str, seconds: int) -> None:
|
||||||
|
# no-op for tests
|
||||||
|
return None
|
||||||
|
|
||||||
|
def xread(self, streams: dict, block: int | None = None, count: int | None = None):
|
||||||
|
assert len(streams) == 1
|
||||||
|
key, last_id = next(iter(streams.items()))
|
||||||
|
entries = self._data.get(key, [])
|
||||||
|
start = 0
|
||||||
|
if last_id != "0-0":
|
||||||
|
for i, (eid, _f) in enumerate(entries):
|
||||||
|
if eid == last_id:
|
||||||
|
start = i + 1
|
||||||
|
break
|
||||||
|
if start >= len(entries):
|
||||||
|
return []
|
||||||
|
end = len(entries) if count is None else min(len(entries), start + count)
|
||||||
|
return [(key, entries[start:end])]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def _patch_get_channel_streams(monkeypatch):
    """Route the broadcast channel lookups to a StreamsBroadcastChannel backed by `_FakeStreams`.

    Also sets `PUBSUB_REDIS_CHANNEL_TYPE` to "streams" on the config object seen by
    `services.app_generate_service`.
    """
    from libs.broadcast_channel.redis.streams_channel import StreamsBroadcastChannel

    fake = _FakeStreams()
    chan = StreamsBroadcastChannel(fake, retention_seconds=60)

    def _get_channel():
        return chan

    # Patch both the source and the imported alias used by MessageGenerator
    monkeypatch.setattr("extensions.ext_redis.get_pubsub_broadcast_channel", _get_channel)
    monkeypatch.setattr("core.app.apps.message_generator.get_pubsub_broadcast_channel", _get_channel)

    # Ensure AppGenerateService sees streams mode
    import services.app_generate_service as ags

    monkeypatch.setattr(ags.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "streams", raising=False)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def _patch_get_channel_pubsub(monkeypatch):
    """Route the broadcast channel lookups to a Pub/Sub channel backed by `_FakeRedisClient`.

    Also sets `PUBSUB_REDIS_CHANNEL_TYPE` to "pubsub" on the config object seen by
    `services.app_generate_service`.
    """
    from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel

    store: dict[str, deque[bytes]] = defaultdict(deque)
    client = _FakeRedisClient(store)
    chan = RedisBroadcastChannel(client)

    def _get_channel():
        return chan

    # Patch both the source and the imported alias used by MessageGenerator
    monkeypatch.setattr("extensions.ext_redis.get_pubsub_broadcast_channel", _get_channel)
    monkeypatch.setattr("core.app.apps.message_generator.get_pubsub_broadcast_channel", _get_channel)

    # Ensure AppGenerateService sees pubsub mode
    import services.app_generate_service as ags

    monkeypatch.setattr(ags.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "pubsub", raising=False)
|
||||||
|
|
||||||
|
|
||||||
|
def _publish_events(app_mode: AppMode, run_id: str, events: list[dict]):
    """Publish each event, JSON-encoded, to the response topic that MessageGenerator uses for this run."""
    # Publish events to the same topic used by MessageGenerator
    response_topic = MessageGenerator.get_response_topic(app_mode, run_id)
    for event in events:
        # Serialise one event at a time so earlier events are sent even if a later one fails.
        serialized = json.dumps(event)
        response_topic.publish(serialized.encode())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("_patch_get_channel_streams")
def test_streams_full_flow_prepublish_and_replay():
    """In streams mode, events published before the consumer starts reading are still delivered in order."""
    app_mode = AppMode.WORKFLOW
    run_id = str(uuid.uuid4())

    # Build start_task that publishes two events immediately
    expected_names = ["workflow_started", "workflow_finished"]
    events = [{"event": name} for name in expected_names]

    def start_task():
        _publish_events(app_mode, run_id, events)

    on_subscribe = AppGenerateService._build_streaming_task_on_subscribe(start_task)

    # Start retrieving BEFORE subscription is established; in streams mode, we also started immediately
    stream = MessageGenerator.retrieve_events(app_mode, run_id, idle_timeout=2.0, on_subscribe=on_subscribe)

    received = []
    for message in stream:
        # skip ping events (delivered as plain strings)
        if not isinstance(message, str):
            received.append(message)
            if message.get("event") == "workflow_finished":
                break

    observed_names = [m.get("event") for m in received]
    assert observed_names == ["workflow_started", "workflow_finished"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("_patch_get_channel_pubsub")
def test_pubsub_full_flow_start_on_subscribe_gated(monkeypatch):
    """In pubsub mode, the task has not run when the callback is built, and its events arrive in order."""
    # Speed up any potential timer if it accidentally triggers
    monkeypatch.setattr("services.app_generate_service.SSE_TASK_START_FALLBACK_MS", 50)

    app_mode = AppMode.WORKFLOW
    run_id = str(uuid.uuid4())

    published_order: list[str] = []

    def start_task():
        # When called (on subscribe), publish both events
        events = [{"event": "workflow_started"}, {"event": "workflow_finished"}]
        _publish_events(app_mode, run_id, events)
        for published in events:
            published_order.append(published["event"])

    on_subscribe = AppGenerateService._build_streaming_task_on_subscribe(start_task)

    # Producer not started yet; only when subscribe happens
    assert published_order == []

    stream = MessageGenerator.retrieve_events(app_mode, run_id, idle_timeout=2.0, on_subscribe=on_subscribe)

    received = []
    for message in stream:
        # Plain strings are skipped; only dict events are collected.
        if not isinstance(message, str):
            received.append(message)
            if message.get("event") == "workflow_finished":
                break

    # Verify publish happened and consumer received in order
    assert published_order == ["workflow_started", "workflow_finished"]
    observed_names = [m.get("event") for m in received]
    assert observed_names == ["workflow_started", "workflow_finished"]
|
||||||
Loading…
Reference in New Issue