Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.
6 changes: 5 additions & 1 deletion data_diff/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@


def connect_to_table(
db_info: Union[str, dict], table_name: Union[DbPath, str], key_column: str = "id", thread_count: Optional[int] = 1, **kwargs
db_info: Union[str, dict],
table_name: Union[DbPath, str],
key_column: str = "id",
thread_count: Optional[int] = 1,
**kwargs,
):
"""Connects to the given database, and creates a TableSegment instance

Expand Down
5 changes: 3 additions & 2 deletions data_diff/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,11 @@
"-": "red",
}


def _remove_passwords_in_dict(d: dict):
for k, v in d.items():
if k == 'password':
d[k] = '*' * len(v)
if k == "password":
d[k] = "*" * len(v)
elif isinstance(v, dict):
_remove_passwords_in_dict(v)

Expand Down
15 changes: 12 additions & 3 deletions data_diff/databases/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
TemporalType,
UnknownColType,
Text,
DbTime,
)
from data_diff.sql import DbPath, SqlOrStr, Compiler, Explain, Select, TableName

Expand Down Expand Up @@ -151,9 +152,10 @@ def _parse_type(

elif issubclass(cls, Decimal):
if numeric_scale is None:
raise ValueError(
f"{self.name}: Unexpected numeric_scale is NULL, for column {'.'.join(table_path)}.{col_name} of type {type_repr}."
)
numeric_scale = 0 # Needed for Oracle.
# raise ValueError(
# f"{self.name}: Unexpected numeric_scale is NULL, for column {'.'.join(table_path)}.{col_name} of type {type_repr}."
# )
return cls(precision=numeric_scale)

elif issubclass(cls, Float):
Expand Down Expand Up @@ -242,6 +244,13 @@ def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None

return f"LIMIT {limit}"

def concat(self, l: List[str]) -> str:
    "Render a standard SQL concat() call over the given column expressions."
    return "concat(%s)" % ", ".join(l)

def timestamp_value(self, t: DbTime) -> str:
    "Render the timestamp as a single-quoted ISO-8601 SQL literal."
    return f"'{t.isoformat()}'"

def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
if isinstance(coltype, String_UUID):
return f"TRIM({value})"
Expand Down
10 changes: 6 additions & 4 deletions data_diff/databases/connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ def match_path(self, dsn):
"presto": MatchUriPath(Presto, ["catalog", "schema"], help_str="presto://<user>@<host>/<catalog>/<schema>"),
"bigquery": MatchUriPath(BigQuery, ["dataset"], help_str="bigquery://<project>/<dataset>"),
"databricks": MatchUriPath(
Databricks, ["catalog", "schema"], help_str="databricks://:access_token@server_name/http_path",
Databricks,
["catalog", "schema"],
help_str="databricks://:access_token@server_name/http_path",
),
"trino": MatchUriPath(Trino, ["catalog", "schema"], help_str="trino://<user>@<host>/<catalog>/<schema>"),
}
Expand Down Expand Up @@ -125,9 +127,9 @@ def connect_to_uri(db_uri: str, thread_count: Optional[int] = 1) -> Database:
if scheme == "databricks":
assert not dsn.user
kw = {}
kw['access_token'] = dsn.password
kw['http_path'] = dsn.path
kw['server_hostname'] = dsn.host
kw["access_token"] = dsn.password
kw["http_path"] = dsn.path
kw["server_hostname"] = dsn.host
kw.update(dsn.query)
else:
kw = matcher.match_path(dsn)
Expand Down
13 changes: 12 additions & 1 deletion data_diff/databases/database_types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import decimal
from abc import ABC, abstractmethod
from typing import Sequence, Optional, Tuple, Union, Dict, Any
from typing import Sequence, Optional, Tuple, Union, Dict, List
from datetime import datetime

from runtype import dataclass
Expand Down Expand Up @@ -120,13 +120,24 @@ def to_string(self, s: str) -> str:
"Provide SQL for casting a column to string"
...

@abstractmethod
def concat(self, s: List[str]) -> str:
    "Provide SQL for concatenating a list of column expressions into one string"
    ...

@abstractmethod
def timestamp_value(self, t: DbTime) -> str:
    "Provide a SQL literal expression for the given timestamp value"
    ...

@abstractmethod
def md5_to_int(self, s: str) -> str:
    "Provide SQL that computes the MD5 of an expression and returns it as an int"
    ...

@abstractmethod
def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None):
    "Provide a SQL fragment for LIMIT/OFFSET inside a SELECT"
    ...

@abstractmethod
Expand Down
15 changes: 13 additions & 2 deletions data_diff/databases/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from .base import ThreadedDatabase, import_helper, ConnectError, QueryError
from .base import DEFAULT_DATETIME_PRECISION, DEFAULT_NUMERIC_PRECISION

SESSION_TIME_ZONE = None # Changed by the tests
SESSION_TIME_ZONE = None # Changed by the tests


@import_helper("oracle")
def import_oracle():
Expand Down Expand Up @@ -89,6 +90,7 @@ def _parse_type(
regexps = {
r"TIMESTAMP\((\d)\) WITH LOCAL TIME ZONE": Timestamp,
r"TIMESTAMP\((\d)\) WITH TIME ZONE": TimestampTZ,
r"TIMESTAMP\((\d)\)": Timestamp,
}
for regexp, t_cls in regexps.items():
m = re.match(regexp + "$", type_repr)
Expand All @@ -99,14 +101,23 @@ def _parse_type(
rounds=self.ROUNDS_ON_PREC_LOSS,
)

return super()._parse_type(type_repr, col_name, type_repr, datetime_precision, numeric_precision, numeric_scale)
return super()._parse_type(
table_name, col_name, type_repr, datetime_precision, numeric_precision, numeric_scale
)

def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None):
    "Oracle has no LIMIT keyword; emit the ANSI FETCH NEXT clause instead."
    if offset:
        # Offsets are not implemented for Oracle queries.
        raise NotImplementedError("No support for OFFSET in query")
    return "FETCH NEXT {} ROWS ONLY".format(limit)

def concat(self, l: List[str]) -> str:
    "Concatenate expressions with Oracle's || operator, wrapped in parentheses."
    return "({})".format(" || ".join(l))

def timestamp_value(self, t: DbTime) -> str:
    "Render an Oracle TIMESTAMP literal (space-separated date and time)."
    return f"timestamp '{t.isoformat(' ')}'"

def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
# Cast is necessary for correct MD5 (trimming not enough)
return f"CAST(TRIM({value}) AS VARCHAR(36))"
5 changes: 3 additions & 2 deletions data_diff/databases/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from .base import ThreadedDatabase, import_helper, ConnectError
from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, _CHECKSUM_BITSIZE, TIMESTAMP_PRECISION_POS

SESSION_TIME_ZONE = None # Changed by the tests
SESSION_TIME_ZONE = None # Changed by the tests


@import_helper("postgresql")
def import_postgresql():
Expand Down Expand Up @@ -49,7 +50,7 @@ def _convert_db_precision_to_digits(self, p: int) -> int:

def create_connection(self):
if not self._args:
self._args['host'] = None # psycopg2 requires 1+ arguments
self._args["host"] = None # psycopg2 requires 1+ arguments

pg = import_postgresql()
try:
Expand Down
4 changes: 4 additions & 0 deletions data_diff/databases/redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
def normalize_number(self, value: str, coltype: FractionalType) -> str:
return self.to_string(f"{value}::decimal(38,{coltype.precision})")

def concat(self, l: List[str]) -> str:
    "Redshift concatenates with the || operator rather than concat()."
    return "(" + " || ".join(l) + ")"

def select_table_schema(self, path: DbPath) -> str:
schema, table = self._normalize_table_path(path)

Expand Down
12 changes: 5 additions & 7 deletions data_diff/databases/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
else:
s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')"

return f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS + coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS + 6}, '0')"
return (
f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS + coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS + 6}, '0')"
)

def normalize_number(self, value: str, coltype: FractionalType) -> str:
    "Cast to a fixed-precision decimal, then stringify for checksum comparison."
    casted = f"cast({value} as decimal(38,{coltype.precision}))"
    return self.to_string(casted)
Expand Down Expand Up @@ -96,9 +98,7 @@ def _parse_type(
if m:
datetime_precision = int(m.group(1))
return t_cls(
precision=datetime_precision
if datetime_precision is not None
else DEFAULT_DATETIME_PRECISION,
precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION,
rounds=self.ROUNDS_ON_PREC_LOSS,
)

Expand All @@ -115,9 +115,7 @@ def _parse_type(
if m:
return n_cls()

return super()._parse_type(
table_path, col_name, type_repr, datetime_precision, numeric_precision
)
return super()._parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision)

def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
    "Strip surrounding whitespace so UUID strings compare consistently."
    return "TRIM({})".format(value)
7 changes: 3 additions & 4 deletions data_diff/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,8 @@ class Checksum(Sql):

def compile(self, c: Compiler):
if len(self.exprs) > 1:
compiled_exprs = ", ".join(f"coalesce({c.compile(expr)}, '<null>')" for expr in self.exprs)
expr = f"concat({compiled_exprs})"
compiled_exprs = [f"coalesce({c.compile(expr)}, '<null>')" for expr in self.exprs]
expr = c.database.concat(compiled_exprs)
else:
# No need to coalesce - safe to assume that key cannot be null
(expr,) = self.exprs
Expand Down Expand Up @@ -180,10 +180,9 @@ def compile(self, c: Compiler):
@dataclass
class Time(Sql):
time: datetime
column: Optional[SqlOrStr] = None

def compile(self, c: Compiler):
return "'%s'" % self.time.isoformat()
return c.database.timestamp_value(self.time)


@dataclass
Expand Down
4 changes: 2 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ toml = "^0.10.2"
[tool.poetry.dev-dependencies]
parameterized = "*"
unittest-parallel = "*"
preql = "^0.2.16"
preql = "^0.2.17"
mysql-connector-python = "*"
databricks-sql-connector = "*"
snowflake-connector-python = "*"
Expand Down
1 change: 0 additions & 1 deletion tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,4 +98,3 @@ def _drop_table_if_exists(conn, table):
conn.query(f"DROP TABLE IF EXISTS {table}", None)
if not isinstance(conn, (db.BigQuery, db.Databricks)):
conn.query("COMMIT", None)

21 changes: 13 additions & 8 deletions tests/test_database_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,21 @@
from data_diff.databases import postgresql, oracle
from data_diff.utils import number_to_human
from data_diff.diff_tables import TableDiffer, TableSegment, DEFAULT_BISECTION_THRESHOLD
from .common import CONN_STRINGS, N_SAMPLES, N_THREADS, BENCHMARK, GIT_REVISION, random_table_suffix, _drop_table_if_exists
from .common import (
CONN_STRINGS,
N_SAMPLES,
N_THREADS,
BENCHMARK,
GIT_REVISION,
random_table_suffix,
_drop_table_if_exists,
)


CONNS = {k: db.connect_to_uri(v, N_THREADS) for k, v in CONN_STRINGS.items()}

CONNS[db.MySQL].query("SET @@session.time_zone='+00:00'", None)
oracle.SESSION_TIME_ZONE = postgresql.SESSION_TIME_ZONE = 'UTC'
oracle.SESSION_TIME_ZONE = postgresql.SESSION_TIME_ZONE = "UTC"

DATABASE_TYPES = {
db.PostgreSQL: {
Expand Down Expand Up @@ -196,12 +204,10 @@
"INT",
"BIGINT",
],

# https://docs.databricks.com/spark/latest/spark-sql/language-manual/data-types/timestamp-type.html
"datetime": [
"TIMESTAMP",
],

# https://docs.databricks.com/spark/latest/spark-sql/language-manual/data-types/float-type.html
# https://docs.databricks.com/spark/latest/spark-sql/language-manual/data-types/double-type.html
# https://docs.databricks.com/spark/latest/spark-sql/language-manual/data-types/decimal-type.html
Expand All @@ -210,10 +216,9 @@
"DOUBLE",
"DECIMAL(6, 2)",
],

"uuid": [
"STRING",
]
],
},
db.Trino: {
"int": [
Expand Down Expand Up @@ -406,7 +411,7 @@ def __iter__(self):
) in source_type_categories.items(): # int, datetime, ..
for source_type in source_types:
for target_type in target_type_categories[type_category]:
if (CONNS.get(source_db, False) and CONNS.get(target_db, False)):
if CONNS.get(source_db, False) and CONNS.get(target_db, False):
type_pairs.append(
(
source_db,
Expand Down Expand Up @@ -480,7 +485,7 @@ def _insert_to_table(conn, table, values, type):
value = str(sample)
elif isinstance(sample, datetime) and isinstance(conn, (db.Presto, db.Oracle, db.Trino)):
value = f"timestamp '{sample}'"
elif isinstance(sample, datetime) and isinstance(conn, db.BigQuery) and type == 'datetime':
elif isinstance(sample, datetime) and isinstance(conn, db.BigQuery) and type == "datetime":
value = f"cast(timestamp '{sample}' as datetime)"
elif isinstance(sample, bytearray):
value = f"'{sample.decode()}'"
Expand Down
Loading