Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 23 additions & 7 deletions data_diff/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,8 @@ def normalize_value_by_type(value: str, coltype: ColType) -> str:

- Dates are expected in the format:
"YYYY-MM-DD HH:mm:SS.FFFFFF"
(number of F depends on coltype.precision)
Or if precision=0 then
"YYYY-MM-DD HH:mm:SS" (without the dot)

Rounded up/down according to coltype.rounds

"""
...
Expand Down Expand Up @@ -474,6 +473,8 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str:


class Oracle(ThreadedDatabase):
ROUNDS_ON_PREC_LOSS = True

def __init__(self, host, port, user, password, *, database, thread_count, **kw):
assert not port
self.kwargs = dict(user=user, password=password, dsn="%s/%s" % (host, database), **kw)
Expand Down Expand Up @@ -509,9 +510,7 @@ def select_table_schema(self, path: DbPath) -> str:

def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
if isinstance(coltype, PrecisionType):
if coltype.precision == 0:
return f"to_char(cast({value} as timestamp({coltype.precision})), 'YYYY-MM-DD HH24:MI:SS')"
return f"to_char(cast({value} as timestamp({coltype.precision})), 'YYYY-MM-DD HH24:MI:SS.FF{coltype.precision or ''}')"
return f"to_char(cast({value} as timestamp({coltype.precision})), 'YYYY-MM-DD HH24:MI:SS.FF6')"
return self.to_string(f"{value}")

def _parse_type(self, type_repr: str, datetime_precision: int = None, numeric_precision: int = None) -> ColType:
Expand All @@ -524,7 +523,9 @@ def _parse_type(self, type_repr: str, datetime_precision: int = None, numeric_pr
m = re.match(regexp + "$", type_repr)
if m:
datetime_precision = int(m.group(1))
return cls(precision=datetime_precision if datetime_precision is not None else DEFAULT_PRECISION)
return cls(precision=datetime_precision if datetime_precision is not None else DEFAULT_PRECISION,
rounds=self.ROUNDS_ON_PREC_LOSS
)

return UnknownColType(type_repr)

Expand All @@ -533,6 +534,21 @@ class Redshift(Postgres):
def md5_to_int(self, s: str) -> str:
return f"strtol(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16)::decimal(38)"

def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
if isinstance(coltype, TemporalType):
if coltype.rounds:
timestamp = f"{value}::timestamp(6)"
secs = f"timestamp 'epoch' + round(extract(epoch from {timestamp})::decimal(38)"
ms = f"extract(ms from {timestamp})"
us = f"extract(us from {timestamp})"
epoch = f"{secs}*1000000 + {ms}*1000 + {us}"
timestamp6 = f"to_char({epoch}, -6+{coltype.precision}) * interval '0.000001 seconds', 'YYYY-mm-dd HH24:MI:SS.US')"
else:
timestamp6 = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')"
return f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')"

return self.to_string(f"{value}")


class MsSQL(ThreadedDatabase):
"AKA sql-server"
Expand Down
58 changes: 37 additions & 21 deletions tests/test_database_types.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import unittest
import preql
import time
from data_diff import database as db
from data_diff.diff_tables import TableDiffer, TableSegment, split_space
from data_diff.diff_tables import TableDiffer, TableSegment
from parameterized import parameterized, parameterized_class
from .common import CONN_STRINGS, str_to_checksum
from .common import CONN_STRINGS
import logging

logging.getLogger("diff_tables").setLevel(logging.WARN)
Expand All @@ -18,11 +17,11 @@
"int": [127, -3, -9, 37, 15, 127],
"datetime_no_timezone": [
"2020-01-01 15:10:10",
"2020-01-01 9:9:9",
"2022-01-01 15:10:01.139",
"2022-01-01 15:10:02.020409",
"2022-01-01 15:10:03.003030",
"2022-01-01 15:10:05.009900",
"2020-02-01 9:9:9",
"2022-03-01 15:10:01.139",
"2022-04-01 15:10:02.020409",
"2022-05-01 15:10:03.003030",
"2022-06-01 15:10:05.009900",
],
"float": [0.0, 0.1, 0.10, 10.0, 100.98],
}
Expand Down Expand Up @@ -101,7 +100,7 @@
# "int",
],
"datetime_no_timezone": [
# "TIMESTAMP",
"TIMESTAMP",
],
# https://docs.aws.amazon.com/redshift/latest/dg/r_Numeric_types201.html#r_Numeric_types201-floating-point-types
"float": [
Expand All @@ -115,9 +114,9 @@
# "int",
],
"datetime_no_timezone": [
# "timestamp",
# "timestamp(6)",
# "timestamp(9)",
"timestamp with local time zone",
"timestamp(6) with local time zone",
"timestamp(9) with local time zone",
],
"float": [
# "float",
Expand Down Expand Up @@ -179,11 +178,20 @@ def expand_params(testcase_func, param_num, param):


def _insert_to_table(conn, table, values):
insertion_query = f"INSERT INTO {table} (id, col) VALUES "
for j, sample in values:
insertion_query += f"({j}, '{sample}'),"

conn.query(insertion_query[0:-1], None)
insertion_query = f"INSERT INTO {table} (id, col) "

if isinstance(conn, db.Oracle):
selects = []
for j, sample in values:
selects.append( f"SELECT {j}, timestamp '{sample}' FROM dual" )
insertion_query += ' UNION ALL '.join(selects)
else:
insertion_query += ' VALUES '
for j, sample in values:
insertion_query += f"({j}, '{sample}'),"
insertion_query = insertion_query[0:-1]

conn.query(insertion_query, None)
if not isinstance(conn, db.BigQuery):
conn.query("COMMIT", None)

Expand All @@ -204,14 +212,22 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego
src_table = src_conn.quote(".".join(src_table_path))
dst_table = dst_conn.quote(".".join(dst_table_path))

src_conn.query(f"DROP TABLE IF EXISTS {src_table}", None)
src_conn.query(f"CREATE TABLE {src_table}(id int, col {source_type});", None)
if isinstance(src_conn, db.Oracle):
src_conn.query(f"DROP TABLE {src_table}", None)
else:
src_conn.query(f"DROP TABLE IF EXISTS {src_table}", None)

src_conn.query(f"CREATE TABLE {src_table}(id int, col {source_type})", None)
_insert_to_table(src_conn, src_table, enumerate(sample_values, 1))

values_in_source = src_conn.query(f"SELECT id, col FROM {src_table}", list)

dst_conn.query(f"DROP TABLE IF EXISTS {dst_table}", None)
dst_conn.query(f"CREATE TABLE {dst_table}(id int, col {target_type});", None)
if isinstance(dst_conn, db.Oracle):
dst_conn.query(f"DROP TABLE {dst_table}", None)
else:
dst_conn.query(f"DROP TABLE IF EXISTS {dst_table}", None)

dst_conn.query(f"CREATE TABLE {dst_table}(id int, col {target_type})", None)
_insert_to_table(dst_conn, dst_table, values_in_source)

self.table = TableSegment(self.src_conn, src_table_path, "id", None, ("col",), quote_columns=False)
Expand Down