Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions data_diff/databases/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,7 @@ def _refine_coltypes(self, table_path: DbPath, col_dict: Dict[str, ColType]):
fields = [self.normalize_uuid(c, String_UUID()) for c in text_columns]
samples_by_row = self.query(Select(fields, TableName(table_path), limit=16), list)
if not samples_by_row:
logger.warning(f"Table {table_path} is empty.")
return
raise ValueError(f"Table {table_path} is empty.")

samples_by_col = list(zip(*samples_by_row))

Expand Down
70 changes: 51 additions & 19 deletions data_diff/diff_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from collections import defaultdict
from typing import List, Tuple, Iterator, Optional
import logging
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import ThreadPoolExecutor, as_completed

from runtype import dataclass

Expand Down Expand Up @@ -315,17 +315,16 @@ def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult:
('-', columns) for items in table2 but not in table1
Where `columns` is a tuple of values for the involved columns, i.e. (id, ...extra)
"""
# Validate options
if self.bisection_factor >= self.bisection_threshold:
raise ValueError("Incorrect param values (bisection factor must be lower than threshold)")
if self.bisection_factor < 2:
raise ValueError("Must have at least two segments per iteration (i.e. bisection_factor >= 2)")

# Query and validate schema
table1, table2 = self._threaded_call("with_schema", [table1, table2])
self._validate_and_adjust_columns(table1, table2)

key_ranges = self._threaded_call("query_key_range", [table1, table2])
mins, maxs = zip(*key_ranges)

key_type = table1._schema[table1.key_column]
key_type2 = table2._schema[table2.key_column]
if not isinstance(key_type, IKey):
Expand All @@ -334,23 +333,42 @@ def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult:
raise NotImplementedError(f"Cannot use column of type {key_type2} as a key")
assert key_type.python_type is key_type2.python_type

# We add 1 because our ranges are exclusive of the end (like in Python)
try:
min_key = min(map(key_type.python_type, mins))
max_key = max(map(key_type.python_type, maxs)) + 1
except (TypeError, ValueError) as e:
raise type(e)(f"Cannot apply {key_type} to {mins}, {maxs}.") from e
# Query min/max values
key_ranges = self._threaded_call_as_completed("query_key_range", [table1, table2])

table1 = table1.new(min_key=min_key, max_key=max_key)
table2 = table2.new(min_key=min_key, max_key=max_key)
# Start with the first completed value, so we don't waste time waiting
min_key1, max_key1 = self._parse_key_range_result(key_type, next(key_ranges))

table1, table2 = [t.new(min_key=min_key1, max_key=max_key1) for t in (table1, table2)]

logger.info(
f"Diffing tables | segments: {self.bisection_factor}, bisection threshold: {self.bisection_threshold}. "
f"key-range: {table1.min_key}..{table2.max_key}, "
f"size: {table2.max_key-table1.min_key}"
)

return self._bisect_and_diff_tables(table1, table2)
# Bisect (split) the table into segments, and diff them recursively.
yield from self._bisect_and_diff_tables(table1, table2)

# Now we check for the second min-max, to diff the portions we "missed".
min_key2, max_key2 = self._parse_key_range_result(key_type, next(key_ranges))

if min_key2 < min_key1:
pre_tables = [t.new(min_key=min_key2, max_key=min_key1) for t in (table1, table2)]
yield from self._bisect_and_diff_tables(*pre_tables)

if max_key2 > max_key1:
post_tables = [t.new(min_key=max_key1, max_key=max_key2) for t in (table1, table2)]
yield from self._bisect_and_diff_tables(*post_tables)

def _parse_key_range_result(self, key_type, key_range):
mn, mx = key_range
cls = key_type.python_type
# We add 1 because our ranges are exclusive of the end (like in Python)
try:
return cls(mn), cls(mx) + 1
except (TypeError, ValueError) as e:
raise type(e)(f"Cannot apply {key_type} to {mn}, {mx}.") from e

def _validate_and_adjust_columns(self, table1, table2):
for c in table1._relevant_columns:
Expand Down Expand Up @@ -474,12 +492,26 @@ def _diff_tables(self, table1, table2, level=0, segment_index=None, segment_coun
if checksum1 != checksum2:
yield from self._bisect_and_diff_tables(table1, table2, level=level, max_rows=max(count1, count2))

def _thread_map(self, func, iter):
def _thread_map(self, func, iterable):
if not self.threaded:
return map(func, iterable)

with ThreadPoolExecutor(max_workers=self.max_threadpool_size) as task_pool:
return task_pool.map(func, iterable)

def _threaded_call(self, func, iterable):
"Calls a method for each object in iterable."
return list(self._thread_map(methodcaller(func), iterable))

def _thread_as_completed(self, func, iterable):
if not self.threaded:
return map(func, iter)
return map(func, iterable)

task_pool = ThreadPoolExecutor(max_workers=self.max_threadpool_size)
return task_pool.map(func, iter)
with ThreadPoolExecutor(max_workers=self.max_threadpool_size) as task_pool:
futures = [task_pool.submit(func, item) for item in iterable]
for future in as_completed(futures):
yield future.result()

def _threaded_call(self, func, iter):
return list(self._thread_map(methodcaller(func), iter))
def _threaded_call_as_completed(self, func, iterable):
"Calls a method for each object in iterable. Returned in order of completion."
return self._thread_as_completed(methodcaller(func), iterable)
6 changes: 3 additions & 3 deletions tests/test_diff_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ def test_string_keys(self):
f"INSERT INTO {self.table_src} VALUES ('unexpected', '<-- this bad value should not break us')", None
)

self.assertRaises(ValueError, differ.diff_tables, self.a, self.b)
self.assertRaises(ValueError, list, differ.diff_tables(self.a, self.b))


@test_per_database
Expand Down Expand Up @@ -592,7 +592,7 @@ def setUp(self):

def test_right_table_empty(self):
differ = TableDiffer()
self.assertRaises(ValueError, differ.diff_tables, self.a, self.b)
self.assertRaises(ValueError, list, differ.diff_tables(self.a, self.b))

def test_left_table_empty(self):
queries = [
Expand All @@ -605,4 +605,4 @@ def test_left_table_empty(self):
_commit(self.connection)

differ = TableDiffer()
self.assertRaises(ValueError, differ.diff_tables, self.a, self.b)
self.assertRaises(ValueError, list, differ.diff_tables(self.a, self.b))