Skip to content

Commit 418bc01

Browse files
committed
Reformatted with 'Black', as requested
1 parent 6b36e6a commit 418bc01

File tree

5 files changed

+158
-108
lines changed

5 files changed

+158
-108
lines changed

data-diff/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from .database import connect_to_uri
44
from .diff_tables import TableSegment, TableDiffer
55

6+
67
def create_source(db_uri: str, table_name: str, key_column: str, extra_columns: Tuple[str, ...] = ()):
78
db = connect_to_uri(db_uri)
8-
return TableSegment(db, (table_name,), key_column, tuple(extra_columns))
9+
return TableSegment(db, (table_name,), key_column, tuple(extra_columns))

data-diff/__main__.py

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,23 +8,37 @@
88

99
import click
1010

11-
LOG_FORMAT = '[%(asctime)s] %(levelname)s - %(message)s'
12-
DATE_FORMAT ='%H:%M:%S'
11+
LOG_FORMAT = "[%(asctime)s] %(levelname)s - %(message)s"
12+
DATE_FORMAT = "%H:%M:%S"
13+
1314

1415
@click.command()
15-
@click.argument('db1_uri')
16-
@click.argument('table1_name')
17-
@click.argument('db2_uri')
18-
@click.argument('table2_name')
19-
@click.option('-k', '--key_column', default='id', help='Name of primary key column')
20-
@click.option('-c', '--columns', default=['updated_at'], multiple=True, help='Names of extra columns to compare')
21-
@click.option('-l', '--limit', default=None, help='Maximum number of differences to find')
22-
@click.option('--bisection-factor', default=32, help='Segments per iteration')
23-
@click.option('--bisection-threshold', default=1024**2, help='Minimal bisection threshold')
24-
@click.option('-s', '--stats', is_flag=True, help='Print stats instead of a detailed diff')
25-
@click.option('-d', '--debug', is_flag=True, help='Print debug info')
26-
@click.option('-v', '--verbose', is_flag=True, help='Print extra info')
27-
def main(db1_uri, table1_name, db2_uri, table2_name, key_column, columns, limit, bisection_factor, bisection_threshold, stats, debug, verbose):
16+
@click.argument("db1_uri")
17+
@click.argument("table1_name")
18+
@click.argument("db2_uri")
19+
@click.argument("table2_name")
20+
@click.option("-k", "--key_column", default="id", help="Name of primary key column")
21+
@click.option("-c", "--columns", default=["updated_at"], multiple=True, help="Names of extra columns to compare")
22+
@click.option("-l", "--limit", default=None, help="Maximum number of differences to find")
23+
@click.option("--bisection-factor", default=32, help="Segments per iteration")
24+
@click.option("--bisection-threshold", default=1024**2, help="Minimal bisection threshold")
25+
@click.option("-s", "--stats", is_flag=True, help="Print stats instead of a detailed diff")
26+
@click.option("-d", "--debug", is_flag=True, help="Print debug info")
27+
@click.option("-v", "--verbose", is_flag=True, help="Print extra info")
28+
def main(
29+
db1_uri,
30+
table1_name,
31+
db2_uri,
32+
table2_name,
33+
key_column,
34+
columns,
35+
limit,
36+
bisection_factor,
37+
bisection_threshold,
38+
stats,
39+
debug,
40+
verbose,
41+
):
2842
if limit and stats:
2943
print("Error: cannot specify a limit when using the -s/--stats switch")
3044
return
@@ -34,7 +48,6 @@ def main(db1_uri, table1_name, db2_uri, table2_name, key_column, columns, limit,
3448
elif verbose:
3549
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT, datefmt=DATE_FORMAT)
3650

37-
3851
db1 = connect_to_uri(db1_uri)
3952
db2 = connect_to_uri(db2_uri)
4053

@@ -54,8 +67,8 @@ def main(db1_uri, table1_name, db2_uri, table2_name, key_column, columns, limit,
5467
percent = 100 * len(diff) / table1.count
5568
print(f"Diff-Total: {len(diff)} changed rows out of {table1.count}")
5669
print(f"Diff-Percent: {percent:.4f}%")
57-
plus = len([1 for op,_ in diff if op=='+'])
58-
minus = len([1 for op,_ in diff if op=='-'])
70+
plus = len([1 for op, _ in diff if op == "+"])
71+
minus = len([1 for op, _ in diff if op == "-"])
5972
print(f"Diff-Split: +{plus} -{minus}")
6073
else:
6174
for op, key in diff_iter:
@@ -67,5 +80,5 @@ def main(db1_uri, table1_name, db2_uri, table2_name, key_column, columns, limit,
6780
logging.info(f"Duration: {end-start:.2f} seconds.")
6881

6982

70-
if __name__ == '__main__':
71-
main()
83+
if __name__ == "__main__":
84+
main()

data-diff/database.py

Lines changed: 50 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -7,29 +7,38 @@
77
from .sql import SqlOrStr, Compiler
88

99

10-
logger = logging.getLogger('database')
10+
logger = logging.getLogger("database")
1111

1212

1313
def import_postgres():
1414
import psycopg2
1515
import psycopg2.extras
16+
1617
psycopg2.extensions.set_wait_callback(psycopg2.extras.wait_select)
1718
return psycopg2
1819

20+
1921
def import_mysql():
2022
import mysql.connector
23+
2124
return mysql.connector
2225

26+
2327
def import_snowflake():
2428
import snowflake.connector
29+
2530
return snowflake
2631

32+
2733
def import_mssql():
2834
import pymssql
35+
2936
return pymssql
3037

38+
3139
def import_oracle():
3240
import cx_Oracle
41+
3342
return cx_Oracle
3443

3544

@@ -38,12 +47,13 @@ class ConnectError(Exception):
3847

3948

4049
def _one(seq):
41-
x ,= seq
50+
(x,) = seq
4251
return x
4352

53+
4454
class Database(ABC):
4555
"""Base abstract class for databases.
46-
56+
4757
Used for providing connection code and implementation specific SQL utilities.
4858
"""
4959

@@ -60,10 +70,10 @@ def query(self, sql_ast: SqlOrStr, res_type: type):
6070
res = self._query(sql_code)
6171
if res_type is int:
6272
res = _one(_one(res))
63-
if res is None: # May happen due to sum() of 0 items
73+
if res is None: # May happen due to sum() of 0 items
6474
return None
6575
return int(res)
66-
elif getattr(res_type, '__origin__', None) is list and len(res_type.__args__) == 1:
76+
elif getattr(res_type, "__origin__", None) is list and len(res_type.__args__) == 1:
6777
if res_type.__args__ == (int,):
6878
return [_one(row) for row in res]
6979
elif res_type.__args__ == (Tuple,):
@@ -87,12 +97,14 @@ def md5_to_int(self, s: str) -> str:
8797
"Provide SQL for computing md5 and returning an int"
8898
...
8999

90-
CHECKSUM_HEXDIGITS = 15 # Must be 15 or lower
100+
101+
CHECKSUM_HEXDIGITS = 15 # Must be 15 or lower
91102
MD5_HEXDIGITS = 32
92103

93-
_CHECKSUM_BITSIZE = CHECKSUM_HEXDIGITS<<2
104+
_CHECKSUM_BITSIZE = CHECKSUM_HEXDIGITS << 2
94105
CHECKSUM_MASK = (2**_CHECKSUM_BITSIZE) - 1
95106

107+
96108
class Postgres(Database):
97109
def __init__(self, host, port, database, user, password):
98110
postgres = import_postgres()
@@ -110,18 +122,18 @@ def md5_to_int(self, s: str) -> str:
110122
return f"('x' || substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}))::bit({_CHECKSUM_BITSIZE})::bigint"
111123

112124
def to_string(self, s: str):
113-
return f'{s}::varchar'
125+
return f"{s}::varchar"
114126

115127

116128
class MySQL(Database):
117129
def __init__(self, host, port, database, user, password):
118130
mysql = import_mysql()
119131

120132
args = dict(host=host, port=port, database=database, user=user, password=password)
121-
self._args = {k:v for k, v in args.items() if v is not None}
133+
self._args = {k: v for k, v in args.items() if v is not None}
122134

123135
try:
124-
self._conn = mysql.connect(charset='utf8', use_unicode=True, **self._args)
136+
self._conn = mysql.connect(charset="utf8", use_unicode=True, **self._args)
125137
except mysql.Error as e:
126138
if e.errno == mysql.errorcode.ER_ACCESS_DENIED_ERROR:
127139
raise ConnectError("Bad user name or password") from e
@@ -131,13 +143,13 @@ def __init__(self, host, port, database, user, password):
131143
raise ConnectError(*e.args) from e
132144

133145
def quote(self, s: str):
134-
return f'`{s}`'
146+
return f"`{s}`"
135147

136148
def md5_to_int(self, s: str) -> str:
137149
return f"cast(conv(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16, 10) as unsigned)"
138150

139151
def to_string(self, s: str):
140-
return f'cast({s} as char)'
152+
return f"cast({s} as char)"
141153

142154

143155
class Oracle(Database):
@@ -155,31 +167,33 @@ def md5_to_int(self, s: str) -> str:
155167
return f"to_number(substr(standard_hash({s}, 'MD5'), 18), 'xxxxxxxxxxxxxxx')"
156168

157169
def quote(self, s: str):
158-
return f'{s}'
170+
return f"{s}"
159171

160172
def to_string(self, s: str):
161-
return f'cast({s} as varchar(1024))'
173+
return f"cast({s} as varchar(1024))"
174+
162175

163176
class Redshift(Postgres):
164177
def md5_to_int(self, s: str) -> str:
165178
return f"strtol(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16)::decimal(38)"
166179

180+
167181
class MsSQL(Database):
168182
"AKA sql-server"
169183

170184
def __init__(self, host, port, database, user, password):
171185
mssql = import_mssql()
172186

173187
args = dict(server=host, port=port, database=database, user=user, password=password)
174-
self._args = {k:v for k, v in args.items() if v is not None}
188+
self._args = {k: v for k, v in args.items() if v is not None}
175189

176190
try:
177191
self._conn = mssql.connect(**self._args)
178192
except mssql.Error as e:
179193
raise ConnectError(*e.args) from e
180194

181195
def quote(self, s: str):
182-
return f'[{s}]'
196+
return f"[{s}]"
183197

184198
def md5_to_int(self, s: str) -> str:
185199
return f"CONVERT(decimal(38,0), CONVERT(bigint, HashBytes('MD5', {s}), 2))"
@@ -188,13 +202,15 @@ def md5_to_int(self, s: str) -> str:
188202
def to_string(self, s: str):
189203
return f"CONVERT(varchar, {s})"
190204

205+
191206
class BigQuery(Database):
192207
def __init__(self, project, dataset):
193208
from google.cloud import bigquery
209+
194210
self._client = bigquery.Client(project)
195211

196212
def quote(self, s: str):
197-
return f'`{s}`'
213+
return f"`{s}`"
198214

199215
def md5_to_int(self, s: str) -> str:
200216
return f"cast(cast( ('0x' || substr(TO_HEX(md5({s})), 18)) as int64) as numeric)"
@@ -206,29 +222,27 @@ def _canonize_value(self, value):
206222

207223
def _query(self, sql_code: str):
208224
from google.cloud import bigquery
225+
209226
try:
210227
res = list(self._client.query(sql_code))
211228
except Exception as e:
212229
msg = "Exception when trying to execute SQL code:\n %s\n\nGot error: %s"
213-
raise ConnectError(msg%(sql_code, e))
230+
raise ConnectError(msg % (sql_code, e))
214231

215232
if res and isinstance(res[0], bigquery.table.Row):
216233
res = [tuple(self._canonize_value(v) for v in row.values()) for row in res]
217234
return res
218235

219236
def to_string(self, s: str):
220-
return f'cast({s} as string)'
237+
return f"cast({s} as string)"
238+
221239

222240
class Snowflake(Database):
223241
def __init__(self, account, user, password, path, schema, database, print_sql=False):
224242
snowflake = import_snowflake()
225-
logging.getLogger('snowflake.connector').setLevel(logging.WARNING)
243+
logging.getLogger("snowflake.connector").setLevel(logging.WARNING)
226244

227-
self._conn = snowflake.connector.connect(
228-
user=user,
229-
password=password,
230-
account=account
231-
)
245+
self._conn = snowflake.connector.connect(user=user, password=password, account=account)
232246
self._conn.cursor().execute(f"USE WAREHOUSE {path.lstrip('/')}")
233247
self._conn.cursor().execute(f"USE DATABASE {database}")
234248
self._conn.cursor().execute(f"USE SCHEMA {schema}")
@@ -240,7 +254,7 @@ def md5_to_int(self, s: str) -> str:
240254
return f"BITAND(md5_number_lower64({s}), {CHECKSUM_MASK})"
241255

242256
def to_string(self, s: str):
243-
return f'cast({s} as string)'
257+
return f"cast({s} as string)"
244258

245259

246260
def connect_to_uri(db_uri: str) -> Database:
@@ -259,28 +273,28 @@ def connect_to_uri(db_uri: str) -> Database:
259273
dsn = dsnparse.parse(db_uri)
260274
if len(dsn.schemes) > 1:
261275
raise NotImplementedError("No support for multiple schemes")
262-
scheme ,= dsn.schemes
276+
(scheme,) = dsn.schemes
263277

264278
if len(dsn.paths) == 0:
265-
path = ''
279+
path = ""
266280
elif len(dsn.paths) == 1:
267-
path ,= dsn.paths
281+
(path,) = dsn.paths
268282
else:
269283
raise ValueError("Bad value for uri, too many paths: %s" % db_uri)
270284

271-
if scheme == 'postgres':
285+
if scheme == "postgres":
272286
return Postgres(dsn.host, dsn.port, path, dsn.user, dsn.password)
273-
elif scheme == 'mysql':
287+
elif scheme == "mysql":
274288
return MySQL(dsn.host, dsn.port, path, dsn.user, dsn.password)
275-
elif scheme == 'snowflake':
289+
elif scheme == "snowflake":
276290
return Snowflake(dsn.host, dsn.user, dsn.password, path, **dsn.query)
277-
elif scheme == 'mssql':
291+
elif scheme == "mssql":
278292
return MsSQL(dsn.host, dsn.port, path, dsn.user, dsn.password)
279-
elif scheme == 'bigquery':
293+
elif scheme == "bigquery":
280294
return BigQuery(dsn.host, path)
281-
elif scheme == 'redshift':
295+
elif scheme == "redshift":
282296
return Redshift(dsn.host, dsn.port, path, dsn.user, dsn.password)
283-
elif scheme == 'oracle':
297+
elif scheme == "oracle":
284298
return Oracle(dsn.host, dsn.port, path, dsn.user, dsn.password)
285299

286300
raise NotImplementedError(f"Scheme {dsn.scheme} currently not supported")

0 commit comments

Comments
 (0)