Skip to content

Commit f2b62e8

Browse files
authored
Merge pull request datafold#224 from datafold/issue221
Bugfix in TableSegment: Sampling now respects the 'where' clause (issue datafold#221)
2 parents c585550 + 940ae6b commit f2b62e8

File tree

5 files changed

+28
-9
lines changed

5 files changed

+28
-9
lines changed

data_diff/databases/base.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -187,25 +187,30 @@ def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
187187
assert len(d) == len(rows)
188188
return d
189189

190-
def _process_table_schema(self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str]):
190+
def _process_table_schema(
191+
self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str], where: str = None
192+
):
191193
accept = {i.lower() for i in filter_columns}
192194

193195
col_dict = {row[0]: self._parse_type(path, *row) for name, row in raw_schema.items() if name.lower() in accept}
194196

195-
self._refine_coltypes(path, col_dict)
197+
self._refine_coltypes(path, col_dict, where)
196198

197199
# Return a dict of form {name: type} after normalization
198200
return col_dict
199201

200-
def _refine_coltypes(self, table_path: DbPath, col_dict: Dict[str, ColType]):
201-
"Refine the types in the column dict, by querying the database for a sample of their values"
202+
def _refine_coltypes(self, table_path: DbPath, col_dict: Dict[str, ColType], where: str = None):
203+
"""Refine the types in the column dict, by querying the database for a sample of their values
204+
205+
'where' restricts the rows to be sampled.
206+
"""
202207

203208
text_columns = [k for k, v in col_dict.items() if isinstance(v, Text)]
204209
if not text_columns:
205210
return
206211

207212
fields = [self.normalize_uuid(c, String_UUID()) for c in text_columns]
208-
samples_by_row = self.query(Select(fields, TableName(table_path), limit=16), list)
213+
samples_by_row = self.query(Select(fields, TableName(table_path), limit=16, where=where and [where]), list)
209214
if not samples_by_row:
210215
raise ValueError(f"Table {table_path} is empty.")
211216

data_diff/databases/database_types.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,9 @@ def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
177177
...
178178

179179
@abstractmethod
180-
def _process_table_schema(self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str]):
180+
def _process_table_schema(
181+
self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str], where: str = None
182+
):
181183
"""Process the result of query_table_schema().
182184
183185
Done in a separate step, to minimize the amount of processed columns.

data_diff/databases/databricks.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,9 @@ def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
8383
assert len(d) == len(rows)
8484
return d
8585

86-
def _process_table_schema(self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str]):
86+
def _process_table_schema(
87+
self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str], where: str = None
88+
):
8789
accept = {i.lower() for i in filter_columns}
8890
rows = [row for name, row in raw_schema.items() if name.lower() in accept]
8991

@@ -115,7 +117,7 @@ def _process_table_schema(self, path: DbPath, raw_schema: Dict[str, tuple], filt
115117

116118
col_dict: Dict[str, ColType] = {row[0]: self._parse_type(path, *row) for row in resulted_rows}
117119

118-
self._refine_coltypes(path, col_dict)
120+
self._refine_coltypes(path, col_dict, where)
119121
return col_dict
120122

121123
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:

data_diff/table_segment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def _normalize_column(self, name: str, template: str = None) -> str:
111111
return self.database.normalize_value_by_type(col, col_type)
112112

113113
def _with_raw_schema(self, raw_schema: dict) -> "TableSegment":
114-
schema = self.database._process_table_schema(self.table_path, raw_schema, self._relevant_columns)
114+
schema = self.database._process_table_schema(self.table_path, raw_schema, self._relevant_columns, self.where)
115115
return self.new(_schema=create_schema(self.database, self.table_path, schema, self.case_sensitive))
116116

117117
def with_schema(self) -> "TableSegment":

tests/test_diff_tables.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,16 @@ def test_string_keys(self):
443443

444444
self.assertRaises(ValueError, list, differ.diff_tables(self.a, self.b))
445445

446+
def test_where_sampling(self):
447+
a = self.a.replace(where="1=1")
448+
449+
differ = TableDiffer()
450+
diff = list(differ.diff_tables(a, self.b))
451+
self.assertEqual(diff, [("-", (str(self.new_uuid), "This one is different"))])
452+
453+
a_empty = self.a.replace(where="1=0")
454+
self.assertRaises(ValueError, list, differ.diff_tables(a_empty, self.b))
455+
446456

447457
@test_per_database
448458
class TestAlphanumericKeys(TestPerDatabase):

0 commit comments

Comments
 (0)