Skip to content

Commit e16f376

Browse files
vi3k6i5larkee
andauthored
feat: adding support for spanner request options tags (#276)
* feat: added support for request options with request tag and transaction tag in supported classes * feat: corrected import for RequestOptions * feat: request options added lint corrections * feat: added system test for request tagging * feat: added annotation to skip request tags validation test while using emulator * feat: lint fix * fix: remove request_option from batch * lint: lint fixes * refactor: undo changes * refactor: undo changes * refactor: remove test_system file, as it has been removed in master * refactor: update code to latest changes * feat: added support for request options with request tag and transaction tag in supported classes * feat: corrected import for RequestOptions * fix: add transaction_tag test for transaction_tag set in transaction class * fix: lint fixes * refactor: lint fixes * fix: change request_options dictionary to RequestOptions object * refactor: fix lint issues * refactor: lint fixes * refactor: move write txn properties to BatchBase * fix: use transaction tag on all write methods * feat: add support for batch commit * feat: add support for setting a transaction tag on batch checkout * refactor: update checks for readability * test: use separate expectation object for readability * test: add run_in_transaction test * test: remove test for unsupported behaviour * style: lint fixes Co-authored-by: larkee <larkee@users.noreply.github.com>
1 parent f59d08b commit e16f376

File tree

10 files changed

+425
-25
lines changed

10 files changed

+425
-25
lines changed

google/cloud/spanner_v1/batch.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ class _BatchBase(_SessionWrapper):
3232
:param session: the session used to perform the commit
3333
"""
3434

35+
transaction_tag = None
36+
_read_only = False
37+
3538
def __init__(self, session):
3639
super(_BatchBase, self).__init__(session)
3740
self._mutations = []
@@ -118,8 +121,7 @@ def delete(self, table, keyset):
118121

119122

120123
class Batch(_BatchBase):
121-
"""Accumulate mutations for transmission during :meth:`commit`.
122-
"""
124+
"""Accumulate mutations for transmission during :meth:`commit`."""
123125

124126
committed = None
125127
commit_stats = None
@@ -160,8 +162,14 @@ def commit(self, return_commit_stats=False, request_options=None):
160162
txn_options = TransactionOptions(read_write=TransactionOptions.ReadWrite())
161163
trace_attributes = {"num_mutations": len(self._mutations)}
162164

163-
if type(request_options) == dict:
165+
if request_options is None:
166+
request_options = RequestOptions()
167+
elif type(request_options) == dict:
164168
request_options = RequestOptions(request_options)
169+
request_options.transaction_tag = self.transaction_tag
170+
171+
# Request tags are not supported for commit requests.
172+
request_options.request_tag = None
165173

166174
request = CommitRequest(
167175
session=self._session.name,

google/cloud/spanner_v1/database.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -494,15 +494,20 @@ def execute_partitioned_dml(
494494
(Optional) Common options for this request.
495495
If a dict is provided, it must be of the same form as the protobuf
496496
message :class:`~google.cloud.spanner_v1.types.RequestOptions`.
497+
Please note, the `transactionTag` setting will be ignored as it is
498+
not supported for partitioned DML.
497499
498500
:rtype: int
499501
:returns: Count of rows affected by the DML statement.
500502
"""
501503
query_options = _merge_query_options(
502504
self._instance._client._query_options, query_options
503505
)
504-
if type(request_options) == dict:
506+
if request_options is None:
507+
request_options = RequestOptions()
508+
elif type(request_options) == dict:
505509
request_options = RequestOptions(request_options)
510+
request_options.transaction_tag = None
506511

507512
if params is not None:
508513
from google.cloud.spanner_v1.transaction import Transaction
@@ -796,12 +801,19 @@ class BatchCheckout(object):
796801
def __init__(self, database, request_options=None):
797802
self._database = database
798803
self._session = self._batch = None
799-
self._request_options = request_options
804+
if request_options is None:
805+
self._request_options = RequestOptions()
806+
elif type(request_options) == dict:
807+
self._request_options = RequestOptions(request_options)
808+
else:
809+
self._request_options = request_options
800810

801811
def __enter__(self):
802812
"""Begin ``with`` block."""
803813
session = self._session = self._database._pool.get()
804814
batch = self._batch = Batch(session)
815+
if self._request_options.transaction_tag:
816+
batch.transaction_tag = self._request_options.transaction_tag
805817
return batch
806818

807819
def __exit__(self, exc_type, exc_val, exc_tb):

google/cloud/spanner_v1/session.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,11 +340,13 @@ def run_in_transaction(self, func, *args, **kw):
340340
"""
341341
deadline = time.time() + kw.pop("timeout_secs", DEFAULT_RETRY_TIMEOUT_SECS)
342342
commit_request_options = kw.pop("commit_request_options", None)
343+
transaction_tag = kw.pop("transaction_tag", None)
343344
attempts = 0
344345

345346
while True:
346347
if self._transaction is None:
347348
txn = self.transaction()
349+
txn.transaction_tag = transaction_tag
348350
else:
349351
txn = self._transaction
350352
if txn._transaction_id is None:

google/cloud/spanner_v1/snapshot.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ class _SnapshotBase(_SessionWrapper):
102102
"""
103103

104104
_multi_use = False
105+
_read_only = True
105106
_transaction_id = None
106107
_read_request_count = 0
107108
_execute_sql_count = 0
@@ -160,6 +161,8 @@ def read(
160161
(Optional) Common options for this request.
161162
If a dict is provided, it must be of the same form as the protobuf
162163
message :class:`~google.cloud.spanner_v1.types.RequestOptions`.
164+
Please note, the `transactionTag` setting will be ignored for
165+
snapshot as it's not supported for read-only transactions.
163166
164167
:type retry: :class:`~google.api_core.retry.Retry`
165168
:param retry: (Optional) The retry settings for this request.
@@ -185,9 +188,17 @@ def read(
185188
metadata = _metadata_with_prefix(database.name)
186189
transaction = self._make_txn_selector()
187190

188-
if type(request_options) == dict:
191+
if request_options is None:
192+
request_options = RequestOptions()
193+
elif type(request_options) == dict:
189194
request_options = RequestOptions(request_options)
190195

196+
if self._read_only:
197+
# Transaction tags are not supported for read only transactions.
198+
request_options.transaction_tag = None
199+
else:
200+
request_options.transaction_tag = self.transaction_tag
201+
191202
request = ReadRequest(
192203
session=self._session.name,
193204
table=table,
@@ -312,8 +323,15 @@ def execute_sql(
312323
default_query_options = database._instance._client._query_options
313324
query_options = _merge_query_options(default_query_options, query_options)
314325

315-
if type(request_options) == dict:
326+
if request_options is None:
327+
request_options = RequestOptions()
328+
elif type(request_options) == dict:
316329
request_options = RequestOptions(request_options)
330+
if self._read_only:
331+
# Transaction tags are not supported for read only transactions.
332+
request_options.transaction_tag = None
333+
else:
334+
request_options.transaction_tag = self.transaction_tag
317335

318336
request = ExecuteSqlRequest(
319337
session=self._session.name,

google/cloud/spanner_v1/transaction.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,15 @@ def commit(self, return_commit_stats=False, request_options=None):
148148
metadata = _metadata_with_prefix(database.name)
149149
trace_attributes = {"num_mutations": len(self._mutations)}
150150

151-
if type(request_options) == dict:
151+
if request_options is None:
152+
request_options = RequestOptions()
153+
elif type(request_options) == dict:
152154
request_options = RequestOptions(request_options)
155+
if self.transaction_tag is not None:
156+
request_options.transaction_tag = self.transaction_tag
157+
158+
# Request tags are not supported for commit requests.
159+
request_options.request_tag = None
153160

154161
request = CommitRequest(
155162
session=self._session.name,
@@ -267,8 +274,11 @@ def execute_update(
267274
default_query_options = database._instance._client._query_options
268275
query_options = _merge_query_options(default_query_options, query_options)
269276

270-
if type(request_options) == dict:
277+
if request_options is None:
278+
request_options = RequestOptions()
279+
elif type(request_options) == dict:
271280
request_options = RequestOptions(request_options)
281+
request_options.transaction_tag = self.transaction_tag
272282

273283
trace_attributes = {"db.statement": dml}
274284

@@ -343,8 +353,11 @@ def batch_update(self, statements, request_options=None):
343353
self._execute_sql_count + 1,
344354
)
345355

346-
if type(request_options) == dict:
356+
if request_options is None:
357+
request_options = RequestOptions()
358+
elif type(request_options) == dict:
347359
request_options = RequestOptions(request_options)
360+
request_options.transaction_tag = self.transaction_tag
348361

349362
trace_attributes = {
350363
# Get just the queries from the DML statement batch

tests/unit/test_batch.py

Lines changed: 77 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import unittest
1717
from tests._helpers import OpenTelemetryBase, StatusCode
18+
from google.cloud.spanner_v1 import RequestOptions
1819

1920
TABLE_NAME = "citizens"
2021
COLUMNS = ["email", "first_name", "last_name", "age"]
@@ -39,6 +40,7 @@ class _BaseTest(unittest.TestCase):
3940
DATABASE_NAME = INSTANCE_NAME + "/databases/" + DATABASE_ID
4041
SESSION_ID = "session-id"
4142
SESSION_NAME = DATABASE_NAME + "/sessions/" + SESSION_ID
43+
TRANSACTION_TAG = "transaction-tag"
4244

4345
def _make_one(self, *args, **kwargs):
4446
return self._getTargetClass()(*args, **kwargs)
@@ -232,18 +234,87 @@ def test_commit_ok(self):
232234
self.assertEqual(committed, now)
233235
self.assertEqual(batch.committed, committed)
234236

235-
(session, mutations, single_use_txn, metadata, request_options) = api._committed
237+
(session, mutations, single_use_txn, request_options, metadata) = api._committed
236238
self.assertEqual(session, self.SESSION_NAME)
237239
self.assertEqual(mutations, batch._mutations)
238240
self.assertIsInstance(single_use_txn, TransactionOptions)
239241
self.assertTrue(type(single_use_txn).pb(single_use_txn).HasField("read_write"))
240242
self.assertEqual(metadata, [("google-cloud-resource-prefix", database.name)])
241-
self.assertEqual(request_options, None)
243+
self.assertEqual(request_options, RequestOptions())
242244

243245
self.assertSpanAttributes(
244246
"CloudSpanner.Commit", attributes=dict(BASE_ATTRIBUTES, num_mutations=1)
245247
)
246248

249+
def _test_commit_with_request_options(self, request_options=None):
250+
import datetime
251+
from google.cloud.spanner_v1 import CommitResponse
252+
from google.cloud.spanner_v1 import TransactionOptions
253+
from google.cloud._helpers import UTC
254+
from google.cloud._helpers import _datetime_to_pb_timestamp
255+
256+
now = datetime.datetime.utcnow().replace(tzinfo=UTC)
257+
now_pb = _datetime_to_pb_timestamp(now)
258+
response = CommitResponse(commit_timestamp=now_pb)
259+
database = _Database()
260+
api = database.spanner_api = _FauxSpannerAPI(_commit_response=response)
261+
session = _Session(database)
262+
batch = self._make_one(session)
263+
batch.transaction_tag = self.TRANSACTION_TAG
264+
batch.insert(TABLE_NAME, COLUMNS, VALUES)
265+
committed = batch.commit(request_options=request_options)
266+
267+
self.assertEqual(committed, now)
268+
self.assertEqual(batch.committed, committed)
269+
270+
if type(request_options) == dict:
271+
expected_request_options = RequestOptions(request_options)
272+
else:
273+
expected_request_options = request_options
274+
expected_request_options.transaction_tag = self.TRANSACTION_TAG
275+
expected_request_options.request_tag = None
276+
277+
(
278+
session,
279+
mutations,
280+
single_use_txn,
281+
actual_request_options,
282+
metadata,
283+
) = api._committed
284+
self.assertEqual(session, self.SESSION_NAME)
285+
self.assertEqual(mutations, batch._mutations)
286+
self.assertIsInstance(single_use_txn, TransactionOptions)
287+
self.assertTrue(type(single_use_txn).pb(single_use_txn).HasField("read_write"))
288+
self.assertEqual(metadata, [("google-cloud-resource-prefix", database.name)])
289+
self.assertEqual(actual_request_options, expected_request_options)
290+
291+
self.assertSpanAttributes(
292+
"CloudSpanner.Commit", attributes=dict(BASE_ATTRIBUTES, num_mutations=1)
293+
)
294+
295+
def test_commit_w_request_tag_success(self):
296+
request_options = RequestOptions(request_tag="tag-1",)
297+
self._test_commit_with_request_options(request_options=request_options)
298+
299+
def test_commit_w_transaction_tag_success(self):
300+
request_options = RequestOptions(transaction_tag="tag-1-1",)
301+
self._test_commit_with_request_options(request_options=request_options)
302+
303+
def test_commit_w_request_and_transaction_tag_success(self):
304+
request_options = RequestOptions(
305+
request_tag="tag-1", transaction_tag="tag-1-1",
306+
)
307+
self._test_commit_with_request_options(request_options=request_options)
308+
309+
def test_commit_w_request_and_transaction_tag_dictionary_success(self):
310+
request_options = {"request_tag": "tag-1", "transaction_tag": "tag-1-1"}
311+
self._test_commit_with_request_options(request_options=request_options)
312+
313+
def test_commit_w_incorrect_tag_dictionary_error(self):
314+
request_options = {"incorrect_tag": "tag-1-1"}
315+
with self.assertRaises(ValueError):
316+
self._test_commit_with_request_options(request_options=request_options)
317+
247318
def test_context_mgr_already_committed(self):
248319
import datetime
249320
from google.cloud._helpers import UTC
@@ -281,13 +352,13 @@ def test_context_mgr_success(self):
281352

282353
self.assertEqual(batch.committed, now)
283354

284-
(session, mutations, single_use_txn, metadata, request_options) = api._committed
355+
(session, mutations, single_use_txn, request_options, metadata) = api._committed
285356
self.assertEqual(session, self.SESSION_NAME)
286357
self.assertEqual(mutations, batch._mutations)
287358
self.assertIsInstance(single_use_txn, TransactionOptions)
288359
self.assertTrue(type(single_use_txn).pb(single_use_txn).HasField("read_write"))
289360
self.assertEqual(metadata, [("google-cloud-resource-prefix", database.name)])
290-
self.assertEqual(request_options, None)
361+
self.assertEqual(request_options, RequestOptions())
291362

292363
self.assertSpanAttributes(
293364
"CloudSpanner.Commit", attributes=dict(BASE_ATTRIBUTES, num_mutations=1)
@@ -341,7 +412,7 @@ def __init__(self, **kwargs):
341412
self.__dict__.update(**kwargs)
342413

343414
def commit(
344-
self, request=None, metadata=None, request_options=None,
415+
self, request=None, metadata=None,
345416
):
346417
from google.api_core.exceptions import Unknown
347418

@@ -350,8 +421,8 @@ def commit(
350421
request.session,
351422
request.mutations,
352423
request.single_use_transaction,
424+
request.request_options,
353425
metadata,
354-
request_options,
355426
)
356427
if self._rpc_error:
357428
raise Unknown("error")

0 commit comments

Comments
 (0)