Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 8afd597

Browse files
committed
Initial support for Alphanumeric IDs + tests (Issue #59)
1 parent e357473 commit 8afd597

File tree

7 files changed, with 200 additions and 31 deletions.

data_diff/databases/base.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
Float,
1717
ColType_UUID,
1818
Native_UUID,
19+
String_Alphanum,
1920
String_UUID,
2021
TemporalType,
2122
UnknownColType,
@@ -221,6 +222,23 @@ def _refine_coltypes(self, table_path: DbPath, col_dict: Dict[str, ColType]):
221222
assert col_name in col_dict
222223
col_dict[col_name] = String_UUID()
223224

225+
alphanum_samples = [s for s in samples if s and String_Alphanum.test_value(s)]
226+
if alphanum_samples:
227+
if len(alphanum_samples) != len(samples):
228+
logger.warning(
229+
f"Mixed Alphanum/Non-Alphanum values detected in column {'.'.join(table_path)}.{col_name}, disabling Alphanum support."
230+
)
231+
else:
232+
assert col_name in col_dict
233+
lens = set(map(len, alphanum_samples))
234+
if len(lens) > 1:
235+
logger.warning(
236+
f"Mixed Alphanum lengths detected in column {'.'.join(table_path)}.{col_name}, disabling Alphanum support."
237+
)
238+
else:
239+
(length,) = lens
240+
col_dict[col_name] = String_Alphanum(length=length)
241+
224242
# @lru_cache()
225243
# def get_table_schema(self, path: DbPath) -> Dict[str, ColType]:
226244
# return self.query_table_schema(path)

data_diff/databases/database_types.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55

66
from runtype import dataclass
77

8-
from data_diff.utils import ArithUUID
8+
from data_diff.utils import ArithAlphanumeric, ArithUUID, ArithString
99

1010

1111
DbPath = Tuple[str, ...]
12-
DbKey = Union[int, str, bytes, ArithUUID]
12+
DbKey = Union[int, str, bytes, ArithUUID, ArithAlphanumeric]
1313
DbTime = datetime
1414

1515

@@ -18,11 +18,6 @@ class ColType:
1818
pass
1919

2020

21-
class IKey(ABC):
22-
"Interface for ColType, for using a column as a key in data-diff"
23-
python_type: type
24-
25-
2621
@dataclass
2722
class PrecisionType(ColType):
2823
precision: int
@@ -63,6 +58,9 @@ class IKey(ABC):
6358
"Interface for ColType, for using a column as a key in data-diff"
6459
python_type: type
6560

61+
def make_value(self, value):
62+
return self.python_type(value)
63+
6664

6765
class Decimal(FractionalType, IKey): # Snowflake may use Decimal as a key
6866
@property
@@ -80,6 +78,10 @@ class ColType_UUID(ColType, IKey):
8078
python_type = ArithUUID
8179

8280

81+
class ColType_Alphanum(ColType, IKey):
82+
python_type = ArithAlphanumeric
83+
84+
8385
class Native_UUID(ColType_UUID):
8486
pass
8587

@@ -88,6 +90,24 @@ class String_UUID(StringType, ColType_UUID):
8890
pass
8991

9092

93+
@dataclass
94+
class String_Alphanum(StringType, ColType_Alphanum):
95+
length: int
96+
97+
@staticmethod
98+
def test_value(value: str) -> bool:
99+
try:
100+
ArithAlphanumeric(value)
101+
return True
102+
except ValueError:
103+
return False
104+
105+
def make_value(self, value):
106+
if len(value) != self.length:
107+
raise ValueError(f"Expected alphanumeric value of length {self.length}, but got '{value}'.")
108+
return self.python_type(value, max_len=self.length)
109+
110+
91111
@dataclass
92112
class Text(StringType):
93113
supported = False

data_diff/diff_tables.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from .utils import safezip, split_space
1717
from .databases.base import Database
1818
from .databases.database_types import (
19-
ArithUUID,
19+
ArithString,
2020
IKey,
2121
Native_UUID,
2222
NumericType,
@@ -175,10 +175,10 @@ def get_values(self) -> list:
175175
def choose_checkpoints(self, count: int) -> List[DbKey]:
176176
"Suggests a bunch of evenly-spaced checkpoints to split by (not including start, end)"
177177
assert self.is_bounded
178-
if isinstance(self.min_key, ArithUUID):
178+
if isinstance(self.min_key, ArithString):
179+
assert type(self.min_key) is type(self.max_key)
179180
checkpoints = split_space(self.min_key.int, self.max_key.int, count)
180-
assert isinstance(self.max_key, ArithUUID)
181-
return [ArithUUID(int=i) for i in checkpoints]
181+
return [self.min_key.new(int=i) for i in checkpoints]
182182

183183
return split_space(self.min_key, self.max_key, count)
184184

@@ -363,7 +363,7 @@ def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult:
363363

364364
def _parse_key_range_result(self, key_type, key_range):
365365
mn, mx = key_range
366-
cls = key_type.python_type
366+
cls = key_type.make_value
367367
# We add 1 because our ranges are exclusive of the end (like in Python)
368368
try:
369369
return cls(mn), cls(mx) + 1

data_diff/sql.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from .utils import join_iter
1010

11-
from .databases.database_types import AbstractDatabase, DbPath, DbKey, DbTime, ArithUUID
11+
from .databases.database_types import AbstractDatabase, DbPath, DbKey, DbTime, ArithString
1212

1313

1414
class Sql:
@@ -69,7 +69,7 @@ def compile(self, c: Compiler):
6969
return "b'%s'" % self.value.decode()
7070
elif isinstance(self.value, str):
7171
return "'%s'" % self.value
72-
elif isinstance(self.value, ArithUUID):
72+
elif isinstance(self.value, ArithString):
7373
return "'%s'" % self.value
7474
return str(self.value)
7575

data_diff/utils.py

Lines changed: 87 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33

44
from typing import Union, Any
55
from uuid import UUID
6+
import string
7+
8+
alphanums = string.digits + string.ascii_lowercase
69

710

811
def safezip(*args):
@@ -16,24 +19,99 @@ def split_space(start, end, count):
1619
return list(range(start, end, (size + 1) // (count + 1)))[1 : count + 1]
1720

1821

19-
class ArithUUID(UUID):
22+
class ArithString:
23+
@classmethod
24+
def new(cls, *args, **kw):
25+
return cls(*args, **kw)
26+
27+
28+
class ArithUUID(UUID, ArithString):
2029
"A UUID that supports basic arithmetic (add, sub)"
2130

31+
def __int__(self):
32+
return self.int
33+
2234
def __add__(self, other: Union[UUID, int]):
2335
if isinstance(other, int):
24-
return type(self)(int=self.int + other)
36+
return self.new(int=self.int + other)
2537
return NotImplemented
2638

2739
def __sub__(self, other: Union[UUID, int]):
2840
if isinstance(other, int):
29-
return type(self)(int=self.int - other)
41+
return self.new(int=self.int - other)
3042
elif isinstance(other, UUID):
3143
return self.int - other.int
3244
return NotImplemented
3345

46+
47+
def numberToBase(num, base, alphabet=None):
    """Render the non-negative integer ``num`` as a digit string in ``base``.

    Args:
        num: The integer to convert. Must be >= 0.
        base: The radix. Must be at least 2 and no larger than the number of
            available digit symbols.
        alphabet: Optional sequence of digit symbols, ordered by value.
            Defaults to the module-level ``alphanums``.

    Returns:
        The most-significant-digit-first representation of ``num``. Zero is
        rendered as the first symbol of the alphabet (``"0"`` by default), so
        the result can always be parsed back with ``int(result, base)``.

    Raises:
        ValueError: If ``num`` is negative or ``base`` is out of range.
    """
    symbols = alphanums if alphabet is None else alphabet

    # base 0 would divide by zero and base 1 would never terminate the loop
    # below; a base larger than the alphabet has no symbol for some digits.
    if not 2 <= base <= len(symbols):
        raise ValueError(f"Base must be between 2 and {len(symbols)}, but got {base}")
    if num < 0:
        raise ValueError(f"Cannot convert negative number {num} to base {base}")
    if num == 0:
        # The loop below emits nothing for zero, so it needs its own digit.
        return symbols[0]

    # Digits are produced least-significant first, then reversed on output.
    reversed_digits = []
    while num > 0:
        num, remainder = divmod(num, base)
        reversed_digits.append(symbols[remainder])
    return "".join(reversed(reversed_digits))
53+
54+
55+
class ArithAlphanumeric(ArithString):
    """An alphanumeric string that supports basic arithmetic (add, sub).

    The value is interpreted as a number whose digits are the characters of
    the module-level ``alphanums`` (decimal digits followed by lowercase ASCII
    letters), in base ``len(alphanums)``. When ``max_len`` is given, the
    string form is left-padded with "0" up to that width.

    Ordering, equality and hashing are all based on the numeric value.
    """

    def __init__(self, str: str = None, int: int = None, max_len=None):
        """Build a value from either its string form or its integer value.

        Args:
            str: The alphanumeric string. Mutually exclusive with ``int``.
            int: The non-negative integer value. Mutually exclusive with ``str``.
            max_len: Optional fixed width of the string form.

        Raises:
            TypeError: If neither or both of ``str`` and ``int`` are given.
            ValueError: If ``str`` contains characters outside ``alphanums``,
                or if the string is longer than ``max_len``.
        """
        # NOTE: the parameter names shadow the builtins inside this method;
        # they are part of the public interface (callers pass ``int=...``).
        if str is None:
            if int is None:
                raise TypeError("ArithAlphanumeric requires either a string or an int value")
            str = numberToBase(int, len(alphanums))
        else:
            if int is not None:
                raise TypeError("ArithAlphanumeric accepts a string or an int value, not both")
            # Reject anything outside the alphabet up front: callers probe
            # values by constructing them and catching ValueError.
            unexpected = sorted(set(str) - set(alphanums))
            if unexpected:
                raise ValueError(f"Unexpected character(s) {unexpected} in alphanum value '{str}'")

        if max_len and len(str) > max_len:
            raise ValueError(f"Length of alphanum value '{str}' is longer than the expected {max_len}")

        self._str = str
        self._max_len = max_len

    @property
    def int(self):
        "The numeric value of the string. An empty string counts as zero."
        return int(self._str, len(alphanums)) if self._str else 0

    def __str__(self):
        s = self._str
        if self._max_len:
            s = s.rjust(self._max_len, "0")
        return s

    def __len__(self):
        # Length of the stored digits, without the padding added by __str__.
        return len(self._str)

    def __int__(self):
        return self.int

    def __repr__(self):
        return f'alphanum"{self._str}"'

    def __add__(self, other: "Union[ArithAlphanumeric, int]"):
        if isinstance(other, int):
            total = self.int + other
            if total < 0:
                raise ValueError("Underflow error when adding to alphanumeric")
            res = self.new(int=total)
            # Compare rendered widths on both sides, so that a value built
            # from an int (short stored digits, padded output) is not
            # mistaken for an overflow.
            if len(str(res)) != len(str(self)):
                raise ValueError("Overflow error when adding to alphanumeric")
            return res
        return NotImplemented

    def __sub__(self, other: "Union[ArithAlphanumeric, int]"):
        if isinstance(other, int):
            difference = self.int - other
            if difference < 0:
                raise ValueError("Underflow error when subtracting from alphanumeric")
            # Go through new() so the result keeps this value's max_len.
            return self.new(int=difference)
        elif isinstance(other, ArithAlphanumeric):
            return self.int - other.int
        return NotImplemented

    def __eq__(self, other):
        if not isinstance(other, type(self)):
            return NotImplemented
        return self.int == other.int

    def __hash__(self):
        # Defining __eq__ disables the inherited hash; keep instances hashable
        # and consistent with equality.
        return hash(self.int)

    def __ge__(self, other):
        if not isinstance(other, type(self)):
            return NotImplemented
        return self.int >= other.int

    def __lt__(self, other):
        if not isinstance(other, type(self)):
            return NotImplemented
        return self.int < other.int

    def new(self, *args, **kw):
        "Create a value of the same type, inheriting max_len unless overridden."
        kw.setdefault("max_len", self._max_len)
        return type(self)(*args, **kw)
114+
37115

38116
def is_uuid(u):
39117
try:
@@ -57,23 +135,24 @@ def number_to_human(n):
57135
def _join_if_any(sym, args):
58136
args = list(args)
59137
if not args:
60-
return ''
138+
return ""
61139
return sym.join(str(a) for a in args if a)
62140

63-
def remove_password_from_url(url: str, replace_with: str="***") -> str:
141+
142+
def remove_password_from_url(url: str, replace_with: str = "***") -> str:
64143
parsed = urlparse(url)
65-
account = parsed.username or ''
144+
account = parsed.username or ""
66145
if parsed.password:
67-
account += ':' + replace_with
146+
account += ":" + replace_with
68147
host = _join_if_any(":", filter(None, [parsed.hostname, parsed.port]))
69148
netloc = _join_if_any("@", filter(None, [account, host]))
70149
replaced = parsed._replace(netloc=netloc)
71150
return replaced.geturl()
72151

152+
73153
def join_iter(joiner: Any, iterable: iter) -> iter:
74154
it = iter(iterable)
75155
yield next(it)
76156
for i in it:
77157
yield joiner
78158
yield i
79-

tests/test_config.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,16 +46,16 @@ def test_basic(self):
4646
def test_remove_password(self):
4747
replace_with = "*****"
4848
urls = [
49-
'd://host/',
50-
'd://host:123/',
51-
'd://user@host:123/',
52-
'd://user:PASS@host:123/',
53-
'd://:PASS@host:123/',
54-
'd://:PASS@host:123/path',
55-
'd://:PASS@host:123/path?whatever#blabla',
49+
"d://host/",
50+
"d://host:123/",
51+
"d://user@host:123/",
52+
"d://user:PASS@host:123/",
53+
"d://:PASS@host:123/",
54+
"d://:PASS@host:123/path",
55+
"d://:PASS@host:123/path?whatever#blabla",
5656
]
5757
for url in urls:
5858
removed = remove_password_from_url(url, replace_with)
59-
expected = url.replace('PASS', replace_with)
59+
expected = url.replace("PASS", replace_with)
6060
removed = remove_password_from_url(url, replace_with)
6161
self.assertEqual(removed, expected)

tests/test_diff_tables.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from data_diff.databases import connect_to_uri
1010
from data_diff.diff_tables import TableDiffer, TableSegment, split_space
1111
from data_diff import databases as db
12+
from data_diff.utils import ArithAlphanumeric
1213

1314
from .common import (
1415
TEST_MYSQL_CONN_STRING,
@@ -369,7 +370,7 @@ def test_diff_sorted_by_key(self):
369370

370371

371372
@test_per_database
372-
class TestStringKeys(TestPerDatabase):
373+
class TestUUIDs(TestPerDatabase):
373374
def setUp(self):
374375
super().setUp()
375376

@@ -408,6 +409,57 @@ def test_string_keys(self):
408409
self.assertRaises(ValueError, list, differ.diff_tables(self.a, self.b))
409410

410411

412+
@test_per_database
413+
class TestAlphanumericKeys(TestPerDatabase):
414+
def setUp(self):
415+
super().setUp()
416+
417+
queries = [
418+
f"CREATE TABLE {self.table_src}(id varchar(100), text_comment varchar(1000))",
419+
]
420+
for i in range(0, 10000, 1000):
421+
queries.append(f"INSERT INTO {self.table_src} VALUES ('{ArithAlphanumeric(int=i, max_len=10)}', '{i}')")
422+
423+
queries += [
424+
f"CREATE TABLE {self.table_dst} AS SELECT * FROM {self.table_src}",
425+
]
426+
427+
self.new_alphanum = "abcdefghij"
428+
queries.append(f"INSERT INTO {self.table_src} VALUES ('{self.new_alphanum}', 'This one is different')")
429+
430+
# TODO test unexpected values?
431+
432+
for query in queries:
433+
self.connection.query(query, None)
434+
435+
_commit(self.connection)
436+
437+
self.a = TableSegment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False)
438+
self.b = TableSegment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False)
439+
440+
def test_alphanum_keys(self):
441+
# Test the class itself
442+
assert str(ArithAlphanumeric(int=0, max_len=1)) == "0"
443+
assert str(ArithAlphanumeric(int=0, max_len=10)) == "0" * 10
444+
assert str(ArithAlphanumeric(int=1, max_len=10)) == "0" * 9 + "1"
445+
446+
# Test in the differ
447+
448+
differ = TableDiffer()
449+
diff = list(differ.diff_tables(self.a, self.b))
450+
self.assertEqual(diff, [("-", (str(self.new_alphanum), "This one is different"))])
451+
452+
self.connection.query(
453+
f"INSERT INTO {self.table_src} VALUES ('@@@', '<-- this bad value should not break us')", None
454+
)
455+
_commit(self.connection)
456+
457+
self.a = TableSegment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False)
458+
self.b = TableSegment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False)
459+
460+
self.assertRaises(NotImplementedError, list, differ.diff_tables(self.a, self.b))
461+
462+
411463
@test_per_database
412464
class TestTableSegment(TestPerDatabase):
413465
def setUp(self) -> None:

0 commit comments

Comments
 (0)