Merge branch 'fix/redis-pubsub-perf' into feat/hitl

This commit is contained in:
QuantumGhost 2026-02-06 14:42:39 +08:00
commit 3d0ff9463f
4 changed files with 97 additions and 14 deletions

View File

@ -7,6 +7,7 @@ from typing import Self
from libs.broadcast_channel.channel import Subscription
from libs.broadcast_channel.exc import SubscriptionClosedError
from redis import Redis, RedisCluster
from redis.client import PubSub
_logger = logging.getLogger(__name__)
@ -22,10 +23,12 @@ class RedisSubscriptionBase(Subscription):
def __init__(
self,
client: Redis | RedisCluster,
pubsub: PubSub,
topic: str,
):
# The _pubsub is None only if the subscription is closed.
self._client = client
self._pubsub: PubSub | None = pubsub
self._topic = topic
self._closed = threading.Event()

View File

@ -42,6 +42,7 @@ class Topic:
def subscribe(self) -> Subscription:
return _RedisSubscription(
client=self._client,
pubsub=self._client.pubsub(),
topic=self._topic,
)
@ -63,7 +64,7 @@ class _RedisSubscription(RedisSubscriptionBase):
def _get_message(self) -> dict | None:
assert self._pubsub is not None
return self._pubsub.get_message(ignore_subscribe_messages=True, timeout=0.1)
return self._pubsub.get_message(ignore_subscribe_messages=True, timeout=1)
def _get_message_type(self) -> str:
return "message"

View File

@ -1,7 +1,7 @@
from __future__ import annotations
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
from redis import Redis
from redis import Redis, RedisCluster
from ._subscription import RedisSubscriptionBase
@ -16,7 +16,7 @@ class ShardedRedisBroadcastChannel:
def __init__(
self,
redis_client: Redis,
redis_client: Redis | RedisCluster,
):
self._client = redis_client
@ -25,7 +25,7 @@ class ShardedRedisBroadcastChannel:
class ShardedTopic:
def __init__(self, redis_client: Redis, topic: str):
def __init__(self, redis_client: Redis | RedisCluster, topic: str):
self._client = redis_client
self._topic = topic
@ -40,6 +40,7 @@ class ShardedTopic:
def subscribe(self) -> Subscription:
return _RedisShardedSubscription(
client=self._client,
pubsub=self._client.pubsub(),
topic=self._topic,
)
@ -68,7 +69,19 @@ class _RedisShardedSubscription(RedisSubscriptionBase):
#
# Since we have already filtered at the caller's site, we can safely set
# `ignore_subscribe_messages=False`.
return self._pubsub.get_sharded_message(ignore_subscribe_messages=False, timeout=0.1) # type: ignore[attr-defined]
if isinstance(self._client, RedisCluster):
# NOTE(QuantumGhost): due to an issue in upstream code, calling `get_sharded_message`
# would use busy-looping to wait for incoming message, consuming excessive CPU quota.
#
# Here we specify the `target_node` to mitigate this problem.
node = self._client.get_node_from_key(self._topic)
return self._pubsub.get_sharded_message(
ignore_subscribe_messages=False,
timeout=1,
target_node=node,
)
else:
return self._pubsub.get_sharded_message(ignore_subscribe_messages=False, timeout=1) # type: ignore[attr-defined]
def _get_message_type(self) -> str:
return "smessage"

View File

@ -181,6 +181,7 @@ class TestShardedTopic:
subscription = sharded_topic.subscribe()
assert isinstance(subscription, _RedisShardedSubscription)
assert subscription._client is mock_redis_client
assert subscription._pubsub is mock_redis_client.pubsub.return_value
assert subscription._topic == "test-sharded-topic"
@ -200,6 +201,11 @@ class SubscriptionTestCase:
class TestRedisSubscription:
"""Test cases for the _RedisSubscription class."""
@pytest.fixture
def mock_redis_client(self) -> MagicMock:
client = MagicMock()
return client
@pytest.fixture
def mock_pubsub(self) -> MagicMock:
"""Create a mock PubSub instance for testing."""
@ -211,9 +217,12 @@ class TestRedisSubscription:
return pubsub
@pytest.fixture
def subscription(self, mock_pubsub: MagicMock) -> Generator[_RedisSubscription, None, None]:
def subscription(
self, mock_pubsub: MagicMock, mock_redis_client: MagicMock
) -> Generator[_RedisSubscription, None, None]:
"""Create a _RedisSubscription instance for testing."""
subscription = _RedisSubscription(
client=mock_redis_client,
pubsub=mock_pubsub,
topic="test-topic",
)
@ -228,13 +237,15 @@ class TestRedisSubscription:
# ==================== Lifecycle Tests ====================
def test_subscription_initialization(self, mock_pubsub: MagicMock):
def test_subscription_initialization(self, mock_pubsub: MagicMock, mock_redis_client: MagicMock):
"""Test that subscription is properly initialized."""
subscription = _RedisSubscription(
client=mock_redis_client,
pubsub=mock_pubsub,
topic="test-topic",
)
assert subscription._client is mock_redis_client
assert subscription._pubsub is mock_pubsub
assert subscription._topic == "test-topic"
assert not subscription._closed.is_set()
@ -486,9 +497,12 @@ class TestRedisSubscription:
),
],
)
def test_subscription_scenarios(self, test_case: SubscriptionTestCase, mock_pubsub: MagicMock):
def test_subscription_scenarios(
self, test_case: SubscriptionTestCase, mock_pubsub: MagicMock, mock_redis_client: MagicMock
):
"""Test various subscription scenarios using table-driven approach."""
subscription = _RedisSubscription(
client=mock_redis_client,
pubsub=mock_pubsub,
topic="test-topic",
)
@ -572,7 +586,7 @@ class TestRedisSubscription:
# Close should still work
subscription.close() # Should not raise
def test_channel_name_variations(self, mock_pubsub: MagicMock):
def test_channel_name_variations(self, mock_pubsub: MagicMock, mock_redis_client: MagicMock):
"""Test various channel name formats."""
channel_names = [
"simple",
@ -586,6 +600,7 @@ class TestRedisSubscription:
for channel_name in channel_names:
subscription = _RedisSubscription(
client=mock_redis_client,
pubsub=mock_pubsub,
topic=channel_name,
)
@ -604,6 +619,11 @@ class TestRedisSubscription:
class TestRedisShardedSubscription:
"""Test cases for the _RedisShardedSubscription class."""
@pytest.fixture
def mock_redis_client(self) -> MagicMock:
client = MagicMock()
return client
@pytest.fixture
def mock_pubsub(self) -> MagicMock:
"""Create a mock PubSub instance for testing."""
@ -615,9 +635,12 @@ class TestRedisShardedSubscription:
return pubsub
@pytest.fixture
def sharded_subscription(self, mock_pubsub: MagicMock) -> Generator[_RedisShardedSubscription, None, None]:
def sharded_subscription(
self, mock_pubsub: MagicMock, mock_redis_client: MagicMock
) -> Generator[_RedisShardedSubscription, None, None]:
"""Create a _RedisShardedSubscription instance for testing."""
subscription = _RedisShardedSubscription(
client=mock_redis_client,
pubsub=mock_pubsub,
topic="test-sharded-topic",
)
@ -634,13 +657,15 @@ class TestRedisShardedSubscription:
# ==================== Lifecycle Tests ====================
def test_sharded_subscription_initialization(self, mock_pubsub: MagicMock):
def test_sharded_subscription_initialization(self, mock_pubsub: MagicMock, mock_redis_client: MagicMock):
"""Test that sharded subscription is properly initialized."""
subscription = _RedisShardedSubscription(
client=mock_redis_client,
pubsub=mock_pubsub,
topic="test-sharded-topic",
)
assert subscription._client is mock_redis_client
assert subscription._pubsub is mock_pubsub
assert subscription._topic == "test-sharded-topic"
assert not subscription._closed.is_set()
@ -808,6 +833,37 @@ class TestRedisShardedSubscription:
assert not sharded_subscription._queue.empty()
assert sharded_subscription._queue.get_nowait() == b"test sharded payload"
def test_get_message_uses_target_node_for_cluster_client(self, mock_pubsub: MagicMock, monkeypatch):
"""Test that cluster clients use target_node for sharded messages."""
class DummyRedisCluster:
def __init__(self):
self.get_node_from_key = MagicMock(return_value="node-1")
monkeypatch.setattr("libs.broadcast_channel.redis.sharded_channel.RedisCluster", DummyRedisCluster)
client = DummyRedisCluster()
subscription = _RedisShardedSubscription(
client=client,
pubsub=mock_pubsub,
topic="test-sharded-topic",
)
mock_pubsub.get_sharded_message.return_value = {
"type": "smessage",
"channel": "test-sharded-topic",
"data": b"payload",
}
result = subscription._get_message()
client.get_node_from_key.assert_called_once_with("test-sharded-topic")
mock_pubsub.get_sharded_message.assert_called_once_with(
ignore_subscribe_messages=False,
timeout=0.1,
target_node="node-1",
)
assert result == mock_pubsub.get_sharded_message.return_value
def test_listener_thread_ignores_subscribe_messages(
self, sharded_subscription: _RedisShardedSubscription, mock_pubsub: MagicMock
):
@ -913,9 +969,12 @@ class TestRedisShardedSubscription:
),
],
)
def test_sharded_subscription_scenarios(self, test_case: SubscriptionTestCase, mock_pubsub: MagicMock):
def test_sharded_subscription_scenarios(
self, test_case: SubscriptionTestCase, mock_pubsub: MagicMock, mock_redis_client: MagicMock
):
"""Test various sharded subscription scenarios using table-driven approach."""
subscription = _RedisShardedSubscription(
client=mock_redis_client,
pubsub=mock_pubsub,
topic="test-sharded-topic",
)
@ -999,7 +1058,7 @@ class TestRedisShardedSubscription:
# Close should still work
sharded_subscription.close() # Should not raise
def test_channel_name_variations(self, mock_pubsub: MagicMock):
def test_channel_name_variations(self, mock_pubsub: MagicMock, mock_redis_client: MagicMock):
"""Test various sharded channel name formats."""
channel_names = [
"simple",
@ -1013,6 +1072,7 @@ class TestRedisShardedSubscription:
for channel_name in channel_names:
subscription = _RedisShardedSubscription(
client=mock_redis_client,
pubsub=mock_pubsub,
topic=channel_name,
)
@ -1060,6 +1120,11 @@ class TestRedisSubscriptionCommon:
"""Parameterized fixture providing subscription type and class."""
return request.param
@pytest.fixture
def mock_redis_client(self) -> MagicMock:
client = MagicMock()
return client
@pytest.fixture
def mock_pubsub(self) -> MagicMock:
"""Create a mock PubSub instance for testing."""
@ -1075,11 +1140,12 @@ class TestRedisSubscriptionCommon:
return pubsub
@pytest.fixture
def subscription(self, subscription_params, mock_pubsub: MagicMock):
def subscription(self, subscription_params, mock_pubsub: MagicMock, mock_redis_client: MagicMock):
"""Create a subscription instance based on parameterized type."""
subscription_type, subscription_class = subscription_params
topic_name = f"test-{subscription_type}-topic"
subscription = subscription_class(
client=mock_redis_client,
pubsub=mock_pubsub,
topic=topic_name,
)