diff --git a/api/libs/broadcast_channel/redis/_subscription.py b/api/libs/broadcast_channel/redis/_subscription.py index fa2be421a1..df81775660 100644 --- a/api/libs/broadcast_channel/redis/_subscription.py +++ b/api/libs/broadcast_channel/redis/_subscription.py @@ -7,6 +7,7 @@ from typing import Self from libs.broadcast_channel.channel import Subscription from libs.broadcast_channel.exc import SubscriptionClosedError +from redis import Redis, RedisCluster from redis.client import PubSub _logger = logging.getLogger(__name__) @@ -22,10 +23,12 @@ class RedisSubscriptionBase(Subscription): def __init__( self, + client: Redis | RedisCluster, pubsub: PubSub, topic: str, ): # The _pubsub is None only if the subscription is closed. + self._client = client self._pubsub: PubSub | None = pubsub self._topic = topic self._closed = threading.Event() diff --git a/api/libs/broadcast_channel/redis/channel.py b/api/libs/broadcast_channel/redis/channel.py index 5bb4f579c1..35a227769c 100644 --- a/api/libs/broadcast_channel/redis/channel.py +++ b/api/libs/broadcast_channel/redis/channel.py @@ -42,6 +42,7 @@ class Topic: def subscribe(self) -> Subscription: return _RedisSubscription( + client=self._client, pubsub=self._client.pubsub(), topic=self._topic, ) @@ -63,7 +64,7 @@ class _RedisSubscription(RedisSubscriptionBase): def _get_message(self) -> dict | None: assert self._pubsub is not None - return self._pubsub.get_message(ignore_subscribe_messages=True, timeout=0.1) + return self._pubsub.get_message(ignore_subscribe_messages=True, timeout=1) def _get_message_type(self) -> str: return "message" diff --git a/api/libs/broadcast_channel/redis/sharded_channel.py b/api/libs/broadcast_channel/redis/sharded_channel.py index 9e8ab90e8e..8d4d7873ca 100644 --- a/api/libs/broadcast_channel/redis/sharded_channel.py +++ b/api/libs/broadcast_channel/redis/sharded_channel.py @@ -1,7 +1,7 @@ from __future__ import annotations from libs.broadcast_channel.channel import Producer, Subscriber, Subscription -from redis import Redis +from redis import Redis, RedisCluster from ._subscription import RedisSubscriptionBase @@ -16,7 +16,7 @@ class ShardedRedisBroadcastChannel: def __init__( self, - redis_client: Redis, + redis_client: Redis | RedisCluster, ): self._client = redis_client @@ -25,7 +25,7 @@ class ShardedRedisBroadcastChannel: class ShardedTopic: - def __init__(self, redis_client: Redis, topic: str): + def __init__(self, redis_client: Redis | RedisCluster, topic: str): self._client = redis_client self._topic = topic @@ -40,6 +40,7 @@ class ShardedTopic: def subscribe(self) -> Subscription: return _RedisShardedSubscription( + client=self._client, pubsub=self._client.pubsub(), topic=self._topic, ) @@ -68,7 +69,19 @@ class _RedisShardedSubscription(RedisSubscriptionBase): # # Since we have already filtered at the caller's site, we can safely set # `ignore_subscribe_messages=False`. - return self._pubsub.get_sharded_message(ignore_subscribe_messages=False, timeout=0.1) # type: ignore[attr-defined] + if isinstance(self._client, RedisCluster): + # NOTE(QuantumGhost): due to an issue in upstream code, calling `get_sharded_message` + # would use busy-looping to wait for incoming message, consuming excessive CPU quota. + # + # Here we specify the `target_node` to mitigate this problem. + node = self._client.get_node_from_key(self._topic) + return self._pubsub.get_sharded_message( + ignore_subscribe_messages=False, + timeout=1, + target_node=node, + ) + else: + return self._pubsub.get_sharded_message(ignore_subscribe_messages=False, timeout=1) # type: ignore[attr-defined] def _get_message_type(self) -> str: return "smessage" diff --git a/api/tests/unit_tests/libs/broadcast_channel/redis/test_channel_unit_tests.py b/api/tests/unit_tests/libs/broadcast_channel/redis/test_channel_unit_tests.py index ccba075fdf..54bb9954d5 100644 --- a/api/tests/unit_tests/libs/broadcast_channel/redis/test_channel_unit_tests.py +++ b/api/tests/unit_tests/libs/broadcast_channel/redis/test_channel_unit_tests.py @@ -181,6 +181,7 @@ class TestShardedTopic: subscription = sharded_topic.subscribe() assert isinstance(subscription, _RedisShardedSubscription) + assert subscription._client is mock_redis_client assert subscription._pubsub is mock_redis_client.pubsub.return_value assert subscription._topic == "test-sharded-topic" @@ -200,6 +201,11 @@ class SubscriptionTestCase: class TestRedisSubscription: """Test cases for the _RedisSubscription class.""" + @pytest.fixture + def mock_redis_client(self) -> MagicMock: + client = MagicMock() + return client + @pytest.fixture def mock_pubsub(self) -> MagicMock: """Create a mock PubSub instance for testing.""" @@ -211,9 +217,12 @@ class TestRedisSubscription: return pubsub @pytest.fixture - def subscription(self, mock_pubsub: MagicMock) -> Generator[_RedisSubscription, None, None]: + def subscription( + self, mock_pubsub: MagicMock, mock_redis_client: MagicMock + ) -> Generator[_RedisSubscription, None, None]: """Create a _RedisSubscription instance for testing.""" subscription = _RedisSubscription( + client=mock_redis_client, pubsub=mock_pubsub, topic="test-topic", ) @@ -228,13 +237,15 @@ class TestRedisSubscription: # ==================== Lifecycle Tests ==================== - def test_subscription_initialization(self, mock_pubsub: MagicMock): + def test_subscription_initialization(self, mock_pubsub: MagicMock, mock_redis_client: MagicMock): """Test that subscription is properly initialized.""" subscription = _RedisSubscription( + client=mock_redis_client, pubsub=mock_pubsub, topic="test-topic", ) + assert subscription._client is mock_redis_client assert subscription._pubsub is mock_pubsub assert subscription._topic == "test-topic" assert not subscription._closed.is_set() @@ -486,9 +497,12 @@ class TestRedisSubscription: ), ], ) - def test_subscription_scenarios(self, test_case: SubscriptionTestCase, mock_pubsub: MagicMock): + def test_subscription_scenarios( + self, test_case: SubscriptionTestCase, mock_pubsub: MagicMock, mock_redis_client: MagicMock + ): """Test various subscription scenarios using table-driven approach.""" subscription = _RedisSubscription( + client=mock_redis_client, pubsub=mock_pubsub, topic="test-topic", ) @@ -572,7 +586,7 @@ class TestRedisSubscription: # Close should still work subscription.close() # Should not raise - def test_channel_name_variations(self, mock_pubsub: MagicMock): + def test_channel_name_variations(self, mock_pubsub: MagicMock, mock_redis_client: MagicMock): """Test various channel name formats.""" channel_names = [ "simple", @@ -586,6 +600,7 @@ class TestRedisSubscription: for channel_name in channel_names: subscription = _RedisSubscription( + client=mock_redis_client, pubsub=mock_pubsub, topic=channel_name, ) @@ -604,6 +619,11 @@ class TestRedisSubscription: class TestRedisShardedSubscription: """Test cases for the _RedisShardedSubscription class.""" + @pytest.fixture + def mock_redis_client(self) -> MagicMock: + client = MagicMock() + return client + @pytest.fixture def mock_pubsub(self) -> MagicMock: """Create a mock PubSub instance for testing.""" @@ -615,9 +635,12 @@ class TestRedisShardedSubscription: return pubsub @pytest.fixture - def sharded_subscription(self, mock_pubsub: MagicMock) -> Generator[_RedisShardedSubscription, None, None]: + def sharded_subscription( + self, mock_pubsub: MagicMock, mock_redis_client: MagicMock + ) -> Generator[_RedisShardedSubscription, None, None]: """Create a _RedisShardedSubscription instance for testing.""" subscription = _RedisShardedSubscription( + client=mock_redis_client, pubsub=mock_pubsub, topic="test-sharded-topic", ) @@ -634,13 +657,15 @@ class TestRedisShardedSubscription: # ==================== Lifecycle Tests ==================== - def test_sharded_subscription_initialization(self, mock_pubsub: MagicMock): + def test_sharded_subscription_initialization(self, mock_pubsub: MagicMock, mock_redis_client: MagicMock): """Test that sharded subscription is properly initialized.""" subscription = _RedisShardedSubscription( + client=mock_redis_client, pubsub=mock_pubsub, topic="test-sharded-topic", ) + assert subscription._client is mock_redis_client assert subscription._pubsub is mock_pubsub assert subscription._topic == "test-sharded-topic" assert not subscription._closed.is_set() @@ -808,6 +833,37 @@ class TestRedisShardedSubscription: assert not sharded_subscription._queue.empty() assert sharded_subscription._queue.get_nowait() == b"test sharded payload" + def test_get_message_uses_target_node_for_cluster_client(self, mock_pubsub: MagicMock, monkeypatch): + """Test that cluster clients use target_node for sharded messages.""" + + class DummyRedisCluster: + def __init__(self): + self.get_node_from_key = MagicMock(return_value="node-1") + + monkeypatch.setattr("libs.broadcast_channel.redis.sharded_channel.RedisCluster", DummyRedisCluster) + + client = DummyRedisCluster() + subscription = _RedisShardedSubscription( + client=client, + pubsub=mock_pubsub, + topic="test-sharded-topic", + ) + mock_pubsub.get_sharded_message.return_value = { + "type": "smessage", + "channel": "test-sharded-topic", + "data": b"payload", + } + + result = subscription._get_message() + + client.get_node_from_key.assert_called_once_with("test-sharded-topic") + mock_pubsub.get_sharded_message.assert_called_once_with( + ignore_subscribe_messages=False, + timeout=0.1, + target_node="node-1", + ) + assert result == mock_pubsub.get_sharded_message.return_value + def test_listener_thread_ignores_subscribe_messages( self, sharded_subscription: _RedisShardedSubscription, mock_pubsub: MagicMock ): @@ -913,9 +969,12 @@ class TestRedisShardedSubscription: ), ], ) - def test_sharded_subscription_scenarios(self, test_case: SubscriptionTestCase, mock_pubsub: MagicMock): + def test_sharded_subscription_scenarios( + self, test_case: SubscriptionTestCase, mock_pubsub: MagicMock, mock_redis_client: MagicMock + ): """Test various sharded subscription scenarios using table-driven approach.""" subscription = _RedisShardedSubscription( + client=mock_redis_client, pubsub=mock_pubsub, topic="test-sharded-topic", ) @@ -999,7 +1058,7 @@ class TestRedisShardedSubscription: # Close should still work sharded_subscription.close() # Should not raise - def test_channel_name_variations(self, mock_pubsub: MagicMock): + def test_channel_name_variations(self, mock_pubsub: MagicMock, mock_redis_client: MagicMock): """Test various sharded channel name formats.""" channel_names = [ "simple", @@ -1013,6 +1072,7 @@ class TestRedisShardedSubscription: for channel_name in channel_names: subscription = _RedisShardedSubscription( + client=mock_redis_client, pubsub=mock_pubsub, topic=channel_name, ) @@ -1060,6 +1120,11 @@ class TestRedisSubscriptionCommon: """Parameterized fixture providing subscription type and class.""" return request.param + @pytest.fixture + def mock_redis_client(self) -> MagicMock: + client = MagicMock() + return client + @pytest.fixture def mock_pubsub(self) -> MagicMock: """Create a mock PubSub instance for testing.""" @@ -1075,11 +1140,12 @@ class TestRedisSubscriptionCommon: return pubsub @pytest.fixture - def subscription(self, subscription_params, mock_pubsub: MagicMock): + def subscription(self, subscription_params, mock_pubsub: MagicMock, mock_redis_client: MagicMock): """Create a subscription instance based on parameterized type.""" subscription_type, subscription_class = subscription_params topic_name = f"test-{subscription_type}-topic" subscription = subscription_class( + client=mock_redis_client, pubsub=mock_pubsub, topic=topic_name, )