mirror of https://github.com/langgenius/dify.git
Merge branch 'fix/redis-pubsub-perf' into feat/hitl
This commit is contained in:
commit
3d0ff9463f
|
|
@ -7,6 +7,7 @@ from typing import Self
|
|||
|
||||
from libs.broadcast_channel.channel import Subscription
|
||||
from libs.broadcast_channel.exc import SubscriptionClosedError
|
||||
from redis import Redis, RedisCluster
|
||||
from redis.client import PubSub
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
|
@ -22,10 +23,12 @@ class RedisSubscriptionBase(Subscription):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
client: Redis | RedisCluster,
|
||||
pubsub: PubSub,
|
||||
topic: str,
|
||||
):
|
||||
# The _pubsub is None only if the subscription is closed.
|
||||
self._client = client
|
||||
self._pubsub: PubSub | None = pubsub
|
||||
self._topic = topic
|
||||
self._closed = threading.Event()
|
||||
|
|
|
|||
|
|
@ -42,6 +42,7 @@ class Topic:
|
|||
|
||||
def subscribe(self) -> Subscription:
|
||||
return _RedisSubscription(
|
||||
client=self._client,
|
||||
pubsub=self._client.pubsub(),
|
||||
topic=self._topic,
|
||||
)
|
||||
|
|
@ -63,7 +64,7 @@ class _RedisSubscription(RedisSubscriptionBase):
|
|||
|
||||
def _get_message(self) -> dict | None:
|
||||
assert self._pubsub is not None
|
||||
return self._pubsub.get_message(ignore_subscribe_messages=True, timeout=0.1)
|
||||
return self._pubsub.get_message(ignore_subscribe_messages=True, timeout=1)
|
||||
|
||||
def _get_message_type(self) -> str:
|
||||
return "message"
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
|
||||
from redis import Redis
|
||||
from redis import Redis, RedisCluster
|
||||
|
||||
from ._subscription import RedisSubscriptionBase
|
||||
|
||||
|
|
@ -16,7 +16,7 @@ class ShardedRedisBroadcastChannel:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
redis_client: Redis,
|
||||
redis_client: Redis | RedisCluster,
|
||||
):
|
||||
self._client = redis_client
|
||||
|
||||
|
|
@ -25,7 +25,7 @@ class ShardedRedisBroadcastChannel:
|
|||
|
||||
|
||||
class ShardedTopic:
|
||||
def __init__(self, redis_client: Redis, topic: str):
|
||||
def __init__(self, redis_client: Redis | RedisCluster, topic: str):
|
||||
self._client = redis_client
|
||||
self._topic = topic
|
||||
|
||||
|
|
@ -40,6 +40,7 @@ class ShardedTopic:
|
|||
|
||||
def subscribe(self) -> Subscription:
|
||||
return _RedisShardedSubscription(
|
||||
client=self._client,
|
||||
pubsub=self._client.pubsub(),
|
||||
topic=self._topic,
|
||||
)
|
||||
|
|
@ -68,7 +69,19 @@ class _RedisShardedSubscription(RedisSubscriptionBase):
|
|||
#
|
||||
# Since we have already filtered at the caller's site, we can safely set
|
||||
# `ignore_subscribe_messages=False`.
|
||||
return self._pubsub.get_sharded_message(ignore_subscribe_messages=False, timeout=0.1) # type: ignore[attr-defined]
|
||||
if isinstance(self._client, RedisCluster):
|
||||
# NOTE(QuantumGhost): due to an issue in upstream code, calling `get_sharded_message`
|
||||
# would use busy-looping to wait for incoming message, consuming excessive CPU quota.
|
||||
#
|
||||
# Here we specify the `target_node` to mitigate this problem.
|
||||
node = self._client.get_node_from_key(self._topic)
|
||||
return self._pubsub.get_sharded_message(
|
||||
ignore_subscribe_messages=False,
|
||||
timeout=1,
|
||||
target_node=node,
|
||||
)
|
||||
else:
|
||||
return self._pubsub.get_sharded_message(ignore_subscribe_messages=False, timeout=1) # type: ignore[attr-defined]
|
||||
|
||||
def _get_message_type(self) -> str:
|
||||
return "smessage"
|
||||
|
|
|
|||
|
|
@ -181,6 +181,7 @@ class TestShardedTopic:
|
|||
subscription = sharded_topic.subscribe()
|
||||
|
||||
assert isinstance(subscription, _RedisShardedSubscription)
|
||||
assert subscription._client is mock_redis_client
|
||||
assert subscription._pubsub is mock_redis_client.pubsub.return_value
|
||||
assert subscription._topic == "test-sharded-topic"
|
||||
|
||||
|
|
@ -200,6 +201,11 @@ class SubscriptionTestCase:
|
|||
class TestRedisSubscription:
|
||||
"""Test cases for the _RedisSubscription class."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis_client(self) -> MagicMock:
|
||||
client = MagicMock()
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def mock_pubsub(self) -> MagicMock:
|
||||
"""Create a mock PubSub instance for testing."""
|
||||
|
|
@ -211,9 +217,12 @@ class TestRedisSubscription:
|
|||
return pubsub
|
||||
|
||||
@pytest.fixture
|
||||
def subscription(self, mock_pubsub: MagicMock) -> Generator[_RedisSubscription, None, None]:
|
||||
def subscription(
|
||||
self, mock_pubsub: MagicMock, mock_redis_client: MagicMock
|
||||
) -> Generator[_RedisSubscription, None, None]:
|
||||
"""Create a _RedisSubscription instance for testing."""
|
||||
subscription = _RedisSubscription(
|
||||
client=mock_redis_client,
|
||||
pubsub=mock_pubsub,
|
||||
topic="test-topic",
|
||||
)
|
||||
|
|
@ -228,13 +237,15 @@ class TestRedisSubscription:
|
|||
|
||||
# ==================== Lifecycle Tests ====================
|
||||
|
||||
def test_subscription_initialization(self, mock_pubsub: MagicMock):
|
||||
def test_subscription_initialization(self, mock_pubsub: MagicMock, mock_redis_client: MagicMock):
|
||||
"""Test that subscription is properly initialized."""
|
||||
subscription = _RedisSubscription(
|
||||
client=mock_redis_client,
|
||||
pubsub=mock_pubsub,
|
||||
topic="test-topic",
|
||||
)
|
||||
|
||||
assert subscription._client is mock_redis_client
|
||||
assert subscription._pubsub is mock_pubsub
|
||||
assert subscription._topic == "test-topic"
|
||||
assert not subscription._closed.is_set()
|
||||
|
|
@ -486,9 +497,12 @@ class TestRedisSubscription:
|
|||
),
|
||||
],
|
||||
)
|
||||
def test_subscription_scenarios(self, test_case: SubscriptionTestCase, mock_pubsub: MagicMock):
|
||||
def test_subscription_scenarios(
|
||||
self, test_case: SubscriptionTestCase, mock_pubsub: MagicMock, mock_redis_client: MagicMock
|
||||
):
|
||||
"""Test various subscription scenarios using table-driven approach."""
|
||||
subscription = _RedisSubscription(
|
||||
client=mock_redis_client,
|
||||
pubsub=mock_pubsub,
|
||||
topic="test-topic",
|
||||
)
|
||||
|
|
@ -572,7 +586,7 @@ class TestRedisSubscription:
|
|||
# Close should still work
|
||||
subscription.close() # Should not raise
|
||||
|
||||
def test_channel_name_variations(self, mock_pubsub: MagicMock):
|
||||
def test_channel_name_variations(self, mock_pubsub: MagicMock, mock_redis_client: MagicMock):
|
||||
"""Test various channel name formats."""
|
||||
channel_names = [
|
||||
"simple",
|
||||
|
|
@ -586,6 +600,7 @@ class TestRedisSubscription:
|
|||
|
||||
for channel_name in channel_names:
|
||||
subscription = _RedisSubscription(
|
||||
client=mock_redis_client,
|
||||
pubsub=mock_pubsub,
|
||||
topic=channel_name,
|
||||
)
|
||||
|
|
@ -604,6 +619,11 @@ class TestRedisSubscription:
|
|||
class TestRedisShardedSubscription:
|
||||
"""Test cases for the _RedisShardedSubscription class."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis_client(self) -> MagicMock:
|
||||
client = MagicMock()
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def mock_pubsub(self) -> MagicMock:
|
||||
"""Create a mock PubSub instance for testing."""
|
||||
|
|
@ -615,9 +635,12 @@ class TestRedisShardedSubscription:
|
|||
return pubsub
|
||||
|
||||
@pytest.fixture
|
||||
def sharded_subscription(self, mock_pubsub: MagicMock) -> Generator[_RedisShardedSubscription, None, None]:
|
||||
def sharded_subscription(
|
||||
self, mock_pubsub: MagicMock, mock_redis_client: MagicMock
|
||||
) -> Generator[_RedisShardedSubscription, None, None]:
|
||||
"""Create a _RedisShardedSubscription instance for testing."""
|
||||
subscription = _RedisShardedSubscription(
|
||||
client=mock_redis_client,
|
||||
pubsub=mock_pubsub,
|
||||
topic="test-sharded-topic",
|
||||
)
|
||||
|
|
@ -634,13 +657,15 @@ class TestRedisShardedSubscription:
|
|||
|
||||
# ==================== Lifecycle Tests ====================
|
||||
|
||||
def test_sharded_subscription_initialization(self, mock_pubsub: MagicMock):
|
||||
def test_sharded_subscription_initialization(self, mock_pubsub: MagicMock, mock_redis_client: MagicMock):
|
||||
"""Test that sharded subscription is properly initialized."""
|
||||
subscription = _RedisShardedSubscription(
|
||||
client=mock_redis_client,
|
||||
pubsub=mock_pubsub,
|
||||
topic="test-sharded-topic",
|
||||
)
|
||||
|
||||
assert subscription._client is mock_redis_client
|
||||
assert subscription._pubsub is mock_pubsub
|
||||
assert subscription._topic == "test-sharded-topic"
|
||||
assert not subscription._closed.is_set()
|
||||
|
|
@ -808,6 +833,37 @@ class TestRedisShardedSubscription:
|
|||
assert not sharded_subscription._queue.empty()
|
||||
assert sharded_subscription._queue.get_nowait() == b"test sharded payload"
|
||||
|
||||
def test_get_message_uses_target_node_for_cluster_client(self, mock_pubsub: MagicMock, monkeypatch):
|
||||
"""Test that cluster clients use target_node for sharded messages."""
|
||||
|
||||
class DummyRedisCluster:
|
||||
def __init__(self):
|
||||
self.get_node_from_key = MagicMock(return_value="node-1")
|
||||
|
||||
monkeypatch.setattr("libs.broadcast_channel.redis.sharded_channel.RedisCluster", DummyRedisCluster)
|
||||
|
||||
client = DummyRedisCluster()
|
||||
subscription = _RedisShardedSubscription(
|
||||
client=client,
|
||||
pubsub=mock_pubsub,
|
||||
topic="test-sharded-topic",
|
||||
)
|
||||
mock_pubsub.get_sharded_message.return_value = {
|
||||
"type": "smessage",
|
||||
"channel": "test-sharded-topic",
|
||||
"data": b"payload",
|
||||
}
|
||||
|
||||
result = subscription._get_message()
|
||||
|
||||
client.get_node_from_key.assert_called_once_with("test-sharded-topic")
|
||||
mock_pubsub.get_sharded_message.assert_called_once_with(
|
||||
ignore_subscribe_messages=False,
|
||||
timeout=0.1,
|
||||
target_node="node-1",
|
||||
)
|
||||
assert result == mock_pubsub.get_sharded_message.return_value
|
||||
|
||||
def test_listener_thread_ignores_subscribe_messages(
|
||||
self, sharded_subscription: _RedisShardedSubscription, mock_pubsub: MagicMock
|
||||
):
|
||||
|
|
@ -913,9 +969,12 @@ class TestRedisShardedSubscription:
|
|||
),
|
||||
],
|
||||
)
|
||||
def test_sharded_subscription_scenarios(self, test_case: SubscriptionTestCase, mock_pubsub: MagicMock):
|
||||
def test_sharded_subscription_scenarios(
|
||||
self, test_case: SubscriptionTestCase, mock_pubsub: MagicMock, mock_redis_client: MagicMock
|
||||
):
|
||||
"""Test various sharded subscription scenarios using table-driven approach."""
|
||||
subscription = _RedisShardedSubscription(
|
||||
client=mock_redis_client,
|
||||
pubsub=mock_pubsub,
|
||||
topic="test-sharded-topic",
|
||||
)
|
||||
|
|
@ -999,7 +1058,7 @@ class TestRedisShardedSubscription:
|
|||
# Close should still work
|
||||
sharded_subscription.close() # Should not raise
|
||||
|
||||
def test_channel_name_variations(self, mock_pubsub: MagicMock):
|
||||
def test_channel_name_variations(self, mock_pubsub: MagicMock, mock_redis_client: MagicMock):
|
||||
"""Test various sharded channel name formats."""
|
||||
channel_names = [
|
||||
"simple",
|
||||
|
|
@ -1013,6 +1072,7 @@ class TestRedisShardedSubscription:
|
|||
|
||||
for channel_name in channel_names:
|
||||
subscription = _RedisShardedSubscription(
|
||||
client=mock_redis_client,
|
||||
pubsub=mock_pubsub,
|
||||
topic=channel_name,
|
||||
)
|
||||
|
|
@ -1060,6 +1120,11 @@ class TestRedisSubscriptionCommon:
|
|||
"""Parameterized fixture providing subscription type and class."""
|
||||
return request.param
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis_client(self) -> MagicMock:
|
||||
client = MagicMock()
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def mock_pubsub(self) -> MagicMock:
|
||||
"""Create a mock PubSub instance for testing."""
|
||||
|
|
@ -1075,11 +1140,12 @@ class TestRedisSubscriptionCommon:
|
|||
return pubsub
|
||||
|
||||
@pytest.fixture
|
||||
def subscription(self, subscription_params, mock_pubsub: MagicMock):
|
||||
def subscription(self, subscription_params, mock_pubsub: MagicMock, mock_redis_client: MagicMock):
|
||||
"""Create a subscription instance based on parameterized type."""
|
||||
subscription_type, subscription_class = subscription_params
|
||||
topic_name = f"test-{subscription_type}-topic"
|
||||
subscription = subscription_class(
|
||||
client=mock_redis_client,
|
||||
pubsub=mock_pubsub,
|
||||
topic=topic_name,
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue