|  | 
| 1 | 1 | import asyncio | 
|  | 2 | +import socket | 
| 2 | 3 | import types | 
|  | 4 | +from unittest.mock import patch | 
| 3 | 5 | 
 | 
| 4 | 6 | import pytest | 
| 5 | 7 | 
 | 
| 6 |  | -from redis.asyncio.connection import PythonParser, UnixDomainSocketConnection | 
| 7 |  | -from redis.exceptions import InvalidResponse | 
|  | 8 | +from redis.asyncio.connection import ( | 
|  | 9 | + Connection, | 
|  | 10 | + PythonParser, | 
|  | 11 | + UnixDomainSocketConnection, | 
|  | 12 | +) | 
|  | 13 | +from redis.asyncio.retry import Retry | 
|  | 14 | +from redis.backoff import NoBackoff | 
|  | 15 | +from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError | 
| 8 | 16 | from redis.utils import HIREDIS_AVAILABLE | 
| 9 | 17 | from tests.conftest import skip_if_server_version_lt | 
| 10 | 18 | 
 | 
| @@ -60,3 +68,44 @@ async def test_socket_param_regression(r): | 
| 60 | 68 | async def test_can_run_concurrent_commands(r): | 
| 61 | 69 |  assert await r.ping() is True | 
| 62 | 70 |  assert all(await asyncio.gather(*(r.ping() for _ in range(10)))) | 
|  | 71 | + | 
|  | 72 | + | 
|  | 73 | +async def test_connect_retry_on_timeout_error(): | 
|  | 74 | + """Test that the _connect function is retried in case of a timeout""" | 
|  | 75 | + conn = Connection(retry_on_timeout=True, retry=Retry(NoBackoff(), 3)) | 
|  | 76 | + origin_connect = conn._connect | 
|  | 77 | + conn._connect = mock.AsyncMock() | 
|  | 78 | + | 
|  | 79 | + async def mock_connect(): | 
|  | 80 | + # connect only on the last retry | 
|  | 81 | + if conn._connect.call_count <= 2: | 
|  | 82 | + raise socket.timeout | 
|  | 83 | + else: | 
|  | 84 | + return await origin_connect() | 
|  | 85 | + | 
|  | 86 | + conn._connect.side_effect = mock_connect | 
|  | 87 | + await conn.connect() | 
|  | 88 | + assert conn._connect.call_count == 3 | 
|  | 89 | + | 
|  | 90 | + | 
|  | 91 | +async def test_connect_without_retry_on_os_error(): | 
|  | 92 | + """Test that the _connect function is not being retried in case of a OSError""" | 
|  | 93 | + with patch.object(Connection, "_connect") as _connect: | 
|  | 94 | + _connect.side_effect = OSError("") | 
|  | 95 | + conn = Connection(retry_on_timeout=True, retry=Retry(NoBackoff(), 2)) | 
|  | 96 | + with pytest.raises(ConnectionError): | 
|  | 97 | + await conn.connect() | 
|  | 98 | + assert _connect.call_count == 1 | 
|  | 99 | + | 
|  | 100 | + | 
|  | 101 | +async def test_connect_timeout_error_without_retry(): | 
|  | 102 | + """Test that the _connect function is not being retried if retry_on_timeout is | 
|  | 103 | + set to False""" | 
|  | 104 | + conn = Connection(retry_on_timeout=False) | 
|  | 105 | + conn._connect = mock.AsyncMock() | 
|  | 106 | + conn._connect.side_effect = socket.timeout | 
|  | 107 | + | 
|  | 108 | + with pytest.raises(TimeoutError) as e: | 
|  | 109 | + await conn.connect() | 
|  | 110 | + assert conn._connect.call_count == 1 | 
|  | 111 | + assert str(e.value) == "Timeout connecting to server" | 
0 commit comments