Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 723b00c

Browse files
committed
Now automatically fixing the column case using the schema.
Added the --keep-column-case switch to disable it.
1 parent 66248cf commit 723b00c

File tree

4 files changed

+56
-24
lines changed

4 files changed

+56
-24
lines changed

data_diff/__main__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
@click.option("-d", "--debug", is_flag=True, help="Print debug info")
5454
@click.option("-v", "--verbose", is_flag=True, help="Print extra info")
5555
@click.option("-i", "--interactive", is_flag=True, help="Confirm queries, implies --debug")
56+
@click.option("--keep-column-case", is_flag=True, help="Don't use the schema to fix the case of given column names.")
5657
@click.option(
5758
"-j",
5859
"--threads",
@@ -79,6 +80,7 @@ def main(
7980
verbose,
8081
interactive,
8182
threads,
83+
keep_column_case,
8284
):
8385
if limit and stats:
8486
print("Error: cannot specify a limit when using the -s/--stats switch")
@@ -119,6 +121,7 @@ def main(
119121
options = dict(
120122
min_update=max_age and parse_time_before_now(max_age),
121123
max_update=min_age and parse_time_before_now(min_age),
124+
case_sensitive=keep_column_case,
122125
)
123126
except ParseError as e:
124127
logging.error("Error while parsing age expression: %s" % e)

data_diff/diff_tables.py

Lines changed: 50 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
"""Provides classes for performing a table diff
22
"""
33

4+
from abc import ABC, abstractmethod
45
import time
56
from operator import attrgetter, methodcaller
67
from collections import defaultdict
7-
from typing import List, Tuple, Iterator, Optional, Mapping
8+
from typing import List, Tuple, Iterator, Optional
89
import logging
910
from concurrent.futures import ThreadPoolExecutor
1011

@@ -36,24 +37,44 @@ def parse_table_name(t):
3637
return tuple(t.split("."))
3738

3839

39-
class CaseInsensitiveDict(Mapping):
40-
def __init__(self, initial=()):
41-
self._dict = {k.lower(): v for k, v in dict(initial).items()}
40+
class Schema(ABC):
41+
@abstractmethod
42+
def get_key(self, key: str) -> str:
43+
...
4244

43-
def __setitem__(self, key, value):
44-
self._dict[key.lower()] = value
45+
@abstractmethod
46+
def __getitem__(self, key: str) -> str:
47+
...
4548

46-
def __getitem__(self, key):
47-
try:
48-
return self._dict[key.lower()]
49-
except KeyError:
50-
raise
49+
@abstractmethod
50+
def __setitem__(self, key: str, value):
51+
...
5152

52-
def __iter__(self):
53-
return iter(self._dict)
53+
@abstractmethod
54+
def __contains__(self, key: str) -> bool:
55+
...
5456

55-
def __len__(self):
56-
return len(self._dict)
57+
58+
class Schema_CaseSensitive(dict, Schema):
    """Schema stored as a plain ``dict``; column names are matched exactly."""

    def get_key(self, key: str) -> str:
        # Exact-match schema: the name is handed back unchanged.
        return key
61+
62+
63+
class Schema_CaseInsensitive(Schema):
    """Schema whose column lookups ignore case.

    Each entry is stored under its lowercased name together with the name as
    it was originally spelled, so ``get_key`` can recover that spelling.
    """

    def __init__(self, initial=()):
        # Anything accepted by ``dict()`` works; the default yields an empty schema.
        self._dict = {k.lower(): (k, v) for k, v in dict(initial).items()}

    def get_key(self, key: str) -> str:
        """Return the stored spelling of ``key``; raises ``KeyError`` if absent."""
        return self._dict[key.lower()][0]

    def __getitem__(self, key: str) -> str:
        """Return the value stored for ``key``, matched case-insensitively."""
        return self._dict[key.lower()][1]

    def __setitem__(self, key: str, value):
        # The spelling used here replaces any earlier spelling of the same name.
        self._dict[key.lower()] = key, value

    def __contains__(self, key):
        return key.lower() in self._dict
5778

5879

5980
@dataclass(frozen=False)
@@ -88,8 +109,8 @@ class TableSegment:
88109
min_update: DbTime = None
89110
max_update: DbTime = None
90111

91-
quote_columns: bool = True
92-
_schema: Mapping[str, ColType] = None
112+
case_sensitive: bool = True
113+
_schema: Schema = None
93114

94115
def __post_init__(self):
95116
if not self.update_column and (self.min_update or self.max_update):
@@ -110,17 +131,24 @@ def _update_column(self):
110131
return self._quote_column(self.update_column)
111132

112133
def _quote_column(self, c):
113-
if self.quote_columns:
114-
return self.database.quote(c)
115-
return c
134+
if self._schema:
135+
c = self._schema.get_key(c)
136+
return self.database.quote(c)
116137

117138
def with_schema(self) -> "TableSegment":
    """Return a segment that carries the table's schema.

    If a schema is already attached, ``self`` is returned unchanged.
    Otherwise the schema is fetched with ``database.query_table_schema``,
    wrapped according to ``case_sensitive``, and passed to ``self.new``.
    """
    if self._schema:
        return self

    schema = self.database.query_table_schema(self.table_path)
    if self.case_sensitive:
        schema = Schema_CaseSensitive(schema)
    else:
        # Names that differ only by case collapse to one entry once lowercased.
        if len({k.lower() for k in schema}) < len(schema):
            logger.warning(
                f'Ambiguous schema for {self.database}:{".".join(self.table_path)} | Columns = {", ".join(list(schema))}'
            )
            # --keep-column-case is the CLI switch that turns case-fixing off.
            logger.warning("We recommend to disable case-insensitivity (use --keep-column-case).")
        schema = Schema_CaseInsensitive(schema)
    return self.new(_schema=schema)
125153

126154
def _make_key_range(self):

tests/test_database.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def test_md5_to_int(self):
1818

1919
self.assertEqual(str_to_checksum(str), self.mysql.query(query, int))
2020

21+
2122
class TestConnect(unittest.TestCase):
2223
def test_bad_uris(self):
2324
self.assertRaises(ValueError, connect_to_uri, "p")

tests/test_database_types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,8 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego
214214
dst_conn.query(f"CREATE TABLE {dst_table}(id int, col {target_type});", None)
215215
_insert_to_table(dst_conn, dst_table, values_in_source)
216216

217-
self.table = TableSegment(self.src_conn, src_table_path, "id", None, ("col",), quote_columns=False)
218-
self.table2 = TableSegment(self.dst_conn, dst_table_path, "id", None, ("col",), quote_columns=False)
217+
self.table = TableSegment(self.src_conn, src_table_path, "id", None, ("col",), case_sensitive=False)
218+
self.table2 = TableSegment(self.dst_conn, dst_table_path, "id", None, ("col",), case_sensitive=False)
219219

220220
self.assertEqual(len(sample_values), self.table.count())
221221
self.assertEqual(len(sample_values), self.table2.count())

0 commit comments

Comments
 (0)