2626)
2727
2828from redis._parsers import AsyncCommandsParser, Encoder
29+ from redis._parsers.commands import CommandPolicies, RequestPolicy, ResponsePolicy
2930from redis._parsers.helpers import (
3031 _RedisCallbacks,
3132 _RedisCallbacksRESP2,
5152 parse_cluster_slots,
5253)
5354from redis.commands import READ_COMMANDS, AsyncRedisClusterCommands
55+ from redis.commands.policies import AsyncPolicyResolver, AsyncStaticPolicyResolver
5456from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot
5557from redis.credentials import CredentialProvider
5658from redis.event import AfterAsyncClusterInstantiationEvent, EventDispatcher
@@ -310,6 +312,7 @@ def __init__(
310312 protocol: Optional[int] = 2,
311313 address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None,
312314 event_dispatcher: Optional[EventDispatcher] = None,
315+ policy_resolver: AsyncPolicyResolver = AsyncStaticPolicyResolver(),
313316 ) -> None:
314317 if db:
315318 raise RedisClusterException(
@@ -423,7 +426,36 @@ def __init__(
423426 self.load_balancing_strategy = load_balancing_strategy
424427 self.reinitialize_steps = reinitialize_steps
425428 self.reinitialize_counter = 0
429+
430+ # For backward compatibility, mapping from existing policies to new one
431+ self._command_flags_mapping: dict[str, Union[RequestPolicy, ResponsePolicy]] = {
432+ self.__class__.RANDOM: RequestPolicy.DEFAULT_KEYLESS,
433+ self.__class__.PRIMARIES: RequestPolicy.ALL_SHARDS,
434+ self.__class__.ALL_NODES: RequestPolicy.ALL_NODES,
435+ self.__class__.REPLICAS: RequestPolicy.ALL_REPLICAS,
436+ self.__class__.DEFAULT_NODE: RequestPolicy.DEFAULT_NODE,
437+ SLOT_ID: RequestPolicy.DEFAULT_KEYED,
438+ }
439+
440+ self._policies_callback_mapping: dict[
441+ Union[RequestPolicy, ResponsePolicy], Callable
442+ ] = {
443+ RequestPolicy.DEFAULT_KEYLESS: lambda command_name: [
444+ self.get_random_primary_or_all_nodes(command_name)
445+ ],
446+ RequestPolicy.DEFAULT_KEYED: self.get_nodes_from_slot,
447+ RequestPolicy.DEFAULT_NODE: lambda: [self.get_default_node()],
448+ RequestPolicy.ALL_SHARDS: self.get_primaries,
449+ RequestPolicy.ALL_NODES: self.get_nodes,
450+ RequestPolicy.ALL_REPLICAS: self.get_replicas,
451+ RequestPolicy.SPECIAL: self.get_special_nodes,
452+ ResponsePolicy.DEFAULT_KEYLESS: lambda res: res,
453+ ResponsePolicy.DEFAULT_KEYED: lambda res: res,
454+ }
455+
456+ self._policy_resolver = policy_resolver
426457 self.commands_parser = AsyncCommandsParser()
458+ self._aggregate_nodes = None
427459 self.node_flags = self.__class__.NODE_FLAGS.copy()
428460 self.command_flags = self.__class__.COMMAND_FLAGS.copy()
429461 self.response_callbacks = kwargs["response_callbacks"]
@@ -619,6 +651,45 @@ def get_node_from_key(
619651
620652 return slot_cache[node_idx]
621653
654+ def get_random_primary_or_all_nodes(self, command_name):
655+ """
656+ Returns random primary or all nodes depends on READONLY mode.
657+ """
658+ if self.read_from_replicas and command_name in READ_COMMANDS:
659+ return self.get_random_node()
660+
661+ return self.get_random_primary_node()
662+
663+ def get_random_primary_node(self) -> "ClusterNode":
664+ """
665+ Returns a random primary node
666+ """
667+ return random.choice(self.get_primaries())
668+
669+ async def get_nodes_from_slot(self, command: str, *args):
670+ """
671+ Returns a list of nodes that hold the specified keys' slots.
672+ """
673+ # get the node that holds the key's slot
674+ return [
675+ self.nodes_manager.get_node_from_slot(
676+ await self._determine_slot(command, *args),
677+ self.read_from_replicas and command in READ_COMMANDS,
678+ self.load_balancing_strategy if command in READ_COMMANDS else None,
679+ )
680+ ]
681+
682+ def get_special_nodes(self) -> Optional[list["ClusterNode"]]:
683+ """
684+ Returns a list of nodes for commands with a special policy.
685+ """
686+ if not self._aggregate_nodes:
687+ raise RedisClusterException(
688+ "Cannot execute FT.CURSOR commands without FT.AGGREGATE"
689+ )
690+
691+ return self._aggregate_nodes
692+
622693 def keyslot(self, key: EncodableT) -> int:
623694 """
624695 Find the keyslot for a given key.
@@ -643,39 +714,34 @@ def set_response_callback(self, command: str, callback: ResponseCallbackT) -> No
643714 self.response_callbacks[command] = callback
644715
645716 async def _determine_nodes(
646- self, command: str, *args: Any, node_flag: Optional[str] = None
717+ self,
718+ command: str,
719+ *args: Any,
720+ request_policy: RequestPolicy,
721+ node_flag: Optional[str] = None,
647722 ) -> List["ClusterNode"]:
648723 # Determine which nodes should be executed the command on.
649724 # Returns a list of target nodes.
650725 if not node_flag:
651726 # get the nodes group for this command if it was predefined
652727 node_flag = self.command_flags.get(command)
653728
654- if node_flag in self.node_flags:
655- if node_flag == self.__class__.DEFAULT_NODE:
656- # return the cluster's default node
657- return [self.nodes_manager.default_node]
658- if node_flag == self.__class__.PRIMARIES:
659- # return all primaries
660- return self.nodes_manager.get_nodes_by_server_type(PRIMARY)
661- if node_flag == self.__class__.REPLICAS:
662- # return all replicas
663- return self.nodes_manager.get_nodes_by_server_type(REPLICA)
664- if node_flag == self.__class__.ALL_NODES:
665- # return all nodes
666- return list(self.nodes_manager.nodes_cache.values())
667- if node_flag == self.__class__.RANDOM:
668- # return a random node
669- return [random.choice(list(self.nodes_manager.nodes_cache.values()))]
729+ if node_flag in self._command_flags_mapping:
730+ request_policy = self._command_flags_mapping[node_flag]
670731
671- # get the node that holds the key's slot
672- return [
673- self.nodes_manager.get_node_from_slot(
674- await self._determine_slot(command, *args),
675- self.read_from_replicas and command in READ_COMMANDS,
676- self.load_balancing_strategy if command in READ_COMMANDS else None,
677- )
678- ]
732+ policy_callback = self._policies_callback_mapping[request_policy]
733+
734+ if request_policy == RequestPolicy.DEFAULT_KEYED:
735+ nodes = await policy_callback(command, *args)
736+ elif request_policy == RequestPolicy.DEFAULT_KEYLESS:
737+ nodes = policy_callback(command)
738+ else:
739+ nodes = policy_callback()
740+
741+ if command.lower() == "ft.aggregate":
742+ self._aggregate_nodes = nodes
743+
744+ return nodes
679745
680746 async def _determine_slot(self, command: str, *args: Any) -> int:
681747 if self.command_flags.get(command) == SLOT_ID:
@@ -780,6 +846,33 @@ async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any:
780846 target_nodes_specified = True
781847 retry_attempts = 0
782848
849+ command_policies = await self._policy_resolver.resolve(args[0].lower())
850+
851+ if not command_policies and not target_nodes_specified:
852+ command_flag = self.command_flags.get(command)
853+ if not command_flag:
854+ # Fallback to default policy
855+ if not self.get_default_node():
856+ slot = None
857+ else:
858+ slot = await self._determine_slot(*args)
859+ if not slot:
860+ command_policies = CommandPolicies()
861+ else:
862+ command_policies = CommandPolicies(
863+ request_policy=RequestPolicy.DEFAULT_KEYED,
864+ response_policy=ResponsePolicy.DEFAULT_KEYED,
865+ )
866+ else:
867+ if command_flag in self._command_flags_mapping:
868+ command_policies = CommandPolicies(
869+ request_policy=self._command_flags_mapping[command_flag]
870+ )
871+ else:
872+ command_policies = CommandPolicies()
873+ elif not command_policies and target_nodes_specified:
874+ command_policies = CommandPolicies()
875+
783876 # Add one for the first execution
784877 execute_attempts = 1 + retry_attempts
785878 for _ in range(execute_attempts):
@@ -795,7 +888,9 @@ async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any:
795888 if not target_nodes_specified:
796889 # Determine the nodes to execute the command on
797890 target_nodes = await self._determine_nodes(
798- *args, node_flag=passed_targets
891+ *args,
892+ request_policy=command_policies.request_policy,
893+ node_flag=passed_targets,
799894 )
800895 if not target_nodes:
801896 raise RedisClusterException(
@@ -806,10 +901,12 @@ async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any:
806901 # Return the processed result
807902 ret = await self._execute_command(target_nodes[0], *args, **kwargs)
808903 if command in self.result_callbacks:
809- return self.result_callbacks[command](
904+ ret = self.result_callbacks[command](
810905 command, {target_nodes[0].name: ret}, **kwargs
811906 )
812- return ret
907+ return self._policies_callback_mapping[
908+ command_policies.response_policy
909+ ](ret)
813910 else:
814911 keys = [node.name for node in target_nodes]
815912 values = await asyncio.gather(
@@ -824,7 +921,9 @@ async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any:
824921 return self.result_callbacks[command](
825922 command, dict(zip(keys, values)), **kwargs
826923 )
827- return dict(zip(keys, values))
924+ return self._policies_callback_mapping[
925+ command_policies.response_policy
926+ ](dict(zip(keys, values)))
828927 except Exception as e:
829928 if retry_attempts > 0 and type(e) in self.__class__.ERRORS_ALLOW_RETRY:
830929 # The nodes and slots cache were should be reinitialized.
@@ -1740,6 +1839,7 @@ def __init__(self, position: int, *args: Any, **kwargs: Any) -> None:
17401839 self.kwargs = kwargs
17411840 self.position = position
17421841 self.result: Union[Any, Exception] = None
1842+ self.command_policies: Optional[CommandPolicies] = None
17431843
17441844 def __repr__(self) -> str:
17451845 return f"[{self.position}] {self.args} ({self.kwargs})"
@@ -1980,16 +2080,51 @@ async def _execute(
19802080 nodes = {}
19812081 for cmd in todo:
19822082 passed_targets = cmd.kwargs.pop("target_nodes", None)
2083+ command_policies = await client._policy_resolver.resolve(
2084+ cmd.args[0].lower()
2085+ )
2086+
19832087 if passed_targets and not client._is_node_flag(passed_targets):
19842088 target_nodes = client._parse_target_nodes(passed_targets)
2089+
2090+ if not command_policies:
2091+ command_policies = CommandPolicies()
19852092 else:
2093+ if not command_policies:
2094+ command_flag = client.command_flags.get(cmd.args[0])
2095+ if not command_flag:
2096+ # Fallback to default policy
2097+ if not client.get_default_node():
2098+ slot = None
2099+ else:
2100+ slot = await client._determine_slot(*cmd.args)
2101+ if not slot:
2102+ command_policies = CommandPolicies()
2103+ else:
2104+ command_policies = CommandPolicies(
2105+ request_policy=RequestPolicy.DEFAULT_KEYED,
2106+ response_policy=ResponsePolicy.DEFAULT_KEYED,
2107+ )
2108+ else:
2109+ if command_flag in client._command_flags_mapping:
2110+ command_policies = CommandPolicies(
2111+ request_policy=client._command_flags_mapping[
2112+ command_flag
2113+ ]
2114+ )
2115+ else:
2116+ command_policies = CommandPolicies()
2117+
19862118 target_nodes = await client._determine_nodes(
1987- *cmd.args, node_flag=passed_targets
2119+ *cmd.args,
2120+ request_policy=command_policies.request_policy,
2121+ node_flag=passed_targets,
19882122 )
19892123 if not target_nodes:
19902124 raise RedisClusterException(
19912125 f"No targets were found to execute {cmd.args} command on"
19922126 )
2127+ cmd.command_policies = command_policies
19932128 if len(target_nodes) > 1:
19942129 raise RedisClusterException(f"Too many targets for command {cmd.args}")
19952130 node = target_nodes[0]
@@ -2010,9 +2145,9 @@ async def _execute(
20102145 for cmd in todo:
20112146 if isinstance(cmd.result, (TryAgainError, MovedError, AskError)):
20122147 try:
2013- cmd.result = await client.execute_command(
2014- * cmd.args, **cmd.kwargs
2015- )
2148+ cmd.result = client._policies_callback_mapping[
2149+ cmd.command_policies.response_policy
2150+ ](await client.execute_command(*cmd.args, **cmd.kwargs) )
20162151 except Exception as e:
20172152 cmd.result = e
20182153
0 commit comments