Skip to content
Next Next commit
Add regression tests and fixes for issue #1128
  • Loading branch information
kristjanvalur committed May 7, 2023
commit a1d5a9bb1df84fd0eed1a1397bcf8c4869eb84b5
4 changes: 3 additions & 1 deletion redis/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,7 +816,9 @@ async def parse_response(self, block: bool = True, timeout: float = 0):
await conn.connect()

read_timeout = None if block else timeout
response = await self._execute(conn, conn.read_response, timeout=read_timeout)
response = await self._execute(
conn, conn.read_response, timeout=read_timeout, disconnect_on_error=False
)

if conn.health_check_interval and response == self.health_check_response:
# ignore the health check message as user might not expect it
Expand Down
28 changes: 18 additions & 10 deletions redis/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,7 +804,11 @@ async def send_packed_command(
raise ConnectionError(
f"Error {err_no} while writing to socket. {errmsg}."
) from e
except Exception:
except BaseException:
# BaseExceptions can be raised when a socket send operation is not
# finished, e.g. due to a timeout. Ideally, a caller could then re-try
# to send un-sent data. However, the send_packed_command() API
# does not support it so there is no point in keeping the connection open.
await self.disconnect(nowait=True)
raise

Expand All @@ -827,7 +831,9 @@ async def can_read_destructive(self):
async def read_response(
self,
disable_decoding: bool = False,
*,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you add this asterisk? I prefer not to break... (if someone is using something like read_response(False, 0.5))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, it was an afterthought. I was the one who added the timeout argument originally and in retrospect it probably should have been a keyword-only argument. Since it was my addition, I thought changing it retroactively to a kw-only would be ok (there was only a two-month interval between those two changes). "timeout" is traditionally a kw-only arg these days.
I guess one should be careful with these things, so I'm moving the asterisk :)

timeout: Optional[float] = None,
disconnect_on_error: bool = True,
):
"""Read the response from a previously sent command"""
read_timeout = timeout if timeout is not None else self.socket_timeout
Expand All @@ -843,22 +849,24 @@ async def read_response(
)
except asyncio.TimeoutError:
if timeout is not None:
# user requested timeout, return None
# user requested timeout, return None. Operation can be retried
return None
# it was a self.socket_timeout error.
await self.disconnect(nowait=True)
if disconnect_on_error:
await self.disconnect(nowait=True)
raise TimeoutError(f"Timeout reading from {self.host}:{self.port}")
except OSError as e:
await self.disconnect(nowait=True)
if disconnect_on_error:
await self.disconnect(nowait=True)
raise ConnectionError(
f"Error while reading from {self.host}:{self.port} : {e.args}"
)
except asyncio.CancelledError:
# need this check for 3.7, where CancelledError
# is subclass of Exception, not BaseException
raise
except Exception:
await self.disconnect(nowait=True)
except BaseException:
# Also by default close in case of BaseException. A lot of code
# relies on this behaviour when doing Command/Response pairs.
# See #1128.
if disconnect_on_error:
await self.disconnect(nowait=True)
raise

if self.health_check_interval:
Expand Down
2 changes: 1 addition & 1 deletion redis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1529,7 +1529,7 @@ def try_read():
return None
else:
conn.connect()
return conn.read_response()
return conn.read_response(disconnect_on_error=False)

response = self._execute(conn, try_read)

Expand Down
24 changes: 18 additions & 6 deletions redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -834,7 +834,11 @@ def send_packed_command(self, command, check_health=True):
errno = e.args[0]
errmsg = e.args[1]
raise ConnectionError(f"Error {errno} while writing to socket. {errmsg}.")
except Exception:
except BaseException:
# BaseExceptions can be raised when a socket send operation is not
# finished, e.g. due to a timeout. Ideally, a caller could then re-try
# to send un-sent data. However, the send_packed_command() API
# does not support it so there is no point in keeping the connection open.
self.disconnect()
raise

Expand All @@ -859,23 +863,31 @@ def can_read(self, timeout=0):
self.disconnect()
raise ConnectionError(f"Error while reading from {host_error}: {e.args}")

def read_response(self, disable_decoding=False):
def read_response(
self, disable_decoding=False, *, disconnect_on_error: bool = True
):
"""Read the response from a previously sent command"""

host_error = self._host_error()

try:
response = self._parser.read_response(disable_decoding=disable_decoding)
except socket.timeout:
self.disconnect()
if disconnect_on_error:
self.disconnect()
raise TimeoutError(f"Timeout reading from {host_error}")
except OSError as e:
self.disconnect()
if disconnect_on_error:
self.disconnect()
raise ConnectionError(
f"Error while reading from {host_error}" f" : {e.args}"
)
except Exception:
self.disconnect()
except BaseException:
# Also by default close in case of BaseException. A lot of code
# relies on this behaviour when doing Command/Response pairs.
# See #1128.
if disconnect_on_error:
self.disconnect()
raise

if self.health_check_interval:
Expand Down
38 changes: 38 additions & 0 deletions tests/test_asyncio/test_commands.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""
Tests async overrides of commands from their mixins
"""
import asyncio
import binascii
import datetime
import re
import sys
from string import ascii_letters

import pytest
Expand All @@ -18,6 +20,11 @@
skip_unless_arch_bits,
)

if sys.version_info.major >= 3 and sys.version_info.minor >= 11:
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is broken as it'll only allow 3.11+, 4.11+, 5.11+, etc.

Always compare against a tuple like this:

Suggested change
if sys.version_info.major >= 3 and sys.version_info.minor >= 11:
if sys.version_info >= (3, 11):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we also want to use (3, 11, 3) here based on the following existing code?

# the functionality is available in 3.11.x but has a major issue before
# 3.11.3. See https://github.com/redis/redis-py/issues/2633
if sys.version_info >= (3, 11, 3):
from asyncio import timeout as async_timeout
else:
from async_timeout import timeout as async_timeout

from asyncio import timeout as async_timeout
else:
from async_timeout import timeout as async_timeout

REDIS_6_VERSION = "5.9.0"


Expand Down Expand Up @@ -3008,6 +3015,37 @@ async def test_module_list(self, r: redis.Redis):
for x in await r.module_list():
assert isinstance(x, dict)

@pytest.mark.onlynoncluster
async def test_interrupted_command(self, r: redis.Redis):
"""
Regression test for issue #1128: An Un-handled BaseException
will leave the socket with un-read response to a previous
command.
"""
ready = asyncio.Event()

async def helper():
with pytest.raises(asyncio.CancelledError):
# blocking pop
ready.set()
await r.brpop(["nonexist"])
# If the following is not done, further Timeout operations will fail,
# because the timeout won't catch its Cancelled Error if the task
# has a pending cancel. Python documentation probably should reflect this.
if sys.version_info.major >= 3 and sys.version_info.minor >= 11:
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto.

Suggested change
if sys.version_info.major >= 3 and sys.version_info.minor >= 11:
if sys.version_info >= (3, 11):
asyncio.current_task().uncancel()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks highly dubious; at least it should never ever be necessary here. Is this a workaround for the stdlib bug fixed in v3.11.3?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it is necessary in current Python if you decide to catch and ignore a Cancel request.
I am the author of pr python/cpython#102815 which fixed the Timeout issue in 3.11.3, and among other things, that PR clarifies this use case. Please see https://github.com/python/cpython/blob/main/Doc/library/asyncio-task.rst#task-cancellation

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I see now. This is necessary since you're catching CancelledError above.

# if all is well, we can continue. The following should not hang.
await r.set("status", "down")

task = asyncio.create_task(helper())
await ready.wait()
await asyncio.sleep(0.01)
# the task is now sleeping, lets send it an exception
task.cancel()
# If all is well, the task should finish right away, otherwise fail with Timeout
async with async_timeout(0.1):
await task


@pytest.mark.onlynoncluster
class TestBinarySave:
Expand Down
35 changes: 35 additions & 0 deletions tests/test_commands.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import binascii
import datetime
import re
import threading
import time
from asyncio import CancelledError
from string import ascii_letters
from unittest import mock
from unittest.mock import patch

import pytest

Expand Down Expand Up @@ -4741,6 +4744,38 @@ def test_psync(self, r):
res = r2.psync(r2.client_id(), 1)
assert b"FULLRESYNC" in res

@pytest.mark.onlynoncluster
def test_interrupted_command(self, r: redis.Redis):
"""
Regression test for issue #1128: An Un-handled BaseException
will leave the socket with un-read response to a previous
command.

The parser's ``read_response`` is patched to raise ``CancelledError``
while a blocking pop is issued; a follow-up command on the same
client must then complete instead of hanging.
"""

# Flipped to True by the helper only after the follow-up command returns.
ok = False

def helper():
with pytest.raises(CancelledError):
# blocking pop
# The patched parser raises CancelledError as soon as the response
# is read, standing in for an interruption in the middle of a command.
with patch.object(
r.connection._parser, "read_response", side_effect=CancelledError
):
r.brpop(["nonexist"])
# if all is well, we can continue.
r.set("status", "down") # should not hang
nonlocal ok
ok = True

# Run the helper in a separate thread so that a hang shows up as a
# still-alive thread instead of blocking the test itself.
thread = threading.Thread(target=helper)
thread.start()
# Wait at most 0.1 seconds; a helper still running after that is
# treated as hung by the assertion below.
thread.join(0.1)
try:
assert not thread.is_alive()
assert ok
finally:
# disconnect here so that fixture cleanup can proceed
r.connection.disconnect()


@pytest.mark.onlynoncluster
class TestBinarySave:
Expand Down