Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 38c56cf

Browse files
authored
Merge pull request #205 from datafold/aug12
Refactor - nicer regexp parsing; Trino now inherits from Presto
2 parents e6a1b1c + 485162c commit 38c56cf

File tree

4 files changed

+29
-121
lines changed

4 files changed

+29
-121
lines changed

data_diff/databases/oracle.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import re
22

3+
from ..utils import match_regexps
4+
35
from .database_types import *
46
from .base import ThreadedDatabase, import_helper, ConnectError, QueryError
57
from .base import DEFAULT_DATETIME_PRECISION, TIMESTAMP_PRECISION_POS
@@ -99,14 +101,10 @@ def _parse_type(
99101
r"TIMESTAMP\((\d)\) WITH TIME ZONE": TimestampTZ,
100102
r"TIMESTAMP\((\d)\)": Timestamp,
101103
}
102-
for regexp, t_cls in regexps.items():
103-
m = re.match(regexp + "$", type_repr)
104-
if m:
105-
datetime_precision = int(m.group(1))
106-
return t_cls(
107-
precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION,
108-
rounds=self.ROUNDS_ON_PREC_LOSS,
109-
)
104+
105+
for m, t_cls in match_regexps(regexps, type_repr):
106+
precision = int(m.group(1))
107+
return t_cls(precision=precision, rounds=self.ROUNDS_ON_PREC_LOSS)
110108

111109
return super()._parse_type(
112110
table_name, col_name, type_repr, datetime_precision, numeric_precision, numeric_scale

data_diff/databases/presto.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import re
22

3+
from ..utils import match_regexps
4+
35
from .database_types import *
46
from .base import Database, import_helper, _query_conn
57
from .base import (
@@ -94,27 +96,18 @@ def _parse_type(
9496
r"timestamp\((\d)\)": Timestamp,
9597
r"timestamp\((\d)\) with time zone": TimestampTZ,
9698
}
97-
for regexp, t_cls in timestamp_regexps.items():
98-
m = re.match(regexp + "$", type_repr)
99-
if m:
100-
datetime_precision = int(m.group(1))
101-
return t_cls(
102-
precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION,
103-
rounds=self.ROUNDS_ON_PREC_LOSS,
104-
)
99+
for m, t_cls in match_regexps(timestamp_regexps, type_repr):
100+
precision = int(m.group(1))
101+
return t_cls(precision=precision, rounds=self.ROUNDS_ON_PREC_LOSS)
105102

106103
number_regexps = {r"decimal\((\d+),(\d+)\)": Decimal}
107-
for regexp, n_cls in number_regexps.items():
108-
m = re.match(regexp + "$", type_repr)
109-
if m:
110-
prec, scale = map(int, m.groups())
111-
return n_cls(scale)
104+
for m, n_cls in match_regexps(number_regexps, type_repr):
105+
_prec, scale = map(int, m.groups())
106+
return n_cls(scale)
112107

113108
string_regexps = {r"varchar\((\d+)\)": Text, r"char\((\d+)\)": Text}
114-
for regexp, n_cls in string_regexps.items():
115-
m = re.match(regexp + "$", type_repr)
116-
if m:
117-
return n_cls()
109+
for m, n_cls in match_regexps(string_regexps, type_repr):
110+
return n_cls()
118111

119112
return super()._parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision)
120113

data_diff/databases/trino.py

Lines changed: 4 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,7 @@
1-
import re
2-
31
from .database_types import *
4-
from .base import Database, import_helper
5-
from .base import (
6-
MD5_HEXDIGITS,
7-
CHECKSUM_HEXDIGITS,
8-
TIMESTAMP_PRECISION_POS,
9-
DEFAULT_DATETIME_PRECISION,
10-
)
2+
from .presto import Presto
3+
from .base import import_helper
4+
from .base import TIMESTAMP_PRECISION_POS
115

126

137
@import_helper("trino")
@@ -17,49 +11,12 @@ def import_trino():
1711
return trino
1812

1913

20-
class Trino(Database):
21-
default_schema = "public"
22-
TYPE_CLASSES = {
23-
# Timestamps
24-
"timestamp with time zone": TimestampTZ,
25-
"timestamp without time zone": Timestamp,
26-
"timestamp": Timestamp,
27-
# Numbers
28-
"integer": Integer,
29-
"bigint": Integer,
30-
"real": Float,
31-
"double": Float,
32-
# Text
33-
"varchar": Text,
34-
}
35-
ROUNDS_ON_PREC_LOSS = True
36-
14+
class Trino(Presto):
3715
def __init__(self, **kw):
3816
trino = import_trino()
3917

4018
self._conn = trino.dbapi.connect(**kw)
4119

42-
def quote(self, s: str):
43-
return f'"{s}"'
44-
45-
def md5_to_int(self, s: str) -> str:
46-
return f"cast(from_base(substr(to_hex(md5(to_utf8({s}))), {1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS}), 16) as decimal(38, 0))"
47-
48-
def to_string(self, s: str):
49-
return f"cast({s} as varchar)"
50-
51-
def _query(self, sql_code: str) -> list:
52-
"""Uses the standard SQL cursor interface"""
53-
c = self._conn.cursor()
54-
c.execute(sql_code)
55-
if sql_code.lower().startswith("select"):
56-
return c.fetchall()
57-
if re.match(r"(insert|create|truncate|drop)", sql_code, re.IGNORECASE):
58-
return c.fetchone()
59-
60-
def close(self):
61-
self._conn.close()
62-
6320
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
6421
if coltype.rounds:
6522
s = f"date_format(cast({value} as timestamp({coltype.precision})), '%Y-%m-%d %H:%i:%S.%f')"
@@ -70,52 +27,5 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
7027
f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS + coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS + 6}, '0')"
7128
)
7229

73-
def normalize_number(self, value: str, coltype: FractionalType) -> str:
74-
return self.to_string(f"cast({value} as decimal(38,{coltype.precision}))")
75-
76-
def select_table_schema(self, path: DbPath) -> str:
77-
schema, table = self._normalize_table_path(path)
78-
79-
return (
80-
f"SELECT column_name, data_type, 3 as datetime_precision, 3 as numeric_precision FROM INFORMATION_SCHEMA.COLUMNS "
81-
f"WHERE table_name = '{table}' AND table_schema = '{schema}'"
82-
)
83-
84-
def _parse_type(
85-
self,
86-
table_path: DbPath,
87-
col_name: str,
88-
type_repr: str,
89-
datetime_precision: int = None,
90-
numeric_precision: int = None,
91-
) -> ColType:
92-
timestamp_regexps = {
93-
r"timestamp\((\d)\)": Timestamp,
94-
r"timestamp\((\d)\) with time zone": TimestampTZ,
95-
}
96-
for regexp, t_cls in timestamp_regexps.items():
97-
m = re.match(regexp + "$", type_repr)
98-
if m:
99-
datetime_precision = int(m.group(1))
100-
return t_cls(
101-
precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION,
102-
rounds=self.ROUNDS_ON_PREC_LOSS,
103-
)
104-
105-
number_regexps = {r"decimal\((\d+),(\d+)\)": Decimal}
106-
for regexp, n_cls in number_regexps.items():
107-
m = re.match(regexp + "$", type_repr)
108-
if m:
109-
prec, scale = map(int, m.groups())
110-
return n_cls(scale)
111-
112-
string_regexps = {r"varchar\((\d+)\)": Text, r"char\((\d+)\)": Text}
113-
for regexp, n_cls in string_regexps.items():
114-
m = re.match(regexp + "$", type_repr)
115-
if m:
116-
return n_cls()
117-
118-
return super()._parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision)
119-
12030
def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
12131
return f"TRIM({value})"

data_diff/utils.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import re
22
import math
3-
from typing import Iterable, Tuple, Union, Any, Sequence
3+
from typing import Iterable, Tuple, Union, Any, Sequence, Dict
44
from typing import TypeVar, Generic
55
from abc import ABC, abstractmethod
66
from urllib.parse import urlparse
@@ -225,7 +225,7 @@ def match_like(pattern: str, strs: Sequence[str]) -> Iterable[str]:
225225

226226

227227
def accumulate(iterable, func=operator.add, *, initial=None):
228-
'Return running totals'
228+
"Return running totals"
229229
# Taken from https://docs.python.org/3/library/itertools.html#itertools.accumulate, to backport 'initial' to 3.7
230230
it = iter(iterable)
231231
total = initial
@@ -238,3 +238,10 @@ def accumulate(iterable, func=operator.add, *, initial=None):
238238
for element in it:
239239
total = func(total, element)
240240
yield total
241+
242+
243+
def match_regexps(regexps: Dict[str, Any], s: str) -> Sequence[tuple]:
244+
for regexp, v in regexps.items():
245+
m = re.match(regexp + "$", s)
246+
if m:
247+
yield m, v

0 commit comments

Comments
 (0)