1414
1515"""DB-API Connection for the Google Cloud Spanner."""
1616
17+ import time
1718import warnings
1819
20+ from google .api_core .exceptions import Aborted
1921from google .api_core .gapic_v1 .client_info import ClientInfo
2022from google .cloud import spanner_v1 as spanner
23+ from google .cloud .spanner_v1 .session import _get_retry_delay
2124
25+ from google .cloud .spanner_dbapi .checksum import _compare_checksums
26+ from google .cloud .spanner_dbapi .checksum import ResultsChecksum
2227from google .cloud .spanner_dbapi .cursor import Cursor
2328from google .cloud .spanner_dbapi .exceptions import InterfaceError
2429from google .cloud .spanner_dbapi .version import DEFAULT_USER_AGENT
2530from google .cloud .spanner_dbapi .version import PY_VERSION
2631
2732
2833AUTOCOMMIT_MODE_WARNING = "This method is non-operational in autocommit mode"
34+ MAX_INTERNAL_RETRIES = 50
2935
3036
3137class Connection :
@@ -48,9 +54,16 @@ def __init__(self, instance, database):
4854
4955 self ._transaction = None
5056 self ._session = None
57+ # SQL statements, which were executed
58+ # within the current transaction
59+ self ._statements = []
5160
5261 self .is_closed = False
5362 self ._autocommit = False
63+ # indicator to know if the session pool used by
64+ # this connection should be cleared on the
65+ # connection close
66+ self ._own_pool = True
5467
5568 @property
5669 def autocommit (self ):
@@ -114,6 +127,58 @@ def _release_session(self):
114127 self .database ._pool .put (self ._session )
115128 self ._session = None
116129
130+ def retry_transaction (self ):
131+ """Retry the aborted transaction.
132+
133+ All the statements executed in the original transaction
134+ will be re-executed in new one. Results checksums of the
135+ original statements and the retried ones will be compared.
136+
137+ :raises: :class:`google.cloud.spanner_dbapi.exceptions.RetryAborted`
138+ If results checksum of the retried statement is
139+ not equal to the checksum of the original one.
140+ """
141+ attempt = 0
142+ while True :
143+ self ._transaction = None
144+ attempt += 1
145+ if attempt > MAX_INTERNAL_RETRIES :
146+ raise
147+
148+ try :
149+ self ._rerun_previous_statements ()
150+ break
151+ except Aborted as exc :
152+ delay = _get_retry_delay (exc .errors [0 ], attempt )
153+ if delay :
154+ time .sleep (delay )
155+
156+ def _rerun_previous_statements (self ):
157+ """
158+ Helper to run all the remembered statements
159+ from the last transaction.
160+ """
161+ for statement in self ._statements :
162+ res_iter , retried_checksum = self .run_statement (statement , retried = True )
163+ # executing all the completed statements
164+ if statement != self ._statements [- 1 ]:
165+ for res in res_iter :
166+ retried_checksum .consume_result (res )
167+
168+ _compare_checksums (statement .checksum , retried_checksum )
169+ # executing the failed statement
170+ else :
171+ # streaming up to the failed result or
172+ # to the end of the streaming iterator
173+ while len (retried_checksum ) < len (statement .checksum ):
174+ try :
175+ res = next (iter (res_iter ))
176+ retried_checksum .consume_result (res )
177+ except StopIteration :
178+ break
179+
180+ _compare_checksums (statement .checksum , retried_checksum )
181+
117182 def transaction_checkout (self ):
118183 """Get a Cloud Spanner transaction.
119184
@@ -158,6 +223,9 @@ def close(self):
158223 ):
159224 self ._transaction .rollback ()
160225
226+ if self ._own_pool :
227+ self .database ._pool .clear ()
228+
161229 self .is_closed = True
162230
163231 def commit (self ):
@@ -168,8 +236,13 @@ def commit(self):
168236 if self ._autocommit :
169237 warnings .warn (AUTOCOMMIT_MODE_WARNING , UserWarning , stacklevel = 2 )
170238 elif self ._transaction :
171- self ._transaction .commit ()
172- self ._release_session ()
239+ try :
240+ self ._transaction .commit ()
241+ self ._release_session ()
242+ self ._statements = []
243+ except Aborted :
244+ self .retry_transaction ()
245+ self .commit ()
173246
174247 def rollback (self ):
175248 """Rolls back any pending transaction.
@@ -182,6 +255,7 @@ def rollback(self):
182255 elif self ._transaction :
183256 self ._transaction .rollback ()
184257 self ._release_session ()
258+ self ._statements = []
185259
186260 def cursor (self ):
187261 """Factory to create a DB-API Cursor."""
@@ -198,6 +272,32 @@ def run_prior_DDL_statements(self):
198272
199273 return self .database .update_ddl (ddl_statements ).result ()
200274
275+ def run_statement (self , statement , retried = False ):
276+ """Run single SQL statement in begun transaction.
277+
278+ This method is never used in autocommit mode. In
279+ !autocommit mode however it remembers every executed
280+ SQL statement with its parameters.
281+
282+ :type statement: :class:`dict`
283+ :param statement: SQL statement to execute.
284+
285+ :rtype: :class:`google.cloud.spanner_v1.streamed.StreamedResultSet`,
286+ :class:`google.cloud.spanner_dbapi.checksum.ResultsChecksum`
287+ :returns: Streamed result set of the statement and a
288+ checksum of this statement results.
289+ """
290+ transaction = self .transaction_checkout ()
291+ if not retried :
292+ self ._statements .append (statement )
293+
294+ return (
295+ transaction .execute_sql (
296+ statement .sql , statement .params , param_types = statement .param_types ,
297+ ),
298+ ResultsChecksum () if retried else statement .checksum ,
299+ )
300+
201301 def __enter__ (self ):
202302 return self
203303
@@ -207,7 +307,12 @@ def __exit__(self, etype, value, traceback):
207307
208308
209309def connect (
210- instance_id , database_id , project = None , credentials = None , pool = None , user_agent = None
310+ instance_id ,
311+ database_id ,
312+ project = None ,
313+ credentials = None ,
314+ pool = None ,
315+ user_agent = None ,
211316):
212317 """Creates a connection to a Google Cloud Spanner database.
213318
@@ -261,4 +366,8 @@ def connect(
261366 if not database .exists ():
262367 raise ValueError ("database '%s' does not exist." % database_id )
263368
264- return Connection (instance , database )
369+ conn = Connection (instance , database )
370+ if pool is not None :
371+ conn ._own_pool = False
372+
373+ return conn
0 commit comments