Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 18b16be

Browse files
authored
Merge pull request #70 from datafold/normalize_more_dbs2
Fixed oracle & redshift support (normalize-fields)
2 parents 311deaa + 483d284 commit 18b16be

File tree

2 files changed

+75
-30
lines changed

2 files changed

+75
-30
lines changed

data_diff/database.py

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ def import_presto():
6363
class ConnectError(Exception):
6464
pass
6565

66+
class QueryError(Exception):
67+
pass
68+
6669

6770
def _one(seq):
6871
(x,) = seq
@@ -156,9 +159,8 @@ def normalize_value_by_type(value: str, coltype: ColType) -> str:
156159
157160
- Dates are expected in the format:
158161
"YYYY-MM-DD HH:mm:SS.FFFFFF"
159-
(number of F depends on coltype.precision)
160-
Or if precision=0 then
161-
"YYYY-MM-DD HH:mm:SS" (without the dot)
162+
163+
Rounded up/down according to coltype.rounds
162164
163165
"""
164166
...
@@ -474,18 +476,26 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
474476

475477

476478
class Oracle(ThreadedDatabase):
479+
ROUNDS_ON_PREC_LOSS = True
480+
477481
def __init__(self, host, port, user, password, *, database, thread_count, **kw):
478482
assert not port
479483
self.kwargs = dict(user=user, password=password, dsn="%s/%s" % (host, database), **kw)
480484
super().__init__(thread_count=thread_count)
481485

482486
def create_connection(self):
483-
oracle = import_oracle()
487+
self._oracle = import_oracle()
484488
try:
485-
return oracle.connect(**self.kwargs)
489+
return self._oracle.connect(**self.kwargs)
486490
except Exception as e:
487491
raise ConnectError(*e.args) from e
488492

493+
def _query(self, sql_code: str):
494+
try:
495+
return super()._query(sql_code)
496+
except self._oracle.DatabaseError as e:
497+
raise QueryError(e)
498+
489499
def md5_to_int(self, s: str) -> str:
490500
# standard_hash is faster than DBMS_CRYPTO.Hash
491501
# TODO: Find a way to use UTL_RAW.CAST_TO_BINARY_INTEGER ?
@@ -509,9 +519,7 @@ def select_table_schema(self, path: DbPath) -> str:
509519

510520
def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
511521
if isinstance(coltype, PrecisionType):
512-
if coltype.precision == 0:
513-
return f"to_char(cast({value} as timestamp({coltype.precision})), 'YYYY-MM-DD HH24:MI:SS')"
514-
return f"to_char(cast({value} as timestamp({coltype.precision})), 'YYYY-MM-DD HH24:MI:SS.FF{coltype.precision or ''}')"
522+
return f"to_char(cast({value} as timestamp({coltype.precision})), 'YYYY-MM-DD HH24:MI:SS.FF6')"
515523
return self.to_string(f"{value}")
516524

517525
def _parse_type(self, type_repr: str, datetime_precision: int = None, numeric_precision: int = None) -> ColType:
@@ -524,7 +532,9 @@ def _parse_type(self, type_repr: str, datetime_precision: int = None, numeric_pr
524532
m = re.match(regexp + "$", type_repr)
525533
if m:
526534
datetime_precision = int(m.group(1))
527-
return cls(precision=datetime_precision if datetime_precision is not None else DEFAULT_PRECISION)
535+
return cls(precision=datetime_precision if datetime_precision is not None else DEFAULT_PRECISION,
536+
rounds=self.ROUNDS_ON_PREC_LOSS
537+
)
528538

529539
return UnknownColType(type_repr)
530540

@@ -533,6 +543,25 @@ class Redshift(Postgres):
533543
def md5_to_int(self, s: str) -> str:
534544
return f"strtol(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16)::decimal(38)"
535545

546+
def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
547+
if isinstance(coltype, TemporalType):
548+
if coltype.rounds:
549+
timestamp = f"{value}::timestamp(6)"
550+
# Get seconds since epoch. Redshift doesn't support milli- or micro-seconds.
551+
secs = f"timestamp 'epoch' + round(extract(epoch from {timestamp})::decimal(38)"
552+
# Get the milliseconds from timestamp.
553+
ms = f"extract(ms from {timestamp})"
554+
# Get the microseconds from timestamp, without the milliseconds!
555+
us = f"extract(us from {timestamp})"
556+
# epoch = Total time since epoch in microseconds.
557+
epoch = f"{secs}*1000000 + {ms}*1000 + {us}"
558+
timestamp6 = f"to_char({epoch}, -6+{coltype.precision}) * interval '0.000001 seconds', 'YYYY-mm-dd HH24:MI:SS.US')"
559+
else:
560+
timestamp6 = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')"
561+
return f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')"
562+
563+
return self.to_string(f"{value}")
564+
536565

537566
class MsSQL(ThreadedDatabase):
538567
"AKA sql-server"

tests/test_database_types.py

Lines changed: 37 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1+
from contextlib import suppress
12
import unittest
2-
import preql
33
import time
44
from data_diff import database as db
5-
from data_diff.diff_tables import TableDiffer, TableSegment, split_space
5+
from data_diff.diff_tables import TableDiffer, TableSegment
66
from parameterized import parameterized, parameterized_class
7-
from .common import CONN_STRINGS, str_to_checksum
7+
from .common import CONN_STRINGS
88
import logging
99

1010
logging.getLogger("diff_tables").setLevel(logging.WARN)
@@ -18,11 +18,11 @@
1818
"int": [127, -3, -9, 37, 15, 127],
1919
"datetime_no_timezone": [
2020
"2020-01-01 15:10:10",
21-
"2020-01-01 9:9:9",
22-
"2022-01-01 15:10:01.139",
23-
"2022-01-01 15:10:02.020409",
24-
"2022-01-01 15:10:03.003030",
25-
"2022-01-01 15:10:05.009900",
21+
"2020-02-01 9:9:9",
22+
"2022-03-01 15:10:01.139",
23+
"2022-04-01 15:10:02.020409",
24+
"2022-05-01 15:10:03.003030",
25+
"2022-06-01 15:10:05.009900",
2626
],
2727
"float": [0.0, 0.1, 0.10, 10.0, 100.98],
2828
}
@@ -101,7 +101,7 @@
101101
# "int",
102102
],
103103
"datetime_no_timezone": [
104-
# "TIMESTAMP",
104+
"TIMESTAMP",
105105
],
106106
# https://docs.aws.amazon.com/redshift/latest/dg/r_Numeric_types201.html#r_Numeric_types201-floating-point-types
107107
"float": [
@@ -115,9 +115,9 @@
115115
# "int",
116116
],
117117
"datetime_no_timezone": [
118-
# "timestamp",
119-
# "timestamp(6)",
120-
# "timestamp(9)",
118+
"timestamp with local time zone",
119+
"timestamp(6) with local time zone",
120+
"timestamp(9) with local time zone",
121121
],
122122
"float": [
123123
# "float",
@@ -179,14 +179,29 @@ def expand_params(testcase_func, param_num, param):
179179

180180

181181
def _insert_to_table(conn, table, values):
182-
insertion_query = f"INSERT INTO {table} (id, col) VALUES "
183-
for j, sample in values:
184-
insertion_query += f"({j}, '{sample}'),"
185-
186-
conn.query(insertion_query[0:-1], None)
182+
insertion_query = f"INSERT INTO {table} (id, col) "
183+
184+
if isinstance(conn, db.Oracle):
185+
selects = []
186+
for j, sample in values:
187+
selects.append( f"SELECT {j}, timestamp '{sample}' FROM dual" )
188+
insertion_query += ' UNION ALL '.join(selects)
189+
else:
190+
insertion_query += ' VALUES '
191+
for j, sample in values:
192+
insertion_query += f"({j}, '{sample}'),"
193+
insertion_query = insertion_query[0:-1]
194+
195+
conn.query(insertion_query, None)
187196
if not isinstance(conn, db.BigQuery):
188197
conn.query("COMMIT", None)
189198

199+
def _drop_table_if_exists(conn, table):
200+
with suppress(db.QueryError):
201+
if isinstance(conn, db.Oracle):
202+
conn.query(f"DROP TABLE {table}", None)
203+
else:
204+
conn.query(f"DROP TABLE IF EXISTS {table}", None)
190205

191206
class TestDiffCrossDatabaseTables(unittest.TestCase):
192207
@parameterized.expand(type_pairs, name_func=expand_params)
@@ -204,14 +219,14 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego
204219
src_table = src_conn.quote(".".join(src_table_path))
205220
dst_table = dst_conn.quote(".".join(dst_table_path))
206221

207-
src_conn.query(f"DROP TABLE IF EXISTS {src_table}", None)
208-
src_conn.query(f"CREATE TABLE {src_table}(id int, col {source_type});", None)
222+
_drop_table_if_exists(src_conn, src_table)
223+
src_conn.query(f"CREATE TABLE {src_table}(id int, col {source_type})", None)
209224
_insert_to_table(src_conn, src_table, enumerate(sample_values, 1))
210225

211226
values_in_source = src_conn.query(f"SELECT id, col FROM {src_table}", list)
212227

213-
dst_conn.query(f"DROP TABLE IF EXISTS {dst_table}", None)
214-
dst_conn.query(f"CREATE TABLE {dst_table}(id int, col {target_type});", None)
228+
_drop_table_if_exists(dst_conn, dst_table)
229+
dst_conn.query(f"CREATE TABLE {dst_table}(id int, col {target_type})", None)
215230
_insert_to_table(dst_conn, dst_table, values_in_source)
216231

217232
self.table = TableSegment(self.src_conn, src_table_path, "id", None, ("col",), case_sensitive=False)
@@ -235,3 +250,4 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego
235250

236251
duration = time.time() - start
237252
# print(f"source_db={source_db.__name__} target_db={target_db.__name__} source_type={source_type} target_type={target_type} duration={round(duration * 1000, 2)}ms")
253+

0 commit comments

Comments (0)