Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.
1 change: 1 addition & 0 deletions data_diff/databases/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ class BaseDialect(abc.ABC):
SUPPORTS_INDEXES: ClassVar[bool] = False
PREVENT_OVERFLOW_WHEN_CONCAT: ClassVar[bool] = False
TYPE_CLASSES: ClassVar[Dict[str, Type[ColType]]] = {}
DEFAULT_NUMERIC_PRECISION: ClassVar[int] = 0 # effective precision when type is just "NUMERIC"

PLACEHOLDER_TABLE = None # Used for Oracle

Expand Down
57 changes: 42 additions & 15 deletions data_diff/databases/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,11 @@ class Dialect(BaseDialect):
}
TYPE_ARRAY_RE = re.compile(r"ARRAY<(.+)>")
TYPE_STRUCT_RE = re.compile(r"STRUCT<(.+)>")
# [BIG]NUMERIC, [BIG]NUMERIC(precision, scale), [BIG]NUMERIC(precision)
TYPE_NUMERIC_RE = re.compile(r"^((BIG)?NUMERIC)(?:\((\d+)(?:, (\d+))?\))?$")
# https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types#parameterized_decimal_type
# The default scale is 9, which means a number can have up to 9 digits after the decimal point.
DEFAULT_NUMERIC_PRECISION = 9

def random(self) -> str:
return "RAND()"
Expand All @@ -94,21 +99,43 @@ def type_repr(self, t) -> str:

def parse_type(self, table_path: DbPath, info: RawColumnInfo) -> ColType:
col_type = super().parse_type(table_path, info)
if isinstance(col_type, UnknownColType):
m = self.TYPE_ARRAY_RE.fullmatch(info.data_type)
if m:
item_info = attrs.evolve(info, data_type=m.group(1))
item_type = self.parse_type(table_path, item_info)
col_type = Array(item_type=item_type)

# We currently ignore structs' structure, but later can parse it too. Examples:
# - STRUCT<INT64, STRING(10)> (unnamed)
# - STRUCT<foo INT64, bar STRING(10)> (named)
# - STRUCT<foo INT64, bar ARRAY<INT64>> (with complex fields)
# - STRUCT<foo INT64, bar STRUCT<a INT64, b INT64>> (nested)
m = self.TYPE_STRUCT_RE.fullmatch(info.data_type)
if m:
col_type = Struct()
if not isinstance(col_type, UnknownColType):
return col_type

m = self.TYPE_ARRAY_RE.fullmatch(info.data_type)
if m:
item_info = attrs.evolve(info, data_type=m.group(1))
item_type = self.parse_type(table_path, item_info)
col_type = Array(item_type=item_type)
return col_type

# We currently ignore structs' structure, but later can parse it too. Examples:
# - STRUCT<INT64, STRING(10)> (unnamed)
# - STRUCT<foo INT64, bar STRING(10)> (named)
# - STRUCT<foo INT64, bar ARRAY<INT64>> (with complex fields)
# - STRUCT<foo INT64, bar STRUCT<a INT64, b INT64>> (nested)
m = self.TYPE_STRUCT_RE.fullmatch(info.data_type)
if m:
col_type = Struct()
return col_type

m = self.TYPE_NUMERIC_RE.fullmatch(info.data_type)
if m:
precision = int(m.group(3)) if m.group(3) else None
scale = int(m.group(4)) if m.group(4) else None

if scale is not None:
# NUMERIC(..., scale) — scale is set explicitly
effective_precision = scale
elif precision is not None:
# NUMERIC(...) — scale is missing but precision is set
# effectively the same as NUMERIC(..., 0)
effective_precision = 0
else:
# NUMERIC → default scale is 9
effective_precision = 9
col_type = Decimal(precision=effective_precision)
return col_type

return col_type

Expand Down
4 changes: 4 additions & 0 deletions data_diff/databases/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ class Dialect(BaseDialect):
SUPPORTS_PRIMARY_KEY = True
SUPPORTS_INDEXES = True

# https://duckdb.org/docs/sql/data_types/numeric#fixed-point-decimals
# The default WIDTH and SCALE is DECIMAL(18, 3), if none are specified.
DEFAULT_NUMERIC_PRECISION = 3

TYPE_CLASSES = {
# Timestamps
"TIMESTAMP WITH TIME ZONE": TimestampTZ,
Expand Down
25 changes: 21 additions & 4 deletions data_diff/databases/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ class PostgresqlDialect(BaseDialect):
SUPPORTS_PRIMARY_KEY: ClassVar[bool] = True
SUPPORTS_INDEXES = True

# https://www.postgresql.org/docs/current/datatype-numeric.html#DATATYPE-NUMERIC-DECIMAL
# without any precision or scale creates an “unconstrained numeric” column
# in which numeric values of any length can be stored, up to the implementation limits.
# https://www.postgresql.org/docs/current/datatype-numeric.html#DATATYPE-NUMERIC-TABLE
DEFAULT_NUMERIC_PRECISION = 16383

TYPE_CLASSES: ClassVar[Dict[str, Type[ColType]]] = {
# Timestamps
"timestamp with time zone": TimestampTZ,
Expand Down Expand Up @@ -185,10 +191,21 @@ def select_table_schema(self, path: DbPath) -> str:
if database:
info_schema_path.insert(0, database)

return (
f"SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale FROM {'.'.join(info_schema_path)} "
f"WHERE table_name = '{table}' AND table_schema = '{schema}'"
)
return f"""SELECT column_name, data_type, datetime_precision,
-- see comment for DEFAULT_NUMERIC_PRECISION
CASE
WHEN data_type = 'numeric'
THEN coalesce(numeric_precision, 131072 + {self.dialect.DEFAULT_NUMERIC_PRECISION})
ELSE numeric_precision
END AS numeric_precision,
CASE
WHEN data_type = 'numeric'
THEN coalesce(numeric_scale, {self.dialect.DEFAULT_NUMERIC_PRECISION})
ELSE numeric_scale
END AS numeric_scale
FROM {'.'.join(info_schema_path)}
WHERE table_name = '{table}' AND table_schema = '{schema}'
"""

def select_table_unique_columns(self, path: DbPath) -> str:
database, schema, table = self._normalize_table_path(path)
Expand Down
3 changes: 3 additions & 0 deletions data_diff/databases/vertica.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ class Dialect(BaseDialect):
"boolean": Boolean,
}

# https://www.vertica.com/docs/9.3.x/HTML/Content/Authoring/SQLReferenceManual/DataTypes/Numeric/NUMERIC.htm#Default
DEFAULT_NUMERIC_PRECISION = 15

def quote(self, s: str):
return f'"{s}"'

Expand Down
33 changes: 33 additions & 0 deletions tests/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,36 @@ def test_three_part_support(self):
d = db.query_table_schema(part.path)
assert len(d) == 1
db.query(part.drop())


@test_each_database
class TestNumericPrecisionParsing(unittest.TestCase):
def test_specified_precision(self):
name = "tbl_" + random_table_suffix()
db = get_conn(self.db_cls)
tbl = table(name, schema={"value": "DECIMAL(10, 2)"})
db.query(tbl.create())
t = table(name)
raw_schema = db.query_table_schema(t.path)
schema = db._process_table_schema(t.path, raw_schema)
self.assertEqual(schema["value"].precision, 2)

def test_specified_zero_precision(self):
name = "tbl_" + random_table_suffix()
db = get_conn(self.db_cls)
tbl = table(name, schema={"value": "DECIMAL(10)"})
db.query(tbl.create())
t = table(name)
raw_schema = db.query_table_schema(t.path)
schema = db._process_table_schema(t.path, raw_schema)
self.assertEqual(schema["value"].precision, 0)

def test_default_precision(self):
name = "tbl_" + random_table_suffix()
db = get_conn(self.db_cls)
tbl = table(name, schema={"value": "DECIMAL"})
db.query(tbl.create())
t = table(name)
raw_schema = db.query_table_schema(t.path)
schema = db._process_table_schema(t.path, raw_schema)
self.assertEqual(schema["value"].precision, db.dialect.DEFAULT_NUMERIC_PRECISION)