Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions tests/common.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from contextlib import suppress
import hashlib
import os
import string
Expand Down Expand Up @@ -86,3 +87,15 @@ def str_to_checksum(str: str):
# 0-indexed, unlike DBs which are 1-indexed here, so +1 in dbs
half_pos = db.MD5_HEXDIGITS - db.CHECKSUM_HEXDIGITS
return int(md5[half_pos:], 16)


def _drop_table_if_exists(conn, table):
with suppress(db.QueryError):
if isinstance(conn, db.Oracle):
conn.query(f"DROP TABLE {table}", None)
conn.query(f"DROP TABLE {table}", None)
else:
conn.query(f"DROP TABLE IF EXISTS {table}", None)
if not isinstance(conn, (db.BigQuery, db.Databricks)):
conn.query("COMMIT", None)

333 changes: 160 additions & 173 deletions tests/test_database_types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from contextlib import suppress
import unittest
import time
import json
Expand All @@ -15,174 +14,14 @@
from data_diff.databases import postgresql, oracle
from data_diff.utils import number_to_human
from data_diff.diff_tables import TableDiffer, TableSegment, DEFAULT_BISECTION_THRESHOLD
from .common import CONN_STRINGS, N_SAMPLES, N_THREADS, BENCHMARK, GIT_REVISION, random_table_suffix
from .common import CONN_STRINGS, N_SAMPLES, N_THREADS, BENCHMARK, GIT_REVISION, random_table_suffix, _drop_table_if_exists


CONNS = {k: db.connect_to_uri(v, N_THREADS) for k, v in CONN_STRINGS.items()}

CONNS[db.MySQL].query("SET @@session.time_zone='+00:00'", None)
oracle.SESSION_TIME_ZONE = postgresql.SESSION_TIME_ZONE = 'UTC'


class PaginatedTable:
# We can't query all the rows at once for large tables. It'll occupy too
# much memory.
RECORDS_PER_BATCH = 1000000

def __init__(self, table, conn):
self.table = table
self.conn = conn

def __iter__(self):
iter = PaginatedTable(self.table, self.conn)
iter.last_id = 0
iter.values = []
iter.value_index = 0
return iter

def __next__(self) -> str:
if self.value_index == len(self.values): # end of current batch
query = f"SELECT id, col FROM {self.table} WHERE id > {self.last_id} ORDER BY id ASC LIMIT {self.RECORDS_PER_BATCH}"
if isinstance(self.conn, db.Oracle):
query = f"SELECT id, col FROM {self.table} WHERE id > {self.last_id} ORDER BY id ASC OFFSET 0 ROWS FETCH NEXT {self.RECORDS_PER_BATCH} ROWS ONLY"

self.values = self.conn.query(query, list)
if len(self.values) == 0: # we must be done!
raise StopIteration
self.last_id = self.values[-1][0]
self.value_index = 0

this_value = self.values[self.value_index]
self.value_index += 1
return this_value


class DateTimeFaker:
MANUAL_FAKES = [
datetime.fromisoformat("2020-01-01 15:10:10"),
datetime.fromisoformat("2020-02-01 09:09:09"),
datetime.fromisoformat("2022-03-01 15:10:01.139"),
datetime.fromisoformat("2022-04-01 15:10:02.020409"),
datetime.fromisoformat("2022-05-01 15:10:03.003030"),
datetime.fromisoformat("2022-06-01 15:10:05.009900"),
]

def __init__(self, max):
self.max = max

def __iter__(self):
iter = DateTimeFaker(self.max)
iter.prev = datetime(2000, 1, 1, 0, 0, 0, 0)
iter.i = 0
return iter

def __len__(self):
return self.max

def __next__(self) -> datetime:
if self.i < len(self.MANUAL_FAKES):
fake = self.MANUAL_FAKES[self.i]
self.i += 1
return fake
elif self.i < self.max:
self.prev = self.prev + timedelta(seconds=3, microseconds=571)
self.i += 1
return self.prev
else:
raise StopIteration


class IntFaker:
MANUAL_FAKES = [127, -3, -9, 37, 15, 127]

def __init__(self, max):
self.max = max

def __iter__(self):
iter = IntFaker(self.max)
iter.prev = -128
iter.i = 0
return iter

def __len__(self):
return self.max

def __next__(self) -> int:
if self.i < len(self.MANUAL_FAKES):
fake = self.MANUAL_FAKES[self.i]
self.i += 1
return fake
elif self.i < self.max:
self.prev += 1
self.i += 1
return self.prev
else:
raise StopIteration


class FloatFaker:
MANUAL_FAKES = [
0.0,
0.1,
0.00188,
0.99999,
0.091919,
0.10,
10.0,
100.98,
0.001201923076923077,
1 / 3,
1 / 5,
1 / 109,
1 / 109489,
1 / 1094893892389,
1 / 10948938923893289,
3.141592653589793,
]

def __init__(self, max):
self.max = max

def __iter__(self):
iter = FloatFaker(self.max)
iter.prev = -10.0001
iter.i = 0
return iter

def __len__(self):
return self.max

def __next__(self) -> float:
if self.i < len(self.MANUAL_FAKES):
fake = self.MANUAL_FAKES[self.i]
self.i += 1
return fake
elif self.i < self.max:
self.prev += 0.00571
self.i += 1
return self.prev
else:
raise StopIteration


class UUID_Faker:
def __init__(self, max):
self.max = max

def __len__(self):
return self.max

def __iter__(self):
return (uuid.uuid1(i) for i in range(self.max))


TYPE_SAMPLES = {
"int": IntFaker(N_SAMPLES),
"datetime": DateTimeFaker(N_SAMPLES),
"float": FloatFaker(N_SAMPLES),
"uuid": UUID_Faker(N_SAMPLES),
}

DATABASE_TYPES = {
db.PostgreSQL: {
# https://www.postgresql.org/docs/current/datatype-numeric.html#DATATYPE-INT
Expand Down Expand Up @@ -399,6 +238,165 @@ def __iter__(self):
}


class PaginatedTable:
# We can't query all the rows at once for large tables. It'll occupy too
# much memory.
RECORDS_PER_BATCH = 1000000

def __init__(self, table, conn):
self.table = table
self.conn = conn

def __iter__(self):
iter = PaginatedTable(self.table, self.conn)
iter.last_id = 0
iter.values = []
iter.value_index = 0
return iter

def __next__(self) -> str:
if self.value_index == len(self.values): # end of current batch
query = f"SELECT id, col FROM {self.table} WHERE id > {self.last_id} ORDER BY id ASC LIMIT {self.RECORDS_PER_BATCH}"
if isinstance(self.conn, db.Oracle):
query = f"SELECT id, col FROM {self.table} WHERE id > {self.last_id} ORDER BY id ASC OFFSET 0 ROWS FETCH NEXT {self.RECORDS_PER_BATCH} ROWS ONLY"

self.values = self.conn.query(query, list)
if len(self.values) == 0: # we must be done!
raise StopIteration
self.last_id = self.values[-1][0]
self.value_index = 0

this_value = self.values[self.value_index]
self.value_index += 1
return this_value


class DateTimeFaker:
MANUAL_FAKES = [
datetime.fromisoformat("2020-01-01 15:10:10"),
datetime.fromisoformat("2020-02-01 09:09:09"),
datetime.fromisoformat("2022-03-01 15:10:01.139"),
datetime.fromisoformat("2022-04-01 15:10:02.020409"),
datetime.fromisoformat("2022-05-01 15:10:03.003030"),
datetime.fromisoformat("2022-06-01 15:10:05.009900"),
]

def __init__(self, max):
self.max = max

def __iter__(self):
iter = DateTimeFaker(self.max)
iter.prev = datetime(2000, 1, 1, 0, 0, 0, 0)
iter.i = 0
return iter

def __len__(self):
return self.max

def __next__(self) -> datetime:
if self.i < len(self.MANUAL_FAKES):
fake = self.MANUAL_FAKES[self.i]
self.i += 1
return fake
elif self.i < self.max:
self.prev = self.prev + timedelta(seconds=3, microseconds=571)
self.i += 1
return self.prev
else:
raise StopIteration


class IntFaker:
MANUAL_FAKES = [127, -3, -9, 37, 15, 127]

def __init__(self, max):
self.max = max

def __iter__(self):
iter = IntFaker(self.max)
iter.prev = -128
iter.i = 0
return iter

def __len__(self):
return self.max

def __next__(self) -> int:
if self.i < len(self.MANUAL_FAKES):
fake = self.MANUAL_FAKES[self.i]
self.i += 1
return fake
elif self.i < self.max:
self.prev += 1
self.i += 1
return self.prev
else:
raise StopIteration


class FloatFaker:
MANUAL_FAKES = [
0.0,
0.1,
0.00188,
0.99999,
0.091919,
0.10,
10.0,
100.98,
0.001201923076923077,
1 / 3,
1 / 5,
1 / 109,
1 / 109489,
1 / 1094893892389,
1 / 10948938923893289,
3.141592653589793,
]

def __init__(self, max):
self.max = max

def __iter__(self):
iter = FloatFaker(self.max)
iter.prev = -10.0001
iter.i = 0
return iter

def __len__(self):
return self.max

def __next__(self) -> float:
if self.i < len(self.MANUAL_FAKES):
fake = self.MANUAL_FAKES[self.i]
self.i += 1
return fake
elif self.i < self.max:
self.prev += 0.00571
self.i += 1
return self.prev
else:
raise StopIteration


class UUID_Faker:
def __init__(self, max):
self.max = max

def __len__(self):
return self.max

def __iter__(self):
return (uuid.uuid1(i) for i in range(self.max))


TYPE_SAMPLES = {
"int": IntFaker(N_SAMPLES),
"datetime": DateTimeFaker(N_SAMPLES),
"float": FloatFaker(N_SAMPLES),
"uuid": UUID_Faker(N_SAMPLES),
}

type_pairs = []
for source_db, source_type_categories in DATABASE_TYPES.items():
for target_db, target_type_categories in DATABASE_TYPES.items():
Expand Down Expand Up @@ -549,17 +547,6 @@ def _create_table_with_indexes(conn, table, type):
conn.query("COMMIT", None)


def _drop_table_if_exists(conn, table):
with suppress(db.QueryError):
if isinstance(conn, db.Oracle):
conn.query(f"DROP TABLE {table}", None)
conn.query(f"DROP TABLE {table}", None)
else:
conn.query(f"DROP TABLE IF EXISTS {table}", None)
if not isinstance(conn, (db.BigQuery, db.Databricks)):
conn.query("COMMIT", None)


class TestDiffCrossDatabaseTables(unittest.TestCase):
maxDiff = 10000

Expand Down
Loading