Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit e8965fd

Browse files
committed
Joindiff: Fix stats collections
1 parent 245aeb6 commit e8965fd

File tree

7 files changed

+58
-54
lines changed

7 files changed

+58
-54
lines changed

data_diff/databases/database_types.py

Lines changed: 49 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,8 @@ class UnknownColType(ColType):
141141

142142

143143
class AbstractDialect(ABC):
144+
"""Dialect-dependent query expressions"""
145+
144146
name: str
145147

146148
@abstractmethod
@@ -177,56 +179,18 @@ def explain_as_text(self, query: str) -> str:
177179
"Provide SQL for explaining a query, returned in as table(varchar)"
178180
...
179181

180-
181-
class AbstractDatabase(AbstractDialect):
182182
@abstractmethod
183-
def timestamp_value(self, t: DbTime) -> str:
183+
def timestamp_value(self, t: datetime) -> str:
184184
"Provide SQL for the given timestamp value"
185185
...
186186

187-
@abstractmethod
188-
def md5_to_int(self, s: str) -> str:
189-
"Provide SQL for computing md5 and returning an int"
190-
...
191-
192-
@abstractmethod
193-
def _query(self, sql_code: str) -> list:
194-
"Send query to database and return result"
195-
...
196-
197-
@abstractmethod
198-
def select_table_schema(self, path: DbPath) -> str:
199-
"Provide SQL for selecting the table schema as (name, type, date_prec, num_prec)"
200-
...
201187

202-
@abstractmethod
203-
def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
204-
"""Query the table for its schema for table in 'path', and return {column: tuple}
205-
where the tuple is (table_name, col_name, type_repr, datetime_precision?, numeric_precision?, numeric_scale?)
206-
"""
207-
...
188+
class AbstractDatadiffDialect(ABC):
189+
"""Dialect-dependent query expressions, that are specific to data-diff"""
208190

209191
@abstractmethod
210-
def _process_table_schema(
211-
self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str], where: str = None
212-
):
213-
"""Process the result of query_table_schema().
214-
215-
Done in a separate step, to minimize the amount of processed columns.
216-
Needed because processing each column may:
217-
* throw errors and warnings
218-
* query the database to sample values
219-
220-
"""
221-
222-
@abstractmethod
223-
def parse_table_name(self, name: str) -> DbPath:
224-
"Parse the given table name into a DbPath"
225-
...
226-
227-
@abstractmethod
228-
def close(self):
229-
"Close connection(s) to the database instance. Querying will stop functioning."
192+
def md5_to_int(self, s: str) -> str:
193+
"Provide SQL for computing md5 and returning an int"
230194
...
231195

232196
@abstractmethod
@@ -294,6 +258,48 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
294258
return self.normalize_uuid(value, coltype)
295259
return self.to_string(value)
296260

261+
262+
class AbstractDatabase(AbstractDialect, AbstractDatadiffDialect):
263+
@abstractmethod
264+
def _query(self, sql_code: str) -> list:
265+
"Send query to database and return result"
266+
...
267+
268+
@abstractmethod
269+
def select_table_schema(self, path: DbPath) -> str:
270+
"Provide SQL for selecting the table schema as (name, type, date_prec, num_prec)"
271+
...
272+
273+
@abstractmethod
274+
def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
275+
"""Query the table for its schema for table in 'path', and return {column: tuple}
276+
where the tuple is (table_name, col_name, type_repr, datetime_precision?, numeric_precision?, numeric_scale?)
277+
"""
278+
...
279+
280+
@abstractmethod
281+
def _process_table_schema(
282+
self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str], where: str = None
283+
):
284+
"""Process the result of query_table_schema().
285+
286+
Done in a separate step, to minimize the amount of processed columns.
287+
Needed because processing each column may:
288+
* throw errors and warnings
289+
* query the database to sample values
290+
291+
"""
292+
293+
@abstractmethod
294+
def parse_table_name(self, name: str) -> DbPath:
295+
"Parse the given table name into a DbPath"
296+
...
297+
298+
@abstractmethod
299+
def close(self):
300+
"Close connection(s) to the database instance. Querying will stop functioning."
301+
...
302+
297303
@abstractmethod
298304
def _normalize_table_path(self, path: DbPath) -> DbPath:
299305
...

data_diff/databases/presto.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def close(self):
8383
self._conn.close()
8484

8585
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
86-
# TODO
86+
# TODO rounds
8787
if coltype.rounds:
8888
s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')"
8989
else:

data_diff/joindiff_tables.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from runtype import dataclass
1212

13-
from data_diff.databases.database_types import DbPath, Schema
13+
from data_diff.databases.database_types import DbPath, NumericType, Schema
1414
from data_diff.databases.base import QueryError
1515

1616

@@ -273,7 +273,7 @@ def _collect_stats(self, i, table):
273273
f"max_{c}": max_(this[c]),
274274
}
275275
for c in table._relevant_columns
276-
if c == "id" # TODO just if the right type
276+
if isinstance(table._schema[c], NumericType)
277277
)
278278
col_exprs["count"] = Count()
279279

data_diff/table_segment.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,6 @@ def count_and_checksum(self) -> Tuple[int, int]:
177177
def query_key_range(self) -> Tuple[int, int]:
178178
"""Query database for minimum and maximum key. This is used for setting the initial bounds."""
179179
# Normalizes the result (needed for UUIDs) after the min/max computation
180-
# TODO better error if there is no schema
181180
(k,) = self.key_columns
182181
select = self._make_select().select(
183182
ApplyFuncAndNormalizeAsString(this[k], min_),

tests/test_diff_tables.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -485,8 +485,6 @@ def setUp(self):
485485
self.new_uuid = uuid.uuid1(32132131)
486486
queries.append(f"INSERT INTO {self.table_src} VALUES ('{self.new_uuid}', 'This one is different')")
487487

488-
# TODO test unexpected values?
489-
490488
for query in queries:
491489
self.connection.query(query, None)
492490

@@ -542,8 +540,6 @@ def setUp(self):
542540
self.new_alphanum = "aBcDeFgHiJ"
543541
queries.append(f"INSERT INTO {self.table_src} VALUES ('{self.new_alphanum}', 'This one is different')")
544542

545-
# TODO test unexpected values?
546-
547543
for query in queries:
548544
self.connection.query(query, None)
549545

@@ -594,8 +590,6 @@ def setUp(self):
594590
self.new_alphanum = "aBcDeFgHiJ"
595591
queries.append(f"INSERT INTO {self.table_src} VALUES ('{self.new_alphanum}', 'This one is different')")
596592

597-
# TODO test unexpected values?
598-
599593
for query in queries:
600594
self.connection.query(query, None)
601595

tests/test_joindiff.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,8 @@ def test_diff_small_tables(self):
137137
self.assertEqual(expected, diff)
138138
self.assertEqual(2, self.differ.stats["table1_count"])
139139
self.assertEqual(1, self.differ.stats["table2_count"])
140+
self.assertEqual(3, self.differ.stats["table1_sum_id"])
141+
self.assertEqual(1, self.differ.stats["table2_sum_id"])
140142

141143
# Test materialize
142144
materialize_path = self.connection.parse_table_name(f"test_mat_{random_table_suffix()}")

tests/test_query.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from cmath import exp
1+
from datetime import datetime
22
from typing import List, Optional
33
import unittest
44
from data_diff.databases.database_types import AbstractDialect, CaseInsensitiveDict, CaseSensitiveDict
@@ -35,6 +35,9 @@ def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None
3535
def explain_as_text(self, query: str) -> str:
3636
return f"explain {query}"
3737

38+
def timestamp_value(self, t: datetime) -> str:
39+
return f"timestamp '{t}'"
40+
3841

3942
class TestQuery(unittest.TestCase):
4043
def setUp(self):

0 commit comments

Comments
 (0)