55import  socket 
66import  threading 
77import  weakref 
8+ from  io  import  SEEK_END 
89from  itertools  import  chain 
910from  queue  import  Empty , Full , LifoQueue 
1011from  time  import  time 
11- from  typing  import  Optional 
12+ from  typing  import  Optional ,  Union 
1213from  urllib .parse  import  parse_qs , unquote , urlparse 
1314
1415from  redis .backoff  import  NoBackoff 
@@ -163,39 +164,47 @@ def parse_error(self, response):
163164
164165
165166class  SocketBuffer :
166-  def  __init__ (self , socket , socket_read_size , socket_timeout ):
167+  def  __init__ (
168+  self , socket : socket .socket , socket_read_size : int , socket_timeout : float 
169+  ):
167170 self ._sock  =  socket 
168171 self .socket_read_size  =  socket_read_size 
169172 self .socket_timeout  =  socket_timeout 
170173 self ._buffer  =  io .BytesIO ()
171-  # number of bytes written to the buffer from the socket 
172-  self .bytes_written  =  0 
173-  # number of bytes read from the buffer 
174-  self .bytes_read  =  0 
175174
176-  @property  
177-  def  length (self ):
178-  return  self .bytes_written  -  self .bytes_read 
175+  def  unread_bytes (self ) ->  int :
176+  """ 
177+  Remaining unread length of buffer 
178+  """ 
179+  pos  =  self ._buffer .tell ()
180+  end  =  self ._buffer .seek (0 , SEEK_END )
181+  self ._buffer .seek (pos )
182+  return  end  -  pos 
179183
180-  def  _read_from_socket (self , length = None , timeout = SENTINEL , raise_on_timeout = True ):
184+  def  _read_from_socket (
185+  self ,
186+  length : Optional [int ] =  None ,
187+  timeout : Union [float , object ] =  SENTINEL ,
188+  raise_on_timeout : Optional [bool ] =  True ,
189+  ) ->  bool :
181190 sock  =  self ._sock 
182191 socket_read_size  =  self .socket_read_size 
183-  buf  =  self ._buffer 
184-  buf .seek (self .bytes_written )
185192 marker  =  0 
186193 custom_timeout  =  timeout  is  not SENTINEL 
187194
195+  buf  =  self ._buffer 
196+  current_pos  =  buf .tell ()
197+  buf .seek (0 , SEEK_END )
198+  if  custom_timeout :
199+  sock .settimeout (timeout )
188200 try :
189-  if  custom_timeout :
190-  sock .settimeout (timeout )
191201 while  True :
192202 data  =  self ._sock .recv (socket_read_size )
193203 # an empty string indicates the server shutdown the socket 
194204 if  isinstance (data , bytes ) and  len (data ) ==  0 :
195205 raise  ConnectionError (SERVER_CLOSED_CONNECTION_ERROR )
196206 buf .write (data )
197207 data_length  =  len (data )
198-  self .bytes_written  +=  data_length 
199208 marker  +=  data_length 
200209
201210 if  length  is  not None  and  length  >  marker :
@@ -215,55 +224,53 @@ def _read_from_socket(self, length=None, timeout=SENTINEL, raise_on_timeout=True
215224 return  False 
216225 raise  ConnectionError (f"Error while reading from socket: { ex .args }  )
217226 finally :
227+  buf .seek (current_pos )
218228 if  custom_timeout :
219229 sock .settimeout (self .socket_timeout )
220230
221-  def  can_read (self , timeout ) :
222-  return  bool (self .length ) or  self ._read_from_socket (
231+  def  can_read (self , timeout :  float )  ->   bool :
232+  return  bool (self .unread_bytes () ) or  self ._read_from_socket (
223233 timeout = timeout , raise_on_timeout = False 
224234 )
225235
226-  def  read (self , length ) :
236+  def  read (self , length :  int )  ->   bytes :
227237 length  =  length  +  2  # make sure to read the \r\n terminator 
228-  # make sure we've read enough data from the socket 
229-  if  length  >  self .length :
230-  self ._read_from_socket (length  -  self .length )
231- 
232-  self ._buffer .seek (self .bytes_read )
238+  # BufferIO will return less than requested if buffer is short 
233239 data  =  self ._buffer .read (length )
234-  self .bytes_read  +=  len (data )
240+  missing  =  length  -  len (data )
241+  if  missing :
242+  # fill up the buffer and read the remainder 
243+  self ._read_from_socket (missing )
244+  data  +=  self ._buffer .read (missing )
235245 return  data [:- 2 ]
236246
237-  def  readline (self ):
247+  def  readline (self )  ->   bytes :
238248 buf  =  self ._buffer 
239-  buf .seek (self .bytes_read )
240249 data  =  buf .readline ()
241250 while  not  data .endswith (SYM_CRLF ):
242251 # there's more data in the socket that we need 
243252 self ._read_from_socket ()
244-  buf .seek (self .bytes_read )
245253 data  =  buf .readline ()
246254
247-  self .bytes_read  +=  len (data )
248255 return  data [:- 2 ]
249256
250-  def  get_pos (self ):
257+  def  get_pos (self )  ->   int :
251258 """ 
252259 Get current read position 
253260 """ 
254-  return  self .bytes_read 
261+  return  self ._buffer . tell () 
255262
256-  def  rewind (self , pos ) :
263+  def  rewind (self , pos :  int )  ->   None :
257264 """ 
258265 Rewind the buffer to a specific position, to re-start reading 
259266 """ 
260-  self .bytes_read   =   pos 
267+  self ._buffer . seek ( pos ) 
261268
262-  def  purge (self ):
269+  def  purge (self )  ->   None :
263270 """ 
264271 After a successful read, purge the read part of buffer 
265272 """ 
266-  unread  =  self .bytes_written   -   self . bytes_read 
273+  unread  =  self .unread_bytes () 
267274
268275 # Only if we have read all of the buffer do we truncate, to 
269276 # reduce the amount of memory thrashing. This heuristic 
@@ -276,13 +283,10 @@ def purge(self):
276283 view  =  self ._buffer .getbuffer ()
277284 view [:unread ] =  view [- unread :]
278285 self ._buffer .truncate (unread )
279-  self .bytes_written  =  unread 
280-  self .bytes_read  =  0 
281286 self ._buffer .seek (0 )
282287
283-  def  close (self ):
288+  def  close (self )  ->   None :
284289 try :
285-  self .bytes_written  =  self .bytes_read  =  0 
286290 self ._buffer .close ()
287291 except  Exception :
288292 # issue #633 suggests the purge/close somehow raised a 
@@ -498,6 +502,7 @@ def read_response(self, disable_decoding=False):
498502 return  response 
499503
500504
505+ DefaultParser : BaseParser 
501506if  HIREDIS_AVAILABLE :
502507 DefaultParser  =  HiredisParser 
503508else :
0 commit comments