33from  redis .asyncio .connection  import  Connection , UnixDomainSocketConnection 
44from  redis .asyncio .retry  import  Retry 
55from  redis .backoff  import  AbstractBackoff , NoBackoff 
6- from  redis .exceptions  import  ConnectionError 
6+ from  redis .exceptions  import  ConnectionError ,  TimeoutError 
77
88
99class  BackoffMock (AbstractBackoff ):
@@ -22,23 +22,55 @@ def compute(self, failures):
2222class  TestConnectionConstructorWithRetry :
2323 "Test that the Connection constructors properly handles Retry objects" 
2424
25+  @pytest .mark .parametrize ("Class" , [Connection , UnixDomainSocketConnection ]) 
26+  def  test_retry_on_error_set (self , Class ):
27+  class  CustomError (Exception ):
28+  pass 
29+ 
30+  retry_on_error  =  [ConnectionError , TimeoutError , CustomError ]
31+  c  =  Class (retry_on_error = retry_on_error )
32+  assert  c .retry_on_error  ==  retry_on_error 
33+  assert  isinstance (c .retry , Retry )
34+  assert  c .retry ._retries  ==  1 
35+  assert  c .retry ._supported_errors  ==  tuple (retry_on_error )
36+ 
37+  @pytest .mark .parametrize ("Class" , [Connection , UnixDomainSocketConnection ]) 
38+  def  test_retry_on_error_not_set (self , Class ):
39+  c  =  Class ()
40+  assert  c .retry_on_error  ==  []
41+  assert  isinstance (c .retry , Retry )
42+  assert  c .retry ._retries  ==  0 
43+ 
2544 @pytest .mark .parametrize ("retry_on_timeout" , [False , True ]) 
2645 @pytest .mark .parametrize ("Class" , [Connection , UnixDomainSocketConnection ]) 
27-  def  test_retry_on_timeout_boolean (self , Class , retry_on_timeout ):
46+  def  test_retry_on_timeout (self , Class , retry_on_timeout ):
2847 c  =  Class (retry_on_timeout = retry_on_timeout )
2948 assert  c .retry_on_timeout  ==  retry_on_timeout 
3049 assert  isinstance (c .retry , Retry )
3150 assert  c .retry ._retries  ==  (1  if  retry_on_timeout  else  0 )
3251
3352 @pytest .mark .parametrize ("retries" , range (10 )) 
3453 @pytest .mark .parametrize ("Class" , [Connection , UnixDomainSocketConnection ]) 
35-  def  test_retry_on_timeout_retry (self , Class , retries : int ):
54+  def  test_retry_with_retry_on_timeout (self , Class , retries : int ):
3655 retry_on_timeout  =  retries  >  0 
3756 c  =  Class (retry_on_timeout = retry_on_timeout , retry = Retry (NoBackoff (), retries ))
3857 assert  c .retry_on_timeout  ==  retry_on_timeout 
3958 assert  isinstance (c .retry , Retry )
4059 assert  c .retry ._retries  ==  retries 
4160
61+  @pytest .mark .parametrize ("retries" , range (10 )) 
62+  @pytest .mark .parametrize ("Class" , [Connection , UnixDomainSocketConnection ]) 
63+  def  test_retry_with_retry_on_error (self , Class , retries : int ):
64+  class  CustomError (Exception ):
65+  pass 
66+ 
67+  retry_on_error  =  [ConnectionError , TimeoutError , CustomError ]
68+  c  =  Class (retry_on_error = retry_on_error , retry = Retry (NoBackoff (), retries ))
69+  assert  c .retry_on_error  ==  retry_on_error 
70+  assert  isinstance (c .retry , Retry )
71+  assert  c .retry ._retries  ==  retries 
72+  assert  c .retry ._supported_errors  ==  tuple (retry_on_error )
73+ 
4274
4375class  TestRetry :
4476 "Test that Retry calls backoff and retries the expected number of times" 
0 commit comments