47 changes: 38 additions & 9 deletions data_diff/database.py
@@ -63,6 +63,9 @@ def import_presto():
class ConnectError(Exception):
pass

class QueryError(Exception):
pass


def _one(seq):
(x,) = seq
@@ -156,9 +159,8 @@ def normalize_value_by_type(value: str, coltype: ColType) -> str:

- Dates are expected in the format:
"YYYY-MM-DD HH:mm:SS.FFFFFF"
(number of F depends on coltype.precision)
Or if precision=0 then
"YYYY-MM-DD HH:mm:SS" (without the dot)

Rounded up/down according to coltype.rounds

"""
...
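A worked reading of the contract above (values illustrative, not from the PR; the zero-padding behavior is inferred from the Redshift implementation further down):

    # precision=6               ->  "2022-01-01 15:10:02.020409"
    # precision=3, rounds=True  ->  "2022-01-01 15:10:02.020000"   (rounded to ms, padded back to 6 digits)
    # precision=0, rounds=True  ->  "2022-01-01 15:10:02.000000"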
@@ -474,18 +476,26 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str:


class Oracle(ThreadedDatabase):
ROUNDS_ON_PREC_LOSS = True

def __init__(self, host, port, user, password, *, database, thread_count, **kw):
assert not port
self.kwargs = dict(user=user, password=password, dsn="%s/%s" % (host, database), **kw)
super().__init__(thread_count=thread_count)

def create_connection(self):
oracle = import_oracle()
self._oracle = import_oracle()
try:
return oracle.connect(**self.kwargs)
return self._oracle.connect(**self.kwargs)
except Exception as e:
raise ConnectError(*e.args) from e

def _query(self, sql_code: str):
try:
return super()._query(sql_code)
except self._oracle.DatabaseError as e:
raise QueryError(e)
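
Wrapping the driver-specific DatabaseError in the shared QueryError lets backend-agnostic callers handle failed statements with a single exception type, which is the pattern the test helper below relies on. A minimal sketch of the call site (conn and table are hypothetical):

    from contextlib import suppress

    with suppress(db.QueryError):
        conn.query(f"DROP TABLE {table}", None)  # Oracle has no IF EXISTS clause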

def md5_to_int(self, s: str) -> str:
# standard_hash is faster than DBMS_CRYPTO.Hash
# TODO: Find a way to use UTL_RAW.CAST_TO_BINARY_INTEGER ?
@@ -509,9 +519,7 @@ def select_table_schema(self, path: DbPath) -> str:

def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
if isinstance(coltype, PrecisionType):
if coltype.precision == 0:
return f"to_char(cast({value} as timestamp({coltype.precision})), 'YYYY-MM-DD HH24:MI:SS')"
return f"to_char(cast({value} as timestamp({coltype.precision})), 'YYYY-MM-DD HH24:MI:SS.FF{coltype.precision or ''}')"
return f"to_char(cast({value} as timestamp({coltype.precision})), 'YYYY-MM-DD HH24:MI:SS.FF6')"
return self.to_string(f"{value}")
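
For a hypothetical column created with coltype.precision == 3, the f-string above renders as:

    to_char(cast(created as timestamp(3)), 'YYYY-MM-DD HH24:MI:SS.FF6')

The cast rounds the value to three fractional digits (Oracle rounds on precision loss, hence ROUNDS_ON_PREC_LOSS = True), and FF6 always prints six digits, so every Oracle timestamp normalizes to the same fixed-width string regardless of the column's precision.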

def _parse_type(self, type_repr: str, datetime_precision: int = None, numeric_precision: int = None) -> ColType:
@@ -524,7 +532,9 @@ def _parse_type(self, type_repr: str, datetime_precision: int = None, numeric_precision: int = None) -> ColType:
m = re.match(regexp + "$", type_repr)
if m:
datetime_precision = int(m.group(1))
return cls(precision=datetime_precision if datetime_precision is not None else DEFAULT_PRECISION)
return cls(precision=datetime_precision if datetime_precision is not None else DEFAULT_PRECISION,
rounds=self.ROUNDS_ON_PREC_LOSS
)

return UnknownColType(type_repr)
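
An illustrative parse result under this change, using a type from the test matrix below (this hunk doesn't show which cls each regex maps to, so take the constructor as schematic):

    # "timestamp(9) with local time zone"  ->  cls(precision=9, rounds=True)   # rounds == Oracle.ROUNDS_ON_PREC_LOSS

The only behavioral change is that rounds is now populated from ROUNDS_ON_PREC_LOSS instead of being left at the column type's default.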

@@ -533,6 +543,25 @@ class Redshift(Postgres):
def md5_to_int(self, s: str) -> str:
return f"strtol(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16)::decimal(38)"

def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
if isinstance(coltype, TemporalType):
if coltype.rounds:
timestamp = f"{value}::timestamp(6)"
# Get seconds since epoch. Redshift doesn't support milli- or micro-seconds.
secs = f"timestamp 'epoch' + round(extract(epoch from {timestamp})::decimal(38)"
# Get the milliseconds from timestamp.
ms = f"extract(ms from {timestamp})"
# Get the microseconds from timestamp, without the milliseconds!
us = f"extract(us from {timestamp})"
# epoch = Total time since epoch in microseconds.
epoch = f"{secs}*1000000 + {ms}*1000 + {us}"
timestamp6 = f"to_char({epoch}, -6+{coltype.precision}) * interval '0.000001 seconds', 'YYYY-mm-dd HH24:MI:SS.US')"
else:
timestamp6 = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')"
return f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')"

return self.to_string(f"{value}")
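
Note the f-string fragments above compose into one balanced expression: secs opens a round( that timestamp6 closes. For a hypothetical column created at coltype.precision == 3, the generated SQL is:

    to_char(timestamp 'epoch' + round(
        extract(epoch from created::timestamp(6))::decimal(38)*1000000
        + extract(ms from created::timestamp(6))*1000
        + extract(us from created::timestamp(6)),
        -6+3) * interval '0.000001 seconds',
      'YYYY-mm-dd HH24:MI:SS.US')

That is: total microseconds since the epoch, rounded to the column's precision (round(..., -3) keeps milliseconds), converted back to a timestamp and printed with six fractional digits; the outer RPAD(LEFT(...)) then zeroes everything past the column's precision.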


class MsSQL(ThreadedDatabase):
"AKA sql-server"
58 changes: 37 additions & 21 deletions tests/test_database_types.py
@@ -1,10 +1,10 @@
from contextlib import suppress
import unittest
import preql
import time
from data_diff import database as db
from data_diff.diff_tables import TableDiffer, TableSegment, split_space
from data_diff.diff_tables import TableDiffer, TableSegment
from parameterized import parameterized, parameterized_class
from .common import CONN_STRINGS, str_to_checksum
from .common import CONN_STRINGS
import logging

logging.getLogger("diff_tables").setLevel(logging.WARN)
@@ -18,11 +18,11 @@
"int": [127, -3, -9, 37, 15, 127],
"datetime_no_timezone": [
"2020-01-01 15:10:10",
"2020-01-01 9:9:9",
"2022-01-01 15:10:01.139",
"2022-01-01 15:10:02.020409",
"2022-01-01 15:10:03.003030",
"2022-01-01 15:10:05.009900",
"2020-02-01 9:9:9",
"2022-03-01 15:10:01.139",
"2022-04-01 15:10:02.020409",
"2022-05-01 15:10:03.003030",
"2022-06-01 15:10:05.009900",
],
"float": [0.0, 0.1, 0.10, 10.0, 100.98],
}
@@ -101,7 +101,7 @@
# "int",
],
"datetime_no_timezone": [
# "TIMESTAMP",
"TIMESTAMP",
],
# https://docs.aws.amazon.com/redshift/latest/dg/r_Numeric_types201.html#r_Numeric_types201-floating-point-types
"float": [
@@ -115,9 +115,9 @@
# "int",
],
"datetime_no_timezone": [
# "timestamp",
# "timestamp(6)",
# "timestamp(9)",
"timestamp with local time zone",
"timestamp(6) with local time zone",
"timestamp(9) with local time zone",
],
"float": [
# "float",
@@ -179,14 +179,29 @@ def expand_params(testcase_func, param_num, param):


def _insert_to_table(conn, table, values):
insertion_query = f"INSERT INTO {table} (id, col) VALUES "
for j, sample in values:
insertion_query += f"({j}, '{sample}'),"

conn.query(insertion_query[0:-1], None)
insertion_query = f"INSERT INTO {table} (id, col) "

if isinstance(conn, db.Oracle):
selects = []
for j, sample in values:
selects.append( f"SELECT {j}, timestamp '{sample}' FROM dual" )
insertion_query += ' UNION ALL '.join(selects)
else:
insertion_query += ' VALUES '
for j, sample in values:
insertion_query += f"({j}, '{sample}'),"
insertion_query = insertion_query[0:-1]

conn.query(insertion_query, None)
if not isinstance(conn, db.BigQuery):
conn.query("COMMIT", None)

def _drop_table_if_exists(conn, table):
with suppress(db.QueryError):
if isinstance(conn, db.Oracle):
conn.query(f"DROP TABLE {table}", None)
else:
conn.query(f"DROP TABLE IF EXISTS {table}", None)

class TestDiffCrossDatabaseTables(unittest.TestCase):
@parameterized.expand(type_pairs, name_func=expand_params)
@@ -204,14 +219,14 @@ def test_types(self, source_db, target_db, source_type, target_type, type_category):
src_table = src_conn.quote(".".join(src_table_path))
dst_table = dst_conn.quote(".".join(dst_table_path))

src_conn.query(f"DROP TABLE IF EXISTS {src_table}", None)
src_conn.query(f"CREATE TABLE {src_table}(id int, col {source_type});", None)
_drop_table_if_exists(src_conn, src_table)
src_conn.query(f"CREATE TABLE {src_table}(id int, col {source_type})", None)
_insert_to_table(src_conn, src_table, enumerate(sample_values, 1))

values_in_source = src_conn.query(f"SELECT id, col FROM {src_table}", list)

dst_conn.query(f"DROP TABLE IF EXISTS {dst_table}", None)
dst_conn.query(f"CREATE TABLE {dst_table}(id int, col {target_type});", None)
_drop_table_if_exists(dst_conn, dst_table)
dst_conn.query(f"CREATE TABLE {dst_table}(id int, col {target_type})", None)
_insert_to_table(dst_conn, dst_table, values_in_source)

self.table = TableSegment(self.src_conn, src_table_path, "id", None, ("col",), quote_columns=False)
@@ -235,3 +250,4 @@ def test_types(self, source_db, target_db, source_type, target_type, type_category):

duration = time.time() - start
# print(f"source_db={source_db.__name__} target_db={target_db.__name__} source_type={source_type} target_type={target_type} duration={round(duration * 1000, 2)}ms")