Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 73e254a

Browse files
authored
Merge pull request #167 from datafold/test_per_database
Initial support for running the tests for multiple databases (replacing TestWithConnection)
2 parents e302802 + 3cf0991 commit 73e254a

File tree

3 files changed

+228
-217
lines changed

3 files changed

+228
-217
lines changed

tests/common.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from contextlib import suppress
12
import hashlib
23
import os
34
import string
@@ -86,3 +87,15 @@ def str_to_checksum(str: str):
8687
# 0-indexed, unlike DBs which are 1-indexed here, so +1 in dbs
8788
half_pos = db.MD5_HEXDIGITS - db.CHECKSUM_HEXDIGITS
8889
return int(md5[half_pos:], 16)
90+
91+
92+
def _drop_table_if_exists(conn, table):
93+
with suppress(db.QueryError):
94+
if isinstance(conn, db.Oracle):
95+
conn.query(f"DROP TABLE {table}", None)
96+
conn.query(f"DROP TABLE {table}", None)
97+
else:
98+
conn.query(f"DROP TABLE IF EXISTS {table}", None)
99+
if not isinstance(conn, (db.BigQuery, db.Databricks)):
100+
conn.query("COMMIT", None)
101+

tests/test_database_types.py

Lines changed: 160 additions & 173 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from contextlib import suppress
21
import unittest
32
import time
43
import json
@@ -15,174 +14,14 @@
1514
from data_diff.databases import postgresql, oracle
1615
from data_diff.utils import number_to_human
1716
from data_diff.diff_tables import TableDiffer, TableSegment, DEFAULT_BISECTION_THRESHOLD
18-
from .common import CONN_STRINGS, N_SAMPLES, N_THREADS, BENCHMARK, GIT_REVISION, random_table_suffix
17+
from .common import CONN_STRINGS, N_SAMPLES, N_THREADS, BENCHMARK, GIT_REVISION, random_table_suffix, _drop_table_if_exists
1918

2019

2120
CONNS = {k: db.connect_to_uri(v, N_THREADS) for k, v in CONN_STRINGS.items()}
2221

2322
CONNS[db.MySQL].query("SET @@session.time_zone='+00:00'", None)
2423
oracle.SESSION_TIME_ZONE = postgresql.SESSION_TIME_ZONE = 'UTC'
2524

26-
27-
class PaginatedTable:
28-
# We can't query all the rows at once for large tables. It'll occupy too
29-
# much memory.
30-
RECORDS_PER_BATCH = 1000000
31-
32-
def __init__(self, table, conn):
33-
self.table = table
34-
self.conn = conn
35-
36-
def __iter__(self):
37-
iter = PaginatedTable(self.table, self.conn)
38-
iter.last_id = 0
39-
iter.values = []
40-
iter.value_index = 0
41-
return iter
42-
43-
def __next__(self) -> str:
44-
if self.value_index == len(self.values): # end of current batch
45-
query = f"SELECT id, col FROM {self.table} WHERE id > {self.last_id} ORDER BY id ASC LIMIT {self.RECORDS_PER_BATCH}"
46-
if isinstance(self.conn, db.Oracle):
47-
query = f"SELECT id, col FROM {self.table} WHERE id > {self.last_id} ORDER BY id ASC OFFSET 0 ROWS FETCH NEXT {self.RECORDS_PER_BATCH} ROWS ONLY"
48-
49-
self.values = self.conn.query(query, list)
50-
if len(self.values) == 0: # we must be done!
51-
raise StopIteration
52-
self.last_id = self.values[-1][0]
53-
self.value_index = 0
54-
55-
this_value = self.values[self.value_index]
56-
self.value_index += 1
57-
return this_value
58-
59-
60-
class DateTimeFaker:
61-
MANUAL_FAKES = [
62-
datetime.fromisoformat("2020-01-01 15:10:10"),
63-
datetime.fromisoformat("2020-02-01 09:09:09"),
64-
datetime.fromisoformat("2022-03-01 15:10:01.139"),
65-
datetime.fromisoformat("2022-04-01 15:10:02.020409"),
66-
datetime.fromisoformat("2022-05-01 15:10:03.003030"),
67-
datetime.fromisoformat("2022-06-01 15:10:05.009900"),
68-
]
69-
70-
def __init__(self, max):
71-
self.max = max
72-
73-
def __iter__(self):
74-
iter = DateTimeFaker(self.max)
75-
iter.prev = datetime(2000, 1, 1, 0, 0, 0, 0)
76-
iter.i = 0
77-
return iter
78-
79-
def __len__(self):
80-
return self.max
81-
82-
def __next__(self) -> datetime:
83-
if self.i < len(self.MANUAL_FAKES):
84-
fake = self.MANUAL_FAKES[self.i]
85-
self.i += 1
86-
return fake
87-
elif self.i < self.max:
88-
self.prev = self.prev + timedelta(seconds=3, microseconds=571)
89-
self.i += 1
90-
return self.prev
91-
else:
92-
raise StopIteration
93-
94-
95-
class IntFaker:
96-
MANUAL_FAKES = [127, -3, -9, 37, 15, 127]
97-
98-
def __init__(self, max):
99-
self.max = max
100-
101-
def __iter__(self):
102-
iter = IntFaker(self.max)
103-
iter.prev = -128
104-
iter.i = 0
105-
return iter
106-
107-
def __len__(self):
108-
return self.max
109-
110-
def __next__(self) -> int:
111-
if self.i < len(self.MANUAL_FAKES):
112-
fake = self.MANUAL_FAKES[self.i]
113-
self.i += 1
114-
return fake
115-
elif self.i < self.max:
116-
self.prev += 1
117-
self.i += 1
118-
return self.prev
119-
else:
120-
raise StopIteration
121-
122-
123-
class FloatFaker:
124-
MANUAL_FAKES = [
125-
0.0,
126-
0.1,
127-
0.00188,
128-
0.99999,
129-
0.091919,
130-
0.10,
131-
10.0,
132-
100.98,
133-
0.001201923076923077,
134-
1 / 3,
135-
1 / 5,
136-
1 / 109,
137-
1 / 109489,
138-
1 / 1094893892389,
139-
1 / 10948938923893289,
140-
3.141592653589793,
141-
]
142-
143-
def __init__(self, max):
144-
self.max = max
145-
146-
def __iter__(self):
147-
iter = FloatFaker(self.max)
148-
iter.prev = -10.0001
149-
iter.i = 0
150-
return iter
151-
152-
def __len__(self):
153-
return self.max
154-
155-
def __next__(self) -> float:
156-
if self.i < len(self.MANUAL_FAKES):
157-
fake = self.MANUAL_FAKES[self.i]
158-
self.i += 1
159-
return fake
160-
elif self.i < self.max:
161-
self.prev += 0.00571
162-
self.i += 1
163-
return self.prev
164-
else:
165-
raise StopIteration
166-
167-
168-
class UUID_Faker:
169-
def __init__(self, max):
170-
self.max = max
171-
172-
def __len__(self):
173-
return self.max
174-
175-
def __iter__(self):
176-
return (uuid.uuid1(i) for i in range(self.max))
177-
178-
179-
TYPE_SAMPLES = {
180-
"int": IntFaker(N_SAMPLES),
181-
"datetime": DateTimeFaker(N_SAMPLES),
182-
"float": FloatFaker(N_SAMPLES),
183-
"uuid": UUID_Faker(N_SAMPLES),
184-
}
185-
18625
DATABASE_TYPES = {
18726
db.PostgreSQL: {
18827
# https://www.postgresql.org/docs/current/datatype-numeric.html#DATATYPE-INT
@@ -399,6 +238,165 @@ def __iter__(self):
399238
}
400239

401240

241+
class PaginatedTable:
242+
# We can't query all the rows at once for large tables. It'll occupy too
243+
# much memory.
244+
RECORDS_PER_BATCH = 1000000
245+
246+
def __init__(self, table, conn):
247+
self.table = table
248+
self.conn = conn
249+
250+
def __iter__(self):
251+
iter = PaginatedTable(self.table, self.conn)
252+
iter.last_id = 0
253+
iter.values = []
254+
iter.value_index = 0
255+
return iter
256+
257+
def __next__(self) -> str:
258+
if self.value_index == len(self.values): # end of current batch
259+
query = f"SELECT id, col FROM {self.table} WHERE id > {self.last_id} ORDER BY id ASC LIMIT {self.RECORDS_PER_BATCH}"
260+
if isinstance(self.conn, db.Oracle):
261+
query = f"SELECT id, col FROM {self.table} WHERE id > {self.last_id} ORDER BY id ASC OFFSET 0 ROWS FETCH NEXT {self.RECORDS_PER_BATCH} ROWS ONLY"
262+
263+
self.values = self.conn.query(query, list)
264+
if len(self.values) == 0: # we must be done!
265+
raise StopIteration
266+
self.last_id = self.values[-1][0]
267+
self.value_index = 0
268+
269+
this_value = self.values[self.value_index]
270+
self.value_index += 1
271+
return this_value
272+
273+
274+
class DateTimeFaker:
275+
MANUAL_FAKES = [
276+
datetime.fromisoformat("2020-01-01 15:10:10"),
277+
datetime.fromisoformat("2020-02-01 09:09:09"),
278+
datetime.fromisoformat("2022-03-01 15:10:01.139"),
279+
datetime.fromisoformat("2022-04-01 15:10:02.020409"),
280+
datetime.fromisoformat("2022-05-01 15:10:03.003030"),
281+
datetime.fromisoformat("2022-06-01 15:10:05.009900"),
282+
]
283+
284+
def __init__(self, max):
285+
self.max = max
286+
287+
def __iter__(self):
288+
iter = DateTimeFaker(self.max)
289+
iter.prev = datetime(2000, 1, 1, 0, 0, 0, 0)
290+
iter.i = 0
291+
return iter
292+
293+
def __len__(self):
294+
return self.max
295+
296+
def __next__(self) -> datetime:
297+
if self.i < len(self.MANUAL_FAKES):
298+
fake = self.MANUAL_FAKES[self.i]
299+
self.i += 1
300+
return fake
301+
elif self.i < self.max:
302+
self.prev = self.prev + timedelta(seconds=3, microseconds=571)
303+
self.i += 1
304+
return self.prev
305+
else:
306+
raise StopIteration
307+
308+
309+
class IntFaker:
310+
MANUAL_FAKES = [127, -3, -9, 37, 15, 127]
311+
312+
def __init__(self, max):
313+
self.max = max
314+
315+
def __iter__(self):
316+
iter = IntFaker(self.max)
317+
iter.prev = -128
318+
iter.i = 0
319+
return iter
320+
321+
def __len__(self):
322+
return self.max
323+
324+
def __next__(self) -> int:
325+
if self.i < len(self.MANUAL_FAKES):
326+
fake = self.MANUAL_FAKES[self.i]
327+
self.i += 1
328+
return fake
329+
elif self.i < self.max:
330+
self.prev += 1
331+
self.i += 1
332+
return self.prev
333+
else:
334+
raise StopIteration
335+
336+
337+
class FloatFaker:
338+
MANUAL_FAKES = [
339+
0.0,
340+
0.1,
341+
0.00188,
342+
0.99999,
343+
0.091919,
344+
0.10,
345+
10.0,
346+
100.98,
347+
0.001201923076923077,
348+
1 / 3,
349+
1 / 5,
350+
1 / 109,
351+
1 / 109489,
352+
1 / 1094893892389,
353+
1 / 10948938923893289,
354+
3.141592653589793,
355+
]
356+
357+
def __init__(self, max):
358+
self.max = max
359+
360+
def __iter__(self):
361+
iter = FloatFaker(self.max)
362+
iter.prev = -10.0001
363+
iter.i = 0
364+
return iter
365+
366+
def __len__(self):
367+
return self.max
368+
369+
def __next__(self) -> float:
370+
if self.i < len(self.MANUAL_FAKES):
371+
fake = self.MANUAL_FAKES[self.i]
372+
self.i += 1
373+
return fake
374+
elif self.i < self.max:
375+
self.prev += 0.00571
376+
self.i += 1
377+
return self.prev
378+
else:
379+
raise StopIteration
380+
381+
382+
class UUID_Faker:
383+
def __init__(self, max):
384+
self.max = max
385+
386+
def __len__(self):
387+
return self.max
388+
389+
def __iter__(self):
390+
return (uuid.uuid1(i) for i in range(self.max))
391+
392+
393+
TYPE_SAMPLES = {
394+
"int": IntFaker(N_SAMPLES),
395+
"datetime": DateTimeFaker(N_SAMPLES),
396+
"float": FloatFaker(N_SAMPLES),
397+
"uuid": UUID_Faker(N_SAMPLES),
398+
}
399+
402400
type_pairs = []
403401
for source_db, source_type_categories in DATABASE_TYPES.items():
404402
for target_db, target_type_categories in DATABASE_TYPES.items():
@@ -549,17 +547,6 @@ def _create_table_with_indexes(conn, table, type):
549547
conn.query("COMMIT", None)
550548

551549

552-
def _drop_table_if_exists(conn, table):
553-
with suppress(db.QueryError):
554-
if isinstance(conn, db.Oracle):
555-
conn.query(f"DROP TABLE {table}", None)
556-
conn.query(f"DROP TABLE {table}", None)
557-
else:
558-
conn.query(f"DROP TABLE IF EXISTS {table}", None)
559-
if not isinstance(conn, (db.BigQuery, db.Databricks)):
560-
conn.query("COMMIT", None)
561-
562-
563550
class TestDiffCrossDatabaseTables(unittest.TestCase):
564551
maxDiff = 10000
565552

0 commit comments

Comments
 (0)