Skip to content

Commit a79786e

Browse files
authored
feat: add support for NUMERIC type (#86)
* feat: add support for NUMERIC type * add tests * fix test name * remove unused import * add NUMERIC to param_types * add system tests * test: update tests to work for emulator * style: fix lint Co-authored-by: larkee <larkee@users.noreply.github.com>
1 parent cbfcc8b commit a79786e

File tree

5 files changed

+155
-10
lines changed

5 files changed

+155
-10
lines changed

google/cloud/spanner_v1/_helpers.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Helper functions for Cloud Spanner."""
1616

1717
import datetime
18+
import decimal
1819
import math
1920

2021
import six
@@ -127,6 +128,8 @@ def _make_value_pb(value):
127128
return Value(string_value=value)
128129
if isinstance(value, ListValue):
129130
return Value(list_value=value)
131+
if isinstance(value, decimal.Decimal):
132+
return Value(string_value=str(value))
130133
raise ValueError("Unknown type: %s" % (value,))
131134

132135

@@ -201,6 +204,8 @@ def _parse_value_pb(value_pb, field_type):
201204
_parse_value_pb(item_pb, field_type.struct_type.fields[i].type)
202205
for (i, item_pb) in enumerate(value_pb.list_value.values)
203206
]
207+
elif field_type.code == type_pb2.NUMERIC:
208+
result = decimal.Decimal(value_pb.string_value)
204209
else:
205210
raise ValueError("Unknown type: %s" % (field_type,))
206211
return result

google/cloud/spanner_v1/param_types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
FLOAT64 = type_pb2.Type(code=type_pb2.FLOAT64)
2626
DATE = type_pb2.Type(code=type_pb2.DATE)
2727
TIMESTAMP = type_pb2.Type(code=type_pb2.TIMESTAMP)
28+
NUMERIC = type_pb2.Type(code=type_pb2.NUMERIC)
2829

2930

3031
def Array(element_type): # pylint: disable=invalid-name

tests/_fixtures.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,58 @@
1616

1717

1818
DDL = """\
19+
CREATE TABLE contacts (
20+
contact_id INT64,
21+
first_name STRING(1024),
22+
last_name STRING(1024),
23+
email STRING(1024) )
24+
PRIMARY KEY (contact_id);
25+
CREATE TABLE contact_phones (
26+
contact_id INT64,
27+
phone_type STRING(1024),
28+
phone_number STRING(1024) )
29+
PRIMARY KEY (contact_id, phone_type),
30+
INTERLEAVE IN PARENT contacts ON DELETE CASCADE;
31+
CREATE TABLE all_types (
32+
pkey INT64 NOT NULL,
33+
int_value INT64,
34+
int_array ARRAY<INT64>,
35+
bool_value BOOL,
36+
bool_array ARRAY<BOOL>,
37+
bytes_value BYTES(16),
38+
bytes_array ARRAY<BYTES(16)>,
39+
date_value DATE,
40+
date_array ARRAY<DATE>,
41+
float_value FLOAT64,
42+
float_array ARRAY<FLOAT64>,
43+
string_value STRING(16),
44+
string_array ARRAY<STRING(16)>,
45+
timestamp_value TIMESTAMP,
46+
timestamp_array ARRAY<TIMESTAMP>,
47+
numeric_value NUMERIC,
48+
numeric_array ARRAY<NUMERIC>)
49+
PRIMARY KEY (pkey);
50+
CREATE TABLE counters (
51+
name STRING(1024),
52+
value INT64 )
53+
PRIMARY KEY (name);
54+
CREATE TABLE string_plus_array_of_string (
55+
id INT64,
56+
name STRING(16),
57+
tags ARRAY<STRING(16)> )
58+
PRIMARY KEY (id);
59+
CREATE INDEX name ON contacts(first_name, last_name);
60+
CREATE TABLE users_history (
61+
id INT64 NOT NULL,
62+
commit_ts TIMESTAMP NOT NULL OPTIONS
63+
(allow_commit_timestamp=true),
64+
name STRING(MAX) NOT NULL,
65+
email STRING(MAX),
66+
deleted BOOL NOT NULL )
67+
PRIMARY KEY(id, commit_ts DESC);
68+
"""
69+
70+
EMULATOR_DDL = """\
1971
CREATE TABLE contacts (
2072
contact_id INT64,
2173
first_name STRING(1024),
@@ -66,3 +118,6 @@
66118
"""
67119

68120
DDL_STATEMENTS = [stmt.strip() for stmt in DDL.split(";") if stmt.strip()]
121+
EMULATOR_DDL_STATEMENTS = [
122+
stmt.strip() for stmt in EMULATOR_DDL.split(";") if stmt.strip()
123+
]

tests/system/test_system.py

Lines changed: 74 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import collections
1616
import datetime
17+
import decimal
1718
import math
1819
import operator
1920
import os
@@ -38,6 +39,7 @@
3839
from google.cloud.spanner_v1.proto.type_pb2 import INT64
3940
from google.cloud.spanner_v1.proto.type_pb2 import STRING
4041
from google.cloud.spanner_v1.proto.type_pb2 import TIMESTAMP
42+
from google.cloud.spanner_v1.proto.type_pb2 import NUMERIC
4143
from google.cloud.spanner_v1.proto.type_pb2 import Type
4244

4345
from google.cloud._helpers import UTC
@@ -52,11 +54,13 @@
5254
from test_utils.retry import RetryResult
5355
from test_utils.system import unique_resource_id
5456
from tests._fixtures import DDL_STATEMENTS
57+
from tests._fixtures import EMULATOR_DDL_STATEMENTS
5558
from tests._helpers import OpenTelemetryBase, HAS_OPENTELEMETRY_INSTALLED
5659

5760

5861
CREATE_INSTANCE = os.getenv("GOOGLE_CLOUD_TESTS_CREATE_SPANNER_INSTANCE") is not None
5962
USE_EMULATOR = os.getenv("SPANNER_EMULATOR_HOST") is not None
63+
SKIP_BACKUP_TESTS = os.getenv("SKIP_BACKUP_TESTS") is not None
6064

6165
if CREATE_INSTANCE:
6266
INSTANCE_ID = "google-cloud" + unique_resource_id("-")
@@ -92,7 +96,8 @@ class Config(object):
9296

9397

9498
def _has_all_ddl(database):
95-
return len(database.ddl_statements) == len(DDL_STATEMENTS)
99+
ddl_statements = EMULATOR_DDL_STATEMENTS if USE_EMULATOR else DDL_STATEMENTS
100+
return len(database.ddl_statements) == len(ddl_statements)
96101

97102

98103
def _list_instances():
@@ -284,8 +289,9 @@ class TestDatabaseAPI(unittest.TestCase, _TestData):
284289
@classmethod
285290
def setUpClass(cls):
286291
pool = BurstyPool(labels={"testcase": "database_api"})
292+
ddl_statements = EMULATOR_DDL_STATEMENTS if USE_EMULATOR else DDL_STATEMENTS
287293
cls._db = Config.INSTANCE.database(
288-
cls.DATABASE_NAME, ddl_statements=DDL_STATEMENTS, pool=pool
294+
cls.DATABASE_NAME, ddl_statements=ddl_statements, pool=pool
289295
)
290296
operation = cls._db.create()
291297
operation.result(30) # raises on failure / timeout.
@@ -359,12 +365,13 @@ def test_update_database_ddl_with_operation_id(self):
359365
temp_db = Config.INSTANCE.database(temp_db_id, pool=pool)
360366
create_op = temp_db.create()
361367
self.to_delete.append(temp_db)
368+
ddl_statements = EMULATOR_DDL_STATEMENTS if USE_EMULATOR else DDL_STATEMENTS
362369

363370
# We want to make sure the operation completes.
364371
create_op.result(240) # raises on failure / timeout.
365372
# random but shortish always start with letter
366373
operation_id = "a" + str(uuid.uuid4())[:8]
367-
operation = temp_db.update_ddl(DDL_STATEMENTS, operation_id=operation_id)
374+
operation = temp_db.update_ddl(ddl_statements, operation_id=operation_id)
368375

369376
self.assertEqual(operation_id, operation.operation.name.split("/")[-1])
370377

@@ -373,7 +380,7 @@ def test_update_database_ddl_with_operation_id(self):
373380

374381
temp_db.reload()
375382

376-
self.assertEqual(len(temp_db.ddl_statements), len(DDL_STATEMENTS))
383+
self.assertEqual(len(temp_db.ddl_statements), len(ddl_statements))
377384

378385
def test_db_batch_insert_then_db_snapshot_read(self):
379386
retry = RetryInstanceState(_has_all_ddl)
@@ -447,15 +454,17 @@ def _unit_of_work(transaction, name):
447454

448455

449456
@unittest.skipIf(USE_EMULATOR, "Skipping backup tests")
457+
@unittest.skipIf(SKIP_BACKUP_TESTS, "Skipping backup tests")
450458
class TestBackupAPI(unittest.TestCase, _TestData):
451459
DATABASE_NAME = "test_database" + unique_resource_id("_")
452460
DATABASE_NAME_2 = "test_database2" + unique_resource_id("_")
453461

454462
@classmethod
455463
def setUpClass(cls):
456464
pool = BurstyPool(labels={"testcase": "database_api"})
465+
ddl_statements = EMULATOR_DDL_STATEMENTS if USE_EMULATOR else DDL_STATEMENTS
457466
db1 = Config.INSTANCE.database(
458-
cls.DATABASE_NAME, ddl_statements=DDL_STATEMENTS, pool=pool
467+
cls.DATABASE_NAME, ddl_statements=ddl_statements, pool=pool
459468
)
460469
db2 = Config.INSTANCE.database(cls.DATABASE_NAME_2, pool=pool)
461470
cls._db = db1
@@ -736,6 +745,8 @@ def test_list_backups(self):
736745
(OTHER_NAN,) = struct.unpack("<d", b"\x01\x00\x01\x00\x00\x00\xf8\xff")
737746
BYTES_1 = b"Ymlu"
738747
BYTES_2 = b"Ym9vdHM="
748+
NUMERIC_1 = decimal.Decimal("0.123456789")
749+
NUMERIC_2 = decimal.Decimal("1234567890")
739750
ALL_TYPES_TABLE = "all_types"
740751
ALL_TYPES_COLUMNS = (
741752
"pkey",
@@ -753,9 +764,18 @@ def test_list_backups(self):
753764
"string_array",
754765
"timestamp_value",
755766
"timestamp_array",
767+
"numeric_value",
768+
"numeric_array",
756769
)
770+
EMULATOR_ALL_TYPES_COLUMNS = ALL_TYPES_COLUMNS[:-2]
757771
AllTypesRowData = collections.namedtuple("AllTypesRowData", ALL_TYPES_COLUMNS)
758772
AllTypesRowData.__new__.__defaults__ = tuple([None for colum in ALL_TYPES_COLUMNS])
773+
EmulatorAllTypesRowData = collections.namedtuple(
774+
"EmulatorAllTypesRowData", EMULATOR_ALL_TYPES_COLUMNS
775+
)
776+
EmulatorAllTypesRowData.__new__.__defaults__ = tuple(
777+
[None for colum in EMULATOR_ALL_TYPES_COLUMNS]
778+
)
759779

760780
ALL_TYPES_ROWDATA = (
761781
# all nulls
@@ -769,6 +789,7 @@ def test_list_backups(self):
769789
AllTypesRowData(pkey=106, string_value=u"VALUE"),
770790
AllTypesRowData(pkey=107, timestamp_value=SOME_TIME),
771791
AllTypesRowData(pkey=108, timestamp_value=NANO_TIME),
792+
AllTypesRowData(pkey=109, numeric_value=NUMERIC_1),
772793
# empty array values
773794
AllTypesRowData(pkey=201, int_array=[]),
774795
AllTypesRowData(pkey=202, bool_array=[]),
@@ -777,6 +798,7 @@ def test_list_backups(self):
777798
AllTypesRowData(pkey=205, float_array=[]),
778799
AllTypesRowData(pkey=206, string_array=[]),
779800
AllTypesRowData(pkey=207, timestamp_array=[]),
801+
AllTypesRowData(pkey=208, numeric_array=[]),
780802
# non-empty array values, including nulls
781803
AllTypesRowData(pkey=301, int_array=[123, 456, None]),
782804
AllTypesRowData(pkey=302, bool_array=[True, False, None]),
@@ -785,6 +807,36 @@ def test_list_backups(self):
785807
AllTypesRowData(pkey=305, float_array=[3.1415926, 2.71828, None]),
786808
AllTypesRowData(pkey=306, string_array=[u"One", u"Two", None]),
787809
AllTypesRowData(pkey=307, timestamp_array=[SOME_TIME, NANO_TIME, None]),
810+
AllTypesRowData(pkey=308, numeric_array=[NUMERIC_1, NUMERIC_2, None]),
811+
)
812+
EMULATOR_ALL_TYPES_ROWDATA = (
813+
# all nulls
814+
EmulatorAllTypesRowData(pkey=0),
815+
# Non-null values
816+
EmulatorAllTypesRowData(pkey=101, int_value=123),
817+
EmulatorAllTypesRowData(pkey=102, bool_value=False),
818+
EmulatorAllTypesRowData(pkey=103, bytes_value=BYTES_1),
819+
EmulatorAllTypesRowData(pkey=104, date_value=SOME_DATE),
820+
EmulatorAllTypesRowData(pkey=105, float_value=1.4142136),
821+
EmulatorAllTypesRowData(pkey=106, string_value=u"VALUE"),
822+
EmulatorAllTypesRowData(pkey=107, timestamp_value=SOME_TIME),
823+
EmulatorAllTypesRowData(pkey=108, timestamp_value=NANO_TIME),
824+
# empty array values
825+
EmulatorAllTypesRowData(pkey=201, int_array=[]),
826+
EmulatorAllTypesRowData(pkey=202, bool_array=[]),
827+
EmulatorAllTypesRowData(pkey=203, bytes_array=[]),
828+
EmulatorAllTypesRowData(pkey=204, date_array=[]),
829+
EmulatorAllTypesRowData(pkey=205, float_array=[]),
830+
EmulatorAllTypesRowData(pkey=206, string_array=[]),
831+
EmulatorAllTypesRowData(pkey=207, timestamp_array=[]),
832+
# non-empty array values, including nulls
833+
EmulatorAllTypesRowData(pkey=301, int_array=[123, 456, None]),
834+
EmulatorAllTypesRowData(pkey=302, bool_array=[True, False, None]),
835+
EmulatorAllTypesRowData(pkey=303, bytes_array=[BYTES_1, BYTES_2, None]),
836+
EmulatorAllTypesRowData(pkey=304, date_array=[SOME_DATE, None]),
837+
EmulatorAllTypesRowData(pkey=305, float_array=[3.1415926, 2.71828, None]),
838+
EmulatorAllTypesRowData(pkey=306, string_array=[u"One", u"Two", None]),
839+
EmulatorAllTypesRowData(pkey=307, timestamp_array=[SOME_TIME, NANO_TIME, None]),
788840
)
789841

790842

@@ -794,8 +846,9 @@ class TestSessionAPI(OpenTelemetryBase, _TestData):
794846
@classmethod
795847
def setUpClass(cls):
796848
pool = BurstyPool(labels={"testcase": "session_api"})
849+
ddl_statements = EMULATOR_DDL_STATEMENTS if USE_EMULATOR else DDL_STATEMENTS
797850
cls._db = Config.INSTANCE.database(
798-
cls.DATABASE_NAME, ddl_statements=DDL_STATEMENTS, pool=pool
851+
cls.DATABASE_NAME, ddl_statements=ddl_statements, pool=pool
799852
)
800853
operation = cls._db.create()
801854
operation.result(30) # raises on failure / timeout.
@@ -899,13 +952,19 @@ def test_batch_insert_then_read_all_datatypes(self):
899952
retry = RetryInstanceState(_has_all_ddl)
900953
retry(self._db.reload)()
901954

955+
if USE_EMULATOR:
956+
all_types_columns = EMULATOR_ALL_TYPES_COLUMNS
957+
all_types_rowdata = EMULATOR_ALL_TYPES_ROWDATA
958+
else:
959+
all_types_columns = ALL_TYPES_COLUMNS
960+
all_types_rowdata = ALL_TYPES_ROWDATA
902961
with self._db.batch() as batch:
903962
batch.delete(ALL_TYPES_TABLE, self.ALL)
904-
batch.insert(ALL_TYPES_TABLE, ALL_TYPES_COLUMNS, ALL_TYPES_ROWDATA)
963+
batch.insert(ALL_TYPES_TABLE, all_types_columns, all_types_rowdata)
905964

906965
with self._db.snapshot(read_timestamp=batch.committed) as snapshot:
907-
rows = list(snapshot.read(ALL_TYPES_TABLE, ALL_TYPES_COLUMNS, self.ALL))
908-
self._check_rows_data(rows, expected=ALL_TYPES_ROWDATA)
966+
rows = list(snapshot.read(ALL_TYPES_TABLE, all_types_columns, self.ALL))
967+
self._check_rows_data(rows, expected=all_types_rowdata)
909968

910969
def test_batch_insert_or_update_then_query(self):
911970
retry = RetryInstanceState(_has_all_ddl)
@@ -1704,9 +1763,10 @@ def test_read_w_index(self):
17041763
MY_COLUMNS = self.COLUMNS[0], self.COLUMNS[2]
17051764
EXTRA_DDL = ["CREATE INDEX contacts_by_last_name ON contacts(last_name)"]
17061765
pool = BurstyPool(labels={"testcase": "read_w_index"})
1766+
ddl_statements = EMULATOR_DDL_STATEMENTS if USE_EMULATOR else DDL_STATEMENTS
17071767
temp_db = Config.INSTANCE.database(
17081768
"test_read" + unique_resource_id("_"),
1709-
ddl_statements=DDL_STATEMENTS + EXTRA_DDL,
1769+
ddl_statements=ddl_statements + EXTRA_DDL,
17101770
pool=pool,
17111771
)
17121772
operation = temp_db.create()
@@ -2282,6 +2342,10 @@ def test_execute_sql_w_date_bindings(self):
22822342
dates = [SOME_DATE, SOME_DATE + datetime.timedelta(days=1)]
22832343
self._bind_test_helper(DATE, SOME_DATE, dates)
22842344

2345+
@unittest.skipIf(USE_EMULATOR, "Skipping NUMERIC")
2346+
def test_execute_sql_w_numeric_bindings(self):
2347+
self._bind_test_helper(NUMERIC, NUMERIC_1, [NUMERIC_1, NUMERIC_2])
2348+
22852349
def test_execute_sql_w_query_param_struct(self):
22862350
NAME = "Phred"
22872351
COUNT = 123

tests/unit/test__helpers.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,15 @@ def test_w_datetime(self):
208208
self.assertIsInstance(value_pb, Value)
209209
self.assertEqual(value_pb.string_value, datetime_helpers.to_rfc3339(now))
210210

211+
def test_w_numeric(self):
212+
import decimal
213+
from google.protobuf.struct_pb2 import Value
214+
215+
value = decimal.Decimal("9999999999999999999999999999.999999999")
216+
value_pb = self._callFUT(value)
217+
self.assertIsInstance(value_pb, Value)
218+
self.assertEqual(value_pb.string_value, str(value))
219+
211220
def test_w_unknown_type(self):
212221
with self.assertRaises(ValueError):
213222
self._callFUT(object())
@@ -431,6 +440,17 @@ def test_w_struct(self):
431440

432441
self.assertEqual(self._callFUT(value_pb, field_type), VALUES)
433442

443+
def test_w_numeric(self):
444+
import decimal
445+
from google.protobuf.struct_pb2 import Value
446+
from google.cloud.spanner_v1.proto.type_pb2 import Type, NUMERIC
447+
448+
VALUE = decimal.Decimal("99999999999999999999999999999.999999999")
449+
field_type = Type(code=NUMERIC)
450+
value_pb = Value(string_value=str(VALUE))
451+
452+
self.assertEqual(self._callFUT(value_pb, field_type), VALUE)
453+
434454
def test_w_unknown_type(self):
435455
from google.protobuf.struct_pb2 import Value
436456
from google.cloud.spanner_v1.proto.type_pb2 import Type

0 commit comments

Comments
 (0)