diff --git a/.github/workflows/integration.yaml b/.github/workflows/integration.yaml index 880edde41f..225dbdedee 100644 --- a/.github/workflows/integration.yaml +++ b/.github/workflows/integration.yaml @@ -46,6 +46,7 @@ jobs: GHSA-w596-4wvx-j9j6 # subversion related git pull, dependency for pytest. There is no impact here. CVE-2026-26007 # dependency for entraid tests CVE-2026-32597 # PyJWT does not validate the crit (Critical) Header Parameter defined in RFC 7515, this will be fixed in the next release + CVE-2026-4539 # pygments ReDoS in AdlLexer, local access only, no fix available yet lint: name: Code linters diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 54f8acfe89..a805c22c54 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -981,6 +981,8 @@ def __init__( self.pending_unsubscribe_channels = set() self.patterns = {} self.pending_unsubscribe_patterns = set() + self.shard_channels = {} + self.pending_unsubscribe_shard_channels = set() self._lock = asyncio.Lock() async def __aenter__(self): @@ -1009,6 +1011,8 @@ async def aclose(self): self.pending_unsubscribe_channels = set() self.patterns = {} self.pending_unsubscribe_patterns = set() + self.shard_channels = {} + self.pending_unsubscribe_shard_channels = set() @deprecated_function(version="5.0.1", reason="Use aclose() instead", name="close") async def close(self) -> None: @@ -1033,6 +1037,7 @@ async def on_connect(self, connection: Connection): # that no decoding is required. self.pending_unsubscribe_channels.clear() self.pending_unsubscribe_patterns.clear() + self.pending_unsubscribe_shard_channels.clear() if self.channels: channels_with_handlers = {} channels_without_handlers = [] @@ -1057,11 +1062,21 @@ async def on_connect(self, connection: Connection): await self.psubscribe( *patterns_without_handlers, **patterns_with_handlers ) + if self.shard_channels: + shard_with_handlers = {} + shard_without_handlers = [] + for k, v in self.shard_channels.items(): + if v is not None: + shard_with_handlers[self.encoder.decode(k, force=True)] = v + else: + shard_without_handlers.append(k) + if shard_with_handlers or shard_without_handlers: + await self.ssubscribe(*shard_without_handlers, **shard_with_handlers) @property def subscribed(self): """Indicates if there are subscriptions to any channels or patterns""" - return bool(self.channels or self.patterns) + return bool(self.channels or self.patterns or self.shard_channels) async def execute_command(self, *args: EncodableT): """Execute a publish/subscribe command""" @@ -1341,6 +1356,40 @@ def unsubscribe(self, *args) -> Awaitable: self.pending_unsubscribe_channels.update(channels) return self.execute_command("UNSUBSCRIBE", *parsed_args) + async def ssubscribe(self, *args, target_node=None, **kwargs): + """ + Subscribes the client to the specified shard channels. + Channels supplied as keyword arguments expect a channel name as the key + and a callable as the value. A channel's callable will be invoked automatically + when a message is received on that channel rather than producing a message via + ``listen()`` or ``get_sharded_message()``. + """ + if args: + args = list_or_args(args[0], args[1:]) + new_s_channels = dict.fromkeys(args) + new_s_channels.update(kwargs) + ret_val = await self.execute_command("SSUBSCRIBE", *new_s_channels.keys()) + # update the s_channels dict AFTER we send the command. we don't want to + # subscribe twice to these channels, once for the command and again + # for the reconnection. + new_s_channels = self._normalize_keys(new_s_channels) + self.shard_channels.update(new_s_channels) + self.pending_unsubscribe_shard_channels.difference_update(new_s_channels) + return ret_val + + def sunsubscribe(self, *args, target_node=None) -> Awaitable: + """ + Unsubscribe from the supplied shard_channels. If empty, unsubscribe from + all shard_channels + """ + if args: + args = list_or_args(args[0], args[1:]) + s_channels = self._normalize_keys(dict.fromkeys(args)) + else: + s_channels = self.shard_channels + self.pending_unsubscribe_shard_channels.update(s_channels) + return self.execute_command("SUNSUBSCRIBE", *args) + async def listen(self) -> AsyncIterator: """Listen for messages on channels this client has been subscribed to""" while self.subscribed: @@ -1412,6 +1461,13 @@ async def handle_message(self, response, ignore_subscribe_messages=False): direction=PubSubDirection.RECEIVE, channel=channel, ) + elif message_type == "smessage": + channel = str_if_bytes(message["channel"]) + await record_pubsub_message( + direction=PubSubDirection.RECEIVE, + channel=channel, + sharded=True, + ) # if this is an unsubscribe message, remove it from memory if message_type in self.UNSUBSCRIBE_MESSAGE_TYPES: @@ -1420,6 +1476,11 @@ async def handle_message(self, response, ignore_subscribe_messages=False): if pattern in self.pending_unsubscribe_patterns: self.pending_unsubscribe_patterns.remove(pattern) self.patterns.pop(pattern, None) + elif message_type == "sunsubscribe": + s_channel = response[1] + if s_channel in self.pending_unsubscribe_shard_channels: + self.pending_unsubscribe_shard_channels.remove(s_channel) + self.shard_channels.pop(s_channel, None) else: channel = response[1] if channel in self.pending_unsubscribe_channels: @@ -1430,6 +1491,8 @@ async def handle_message(self, response, ignore_subscribe_messages=False): # if there's a message handler, invoke it if message_type == "pmessage": handler = self.patterns.get(message["pattern"], None) + elif message_type == "smessage": + handler = self.shard_channels.get(message["channel"], None) else: handler = self.channels.get(message["channel"], None) if handler: diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 3d416d6973..882662bfbb 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -33,8 +33,14 @@ _RedisCallbacksRESP2, _RedisCallbacksRESP3, ) -from redis.asyncio.client import ResponseCallbackT -from redis.asyncio.connection import Connection, SSLConnection, parse_url +from redis.asyncio.client import PubSub, ResponseCallbackT +from redis.asyncio.connection import ( + AbstractConnection, + Connection, + ConnectionPool, + SSLConnection, + parse_url, +) from redis.asyncio.lock import Lock from redis.asyncio.observability.recorder import ( record_error_count, @@ -57,6 +63,7 @@ parse_cluster_slots, ) from redis.commands import READ_COMMANDS, AsyncRedisClusterCommands +from redis.commands.helpers import list_or_args from redis.commands.policies import AsyncPolicyResolver, AsyncStaticPolicyResolver from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot from redis.credentials import CredentialProvider @@ -1222,6 +1229,27 @@ def pipeline( return ClusterPipeline(self, transaction) + def pubsub( + self, + node: Optional["ClusterNode"] = None, + host: Optional[str] = None, + port: Optional[int] = None, + **kwargs: Any, + ) -> "ClusterPubSub": + """ + Create and return a ClusterPubSub instance. + + Allows passing a ClusterNode, or host&port, to get a pubsub instance + connected to the specified node + + :param node: ClusterNode to connect to + :param host: Host of the node to connect to + :param port: Port of the node to connect to + :param kwargs: Additional keyword arguments + :return: ClusterPubSub instance + """ + return ClusterPubSub(self, node=node, host=host, port=port, **kwargs) + def lock( self, name: KeyT, @@ -2973,3 +3001,327 @@ async def discard(self): async def unlink(self, *names): return self.execute_command("UNLINK", *names) + + +class ClusterPubSub(PubSub): + """ + Async cluster implementation for pub/sub. + + IMPORTANT: before using ClusterPubSub, read about the known limitations + with pubsub in Cluster mode and learn how to workaround them: + https://redis.readthedocs.io/en/stable/clustering.html#known-pubsub-limitations + """ + + def __init__( + self, + redis_cluster: "RedisCluster", + node: Optional["ClusterNode"] = None, + host: Optional[str] = None, + port: Optional[int] = None, + push_handler_func: Optional[Callable] = None, + event_dispatcher: Optional[EventDispatcher] = None, + **kwargs: Any, + ) -> None: + """ + When a pubsub instance is created without specifying a node, a single + node will be transparently chosen for the pubsub connection on the + first command execution. The node will be determined by: + 1. Hashing the channel name in the request to find its keyslot + 2. Selecting a node that handles the keyslot: If read_from_replicas is + set to true or load_balancing_strategy is set, a replica can be selected. + + :param redis_cluster: RedisCluster instance + :param node: ClusterNode to connect to + :param host: Host of the node to connect to + :param port: Port of the node to connect to + :param push_handler_func: Optional push handler function + :param event_dispatcher: Optional event dispatcher + :param kwargs: Additional keyword arguments + """ + self.node = None + self.set_pubsub_node(redis_cluster, node, host, port) + + # Create connection pool if node is specified + if self.node is not None: + connection_pool = ConnectionPool( + connection_class=self.node.connection_class, + **self.node.connection_kwargs, + ) + else: + connection_pool = None + + self.cluster = redis_cluster + self.node_pubsub_mapping: Dict[str, PubSub] = {} + self._pubsubs_generator = self._pubsubs_generator() + if event_dispatcher is None: + self._event_dispatcher = EventDispatcher() + else: + self._event_dispatcher = event_dispatcher + super().__init__( + connection_pool=connection_pool, + encoder=redis_cluster.encoder, + push_handler_func=push_handler_func, + event_dispatcher=self._event_dispatcher, + **kwargs, + ) + + def set_pubsub_node( + self, + cluster: "RedisCluster", + node: Optional["ClusterNode"] = None, + host: Optional[str] = None, + port: Optional[int] = None, + ) -> None: + """ + The pubsub node will be set according to the passed node, host and port + When none of the node, host, or port are specified - the node is set + to None and will be determined by the keyslot of the channel in the + first command to be executed. + RedisClusterException will be thrown if the passed node does not exist + in the cluster. + If host is passed without port, or vice versa, a DataError will be + thrown. + """ + if node is not None: + # node is passed by the user + self._raise_on_invalid_node(cluster, node, node.host, node.port) + pubsub_node = node + elif host is not None and port is not None: + # host and port passed by the user + node = cluster.get_node(host=host, port=port) + self._raise_on_invalid_node(cluster, node, host, port) + pubsub_node = node + elif host is not None or port is not None: + # only one of host and port is specified + raise DataError("Specify both host and port") + else: + # nothing specified by the user + pubsub_node = None + self.node = pubsub_node + + def get_pubsub_node(self) -> Optional["ClusterNode"]: + """ + Get the node that is being used as the pubsub connection. + + :return: The ClusterNode being used for pubsub, or None if not yet determined + """ + return self.node + + def _get_node_pubsub(self, node: "ClusterNode") -> PubSub: + """Get or create a PubSub instance for the given node.""" + try: + return self.node_pubsub_mapping[node.name] + except KeyError: + # Create a minimal connection pool for this node + connection_pool = ConnectionPool( + connection_class=node.connection_class, **node.connection_kwargs + ) + + pubsub = PubSub( + connection_pool=connection_pool, + encoder=self.cluster.encoder, + push_handler_func=self.push_handler_func, + event_dispatcher=self._event_dispatcher, + ) + self.node_pubsub_mapping[node.name] = pubsub + return pubsub + + async def _sharded_message_generator( + self, timeout: float = 0.0 + ) -> Optional[Dict[str, Any]]: + """Generate messages from shard channels across all nodes.""" + for _ in range(len(self.node_pubsub_mapping)): + pubsub = next(self._pubsubs_generator) + # Don't pass ignore_subscribe_messages here - let get_sharded_message + # handle the filtering after processing subscription state changes + message = await pubsub.get_message( + ignore_subscribe_messages=False, timeout=timeout + ) + if message is not None: + return message + return None + + def _pubsubs_generator(self) -> Generator[PubSub, None, None]: + """Generator that yields PubSub instances in round-robin fashion.""" + while True: + current_nodes = list(self.node_pubsub_mapping.values()) + if not current_nodes: + return # Avoid infinite loop when no subscriptions exist + yield from current_nodes + + async def get_sharded_message( + self, + ignore_subscribe_messages: bool = False, + timeout: float = 0.0, + target_node: Optional["ClusterNode"] = None, + ) -> Optional[Dict[str, Any]]: + """ + Get a message from shard channels. + + :param ignore_subscribe_messages: Whether to ignore subscribe messages + :param timeout: Timeout for message retrieval + :param target_node: Specific node to get message from + :return: Message dictionary or None + """ + if target_node: + pubsub = self.node_pubsub_mapping.get(target_node.name) + if pubsub: + # Don't pass ignore_subscribe_messages here - let get_sharded_message + # handle the filtering after processing subscription state changes + message = await pubsub.get_message( + ignore_subscribe_messages=False, timeout=timeout + ) + else: + message = None + else: + message = await self._sharded_message_generator(timeout=timeout) + + if message is None: + return None + elif str_if_bytes(message["type"]) == "sunsubscribe": + if message["channel"] in self.pending_unsubscribe_shard_channels: + self.pending_unsubscribe_shard_channels.remove(message["channel"]) + self.shard_channels.pop(message["channel"], None) + node = self.cluster.get_node_from_key(message["channel"]) + if node and node.name in self.node_pubsub_mapping: + pubsub = self.node_pubsub_mapping[node.name] + if not pubsub.subscribed: + self.node_pubsub_mapping.pop(node.name) + + # Only suppress subscribe/unsubscribe messages, not data messages (smessage) + if str_if_bytes(message["type"]) in ("ssubscribe", "sunsubscribe"): + if self.ignore_subscribe_messages or ignore_subscribe_messages: + return None + return message + + async def ssubscribe(self, *args: Any, **kwargs: Any) -> None: + """ + Subscribe to shard channels. + + :param args: Channel names + :param kwargs: Channel names with handlers + """ + if args: + args = list_or_args(args[0], args[1:]) + s_channels = dict.fromkeys(args) + s_channels.update(kwargs) + + for s_channel, handler in s_channels.items(): + node = self.cluster.get_node_from_key(s_channel) + if node: + pubsub = self._get_node_pubsub(node) + if handler: + await pubsub.ssubscribe(**{s_channel: handler}) + else: + await pubsub.ssubscribe(s_channel) + self.shard_channels.update(pubsub.shard_channels) + self.pending_unsubscribe_shard_channels.difference_update( + self._normalize_keys({s_channel: None}) + ) + + async def sunsubscribe(self, *args: Any) -> None: + """ + Unsubscribe from shard channels. + + :param args: Channel names to unsubscribe from. If empty, unsubscribe from all. + """ + if args: + args = list_or_args(args[0], args[1:]) + else: + args = list(self.shard_channels.keys()) + + for s_channel in args: + node = self.cluster.get_node_from_key(s_channel) + if node and node.name in self.node_pubsub_mapping: + pubsub = self.node_pubsub_mapping[node.name] + await pubsub.sunsubscribe(s_channel) + self.pending_unsubscribe_shard_channels.update( + pubsub.pending_unsubscribe_shard_channels + ) + + def get_redis_connection(self) -> Optional["AbstractConnection"]: + """ + Get the Redis connection of the pubsub connected node. + + Returns the pubsub's dedicated connection (acquired from its own + connection pool), not from the ClusterNode's connection pool. + This avoids the connection pool resource leak that would occur + if we called node.acquire_connection() without releasing. + """ + # Return the pubsub's own dedicated connection, which is acquired + # from self.connection_pool when executing pubsub commands. + # This is safe because it's the connection dedicated to this pubsub + # instance, not a shared pool connection from the ClusterNode. + return self.connection + + async def aclose(self) -> None: + """ + Disconnect the pubsub connection. + """ + # Close all shard pubsub instances first + for pubsub in self.node_pubsub_mapping.values(): + await pubsub.aclose() + # Let parent handle self.connection disconnect under the lock + # (includes disconnect, release to pool, and clearing self.connection) + await super().aclose() + + def _raise_on_invalid_node( + self, + redis_cluster: "RedisCluster", + node: Optional["ClusterNode"], + host: Optional[str], + port: Optional[int], + ) -> None: + """ + Raise a RedisClusterException if the node is None or doesn't exist in + the cluster. + """ + if node is None or redis_cluster.get_node(node_name=node.name) is None: + raise RedisClusterException( + f"Node {host}:{port} doesn't exist in the cluster" + ) + + async def execute_command(self, *args: Any, **kwargs: Any) -> Any: + """ + Execute a command on the appropriate cluster node. + + Taken code from redis-py and tweaked to make it work within a cluster. + """ + # NOTE: don't parse the response in this function -- it could pull a + # legitimate message off the stack if the connection is already + # subscribed to one or more channels + + # For shard commands, route to appropriate node + command = args[0].upper() if args else "" + if command in ("SSUBSCRIBE", "SUNSUBSCRIBE", "SPUBLISH"): + if len(args) > 1: + channel = args[1] + node = self.cluster.get_node_from_key(channel) + if node: + pubsub = self._get_node_pubsub(node) + return await pubsub.execute_command(*args, **kwargs) + + # For other commands, use the set node or lazily discover one + if self.connection is None: + if self.connection_pool is None: + if len(args) > 1: + # Hash the first channel and get one of the nodes holding + # this slot + channel = args[1] + slot = self.cluster.keyslot(channel) + node = self.cluster.nodes_manager.get_node_from_slot( + slot, + self.cluster.read_from_replicas, + self.cluster.load_balancing_strategy, + ) + else: + # Get a random node + node = self.cluster.get_random_node() + self.node = node + self.connection_pool = ConnectionPool( + connection_class=node.connection_class, + **node.connection_kwargs, + ) + + # Now we have a connection_pool, use parent's execute_command + return await super().execute_command(*args, **kwargs) diff --git a/redis/cluster.py b/redis/cluster.py index e33f5e534a..388b83496d 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -2528,7 +2528,7 @@ class ClusterPubSub(PubSub): IMPORTANT: before using ClusterPubSub, read about the known limitations with pubsub in Cluster mode and learn how to workaround them: - https://redis-py-cluster.readthedocs.io/en/stable/pubsub.html + https://redis.readthedocs.io/en/stable/clustering.html#known-pubsub-limitations """ def __init__( @@ -2692,14 +2692,18 @@ def _sharded_message_generator(self, timeout=0.0): def _pubsubs_generator(self): while True: current_nodes = list(self.node_pubsub_mapping.values()) + if not current_nodes: + return # Avoid infinite loop when no subscriptions exist yield from current_nodes def get_sharded_message( self, ignore_subscribe_messages=False, timeout=0.0, target_node=None ): if target_node: + # Don't pass ignore_subscribe_messages here - let get_sharded_message + # handle the filtering after processing subscription state changes message = self.node_pubsub_mapping[target_node.name].get_message( - ignore_subscribe_messages=ignore_subscribe_messages, timeout=timeout + ignore_subscribe_messages=False, timeout=timeout ) else: message = self._sharded_message_generator(timeout=timeout) @@ -2716,8 +2720,10 @@ def get_sharded_message( # There are no subscriptions anymore, set subscribed_event flag # to false self.subscribed_event.clear() - if self.ignore_subscribe_messages or ignore_subscribe_messages: - return None + # Only suppress subscribe/unsubscribe messages, not data messages (smessage) + if str_if_bytes(message["type"]) in ("ssubscribe", "sunsubscribe"): + if self.ignore_subscribe_messages or ignore_subscribe_messages: + return None return message def ssubscribe(self, *args, **kwargs): diff --git a/redis/commands/cluster.py b/redis/commands/cluster.py index dea73763da..6f196ad429 100644 --- a/redis/commands/cluster.py +++ b/redis/commands/cluster.py @@ -1199,6 +1199,7 @@ class AsyncRedisClusterCommands( AsyncClusterMultiKeyCommands, AsyncClusterManagementCommands, AsyncACLCommands, + PubSubCommands, AsyncClusterDataAccessCommands, AsyncScriptCommands, AsyncFunctionCommands, diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index 1430760438..2550a59c97 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -4083,3 +4083,528 @@ async def test_multiple_pipeline_executions_record_multiple_metrics( # Should have at least 2 pipeline events assert len(pipeline_calls) >= 2 + + +@pytest.mark.onlycluster +@pytest.mark.skipif( + 'not config.REDIS_INFO.get("cluster_enabled", False)', + reason="Requires Redis Cluster", +) +class TestClusterPubSub: + """ + Test ClusterPubSub with shard channels functionality + """ + + async def wait_for_message( + self, pubsub, timeout=0.2, ignore_subscribe_messages=False, sharded=False + ): + """Helper method to wait for a message with timeout. + + Args: + pubsub: The PubSub instance + timeout: Timeout in seconds + ignore_subscribe_messages: Whether to ignore subscribe messages + sharded: If True, use get_sharded_message() instead of get_message() + """ + import asyncio + + now = asyncio.get_running_loop().time() + end_time = now + timeout + while now < end_time: + if sharded: + message = await pubsub.get_sharded_message( + ignore_subscribe_messages=ignore_subscribe_messages, + timeout=0.01, + ) + else: + message = await pubsub.get_message( + ignore_subscribe_messages=ignore_subscribe_messages + ) + if message is not None: + return message + await asyncio.sleep(0.01) + now = asyncio.get_running_loop().time() + return None + + def make_message(self, type, channel, data, pattern=None): + """Helper method to create expected message format""" + return { + "type": type, + "pattern": pattern and pattern.encode("utf-8") or None, + "channel": channel and channel.encode("utf-8") or None, + "data": data.encode("utf-8") if isinstance(data, str) else data, + } + + @skip_if_server_version_lt("7.0.0") + async def test_cluster_pubsub_creation(self, r): + """Test basic ClusterPubSub creation""" + pubsub = r.pubsub() + assert pubsub is not None + assert hasattr(pubsub, "ssubscribe") + assert hasattr(pubsub, "sunsubscribe") + assert hasattr(pubsub, "get_sharded_message") + await pubsub.aclose() + + @skip_if_server_version_lt("7.0.0") + async def test_cluster_pubsub_with_node(self, r): + """Test ClusterPubSub creation with specific node""" + nodes = r.get_nodes() + if nodes: + node = nodes[0] + pubsub = r.pubsub(node=node) + assert pubsub.node == node + await pubsub.aclose() + + @skip_if_server_version_lt("7.0.0") + async def test_cluster_pubsub_with_host_port(self, r): + """Test ClusterPubSub creation with host and port""" + nodes = r.get_nodes() + if nodes: + node = nodes[0] + pubsub = r.pubsub(host=node.host, port=node.port) + assert pubsub.node.host == node.host + assert pubsub.node.port == node.port + await pubsub.aclose() + + @skip_if_server_version_lt("7.0.0") + async def test_shard_channel_subscribe_unsubscribe(self, r): + """Test shard channel subscribe and unsubscribe""" + pubsub = r.pubsub() + + try: + # Test channels that map to different nodes + channels = ["shard_test_1", "shard_test_2", "shard_test_3"] + + # Subscribe to shard channels + await pubsub.ssubscribe(*channels) + + # Verify subscription messages - one ssubscribe confirmation per channel + received_channels = set() + for _ in range(len(channels)): + msg = await self.wait_for_message(pubsub, timeout=1.0, sharded=True) + assert msg is not None, "Expected subscription confirmation message" + assert msg["type"] == "ssubscribe" + assert msg["channel"].decode() in channels + received_channels.add(msg["channel"].decode()) + + # Verify we got confirmations for all channels + assert received_channels == set(channels) + + # Unsubscribe from shard channels + await pubsub.sunsubscribe(*channels) + + finally: + await pubsub.aclose() + + @skip_if_server_version_lt("7.0.0") + async def test_shard_channel_attributes(self, r): + """Test shard channel attributes""" + pubsub = r.pubsub() + + try: + # Initially no shard channels + assert not pubsub.shard_channels + assert not pubsub.pending_unsubscribe_shard_channels + + # Subscribe to a shard channel + await pubsub.ssubscribe("test_shard_attr") + + # Should have shard channel information + # Note: The exact behavior may depend on implementation details + + finally: + await pubsub.aclose() + + @skip_if_server_version_lt("7.0.0") + async def test_shard_channel_with_handler(self, r): + """Test shard channel subscription with message handler""" + pubsub = r.pubsub() + + try: + received_messages = [] + + def message_handler(message): + received_messages.append(message) + + # Subscribe with handler + await pubsub.ssubscribe(test_handler_channel=message_handler) + + # This test verifies that the handler mechanism is properly set up + # Actual message delivery testing would require a live Redis cluster + + finally: + await pubsub.aclose() + + @skip_if_server_version_lt("7.0.0") + async def test_invalid_node_raises_exception(self, r): + """Test that invalid node raises appropriate exception""" + with pytest.raises(RedisClusterException): + r.pubsub(host="invalid_host", port=9999) + + @skip_if_server_version_lt("7.0.0") + async def test_partial_host_port_raises_exception(self, r): + """Test that providing only host or port raises DataError""" + with pytest.raises(DataError): + r.pubsub(host="localhost") # Missing port + + with pytest.raises(DataError): + r.pubsub(port=7000) # Missing host + + @skip_if_server_version_lt("7.0.0") + async def test_pubsub_without_specifying_node(self, r): + """ + Test creation of pubsub instance without specifying a node. The node + should be determined based on the keyslot of the first command + execution. + """ + channel_name = "foo" + node = r.get_node_from_key(channel_name) + p = r.pubsub() + try: + assert p.get_pubsub_node() is None + await p.subscribe(channel_name) + assert p.get_pubsub_node() == node + finally: + await p.aclose() + + @skip_if_server_version_lt("7.0.0") + async def test_get_pubsub_node(self, r): + """Test get_pubsub_node returns the correct node""" + nodes = r.get_nodes() + if nodes: + node = nodes[0] + p = r.pubsub(node=node) + try: + assert p.get_pubsub_node() == node + finally: + await p.aclose() + + @skip_if_server_version_lt("7.0.0") + async def test_get_sharded_message_with_publish(self, r): + """ + Test get_sharded_message returns published messages correctly. + Validates that sharded message retrieval works end-to-end. + """ + pubsub = r.pubsub() + channel = "test-channel:{0}" + + try: + # Subscribe to the shard channel + await pubsub.ssubscribe(channel) + + # Read subscription confirmation using sharded message retrieval + msg = await self.wait_for_message(pubsub, timeout=1.0, sharded=True) + assert msg is not None + assert msg["type"] == "ssubscribe" + + # Publish a message using execute_command directly (spublish not on RedisCluster) + await r.execute_command("SPUBLISH", channel, "test message") + + # Read the published message using get_sharded_message + msg = await pubsub.get_sharded_message(timeout=1.0) + # May need to retry a few times + for _ in range(10): + if msg is not None and msg.get("type") == "smessage": + break + await asyncio.sleep(0.1) + msg = await pubsub.get_sharded_message(timeout=0.1) + + assert msg is not None + assert msg["type"] == "smessage" + assert msg["channel"] == channel.encode() + assert msg["data"] == b"test message" + finally: + await pubsub.aclose() + + @skip_if_server_version_lt("7.0.0") + async def test_get_sharded_message_multiple_channels(self, r): + """ + Test get_sharded_message with multiple channels on potentially different nodes. + Validates round-robin message retrieval across nodes. + """ + pubsub = r.pubsub() + channel1 = "test-channel:{0}" + channel2 = "test-channel:{6}" + + try: + # Subscribe to both channels + await pubsub.ssubscribe(channel1, channel2) + + # Read subscription confirmations using sharded retrieval + for _ in range(2): + msg = await self.wait_for_message(pubsub, timeout=1.0, sharded=True) + assert msg is not None + + # Publish messages to both channels using execute_command + await r.execute_command("SPUBLISH", channel1, "msg1") + await r.execute_command("SPUBLISH", channel2, "msg2") + + # Read messages using get_sharded_message + messages = [] + for _ in range(10): + msg = await pubsub.get_sharded_message(timeout=0.2) + if msg and msg.get("type") == "smessage": + messages.append(msg) + if len(messages) >= 2: + break + + assert len(messages) == 2 + channels_received = {msg["channel"] for msg in messages} + assert channel1.encode() in channels_received + assert channel2.encode() in channels_received + finally: + await pubsub.aclose() + + @skip_if_server_version_lt("7.0.0") + async def test_get_sharded_message_timeout_returns_none(self, r): + """ + Test that get_sharded_message with timeout returns None when no message + arrives within the timeout period. + """ + pubsub = r.pubsub() + channel = "test-channel:{0}" + + try: + await pubsub.ssubscribe(channel) + # Read subscription confirmation using sharded retrieval + msg = await self.wait_for_message(pubsub, timeout=1.0, sharded=True) + assert msg is not None + + # Call get_sharded_message with a short timeout - should return None + import time + + start = time.monotonic() + msg = await pubsub.get_sharded_message(timeout=0.1) + elapsed = time.monotonic() - start + + assert msg is None + # Verify timeout was approximately respected (allow some slack) + assert elapsed < 0.5 + finally: + await pubsub.aclose() + + @skip_if_server_version_lt("7.0.0") + async def test_get_sharded_message_timeout_zero_returns_immediately(self, r): + """ + Test that get_sharded_message(timeout=0) returns immediately without blocking. + """ + pubsub = r.pubsub() + channel = "test-channel:{0}" + + try: + await pubsub.ssubscribe(channel) + # Read subscription confirmation using sharded retrieval + msg = await self.wait_for_message(pubsub, timeout=1.0, sharded=True) + assert msg is not None + + # get_sharded_message with timeout=0 should return immediately + import time + + start = time.monotonic() + msg = await pubsub.get_sharded_message(timeout=0) + elapsed = time.monotonic() - start + + assert msg is None + assert elapsed < 0.1 + finally: + await pubsub.aclose() + + @skip_if_server_version_lt("7.0.0") + async def test_generator_handles_concurrent_mapping_changes(self, r): + """ + Test that the generator properly handles mapping changes during iteration. + This validates the fix for RuntimeError: dictionary changed size during iteration. + """ + pubsub = r.pubsub() + channel1 = "test-channel:{0}" + channel2 = "test-channel:{6}" + + try: + # Subscribe to first channel + await pubsub.ssubscribe(channel1) + msg = await self.wait_for_message(pubsub, timeout=1.0, sharded=True) + assert msg is not None + + # Get initial mapping size (cluster pubsub only) + assert hasattr(pubsub, "node_pubsub_mapping"), "Test requires ClusterPubSub" + initial_size = len(pubsub.node_pubsub_mapping) + + # Subscribe to second channel (modifies mapping during potential iteration) + await pubsub.ssubscribe(channel2) + msg = await self.wait_for_message(pubsub, timeout=1.0, sharded=True) + assert msg is not None + + # Verify mapping was updated + assert len(pubsub.node_pubsub_mapping) >= initial_size + + # Publish and read messages - should not raise RuntimeError + await r.execute_command("SPUBLISH", channel1, "msg1") + await r.execute_command("SPUBLISH", channel2, "msg2") + + messages_received = 0 + for _ in range(10): + msg = await pubsub.get_sharded_message(timeout=0.2) + if msg and msg.get("type") == "smessage": + messages_received += 1 + if messages_received >= 2: + break + + assert messages_received == 2 + finally: + await pubsub.aclose() + + @skip_if_server_version_lt("7.0.0") + async def test_get_redis_connection(self, r): + """ + Test that get_redis_connection() returns the pubsub's dedicated + connection after subscribing to a channel. + """ + node = r.get_default_node() + p = r.pubsub(node=node) + try: + # Before subscribing, connection should be None + assert p.get_redis_connection() is None + + # Subscribe to establish the dedicated pubsub connection + await p.subscribe("test-channel") + + # Now get_redis_connection() should return the dedicated connection + connection = p.get_redis_connection() + assert connection is not None + # The connection should be from the node + assert connection.host == node.host + assert connection.port == node.port + finally: + await p.aclose() + + @skip_if_server_version_lt("7.0.0") + async def test_init_pubsub_with_non_existent_node(self, r): + """ + Test creation of pubsub instance with node that doesn't exist in the + cluster. RedisClusterException should be raised. + """ + from redis.cluster import ClusterNode + + node = ClusterNode("1.1.1.1", 1111) + with pytest.raises(RedisClusterException): + r.pubsub(node=node) + + @skip_if_server_version_lt("7.0.0") + async def test_pubsub_channels_merge_results(self, r): + """ + Test that pubsub_channels merges results from all nodes. + """ + nodes = r.get_nodes() + channels = [] + pubsub_nodes = [] + i = 0 + try: + for node in nodes: + channel = f"foo{i}" + # Create pubsub clients connected to different nodes + p = r.pubsub(node=node) + pubsub_nodes.append(p) + await p.subscribe(channel) + b_channel = channel.encode("utf-8") + channels.append(b_channel) + # Read subscription confirmation + await self.wait_for_message(p, timeout=1.0) + i += 1 + + # Assert that the cluster's pubsub_channels function returns ALL channels + await asyncio.sleep(0.3) # Allow time for subscriptions to propagate + result = await r.pubsub_channels(target_nodes="all") + result.sort() + channels.sort() + assert result == channels + finally: + for p in pubsub_nodes: + await p.aclose() + + @skip_if_server_version_lt("7.0.0") + async def test_pubsub_numsub_merge_results(self, r): + """ + Test that pubsub_numsub merges subscription counts from all nodes. + """ + nodes = r.get_nodes() + pubsub_nodes = [] + channel = "foo" + b_channel = channel.encode("utf-8") + try: + for node in nodes: + # Create pubsub clients connected to different nodes + p = r.pubsub(node=node) + pubsub_nodes.append(p) + await p.subscribe(channel) + # Read subscription confirmation + await self.wait_for_message(p, timeout=1.0) + + # Assert cluster's pubsub_numsub returns ALL clients + await asyncio.sleep(0.3) # Allow time for subscriptions to propagate + result = await r.pubsub_numsub(channel, target_nodes="all") + assert result == [(b_channel, len(nodes))] + finally: + for p in pubsub_nodes: + await p.aclose() + + @skip_if_server_version_lt("7.0.0") + async def test_pubsub_numpat_merge_results(self, r): + """ + Test that pubsub_numpat merges pattern subscription counts from all nodes. + """ + nodes = r.get_nodes() + pubsub_nodes = [] + pattern = "foo*" + try: + for node in nodes: + # Create pubsub clients connected to different nodes + p = r.pubsub(node=node) + pubsub_nodes.append(p) + await p.psubscribe(pattern) + # Read subscription confirmation + await self.wait_for_message(p, timeout=1.0) + + # Assert cluster's pubsub_numpat returns ALL pattern subscriptions + await asyncio.sleep(0.3) # Allow time for subscriptions to propagate + result = await r.pubsub_numpat(target_nodes="all") + assert result == len(nodes) + finally: + for p in pubsub_nodes: + await p.aclose() + + @skip_if_server_version_lt("7.0.0") + async def test_shard_channel_message_handler_with_message(self, r): + """ + Test shard channel subscription with message handler receives actual messages. + This verifies the handler callback is executed when messages arrive. + """ + pubsub = r.pubsub(ignore_subscribe_messages=True) + channel = "test-handler-channel:{0}" + received_messages = [] + + def message_handler(message): + received_messages.append(message) + + try: + # Subscribe with handler + await pubsub.ssubscribe(**{channel: message_handler}) + + # Publish a message + await r.spublish(channel, "test message") + + # Read message using get_sharded_message - handler should be called + for _ in range(10): + msg = await pubsub.get_sharded_message(timeout=0.2) + # When handler is set, get_sharded_message returns None + # but the handler receives the message + if received_messages: + break + + # Verify handler received the message + assert msg is None + assert len(received_messages) == 1 + assert received_messages[0]["type"] == "smessage" + assert received_messages[0]["channel"] == channel.encode() + assert received_messages[0]["data"] == b"test message" + finally: + await pubsub.aclose() diff --git a/tests/test_asyncio/test_pubsub.py b/tests/test_asyncio/test_pubsub.py index f0101d7797..53600a76c2 100644 --- a/tests/test_asyncio/test_pubsub.py +++ b/tests/test_asyncio/test_pubsub.py @@ -73,6 +73,15 @@ def make_subscribe_test_data(pubsub, type): "unsub_func": pubsub.unsubscribe, "keys": ["foo", "bar", "uni" + chr(4456) + "code"], } + elif type == "shard_channel": + return { + "p": pubsub, + "sub_type": "ssubscribe", + "unsub_type": "sunsubscribe", + "sub_func": pubsub.ssubscribe, + "unsub_func": pubsub.sunsubscribe, + "keys": ["foo", "bar", "uni" + chr(4456) + "code"], + } elif type == "pattern": return { "p": pubsub, @@ -120,6 +129,12 @@ async def test_pattern_subscribe_unsubscribe(self, pubsub): kwargs = make_subscribe_test_data(pubsub, "pattern") await self._test_subscribe_unsubscribe(**kwargs) + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("7.0.0") + async def test_shard_channel_subscribe_unsubscribe(self, pubsub): + kwargs = make_subscribe_test_data(pubsub, "shard_channel") + await self._test_subscribe_unsubscribe(**kwargs) + @pytest.mark.onlynoncluster async def _test_resubscribe_on_reconnection( self, p, sub_type, unsub_type, sub_func, unsub_func, keys @@ -265,6 +280,12 @@ async def test_subscribe_property_with_patterns(self, pubsub): kwargs = make_subscribe_test_data(pubsub, "pattern") await self._test_subscribed_property(**kwargs) + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("7.0.0") + async def test_subscribe_property_with_shard_channels(self, pubsub): + kwargs = make_subscribe_test_data(pubsub, "shard_channel") + await self._test_subscribed_property(**kwargs) + async def test_aclosing(self, r: redis.Redis): p = r.pubsub() async with aclosing(p): @@ -485,8 +506,7 @@ async def test_unicode_channel_message_handler(self, r: redis.Redis): await p.aclose() @pytest.mark.onlynoncluster - # see: https://redis-py-cluster.readthedocs.io/en/stable/pubsub.html - # #known-limitations-with-pubsub + # see: https://redis.readthedocs.io/en/stable/clustering.html#known-pubsub-limitations async def test_unicode_pattern_message_handler(self, r: redis.Redis): p = r.pubsub(ignore_subscribe_messages=True) pattern = "uni" + chr(4456) + "*" diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index e128a11aaf..6ce050acb5 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -628,8 +628,7 @@ def test_unicode_shard_channel_message_handler(self, r): assert self.message == make_message("smessage", channel, "test message") @pytest.mark.onlynoncluster - # see: https://redis-py-cluster.readthedocs.io/en/stable/pubsub.html - # #known-limitations-with-pubsub + # see: https://redis.readthedocs.io/en/stable/clustering.html#known-pubsub-limitations def test_unicode_pattern_message_handler(self, r): p = r.pubsub(ignore_subscribe_messages=True) pattern = "uni" + chr(4456) + "*"