Skip to content

Commit 0fa11c7

Browse files
committed
Refactor (removed 35 lines)
1 parent bc451cd commit 0fa11c7

File tree

3 files changed

+16
-51
lines changed

3 files changed

+16
-51
lines changed

data_diff/diff_tables.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,9 @@ def __post_init__(self):
9191
raise ValueError(f"Error: min_key expected to be smaller than max_key! ({self.min_key} >= {self.max_key})")
9292

9393
if self.min_update is not None and self.max_update is not None and self.min_update >= self.max_update:
94-
raise ValueError(f"Error: min_update expected to be smaller than max_update! ({self.min_update} >= {self.max_update})")
94+
raise ValueError(
95+
f"Error: min_update expected to be smaller than max_update! ({self.min_update} >= {self.max_update})"
96+
)
9597

9698
@property
9799
def _update_column(self):

tests/test_database_types.py

Lines changed: 12 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from datetime import datetime, timedelta, timezone
99
import logging
1010
from decimal import Decimal
11+
from itertools import islice, accumulate, repeat, chain
12+
1113
from parameterized import parameterized
1214

1315
from data_diff import databases as db
@@ -290,54 +292,28 @@ def __init__(self, max):
290292
self.max = max
291293

292294
def __iter__(self):
293-
iter = DateTimeFaker(self.max)
294-
iter.prev = datetime(2000, 1, 1, 0, 0, 0, 0)
295-
iter.i = 0
296-
return iter
295+
initial = datetime(2000, 1, 1, 0, 0, 0, 0)
296+
step = timedelta(seconds=3, microseconds=571)
297+
return islice(chain(self.MANUAL_FAKES, accumulate(repeat(step), initial=initial)), self.max)
297298

298299
def __len__(self):
299300
return self.max
300301

301-
def __next__(self) -> datetime:
302-
if self.i < len(self.MANUAL_FAKES):
303-
fake = self.MANUAL_FAKES[self.i]
304-
self.i += 1
305-
return fake
306-
elif self.i < self.max:
307-
self.prev = self.prev + timedelta(seconds=3, microseconds=571)
308-
self.i += 1
309-
return self.prev
310-
else:
311-
raise StopIteration
312-
313302

314303
class IntFaker:
315-
MANUAL_FAKES = [127, -3, -9, 37, 15, 127]
304+
MANUAL_FAKES = [127, -3, -9, 37, 15, 0]
316305

317306
def __init__(self, max):
318307
self.max = max
319308

320309
def __iter__(self):
321-
iter = IntFaker(self.max)
322-
iter.prev = -128
323-
iter.i = 0
324-
return iter
310+
initial = -128
311+
step = 1
312+
return islice(chain(self.MANUAL_FAKES, accumulate(repeat(step), initial=initial)), self.max)
325313

326314
def __len__(self):
327315
return self.max
328316

329-
def __next__(self) -> int:
330-
if self.i < len(self.MANUAL_FAKES):
331-
fake = self.MANUAL_FAKES[self.i]
332-
self.i += 1
333-
return fake
334-
elif self.i < self.max:
335-
self.prev += 1
336-
self.i += 1
337-
return self.prev
338-
else:
339-
raise StopIteration
340-
341317

342318
class FloatFaker:
343319
MANUAL_FAKES = [
@@ -363,26 +339,13 @@ def __init__(self, max):
363339
self.max = max
364340

365341
def __iter__(self):
366-
iter = FloatFaker(self.max)
367-
iter.prev = -10.0001
368-
iter.i = 0
369-
return iter
342+
initial = -10.0001
343+
step = 0.00571
344+
return islice(chain(self.MANUAL_FAKES, accumulate(repeat(step), initial=initial)), self.max)
370345

371346
def __len__(self):
372347
return self.max
373348

374-
def __next__(self) -> float:
375-
if self.i < len(self.MANUAL_FAKES):
376-
fake = self.MANUAL_FAKES[self.i]
377-
self.i += 1
378-
return fake
379-
elif self.i < self.max:
380-
self.prev += 0.00571
381-
self.i += 1
382-
return self.prev
383-
else:
384-
raise StopIteration
385-
386349

387350
class UUID_Faker:
388351
def __init__(self, max):

tests/test_diff_tables.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import preql
77
import arrow # comes with preql
88

9-
from data_diff.databases import connect
9+
from data_diff.databases.connect import connect
1010
from data_diff.diff_tables import TableDiffer, TableSegment, split_space
1111
from data_diff import databases as db
1212
from data_diff.utils import ArithAlphanumeric

0 commit comments

Comments
 (0)