|  | 
| 1 | 1 | import asyncio | 
|  | 2 | +import functools | 
| 2 | 3 | import sys | 
| 3 | 4 | from typing import Optional | 
| 4 | 5 | 
 | 
|  | 
| 20 | 21 | pytestmark = pytest.mark.asyncio(forbid_global_loop=True) | 
| 21 | 22 | 
 | 
| 22 | 23 | 
 | 
|  | 24 | +def with_timeout(t): | 
|  | 25 | + def wrapper(corofunc): | 
|  | 26 | + @functools.wraps(corofunc) | 
|  | 27 | + async def run(*args, **kwargs): | 
|  | 28 | + async with async_timeout.timeout(t): | 
|  | 29 | + return await corofunc(*args, **kwargs) | 
|  | 30 | + | 
|  | 31 | + return run | 
|  | 32 | + | 
|  | 33 | + return wrapper | 
|  | 34 | + | 
|  | 35 | + | 
| 23 | 36 | async def wait_for_message(pubsub, timeout=0.1, ignore_subscribe_messages=False): | 
| 24 | 37 |  now = asyncio.get_event_loop().time() | 
| 25 | 38 |  timeout = now + timeout | 
| @@ -603,6 +616,76 @@ async def test_get_message_with_timeout_returns_none(self, r: redis.Redis): | 
| 603 | 616 |  assert await p.get_message(timeout=0.01) is None | 
| 604 | 617 | 
 | 
| 605 | 618 | 
 | 
|  | 619 | +@pytest.mark.onlynoncluster | 
|  | 620 | +class TestPubSubReconnect: | 
|  | 621 | + # @pytest.mark.xfail | 
|  | 622 | + @with_timeout(2) | 
|  | 623 | + async def test_reconnect_listen(self, r: redis.Redis): | 
|  | 624 | + """ | 
|  | 625 | + Test that a loop processing PubSub messages can survive | 
|  | 626 | + a disconnect, by issuing a connect() call. | 
|  | 627 | + """ | 
|  | 628 | + messages = asyncio.Queue() | 
|  | 629 | + pubsub = r.pubsub() | 
|  | 630 | + interrupt = False | 
|  | 631 | + | 
|  | 632 | + async def loop(): | 
|  | 633 | + # must make sure the task exits | 
|  | 634 | + async with async_timeout.timeout(2): | 
|  | 635 | + nonlocal interrupt | 
|  | 636 | + await pubsub.subscribe("foo") | 
|  | 637 | + while True: | 
|  | 638 | + # print("loop") | 
|  | 639 | + try: | 
|  | 640 | + try: | 
|  | 641 | + await pubsub.connect() | 
|  | 642 | + await loop_step() | 
|  | 643 | + # print("succ") | 
|  | 644 | + except redis.ConnectionError: | 
|  | 645 | + err = True | 
|  | 646 | + # print("err") | 
|  | 647 | + await asyncio.sleep(0.1) | 
|  | 648 | + except asyncio.CancelledError: | 
|  | 649 | + # we use a cancel to interrupt the "listen" when we perform a disconnect | 
|  | 650 | + # print("cancel", interrupt) | 
|  | 651 | + if interrupt: | 
|  | 652 | + interrupt = False | 
|  | 653 | + else: | 
|  | 654 | + raise | 
|  | 655 | + | 
|  | 656 | + async def loop_step(): | 
|  | 657 | + # get a single message via listen() | 
|  | 658 | + async for message in pubsub.listen(): | 
|  | 659 | + await messages.put(message) | 
|  | 660 | + break | 
|  | 661 | + | 
|  | 662 | + task = asyncio.get_event_loop().create_task(loop()) | 
|  | 663 | + # get the initial connect message | 
|  | 664 | + async with async_timeout.timeout(1): | 
|  | 665 | + message = await messages.get() | 
|  | 666 | + assert message == { | 
|  | 667 | + "channel": b"foo", | 
|  | 668 | + "data": 1, | 
|  | 669 | + "pattern": None, | 
|  | 670 | + "type": "subscribe", | 
|  | 671 | + } | 
|  | 672 | + # now, disconnect the connection. | 
|  | 673 | + await pubsub.connection.disconnect() | 
|  | 674 | + interrupt = True | 
|  | 675 | + task.cancel() # interrupt the listen call | 
|  | 676 | + # await another auto-connect message | 
|  | 677 | + message = await messages.get() | 
|  | 678 | + assert message == { | 
|  | 679 | + "channel": b"foo", | 
|  | 680 | + "data": 1, | 
|  | 681 | + "pattern": None, | 
|  | 682 | + "type": "subscribe", | 
|  | 683 | + } | 
|  | 684 | + task.cancel() | 
|  | 685 | + with pytest.raises(asyncio.CancelledError): | 
|  | 686 | + await task | 
|  | 687 | + | 
|  | 688 | + | 
| 606 | 689 | @pytest.mark.onlynoncluster | 
| 607 | 690 | class TestPubSubRun: | 
| 608 | 691 |  async def _subscribe(self, p, *args, **kwargs): | 
|  | 
0 commit comments