Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 35ae1be

Browse files
authored
Merge pull request #215 from datafold/aug24
Cleanup
2 parents eeb33a9 + bcda6fc commit 35ae1be

File tree

10 files changed, with 36 additions and 39 deletions.

data_diff/__main__.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -262,14 +262,14 @@ def _main(
262262
print(f"Diff-Percent: {percent:.14f}%")
263263
print(f"Diff-Split: +{plus} -{minus}")
264264
else:
265-
for op, columns in diff_iter:
265+
for op, values in diff_iter:
266266
color = COLOR_SCHEME[op]
267267

268268
if json_output:
269-
jsonl = json.dumps([op, list(columns)])
269+
jsonl = json.dumps([op, list(values)])
270270
rich.print(f"[{color}]{jsonl}[/{color}]")
271271
else:
272-
text = f"{op} {', '.join(columns)}"
272+
text = f"{op} {', '.join(values)}"
273273
rich.print(f"[{color}]{text}[/{color}]")
274274

275275
sys.stdout.flush()

data_diff/databases/base.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -2,12 +2,12 @@
22
import sys
33
import logging
44
from typing import Dict, Tuple, Optional, Sequence, Type, List
5-
from functools import lru_cache, wraps
5+
from functools import wraps
66
from concurrent.futures import ThreadPoolExecutor
77
import threading
88
from abc import abstractmethod
99

10-
from data_diff.utils import CaseAwareMapping, is_uuid, safezip
10+
from data_diff.utils import is_uuid, safezip
1111
from .database_types import (
1212
AbstractDatabase,
1313
ColType,
@@ -92,7 +92,7 @@ def query(self, sql_ast: SqlOrStr, res_type: type):
9292
logger.debug("Running SQL (%s): %s", type(self).__name__, sql_code)
9393
if getattr(self, "_interactive", False) and isinstance(sql_ast, Select):
9494
explained_sql = compiler.compile(Explain(sql_ast))
95-
logger.info(f"EXPLAIN for SQL SELECT")
95+
logger.info("EXPLAIN for SQL SELECT")
9696
logger.info(self._query(explained_sql))
9797
answer = input("Continue? [y/n] ")
9898
if not answer.lower() in ["y", "yes"]:
@@ -108,7 +108,7 @@ def query(self, sql_ast: SqlOrStr, res_type: type):
108108
assert len(res) == 1, (sql_code, res)
109109
return res[0]
110110
elif getattr(res_type, "__origin__", None) is list and len(res_type.__args__) == 1:
111-
if res_type.__args__ == (int,) or res_type.__args__ == (str,):
111+
if res_type.__args__ in ((int,), (str,)):
112112
return [_one(row) for row in res]
113113
elif res_type.__args__ == (Tuple,):
114114
return [tuple(row) for row in res]
@@ -271,7 +271,7 @@ def concat(self, l: List[str]) -> str:
271271
return f"concat({joined_exprs})"
272272

273273
def timestamp_value(self, t: DbTime) -> str:
274-
return "'%s'" % t.isoformat()
274+
return f"'{t.isoformat()}'"
275275

276276
def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
277277
if isinstance(coltype, String_UUID):

data_diff/databases/database_types.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,6 @@
1515

1616
class ColType:
1717
supported = True
18-
pass
1918

2019

2120
@dataclass
@@ -141,7 +140,7 @@ def to_string(self, s: str) -> str:
141140
...
142141

143142
@abstractmethod
144-
def concat(self, s: List[str]) -> str:
143+
def concat(self, l: List[str]) -> str:
145144
"Provide SQL for concatenating a bunch of columns into a string"
146145
...
147146

@@ -263,6 +262,7 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
263262
return self.normalize_uuid(value, coltype)
264263
return self.to_string(value)
265264

265+
@abstractmethod
266266
def _normalize_table_path(self, path: DbPath) -> DbPath:
267267
...
268268

data_diff/databases/databricks.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,4 @@
11
import logging
2-
import math
32

43
from .database_types import *
54
from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, Database, import_helper, _query_conn, parse_table_name
@@ -125,9 +124,9 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
125124
if coltype.rounds:
126125
timestamp = f"cast(round(unix_micros({value}) / 1000000, {coltype.precision}) * 1000000 as bigint)"
127126
return f"date_format(timestamp_micros({timestamp}), 'yyyy-MM-dd HH:mm:ss.SSSSSS')"
128-
else:
129-
precision_format = "S" * coltype.precision + "0" * (6 - coltype.precision)
130-
return f"date_format({value}, 'yyyy-MM-dd HH:mm:ss.{precision_format}')"
127+
128+
precision_format = "S" * coltype.precision + "0" * (6 - coltype.precision)
129+
return f"date_format({value}, 'yyyy-MM-dd HH:mm:ss.{precision_format}')"
131130

132131
def normalize_number(self, value: str, coltype: NumericType) -> str:
133132
return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))")

data_diff/databases/mysql.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -46,8 +46,7 @@ def create_connection(self):
4646
raise ConnectError("Bad user name or password") from e
4747
elif e.errno == mysql.errorcode.ER_BAD_DB_ERROR:
4848
raise ConnectError("Database does not exist") from e
49-
else:
50-
raise ConnectError(*e.args) from e
49+
raise ConnectError(*e.args) from e
5150

5251
def quote(self, s: str):
5352
return f"`{s}`"

data_diff/databases/oracle.py

Lines changed: 9 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,8 @@
1-
import re
2-
31
from ..utils import match_regexps
42

53
from .database_types import *
64
from .base import ThreadedDatabase, import_helper, ConnectError, QueryError
7-
from .base import DEFAULT_DATETIME_PRECISION, TIMESTAMP_PRECISION_POS
5+
from .base import TIMESTAMP_PRECISION_POS
86

97
SESSION_TIME_ZONE = None # Changed by the tests
108

@@ -29,7 +27,7 @@ class Oracle(ThreadedDatabase):
2927
ROUNDS_ON_PREC_LOSS = True
3028

3129
def __init__(self, *, host, database, thread_count, **kw):
32-
self.kwargs = dict(dsn="%s/%s" % (host, database) if database else host, **kw)
30+
self.kwargs = dict(dsn=f"{host}/{database}" if database else host, **kw)
3331

3432
self.default_schema = kw.get("user")
3533

@@ -73,12 +71,12 @@ def select_table_schema(self, path: DbPath) -> str:
7371
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
7472
if coltype.rounds:
7573
return f"to_char(cast({value} as timestamp({coltype.precision})), 'YYYY-MM-DD HH24:MI:SS.FF6')"
74+
75+
if coltype.precision > 0:
76+
truncated = f"to_char({value}, 'YYYY-MM-DD HH24:MI:SS.FF{coltype.precision}')"
7677
else:
77-
if coltype.precision > 0:
78-
truncated = f"to_char({value}, 'YYYY-MM-DD HH24:MI:SS.FF{coltype.precision}')"
79-
else:
80-
truncated = f"to_char({value}, 'YYYY-MM-DD HH24:MI:SS.')"
81-
return f"RPAD({truncated}, {TIMESTAMP_PRECISION_POS+6}, '0')"
78+
truncated = f"to_char({value}, 'YYYY-MM-DD HH24:MI:SS.')"
79+
return f"RPAD({truncated}, {TIMESTAMP_PRECISION_POS+6}, '0')"
8280

8381
def normalize_number(self, value: str, coltype: FractionalType) -> str:
8482
# FM999.9990
@@ -89,7 +87,7 @@ def normalize_number(self, value: str, coltype: FractionalType) -> str:
8987

9088
def _parse_type(
9189
self,
92-
table_name: DbPath,
90+
table_path: DbPath,
9391
col_name: str,
9492
type_repr: str,
9593
datetime_precision: int = None,
@@ -107,7 +105,7 @@ def _parse_type(
107105
return t_cls(precision=precision, rounds=self.ROUNDS_ON_PREC_LOSS)
108106

109107
return super()._parse_type(
110-
table_name, col_name, type_repr, datetime_precision, numeric_precision, numeric_scale
108+
table_path, col_name, type_repr, datetime_precision, numeric_precision, numeric_scale
111109
)
112110

113111
def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None):

data_diff/databases/presto.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -3,13 +3,11 @@
33
from ..utils import match_regexps
44

55
from .database_types import *
6-
from .base import Database, import_helper, _query_conn
6+
from .base import Database, import_helper
77
from .base import (
88
MD5_HEXDIGITS,
99
CHECKSUM_HEXDIGITS,
1010
TIMESTAMP_PRECISION_POS,
11-
DEFAULT_DATETIME_PRECISION,
12-
DEFAULT_NUMERIC_PRECISION,
1311
)
1412

1513

data_diff/diff_tables.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -12,10 +12,13 @@
1212

1313
from runtype import dataclass
1414

15-
from .sql import Select, Checksum, Compare, DbPath, DbKey, DbTime, Count, TableName, Time, Value
15+
from .sql import Select, Checksum, Compare, Count, TableName, Time, Value
1616
from .utils import CaseAwareMapping, CaseInsensitiveDict, safezip, split_space, CaseSensitiveDict, ArithString
1717
from .databases.base import Database
1818
from .databases.database_types import (
19+
DbPath,
20+
DbKey,
21+
DbTime,
1922
IKey,
2023
Native_UUID,
2124
NumericType,
@@ -269,7 +272,7 @@ def diff_sets(a: set, b: set) -> Iterator:
269272
for i in s2 - s1:
270273
d[i[0]].append(("+", i))
271274

272-
for k, v in sorted(d.items(), key=lambda i: i[0]):
275+
for _k, v in sorted(d.items(), key=lambda i: i[0]):
273276
yield from v
274277

275278

data_diff/sql.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,14 +1,14 @@
11
"""Provides classes for a pseudo-SQL AST that compiles to SQL code
22
"""
33

4-
from typing import List, Sequence, Union, Tuple, Optional
4+
from typing import Sequence, Union, Optional
55
from datetime import datetime
66

77
from runtype import dataclass
88

99
from .utils import join_iter, ArithString
1010

11-
from .databases.database_types import AbstractDatabase, DbPath, DbKey, DbTime
11+
from .databases.database_types import AbstractDatabase, DbPath
1212

1313

1414
class Sql:
@@ -66,11 +66,11 @@ class Value(Sql):
6666

6767
def compile(self, c: Compiler):
6868
if isinstance(self.value, bytes):
69-
return "b'%s'" % self.value.decode()
69+
return f"b'{self.value.decode()}'"
7070
elif isinstance(self.value, str):
71-
return "'%s'" % self.value
71+
return f"'{self.value}'"
7272
elif isinstance(self.value, ArithString):
73-
return "'%s'" % self.value
73+
return f"'{self.value}'"
7474
return str(self.value)
7575

7676

tests/test_database.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@ class TestConnect(unittest.TestCase):
2323
def test_bad_uris(self):
2424
self.assertRaises(ValueError, connect_to_uri, "p")
2525
self.assertRaises(ValueError, connect_to_uri, "postgresql:///bla/foo")
26-
self.assertRaises(ValueError, connect_to_uri, "snowflake://erez:erez27Snow@bya42734/xdiffdev/TEST1")
26+
self.assertRaises(ValueError, connect_to_uri, "snowflake://user:pass@bya42734/xdiffdev/TEST1")
2727
self.assertRaises(
28-
ValueError, connect_to_uri, "snowflake://erez:erez27Snow@bya42734/xdiffdev/TEST1?warehouse=ha&schema=dup"
28+
ValueError, connect_to_uri, "snowflake://user:pass@bya42734/xdiffdev/TEST1?warehouse=ha&schema=dup"
2929
)

0 commit comments

Comments
 (0)