Skip to content

Commit d06c593

Browse files
committed
feat: Inline Begin transction for RW transactions
1 parent 06725fc commit d06c593

File tree

7 files changed

+646
-75
lines changed

7 files changed

+646
-75
lines changed

google/cloud/spanner_v1/pool.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -515,7 +515,6 @@ def begin_pending_transactions(self):
515515
"""Begin all transactions for sessions added to the pool."""
516516
while not self._pending_sessions.empty():
517517
session = self._pending_sessions.get()
518-
session._transaction.begin()
519518
super(TransactionPingingPool, self).put(session)
520519

521520

google/cloud/spanner_v1/session.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -352,9 +352,7 @@ def run_in_transaction(self, func, *args, **kw):
352352
txn.transaction_tag = transaction_tag
353353
else:
354354
txn = self._transaction
355-
if txn._transaction_id is None:
356-
txn.begin()
357-
355+
358356
try:
359357
attempts += 1
360358
return_value = func(txn, *args, **kw)

google/cloud/spanner_v1/transaction.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,7 @@ def _check_state(self):
6161
:raises: :exc:`ValueError` if the object's state is invalid for making
6262
API requests.
6363
"""
64-
if self._transaction_id is None:
65-
raise ValueError("Transaction is not begun")
66-
64+
6765
if self.committed is not None:
6866
raise ValueError("Transaction is already committed")
6967

@@ -78,7 +76,11 @@ def _make_txn_selector(self):
7876
:returns: a selector configured for read-write transaction semantics.
7977
"""
8078
self._check_state()
81-
return TransactionSelector(id=self._transaction_id)
79+
80+
if self._transaction_id is None:
81+
return TransactionSelector(begin=TransactionOptions(read_write=TransactionOptions.ReadWrite()))
82+
else:
83+
return TransactionSelector(id=self._transaction_id)
8284

8385
def begin(self):
8486
"""Begin a transaction on the database.
@@ -111,15 +113,17 @@ def begin(self):
111113
def rollback(self):
112114
"""Roll back a transaction on the database."""
113115
self._check_state()
114-
database = self._session._database
115-
api = database.spanner_api
116-
metadata = _metadata_with_prefix(database.name)
117-
with trace_call("CloudSpanner.Rollback", self._session):
118-
api.rollback(
119-
session=self._session.name,
120-
transaction_id=self._transaction_id,
121-
metadata=metadata,
122-
)
116+
117+
if self._transaction_id is not None:
118+
database = self._session._database
119+
api = database.spanner_api
120+
metadata = _metadata_with_prefix(database.name)
121+
with trace_call("CloudSpanner.Rollback", self._session):
122+
api.rollback(
123+
session=self._session.name,
124+
transaction_id=self._transaction_id,
125+
metadata=metadata,
126+
)
123127
self.rolled_back = True
124128
del self._session._transaction
125129

@@ -142,6 +146,8 @@ def commit(self, return_commit_stats=False, request_options=None):
142146
:raises ValueError: if there are no mutations to commit.
143147
"""
144148
self._check_state()
149+
if self._transaction_id is None:
150+
self.begin()
145151

146152
database = self._session._database
147153
api = database.spanner_api
@@ -302,6 +308,10 @@ def execute_update(
302308
response = api.execute_sql(
303309
request=request, metadata=metadata, retry=retry, timeout=timeout
304310
)
311+
312+
if self._transaction_id is None and response.metadata.transaction is not None:
313+
self._transaction_id = response.metadata.transaction.id
314+
305315
return response.stats.row_count_exact
306316

307317
def batch_update(self, statements, request_options=None):
@@ -378,11 +388,15 @@ def batch_update(self, statements, request_options=None):
378388
row_counts = [
379389
result_set.stats.row_count_exact for result_set in response.result_sets
380390
]
391+
392+
for result_set in response.result_sets:
393+
if self._transaction_id is None and result_set.metadata.transaction is not None:
394+
self._transaction_id = result_set.metadata.transaction.id
395+
381396
return response.status, row_counts
382397

383398
def __enter__(self):
384399
"""Begin ``with`` block."""
385-
self.begin()
386400
return self
387401

388402
def __exit__(self, exc_type, exc_val, exc_tb):

tests/unit/test_pool.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -656,7 +656,7 @@ def test_bind(self):
656656
for session in SESSIONS:
657657
session.create.assert_not_called()
658658
txn = session._transaction
659-
txn.begin.assert_called_once_with()
659+
txn.begin.assert_not_called()
660660

661661
self.assertTrue(pool._pending_sessions.empty())
662662

@@ -685,7 +685,7 @@ def test_bind_w_timestamp_race(self):
685685
for session in SESSIONS:
686686
session.create.assert_not_called()
687687
txn = session._transaction
688-
txn.begin.assert_called_once_with()
688+
txn.begin.assert_not_called()
689689

690690
self.assertTrue(pool._pending_sessions.empty())
691691

@@ -771,7 +771,7 @@ def test_begin_pending_transactions_non_empty(self):
771771
pool.begin_pending_transactions() # no raise
772772

773773
for txn in TRANSACTIONS:
774-
txn.begin.assert_called_once_with()
774+
txn.begin.assert_not_called()
775775

776776
self.assertTrue(pending.empty())
777777

tests/unit/test_session.py

Lines changed: 4 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -725,17 +725,6 @@ def unit_of_work(txn, *args, **kw):
725725
self.assertEqual(args, ())
726726
self.assertEqual(kw, {})
727727

728-
expected_options = TransactionOptions(read_write=TransactionOptions.ReadWrite())
729-
gax_api.begin_transaction.assert_called_once_with(
730-
session=self.SESSION_NAME,
731-
options=expected_options,
732-
metadata=[("google-cloud-resource-prefix", database.name)],
733-
)
734-
gax_api.rollback.assert_called_once_with(
735-
session=self.SESSION_NAME,
736-
transaction_id=TRANSACTION_ID,
737-
metadata=[("google-cloud-resource-prefix", database.name)],
738-
)
739728

740729
def test_run_in_transaction_callback_raises_non_abort_rpc_error(self):
741730
from google.api_core.exceptions import Cancelled
@@ -780,12 +769,6 @@ def unit_of_work(txn, *args, **kw):
780769
self.assertEqual(args, ())
781770
self.assertEqual(kw, {})
782771

783-
expected_options = TransactionOptions(read_write=TransactionOptions.ReadWrite())
784-
gax_api.begin_transaction.assert_called_once_with(
785-
session=self.SESSION_NAME,
786-
options=expected_options,
787-
metadata=[("google-cloud-resource-prefix", database.name)],
788-
)
789772
gax_api.rollback.assert_not_called()
790773

791774
def test_run_in_transaction_w_args_w_kwargs_wo_abort(self):
@@ -1141,16 +1124,10 @@ def unit_of_work(txn, *args, **kw):
11411124
self.assertEqual(kw, {})
11421125

11431126
expected_options = TransactionOptions(read_write=TransactionOptions.ReadWrite())
1144-
self.assertEqual(
1145-
gax_api.begin_transaction.call_args_list,
1146-
[
1147-
mock.call(
1148-
session=self.SESSION_NAME,
1149-
options=expected_options,
1150-
metadata=[("google-cloud-resource-prefix", database.name)],
1151-
)
1152-
]
1153-
* 2,
1127+
gax_api.begin_transaction.assert_called_once_with(
1128+
session=self.SESSION_NAME,
1129+
options=expected_options,
1130+
metadata=[("google-cloud-resource-prefix", database.name)],
11541131
)
11551132
request = CommitRequest(
11561133
session=self.SESSION_NAME,

0 commit comments

Comments
 (0)