Skip to content

Commit 15623cd

Browse files
authored
feat: Implementation for Begin and Rollback clientside statements (#1041)
* fix: Refactoring tests to use fixtures properly * Not using autouse fixtures for few tests where not needed * feat: Implementation for Begin and Rollback clientside statements * Incorporating comments * Formatting * Comments incorporated * Fixing tests * Small fix * Test fix as emulator was going OOM
1 parent aa36b07 commit 15623cd

File tree

8 files changed

+824
-733
lines changed

8 files changed

+824
-733
lines changed

google/cloud/spanner_dbapi/client_side_statement_executor.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,30 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from typing import TYPE_CHECKING
15+
16+
if TYPE_CHECKING:
17+
from google.cloud.spanner_dbapi import Connection
1418
from google.cloud.spanner_dbapi.parsed_statement import (
1519
ParsedStatement,
1620
ClientSideStatementType,
1721
)
1822

1923

20-
def execute(connection, parsed_statement: ParsedStatement):
24+
def execute(connection: "Connection", parsed_statement: ParsedStatement):
2125
"""Executes the client side statements by calling the relevant method.
2226
2327
It is an internal method that can make backwards-incompatible changes.
2428
29+
:type connection: Connection
30+
:param connection: Connection object of the dbApi
31+
2532
:type parsed_statement: ParsedStatement
2633
:param parsed_statement: parsed_statement based on the sql query
2734
"""
2835
if parsed_statement.client_side_statement_type == ClientSideStatementType.COMMIT:
2936
return connection.commit()
37+
if parsed_statement.client_side_statement_type == ClientSideStatementType.BEGIN:
38+
return connection.begin()
39+
if parsed_statement.client_side_statement_type == ClientSideStatementType.ROLLBACK:
40+
return connection.rollback()

google/cloud/spanner_dbapi/client_side_statement_parser.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@
2020
ClientSideStatementType,
2121
)
2222

23+
RE_BEGIN = re.compile(r"^\s*(BEGIN|START)(TRANSACTION)?", re.IGNORECASE)
2324
RE_COMMIT = re.compile(r"^\s*(COMMIT)(TRANSACTION)?", re.IGNORECASE)
25+
RE_ROLLBACK = re.compile(r"^\s*(ROLLBACK)(TRANSACTION)?", re.IGNORECASE)
2426

2527

2628
def parse_stmt(query):
@@ -39,4 +41,12 @@ def parse_stmt(query):
3941
return ParsedStatement(
4042
StatementType.CLIENT_SIDE, query, ClientSideStatementType.COMMIT
4143
)
44+
if RE_BEGIN.match(query):
45+
return ParsedStatement(
46+
StatementType.CLIENT_SIDE, query, ClientSideStatementType.BEGIN
47+
)
48+
if RE_ROLLBACK.match(query):
49+
return ParsedStatement(
50+
StatementType.CLIENT_SIDE, query, ClientSideStatementType.ROLLBACK
51+
)
4252
return None

google/cloud/spanner_dbapi/connection.py

Lines changed: 74 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@
3434
from google.rpc.code_pb2 import ABORTED
3535

3636

37-
AUTOCOMMIT_MODE_WARNING = "This method is non-operational in autocommit mode"
37+
CLIENT_TRANSACTION_NOT_STARTED_WARNING = (
38+
"This method is non-operational as transaction has not started"
39+
)
3840
MAX_INTERNAL_RETRIES = 50
3941

4042

@@ -104,6 +106,7 @@ def __init__(self, instance, database=None, read_only=False):
104106
self._read_only = read_only
105107
self._staleness = None
106108
self.request_priority = None
109+
self._transaction_begin_marked = False
107110

108111
@property
109112
def autocommit(self):
@@ -122,7 +125,7 @@ def autocommit(self, value):
122125
:type value: bool
123126
:param value: New autocommit mode state.
124127
"""
125-
if value and not self._autocommit and self.inside_transaction:
128+
if value and not self._autocommit and self._spanner_transaction_started:
126129
self.commit()
127130

128131
self._autocommit = value
@@ -137,17 +140,35 @@ def database(self):
137140
return self._database
138141

139142
@property
140-
def inside_transaction(self):
141-
"""Flag: transaction is started.
143+
def _spanner_transaction_started(self):
144+
"""Flag: whether transaction started at Spanner. This means that we had
145+
made atleast one call to Spanner. Property client_transaction_started
146+
would always be true if this is true as transaction has to start first
147+
at clientside than at Spanner
142148
143149
Returns:
144-
bool: True if transaction begun, False otherwise.
150+
bool: True if Spanner transaction started, False otherwise.
145151
"""
146152
return (
147153
self._transaction
148154
and not self._transaction.committed
149155
and not self._transaction.rolled_back
150-
)
156+
) or (self._snapshot is not None)
157+
158+
@property
159+
def inside_transaction(self):
160+
"""Deprecated property which won't be supported in future versions.
161+
Please use spanner_transaction_started property instead."""
162+
return self._spanner_transaction_started
163+
164+
@property
165+
def _client_transaction_started(self):
166+
"""Flag: whether transaction started at client side.
167+
168+
Returns:
169+
bool: True if transaction started, False otherwise.
170+
"""
171+
return (not self._autocommit) or self._transaction_begin_marked
151172

152173
@property
153174
def instance(self):
@@ -175,7 +196,7 @@ def read_only(self, value):
175196
Args:
176197
value (bool): True for ReadOnly mode, False for ReadWrite.
177198
"""
178-
if self.inside_transaction:
199+
if self._spanner_transaction_started:
179200
raise ValueError(
180201
"Connection read/write mode can't be changed while a transaction is in progress. "
181202
"Commit or rollback the current transaction and try again."
@@ -213,7 +234,7 @@ def staleness(self, value):
213234
Args:
214235
value (dict): Staleness type and value.
215236
"""
216-
if self.inside_transaction:
237+
if self._spanner_transaction_started:
217238
raise ValueError(
218239
"`staleness` option can't be changed while a transaction is in progress. "
219240
"Commit or rollback the current transaction and try again."
@@ -331,15 +352,16 @@ def transaction_checkout(self):
331352
"""Get a Cloud Spanner transaction.
332353
333354
Begin a new transaction, if there is no transaction in
334-
this connection yet. Return the begun one otherwise.
355+
this connection yet. Return the started one otherwise.
335356
336-
The method is non operational in autocommit mode.
357+
This method is a no-op if the connection is in autocommit mode and no
358+
explicit transaction has been started
337359
338360
:rtype: :class:`google.cloud.spanner_v1.transaction.Transaction`
339361
:returns: A Cloud Spanner transaction object, ready to use.
340362
"""
341-
if not self.autocommit:
342-
if not self.inside_transaction:
363+
if not self.read_only and self._client_transaction_started:
364+
if not self._spanner_transaction_started:
343365
self._transaction = self._session_checkout().transaction()
344366
self._transaction.begin()
345367

@@ -354,7 +376,7 @@ def snapshot_checkout(self):
354376
:rtype: :class:`google.cloud.spanner_v1.snapshot.Snapshot`
355377
:returns: A Cloud Spanner snapshot object, ready to use.
356378
"""
357-
if self.read_only and not self.autocommit:
379+
if self.read_only and self._client_transaction_started:
358380
if not self._snapshot:
359381
self._snapshot = Snapshot(
360382
self._session_checkout(), multi_use=True, **self.staleness
@@ -369,55 +391,80 @@ def close(self):
369391
The connection will be unusable from this point forward. If the
370392
connection has an active transaction, it will be rolled back.
371393
"""
372-
if self.inside_transaction:
394+
if self._spanner_transaction_started and not self.read_only:
373395
self._transaction.rollback()
374396

375397
if self._own_pool and self.database:
376398
self.database._pool.clear()
377399

378400
self.is_closed = True
379401

402+
@check_not_closed
403+
def begin(self):
404+
"""
405+
Marks the transaction as started.
406+
407+
:raises: :class:`InterfaceError`: if this connection is closed.
408+
:raises: :class:`OperationalError`: if there is an existing transaction that has begin or is running
409+
"""
410+
if self._transaction_begin_marked:
411+
raise OperationalError("A transaction has already started")
412+
if self._spanner_transaction_started:
413+
raise OperationalError(
414+
"Beginning a new transaction is not allowed when a transaction is already running"
415+
)
416+
self._transaction_begin_marked = True
417+
380418
def commit(self):
381419
"""Commits any pending transaction to the database.
382420
383-
This method is non-operational in autocommit mode.
421+
This is a no-op if there is no active client transaction.
384422
"""
385423
if self.database is None:
386424
raise ValueError("Database needs to be passed for this operation")
387-
self._snapshot = None
388425

389-
if self._autocommit:
390-
warnings.warn(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2)
426+
if not self._client_transaction_started:
427+
warnings.warn(
428+
CLIENT_TRANSACTION_NOT_STARTED_WARNING, UserWarning, stacklevel=2
429+
)
391430
return
392431

393432
self.run_prior_DDL_statements()
394-
if self.inside_transaction:
433+
if self._spanner_transaction_started:
395434
try:
396-
if not self.read_only:
435+
if self.read_only:
436+
self._snapshot = None
437+
else:
397438
self._transaction.commit()
398439

399440
self._release_session()
400441
self._statements = []
442+
self._transaction_begin_marked = False
401443
except Aborted:
402444
self.retry_transaction()
403445
self.commit()
404446

405447
def rollback(self):
406448
"""Rolls back any pending transaction.
407449
408-
This is a no-op if there is no active transaction or if the connection
409-
is in autocommit mode.
450+
This is a no-op if there is no active client transaction.
410451
"""
411-
self._snapshot = None
412452

413-
if self._autocommit:
414-
warnings.warn(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2)
415-
elif self._transaction:
416-
if not self.read_only:
453+
if not self._client_transaction_started:
454+
warnings.warn(
455+
CLIENT_TRANSACTION_NOT_STARTED_WARNING, UserWarning, stacklevel=2
456+
)
457+
return
458+
459+
if self._spanner_transaction_started:
460+
if self.read_only:
461+
self._snapshot = None
462+
else:
417463
self._transaction.rollback()
418464

419465
self._release_session()
420466
self._statements = []
467+
self._transaction_begin_marked = False
421468

422469
@check_not_closed
423470
def cursor(self):

google/cloud/spanner_dbapi/cursor.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ def execute(self, sql, args=None):
250250
)
251251
if parsed_statement.statement_type == StatementType.DDL:
252252
self._batch_DDLs(sql)
253-
if self.connection.autocommit:
253+
if not self.connection._client_transaction_started:
254254
self.connection.run_prior_DDL_statements()
255255
return
256256

@@ -264,7 +264,7 @@ def execute(self, sql, args=None):
264264

265265
sql, args = sql_pyformat_args_to_spanner(sql, args or None)
266266

267-
if not self.connection.autocommit:
267+
if self.connection._client_transaction_started:
268268
statement = Statement(
269269
sql,
270270
args,
@@ -348,7 +348,7 @@ def executemany(self, operation, seq_of_params):
348348
)
349349
statements.append((sql, params, get_param_types(params)))
350350

351-
if self.connection.autocommit:
351+
if not self.connection._client_transaction_started:
352352
self.connection.database.run_in_transaction(
353353
self._do_batch_update, statements, many_result_set
354354
)
@@ -396,7 +396,10 @@ def fetchone(self):
396396
sequence, or None when no more data is available."""
397397
try:
398398
res = next(self)
399-
if not self.connection.autocommit and not self.connection.read_only:
399+
if (
400+
self.connection._client_transaction_started
401+
and not self.connection.read_only
402+
):
400403
self._checksum.consume_result(res)
401404
return res
402405
except StopIteration:
@@ -414,7 +417,10 @@ def fetchall(self):
414417
res = []
415418
try:
416419
for row in self:
417-
if not self.connection.autocommit and not self.connection.read_only:
420+
if (
421+
self.connection._client_transaction_started
422+
and not self.connection.read_only
423+
):
418424
self._checksum.consume_result(row)
419425
res.append(row)
420426
except Aborted:
@@ -443,7 +449,10 @@ def fetchmany(self, size=None):
443449
for _ in range(size):
444450
try:
445451
res = next(self)
446-
if not self.connection.autocommit and not self.connection.read_only:
452+
if (
453+
self.connection._client_transaction_started
454+
and not self.connection.read_only
455+
):
447456
self._checksum.consume_result(res)
448457
items.append(res)
449458
except StopIteration:
@@ -473,7 +482,7 @@ def _handle_DQL(self, sql, params):
473482
if self.connection.database is None:
474483
raise ValueError("Database needs to be passed for this operation")
475484
sql, params = parse_utils.sql_pyformat_args_to_spanner(sql, params)
476-
if self.connection.read_only and not self.connection.autocommit:
485+
if self.connection.read_only and self.connection._client_transaction_started:
477486
# initiate or use the existing multi-use snapshot
478487
self._handle_DQL_with_snapshot(
479488
self.connection.snapshot_checkout(), sql, params

google/cloud/spanner_dbapi/parsed_statement.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class StatementType(Enum):
2727
class ClientSideStatementType(Enum):
2828
COMMIT = 1
2929
BEGIN = 2
30+
ROLLBACK = 3
3031

3132

3233
@dataclass

0 commit comments

Comments
 (0)