Skip to content
Prev Previous commit
Next Next commit
Shrded Pubsub TestPubSubSubscribeUnsubscribe
  • Loading branch information
dvora-h committed May 22, 2023
commit c6c8a0485f49c110fc8a07cf876e9a1046de1ede
1 change: 1 addition & 0 deletions dev_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@ pytest-cov>=4.0.0
vulture>=2.3.0
ujson>=4.2.0
wheel>=0.30.0
urllib3<2
uvloop
52 changes: 36 additions & 16 deletions redis/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -1681,7 +1681,7 @@ def _raise_on_invalid_node(self, redis_cluster, node, host, port):
f"Node {host}:{port} doesn't exist in the cluster"
)

def execute_command(self, *args, ):
def execute_command(self, *args):
"""
Execute a subscribe/unsubscribe command.

Expand Down Expand Up @@ -1723,36 +1723,41 @@ def _get_node_pubsub(self, node):
pubsub = node.redis_connection.pubsub()
self.node_pubsub_mapping[node.name] = pubsub
return pubsub
def _sharded_message_generator(self, ignore_subscribe_messages=False):
while True:

def _sharded_message_generator(self):
for _ in range(len(self.node_pubsub_mapping)):
pubsub = next(self._pubsubs_generator)
message = pubsub.get_message(ignore_subscribe_messages=ignore_subscribe_messages)
message = pubsub.get_message()
if message is not None:
return message
return None

def _pubsubs_generator(self):
while True:
for pubsub in self.node_pubsub_mapping.values():
yield pubsub

def get_sharded_message(
self, ignore_subscribe_messages=False, timeout=0.0, target_node=None
):
self, ignore_subscribe_messages=False, timeout=0.0, target_node=None
):
if target_node:
message = self.node_pubsub_mapping[target_node.name].get_message(
ignore_subscribe_messages=ignore_subscribe_messages, timeout=timeout
)
else:
message = self._sharded_message_generator(ignore_subscribe_messages=ignore_subscribe_messages)
message = self._sharded_message_generator()
if message is None:
return None
elif str_if_bytes(message["type"]) == "sunsubscribe":
self.shard_channels.pop(message["channel"], None)
if message["channel"] in self.pending_unsubscribe_shard_channels:
self.pending_unsubscribe_shard_channels.remove(message["channel"])
self.shard_channels.pop(message["channel"], None)
if not self.channels and not self.patterns and not self.shard_channels:
# There are no subscriptions anymore, set subscribed_event flag
# to false
self.subscribed_event.clear()
# There are no subscriptions anymore, set subscribed_event flag
# to false
self.subscribed_event.clear()
if self.ignore_subscribe_messages or ignore_subscribe_messages:
return None
return message

def ssubscribe(self, *args, **kwargs):
Expand All @@ -1762,12 +1767,14 @@ def ssubscribe(self, *args, **kwargs):
node = self.cluster.get_node_from_key(s_channel)
pubsub = self._get_node_pubsub(node)
pubsub.ssubscribe(s_channel)
# self.subscribed = self.subscribed or self._get_node_pubsub(node).subscribed
self.shard_channels.update(pubsub.shard_channels)
self.pending_unsubscribe_shard_channels.difference_update(
self._normalize_keys({s_channel: None})
)
if pubsub.subscribed and not self.subscribed:
self.subscribed_event.set()
self.health_check_response_counter = 0

def sunsubscribe(self, *args):
if args:
args = list_or_args(args[0], args[1:])
Expand All @@ -1776,7 +1783,11 @@ def sunsubscribe(self, *args):

for s_channel in args:
node = self.cluster.get_node_from_key(s_channel)
self._get_node_pubsub(node).sunsubscribe(s_channel)
p = self._get_node_pubsub(node)
p.sunsubscribe(s_channel)
self.pending_unsubscribe_shard_channels.update(
p.pending_unsubscribe_shard_channels
)

def get_redis_connection(self):
"""
Expand All @@ -1785,6 +1796,15 @@ def get_redis_connection(self):
if self.node is not None:
return self.node.redis_connection

def disconnect(self):
"""
Disconnect the pubsub connection.
"""
if self.connection:
self.connection.disconnect()
for pubsub in self.node_pubsub_mapping.values():
pubsub.connection.disconnect()


class ClusterPipeline(RedisCluster):
"""
Expand Down
4 changes: 1 addition & 3 deletions redis/parsers/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,7 @@ def _read_from_socket(
sock.settimeout(self.socket_timeout)

def can_read(self, timeout: float) -> bool:
read = self._read_from_socket(
timeout=timeout, raise_on_timeout=False
)
read = self._read_from_socket(timeout=timeout, raise_on_timeout=False)
_bytes = bool(self.unread_bytes())
return _bytes or read

Expand Down
164 changes: 139 additions & 25 deletions tests/test_pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import socket
import threading
import time
from collections import defaultdict
from unittest import mock
from unittest.mock import patch

Expand All @@ -20,14 +21,18 @@
)


def wait_for_message(pubsub, timeout=0.5, ignore_subscribe_messages=False, node=None):
def wait_for_message(
pubsub, timeout=0.5, ignore_subscribe_messages=False, node=None, func=None
):
now = time.time()
timeout = now + timeout
while now < timeout:
if node:
message = pubsub.get_sharded_message(
ignore_subscribe_messages=ignore_subscribe_messages, target_node=node
)
elif func:
message = func(ignore_subscribe_messages=ignore_subscribe_messages)
else:
message = pubsub.get_message(
ignore_subscribe_messages=ignore_subscribe_messages
Expand Down Expand Up @@ -116,30 +121,34 @@ def test_shard_channel_subscribe_unsubscribe(self, r):
@pytest.mark.onlycluster
@skip_if_server_version_lt("7.0.0")
def test_shard_channel_subscribe_unsubscribe_cluster(self, r):
node_channels = defaultdict(int)
p = r.pubsub()
keys = {
"foo": r.get_node_from_key("foo"),
"bar": r.get_node_from_key("bar"),
"uni" + chr(4456) + "code": r.get_node_from_key("uni" + chr(4456) + "code"),
}

for key in keys.keys():
for key, node in keys.items():
assert p.ssubscribe(key) is None
# should be a message for each channel/pattern we just subscribed to
data = [1, 1, 2]
for i, (key, node) in enumerate(keys.items()):
assert wait_for_message(p, node=node) == make_message("ssubscribe", key, data[i])

# should be a message for each shard_channel we just subscribed to
for key, node in keys.items():
node_channels[node.name] += 1
assert wait_for_message(p, node=node) == make_message(
"ssubscribe", key, node_channels[node.name]
)

for key in keys.keys():
assert p.sunsubscribe(key) is None

# should be a message for each channel/pattern we just unsubscribed
# should be a message for each shard_channel we just unsubscribed
# from
data = [0, 1, 0]
breakpoint()
for i, (key, node) in enumerate(keys.items()):
assert wait_for_message(p, node=node) == make_message("sunsubscribe", key, data[i])
breakpoint()
for key, node in keys.items():
node_channels[node.name] -= 1
assert wait_for_message(p, node=node) == make_message(
"sunsubscribe", key, node_channels[node.name]
)

def _test_resubscribe_on_reconnection(
self, p, sub_type, unsub_type, sub_func, unsub_func, keys
Expand All @@ -154,7 +163,7 @@ def _test_resubscribe_on_reconnection(

# manually disconnect
p.connection.disconnect()

# breakpoint()
# calling get_message again reconnects and resubscribes
# note, we may not re-subscribe to channels in exactly the same order
# so we have to do some extra checks to make sure we got them all
Expand Down Expand Up @@ -252,38 +261,103 @@ def test_subscribe_property_with_shard_channels(self, r):
kwargs = make_subscribe_test_data(r.pubsub(), "shard_channel")
self._test_subscribed_property(**kwargs)

@pytest.mark.onlycluster
@skip_if_server_version_lt("7.0.0")
def test_subscribe_property_with_shard_channels_cluster(self, r):
p = r.pubsub()
keys = ["foo", "bar", "uni" + chr(4456) + "code"]
nodes = [r.get_node_from_key(key) for key in keys]
assert p.subscribed is False
p.ssubscribe(keys[0])
# we're now subscribed even though we haven't processed the
# reply from the server just yet
assert p.subscribed is True
assert wait_for_message(p, node=nodes[0]) == make_message(
"ssubscribe", keys[0], 1
)
# we're still subscribed
assert p.subscribed is True

# unsubscribe from all shard_channels
p.sunsubscribe()
# we're still technically subscribed until we process the
# response messages from the server
assert p.subscribed is True
assert wait_for_message(p, node=nodes[0]) == make_message(
"sunsubscribe", keys[0], 0
)
# now we're no longer subscribed as no more messages can be delivered
# to any channels we were listening to
assert p.subscribed is False

# subscribing again flips the flag back
p.ssubscribe(keys[0])
assert p.subscribed is True
assert wait_for_message(p, node=nodes[0]) == make_message(
"ssubscribe", keys[0], 1
)

# unsubscribe again
p.sunsubscribe()
assert p.subscribed is True
# subscribe to another shard_channel before reading the unsubscribe response
p.ssubscribe(keys[1])
assert p.subscribed is True
# read the unsubscribe for key1
assert wait_for_message(p, node=nodes[0]) == make_message(
"sunsubscribe", keys[0], 0
)
# we're still subscribed to key2, so subscribed should still be True
assert p.subscribed is True
# read the key2 subscribe message
assert wait_for_message(p, node=nodes[1]) == make_message(
"ssubscribe", keys[1], 1
)
p.sunsubscribe()
# haven't read the message yet, so we're still subscribed
assert p.subscribed is True
assert wait_for_message(p, node=nodes[1]) == make_message(
"sunsubscribe", keys[1], 0
)
# now we're finally unsubscribed
assert p.subscribed is False

def test_ignore_all_subscribe_messages(self, r):
p = r.pubsub(ignore_subscribe_messages=True)

checks = (
(p.subscribe, "foo"),
(p.unsubscribe, "foo"),
(p.psubscribe, "f*"),
(p.punsubscribe, "f*"),
(p.subscribe, "foo", p.get_message),
(p.unsubscribe, "foo", p.get_message),
(p.psubscribe, "f*", p.get_message),
(p.punsubscribe, "f*", p.get_message),
(p.ssubscribe, "foo", p.get_sharded_message),
(p.sunsubscribe, "foo", p.get_sharded_message),
)

assert p.subscribed is False
for func, channel in checks:
for func, channel, get_func in checks:
assert func(channel) is None
assert p.subscribed is True
assert wait_for_message(p) is None
assert wait_for_message(p, func=get_func) is None
assert p.subscribed is False

def test_ignore_individual_subscribe_messages(self, r):
p = r.pubsub()

checks = (
(p.subscribe, "foo"),
(p.unsubscribe, "foo"),
(p.psubscribe, "f*"),
(p.punsubscribe, "f*"),
(p.subscribe, "foo", p.get_message),
(p.unsubscribe, "foo", p.get_message),
(p.psubscribe, "f*", p.get_message),
(p.punsubscribe, "f*", p.get_message),
(p.ssubscribe, "foo", p.get_sharded_message),
(p.sunsubscribe, "foo", p.get_sharded_message),
)

assert p.subscribed is False
for func, channel in checks:
for func, channel, get_func in checks:
assert func(channel) is None
assert p.subscribed is True
message = wait_for_message(p, ignore_subscribe_messages=True)
message = wait_for_message(p, ignore_subscribe_messages=True, func=get_func)
assert message is None
assert p.subscribed is False

Expand Down Expand Up @@ -316,6 +390,26 @@ def _test_sub_unsub_resub(
assert wait_for_message(p) == make_message(sub_type, key, 1)
assert p.subscribed is True

@pytest.mark.onlycluster
@skip_if_server_version_lt("7.0.0")
def test_sub_unsub_resub_shard_channels_cluster(self, r):
p = r.pubsub()
key = "foo"
p.ssubscribe(key)
p.sunsubscribe(key)
p.ssubscribe(key)
assert p.subscribed is True
assert wait_for_message(p, func=p.get_sharded_message) == make_message(
"ssubscribe", key, 1
)
assert wait_for_message(p, func=p.get_sharded_message) == make_message(
"sunsubscribe", key, 0
)
assert wait_for_message(p, func=p.get_sharded_message) == make_message(
"ssubscribe", key, 1
)
assert p.subscribed is True

def test_sub_unsub_all_resub_channels(self, r):
kwargs = make_subscribe_test_data(r.pubsub(), "channel")
self._test_sub_unsub_all_resub(**kwargs)
Expand Down Expand Up @@ -344,6 +438,26 @@ def _test_sub_unsub_all_resub(
assert wait_for_message(p) == make_message(sub_type, key, 1)
assert p.subscribed is True

@pytest.mark.onlycluster
@skip_if_server_version_lt("7.0.0")
def test_sub_unsub_all_resub_shard_channels_cluster(self, r):
p = r.pubsub()
key = "foo"
p.ssubscribe(key)
p.sunsubscribe()
p.ssubscribe(key)
assert p.subscribed is True
assert wait_for_message(p, func=p.get_sharded_message) == make_message(
"ssubscribe", key, 1
)
assert wait_for_message(p, func=p.get_sharded_message) == make_message(
"sunsubscribe", key, 0
)
assert wait_for_message(p, func=p.get_sharded_message) == make_message(
"ssubscribe", key, 1
)
assert p.subscribed is True


class TestPubSubMessages:
def setup_method(self, method):
Expand Down