| 
1 | 1 | import asyncio  | 
 | 2 | +import functools  | 
2 | 3 | import sys  | 
3 | 4 | from typing import Optional  | 
4 | 5 | 
 
  | 
 | 
20 | 21 | pytestmark = pytest.mark.asyncio(forbid_global_loop=True)  | 
21 | 22 | 
 
  | 
22 | 23 | 
 
  | 
 | 24 | +def with_timeout(t):  | 
 | 25 | + def wrapper(corofunc):  | 
 | 26 | + @functools.wraps(corofunc)  | 
 | 27 | + async def run(*args, **kwargs):  | 
 | 28 | + async with async_timeout.timeout(t):  | 
 | 29 | + return await corofunc(*args, **kwargs)  | 
 | 30 | + | 
 | 31 | + return run  | 
 | 32 | + | 
 | 33 | + return wrapper  | 
 | 34 | + | 
 | 35 | + | 
23 | 36 | async def wait_for_message(pubsub, timeout=0.1, ignore_subscribe_messages=False):  | 
24 | 37 |  now = asyncio.get_event_loop().time()  | 
25 | 38 |  timeout = now + timeout  | 
@@ -603,6 +616,75 @@ async def test_get_message_with_timeout_returns_none(self, r: redis.Redis):  | 
603 | 616 |  assert await p.get_message(timeout=0.01) is None  | 
604 | 617 | 
 
  | 
605 | 618 | 
 
  | 
 | 619 | +@pytest.mark.onlynoncluster  | 
 | 620 | +class TestPubSubReconnect:  | 
 | 621 | + # @pytest.mark.xfail  | 
 | 622 | + @with_timeout(2)  | 
 | 623 | + async def test_reconnect_listen(self, r: redis.Redis):  | 
 | 624 | + """  | 
 | 625 | + Test that a loop processing PubSub messages can survive  | 
 | 626 | + a disconnect, by issuing a connect() call.  | 
 | 627 | + """  | 
 | 628 | + messages = asyncio.Queue()  | 
 | 629 | + pubsub = r.pubsub()  | 
 | 630 | + interrupt = False  | 
 | 631 | + | 
 | 632 | + async def loop():  | 
 | 633 | + # must make sure the task exits  | 
 | 634 | + async with async_timeout.timeout(2):  | 
 | 635 | + nonlocal interrupt  | 
 | 636 | + await pubsub.subscribe("foo")  | 
 | 637 | + while True:  | 
 | 638 | + # print("loop")  | 
 | 639 | + try:  | 
 | 640 | + try:  | 
 | 641 | + await pubsub.connect()  | 
 | 642 | + await loop_step()  | 
 | 643 | + # print("succ")  | 
 | 644 | + except redis.ConnectionError:  | 
 | 645 | + await asyncio.sleep(0.1)  | 
 | 646 | + except asyncio.CancelledError:  | 
 | 647 | + # we use a cancel to interrupt the "listen"  | 
 | 648 | + # when we perform a disconnect  | 
 | 649 | + # print("cancel", interrupt)  | 
 | 650 | + if interrupt:  | 
 | 651 | + interrupt = False  | 
 | 652 | + else:  | 
 | 653 | + raise  | 
 | 654 | + | 
 | 655 | + async def loop_step():  | 
 | 656 | + # get a single message via listen()  | 
 | 657 | + async for message in pubsub.listen():  | 
 | 658 | + await messages.put(message)  | 
 | 659 | + break  | 
 | 660 | + | 
 | 661 | + task = asyncio.get_event_loop().create_task(loop())  | 
 | 662 | + # get the initial connect message  | 
 | 663 | + async with async_timeout.timeout(1):  | 
 | 664 | + message = await messages.get()  | 
 | 665 | + assert message == {  | 
 | 666 | + "channel": b"foo",  | 
 | 667 | + "data": 1,  | 
 | 668 | + "pattern": None,  | 
 | 669 | + "type": "subscribe",  | 
 | 670 | + }  | 
 | 671 | + # now, disconnect the connection.  | 
 | 672 | + await pubsub.connection.disconnect()  | 
 | 673 | + interrupt = True  | 
 | 674 | + task.cancel() # interrupt the listen call  | 
 | 675 | + # await another auto-connect message  | 
 | 676 | + message = await messages.get()  | 
 | 677 | + assert message == {  | 
 | 678 | + "channel": b"foo",  | 
 | 679 | + "data": 1,  | 
 | 680 | + "pattern": None,  | 
 | 681 | + "type": "subscribe",  | 
 | 682 | + }  | 
 | 683 | + task.cancel()  | 
 | 684 | + with pytest.raises(asyncio.CancelledError):  | 
 | 685 | + await task  | 
 | 686 | + | 
 | 687 | + | 
606 | 688 | @pytest.mark.onlynoncluster  | 
607 | 689 | class TestPubSubRun:  | 
608 | 690 |  async def _subscribe(self, p, *args, **kwargs):  | 
 | 
0 commit comments