Skip to content

Commit cb25de4

Browse files
perf: Skip gRPC trailers for StreamingRead & ExecuteStreamingSql (#1385)
* perf: Skip gRPC trailers for StreamingRead & ExecuteStreamingSql * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * add mockspanner tests * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * Fix None issue * Optimize imports * optimize imports * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * Remove setup * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * Remove .python-version --------- Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
1 parent 21f5028 commit cb25de4

File tree

5 files changed

+145
-26
lines changed

5 files changed

+145
-26
lines changed

google/cloud/spanner_v1/streamed.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def __init__(
5353
self._column_info = column_info # Column information
5454
self._field_decoders = None
5555
self._lazy_decode = lazy_decode # Return protobuf values
56+
self._done = False
5657

5758
@property
5859
def fields(self):
@@ -154,11 +155,16 @@ def _consume_next(self):
154155

155156
self._merge_values(values)
156157

158+
if response_pb.last:
159+
self._done = True
160+
157161
def __iter__(self):
158162
while True:
159163
iter_rows, self._rows[:] = self._rows[:], ()
160164
while iter_rows:
161165
yield iter_rows.pop(0)
166+
if self._done:
167+
return
162168
try:
163169
self._consume_next()
164170
except StopIteration:

google/cloud/spanner_v1/testing/mock_spanner.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,17 @@
3535
class MockSpanner:
3636
def __init__(self):
3737
self.results = {}
38+
self.execute_streaming_sql_results = {}
3839
self.errors = {}
3940

4041
def add_result(self, sql: str, result: result_set.ResultSet):
4142
self.results[sql.lower().strip()] = result
4243

44+
def add_execute_streaming_sql_results(
45+
self, sql: str, partial_result_sets: list[result_set.PartialResultSet]
46+
):
47+
self.execute_streaming_sql_results[sql.lower().strip()] = partial_result_sets
48+
4349
def get_result(self, sql: str) -> result_set.ResultSet:
4450
result = self.results.get(sql.lower().strip())
4551
if result is None:
@@ -55,9 +61,20 @@ def pop_error(self, context):
5561
if error:
5662
context.abort_with_status(error)
5763

58-
def get_result_as_partial_result_sets(
64+
def get_execute_streaming_sql_results(
5965
self, sql: str, started_transaction: transaction.Transaction
60-
) -> [result_set.PartialResultSet]:
66+
) -> list[result_set.PartialResultSet]:
67+
if self.execute_streaming_sql_results.get(sql.lower().strip()):
68+
partials = self.execute_streaming_sql_results[sql.lower().strip()]
69+
else:
70+
partials = self.get_result_as_partial_result_sets(sql)
71+
if started_transaction:
72+
partials[0].metadata.transaction = started_transaction
73+
return partials
74+
75+
def get_result_as_partial_result_sets(
76+
self, sql: str
77+
) -> list[result_set.PartialResultSet]:
6178
result: result_set.ResultSet = self.get_result(sql)
6279
partials = []
6380
first = True
@@ -70,11 +87,10 @@ def get_result_as_partial_result_sets(
7087
partial = result_set.PartialResultSet()
7188
if first:
7289
partial.metadata = ResultSetMetadata(result.metadata)
90+
first = False
7391
partial.values.extend(row)
7492
partials.append(partial)
7593
partials[len(partials) - 1].stats = result.stats
76-
if started_transaction:
77-
partials[0].metadata.transaction = started_transaction
7894
return partials
7995

8096

@@ -149,7 +165,7 @@ def ExecuteStreamingSql(self, request, context):
149165
self._requests.append(request)
150166
self.mock_spanner.pop_error(context)
151167
started_transaction = self.__maybe_create_transaction(request)
152-
partials = self.mock_spanner.get_result_as_partial_result_sets(
168+
partials = self.mock_spanner.get_execute_streaming_sql_results(
153169
request.sql, started_transaction
154170
)
155171
for result in partials:

tests/mockserver_tests/mock_server_test_base.py

Lines changed: 51 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,27 +14,34 @@
1414

1515
import unittest
1616

17-
from google.cloud.spanner_dbapi.parsed_statement import AutocommitDmlMode
18-
from google.cloud.spanner_v1.testing.mock_database_admin import DatabaseAdminServicer
19-
from google.cloud.spanner_v1.testing.mock_spanner import (
20-
start_mock_server,
21-
SpannerServicer,
22-
)
23-
import google.cloud.spanner_v1.types.type as spanner_type
24-
import google.cloud.spanner_v1.types.result_set as result_set
17+
import grpc
2518
from google.api_core.client_options import ClientOptions
2619
from google.auth.credentials import AnonymousCredentials
27-
from google.cloud.spanner_v1 import Client, TypeCode, FixedSizePool
28-
from google.cloud.spanner_v1.database import Database
29-
from google.cloud.spanner_v1.instance import Instance
30-
import grpc
31-
from google.rpc import code_pb2
32-
from google.rpc import status_pb2
33-
from google.rpc.error_details_pb2 import RetryInfo
20+
from google.cloud.spanner_v1 import Type
21+
22+
from google.cloud.spanner_v1 import StructType
23+
from google.cloud.spanner_v1._helpers import _make_value_pb
24+
25+
from google.cloud.spanner_v1 import PartialResultSet
3426
from google.protobuf.duration_pb2 import Duration
27+
from google.rpc import code_pb2, status_pb2
28+
29+
from google.rpc.error_details_pb2 import RetryInfo
3530
from grpc_status._common import code_to_grpc_status_code
3631
from grpc_status.rpc_status import _Status
3732

33+
import google.cloud.spanner_v1.types.result_set as result_set
34+
import google.cloud.spanner_v1.types.type as spanner_type
35+
from google.cloud.spanner_dbapi.parsed_statement import AutocommitDmlMode
36+
from google.cloud.spanner_v1 import Client, FixedSizePool, ResultSetMetadata, TypeCode
37+
from google.cloud.spanner_v1.database import Database
38+
from google.cloud.spanner_v1.instance import Instance
39+
from google.cloud.spanner_v1.testing.mock_database_admin import DatabaseAdminServicer
40+
from google.cloud.spanner_v1.testing.mock_spanner import (
41+
SpannerServicer,
42+
start_mock_server,
43+
)
44+
3845

3946
# Creates an aborted status with the smallest possible retry delay.
4047
def aborted_status() -> _Status:
@@ -57,6 +64,27 @@ def aborted_status() -> _Status:
5764
return status
5865

5966

67+
def _make_partial_result_sets(
68+
fields: list[tuple[str, TypeCode]], results: list[dict]
69+
) -> list[result_set.PartialResultSet]:
70+
partial_result_sets = []
71+
for result in results:
72+
partial_result_set = PartialResultSet()
73+
if len(partial_result_sets) == 0:
74+
# setting the metadata
75+
metadata = ResultSetMetadata(row_type=StructType(fields=[]))
76+
for field in fields:
77+
metadata.row_type.fields.append(
78+
StructType.Field(name=field[0], type_=Type(code=field[1]))
79+
)
80+
partial_result_set.metadata = metadata
81+
for value in result["values"]:
82+
partial_result_set.values.append(_make_value_pb(value))
83+
partial_result_set.last = result.get("last") or False
84+
partial_result_sets.append(partial_result_set)
85+
return partial_result_sets
86+
87+
6088
# Creates an UNAVAILABLE status with the smallest possible retry delay.
6189
def unavailable_status() -> _Status:
6290
error = status_pb2.Status(
@@ -101,6 +129,14 @@ def add_select1_result():
101129
add_single_result("select 1", "c", TypeCode.INT64, [("1",)])
102130

103131

132+
def add_execute_streaming_sql_results(
133+
sql: str, partial_result_sets: list[result_set.PartialResultSet]
134+
):
135+
MockServerTestBase.spanner_service.mock_spanner.add_execute_streaming_sql_results(
136+
sql, partial_result_sets
137+
)
138+
139+
104140
def add_single_result(
105141
sql: str, column_name: str, type_code: spanner_type.TypeCode, row
106142
):

tests/mockserver_tests/test_basics.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,24 @@
1717
from google.cloud.spanner_dbapi.parsed_statement import AutocommitDmlMode
1818
from google.cloud.spanner_v1 import (
1919
BatchCreateSessionsRequest,
20-
ExecuteSqlRequest,
2120
BeginTransactionRequest,
22-
TransactionOptions,
2321
ExecuteBatchDmlRequest,
22+
ExecuteSqlRequest,
23+
TransactionOptions,
2424
TypeCode,
2525
)
26-
from google.cloud.spanner_v1.transaction import Transaction
2726
from google.cloud.spanner_v1.testing.mock_spanner import SpannerServicer
27+
from google.cloud.spanner_v1.transaction import Transaction
2828

2929
from tests.mockserver_tests.mock_server_test_base import (
3030
MockServerTestBase,
31+
_make_partial_result_sets,
3132
add_select1_result,
33+
add_single_result,
3234
add_update_count,
3335
add_error,
3436
unavailable_status,
35-
add_single_result,
37+
add_execute_streaming_sql_results,
3638
)
3739

3840

@@ -176,6 +178,31 @@ def test_last_statement_query(self):
176178
self.assertEqual(1, len(requests), msg=requests)
177179
self.assertTrue(requests[0].last_statement, requests[0])
178180

181+
def test_execute_streaming_sql_last_field(self):
182+
partial_result_sets = _make_partial_result_sets(
183+
[("ID", TypeCode.INT64), ("NAME", TypeCode.STRING)],
184+
[
185+
{"values": ["1", "ABC", "2", "DEF"]},
186+
{"values": ["3", "GHI"], "last": True},
187+
],
188+
)
189+
190+
sql = "select * from my_table"
191+
add_execute_streaming_sql_results(sql, partial_result_sets)
192+
count = 1
193+
with self.database.snapshot() as snapshot:
194+
results = snapshot.execute_sql(sql)
195+
result_list = []
196+
for row in results:
197+
result_list.append(row)
198+
self.assertEqual(count, row[0])
199+
count += 1
200+
self.assertEqual(3, len(result_list))
201+
requests = self.spanner_service.requests
202+
self.assertEqual(2, len(requests), msg=requests)
203+
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
204+
self.assertTrue(isinstance(requests[1], ExecuteSqlRequest))
205+
179206

180207
def _execute_query(transaction: Transaction, sql: str):
181208
rows = transaction.execute_sql(sql, last_statement=True)

tests/unit/test_streamed.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,12 +122,12 @@ def _make_result_set_stats(query_plan=None, **kw):
122122

123123
@staticmethod
124124
def _make_partial_result_set(
125-
values, metadata=None, stats=None, chunked_value=False
125+
values, metadata=None, stats=None, chunked_value=False, last=False
126126
):
127127
from google.cloud.spanner_v1 import PartialResultSet
128128

129129
results = PartialResultSet(
130-
metadata=metadata, stats=stats, chunked_value=chunked_value
130+
metadata=metadata, stats=stats, chunked_value=chunked_value, last=last
131131
)
132132
for v in values:
133133
results.values.append(v)
@@ -162,6 +162,40 @@ def test__merge_chunk_bool(self):
162162
with self.assertRaises(Unmergeable):
163163
streamed._merge_chunk(chunk)
164164

165+
def test__PartialResultSetWithLastFlag(self):
166+
from google.cloud.spanner_v1 import TypeCode
167+
168+
fields = [
169+
self._make_scalar_field("ID", TypeCode.INT64),
170+
self._make_scalar_field("NAME", TypeCode.STRING),
171+
]
172+
for length in range(4, 6):
173+
metadata = self._make_result_set_metadata(fields)
174+
result_sets = [
175+
self._make_partial_result_set(
176+
[self._make_value(0), "google_0"], metadata=metadata
177+
)
178+
]
179+
for i in range(1, 5):
180+
bares = [i]
181+
values = [
182+
[self._make_value(bare), "google_" + str(bare)] for bare in bares
183+
]
184+
result_sets.append(
185+
self._make_partial_result_set(
186+
*values, metadata=metadata, last=(i == length - 1)
187+
)
188+
)
189+
190+
iterator = _MockCancellableIterator(*result_sets)
191+
streamed = self._make_one(iterator)
192+
count = 0
193+
for row in streamed:
194+
self.assertEqual(row[0], count)
195+
self.assertEqual(row[1], "google_" + str(count))
196+
count += 1
197+
self.assertEqual(count, length)
198+
165199
def test__merge_chunk_numeric(self):
166200
from google.cloud.spanner_v1 import TypeCode
167201

0 commit comments

Comments
 (0)