Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions data_diff/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,14 +262,14 @@ def _main(
print(f"Diff-Percent: {percent:.14f}%")
print(f"Diff-Split: +{plus} -{minus}")
else:
for op, columns in diff_iter:
for op, values in diff_iter:
color = COLOR_SCHEME[op]

if json_output:
jsonl = json.dumps([op, list(columns)])
jsonl = json.dumps([op, list(values)])
rich.print(f"[{color}]{jsonl}[/{color}]")
else:
text = f"{op} {', '.join(columns)}"
text = f"{op} {', '.join(values)}"
rich.print(f"[{color}]{text}[/{color}]")

sys.stdout.flush()
Expand Down
10 changes: 5 additions & 5 deletions data_diff/databases/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
import sys
import logging
from typing import Dict, Tuple, Optional, Sequence, Type, List
from functools import lru_cache, wraps
from functools import wraps
from concurrent.futures import ThreadPoolExecutor
import threading
from abc import abstractmethod

from data_diff.utils import CaseAwareMapping, is_uuid, safezip
from data_diff.utils import is_uuid, safezip
from .database_types import (
AbstractDatabase,
ColType,
Expand Down Expand Up @@ -92,7 +92,7 @@ def query(self, sql_ast: SqlOrStr, res_type: type):
logger.debug("Running SQL (%s): %s", type(self).__name__, sql_code)
if getattr(self, "_interactive", False) and isinstance(sql_ast, Select):
explained_sql = compiler.compile(Explain(sql_ast))
logger.info(f"EXPLAIN for SQL SELECT")
logger.info("EXPLAIN for SQL SELECT")
logger.info(self._query(explained_sql))
answer = input("Continue? [y/n] ")
if not answer.lower() in ["y", "yes"]:
Expand All @@ -108,7 +108,7 @@ def query(self, sql_ast: SqlOrStr, res_type: type):
assert len(res) == 1, (sql_code, res)
return res[0]
elif getattr(res_type, "__origin__", None) is list and len(res_type.__args__) == 1:
if res_type.__args__ == (int,) or res_type.__args__ == (str,):
if res_type.__args__ in ((int,), (str,)):
return [_one(row) for row in res]
elif res_type.__args__ == (Tuple,):
return [tuple(row) for row in res]
Expand Down Expand Up @@ -271,7 +271,7 @@ def concat(self, l: List[str]) -> str:
return f"concat({joined_exprs})"

def timestamp_value(self, t: DbTime) -> str:
    """Return *t* rendered as a single-quoted SQL literal.

    The text between the quotes is whatever ``t.isoformat()`` produces.
    """
    return f"'{t.isoformat()}'"

def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
if isinstance(coltype, String_UUID):
Expand Down
4 changes: 2 additions & 2 deletions data_diff/databases/database_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

class ColType:
    """Column type descriptor; ``supported`` is True unless overridden."""

    # Class-level flag, True by default on this class.
    supported = True


@dataclass
Expand Down Expand Up @@ -141,7 +140,7 @@ def to_string(self, s: str) -> str:
...

@abstractmethod
def concat(self, l: List[str]) -> str:
    """Provide SQL for concatenating a bunch of columns into a string.

    Abstract: this declaration has no implementation of its own. The
    parameter is named ``l`` to match the concrete ``concat`` override.
    """
    ...

Expand Down Expand Up @@ -263,6 +262,7 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
return self.normalize_uuid(value, coltype)
return self.to_string(value)

@abstractmethod
def _normalize_table_path(self, path: DbPath) -> DbPath:
    """Abstract hook taking and returning a ``DbPath``.

    NOTE(review): presumably returns a normalized form of *path*; the
    exact rules live in the concrete implementations -- confirm there.
    """
    ...

Expand Down
7 changes: 3 additions & 4 deletions data_diff/databases/databricks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
import math

from .database_types import *
from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, Database, import_helper, _query_conn, parse_table_name
Expand Down Expand Up @@ -125,9 +124,9 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
if coltype.rounds:
timestamp = f"cast(round(unix_micros({value}) / 1000000, {coltype.precision}) * 1000000 as bigint)"
return f"date_format(timestamp_micros({timestamp}), 'yyyy-MM-dd HH:mm:ss.SSSSSS')"
else:
precision_format = "S" * coltype.precision + "0" * (6 - coltype.precision)
return f"date_format({value}, 'yyyy-MM-dd HH:mm:ss.{precision_format}')"

precision_format = "S" * coltype.precision + "0" * (6 - coltype.precision)
return f"date_format({value}, 'yyyy-MM-dd HH:mm:ss.{precision_format}')"

def normalize_number(self, value: str, coltype: NumericType) -> str:
    """Return SQL that renders *value* as a string with a fixed decimal scale.

    *value* is cast to ``decimal(38, coltype.precision)`` and the resulting
    expression is passed through ``self.to_string``.
    """
    return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))")
Expand Down
3 changes: 1 addition & 2 deletions data_diff/databases/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,7 @@ def create_connection(self):
raise ConnectError("Bad user name or password") from e
elif e.errno == mysql.errorcode.ER_BAD_DB_ERROR:
raise ConnectError("Database does not exist") from e
else:
raise ConnectError(*e.args) from e
raise ConnectError(*e.args) from e

def quote(self, s: str):
    """Return *s* wrapped in backticks."""
    return f"`{s}`"
Expand Down
20 changes: 9 additions & 11 deletions data_diff/databases/oracle.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import re

from ..utils import match_regexps

from .database_types import *
from .base import ThreadedDatabase, import_helper, ConnectError, QueryError
from .base import DEFAULT_DATETIME_PRECISION, TIMESTAMP_PRECISION_POS
from .base import TIMESTAMP_PRECISION_POS

SESSION_TIME_ZONE = None # Changed by the tests

Expand All @@ -29,7 +27,7 @@ class Oracle(ThreadedDatabase):
ROUNDS_ON_PREC_LOSS = True

def __init__(self, *, host, database, thread_count, **kw):
self.kwargs = dict(dsn="%s/%s" % (host, database) if database else host, **kw)
self.kwargs = dict(dsn=f"{host}/{database}" if database else host, **kw)

self.default_schema = kw.get("user")

Expand Down Expand Up @@ -73,12 +71,12 @@ def select_table_schema(self, path: DbPath) -> str:
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
    """Return SQL that formats *value* as a timestamp string with 6 fractional digits.

    When ``coltype.rounds`` is set, the value is cast to
    ``timestamp(coltype.precision)`` and formatted with ``FF6``. Otherwise it
    is formatted with only ``coltype.precision`` fractional digits and
    right-padded with ``'0'`` up to ``TIMESTAMP_PRECISION_POS + 6`` characters.
    """
    if coltype.rounds:
        return f"to_char(cast({value} as timestamp({coltype.precision})), 'YYYY-MM-DD HH24:MI:SS.FF6')"

    if coltype.precision > 0:
        truncated = f"to_char({value}, 'YYYY-MM-DD HH24:MI:SS.FF{coltype.precision}')"
    else:
        # Precision 0 gets no FF specifier here; only the trailing dot is
        # emitted and RPAD supplies the zeros.
        truncated = f"to_char({value}, 'YYYY-MM-DD HH24:MI:SS.')"
    return f"RPAD({truncated}, {TIMESTAMP_PRECISION_POS+6}, '0')"

def normalize_number(self, value: str, coltype: FractionalType) -> str:
# FM999.9990
Expand All @@ -89,7 +87,7 @@ def normalize_number(self, value: str, coltype: FractionalType) -> str:

def _parse_type(
self,
table_name: DbPath,
table_path: DbPath,
col_name: str,
type_repr: str,
datetime_precision: int = None,
Expand All @@ -107,7 +105,7 @@ def _parse_type(
return t_cls(precision=precision, rounds=self.ROUNDS_ON_PREC_LOSS)

return super()._parse_type(
table_name, col_name, type_repr, datetime_precision, numeric_precision, numeric_scale
table_path, col_name, type_repr, datetime_precision, numeric_precision, numeric_scale
)

def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None):
Expand Down
4 changes: 1 addition & 3 deletions data_diff/databases/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,11 @@
from ..utils import match_regexps

from .database_types import *
from .base import Database, import_helper, _query_conn
from .base import Database, import_helper
from .base import (
MD5_HEXDIGITS,
CHECKSUM_HEXDIGITS,
TIMESTAMP_PRECISION_POS,
DEFAULT_DATETIME_PRECISION,
DEFAULT_NUMERIC_PRECISION,
)


Expand Down
7 changes: 5 additions & 2 deletions data_diff/diff_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@

from runtype import dataclass

from .sql import Select, Checksum, Compare, DbPath, DbKey, DbTime, Count, TableName, Time, Value
from .sql import Select, Checksum, Compare, Count, TableName, Time, Value
from .utils import CaseAwareMapping, CaseInsensitiveDict, safezip, split_space, CaseSensitiveDict, ArithString
from .databases.base import Database
from .databases.database_types import (
DbPath,
DbKey,
DbTime,
IKey,
Native_UUID,
NumericType,
Expand Down Expand Up @@ -269,7 +272,7 @@ def diff_sets(a: set, b: set) -> Iterator:
for i in s2 - s1:
d[i[0]].append(("+", i))

for k, v in sorted(d.items(), key=lambda i: i[0]):
for _k, v in sorted(d.items(), key=lambda i: i[0]):
yield from v


Expand Down
10 changes: 5 additions & 5 deletions data_diff/sql.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
"""Provides classes for a pseudo-SQL AST that compiles to SQL code
"""

from typing import List, Sequence, Union, Tuple, Optional
from typing import Sequence, Union, Optional
from datetime import datetime

from runtype import dataclass

from .utils import join_iter, ArithString

from .databases.database_types import AbstractDatabase, DbPath, DbKey, DbTime
from .databases.database_types import AbstractDatabase, DbPath


class Sql:
Expand Down Expand Up @@ -66,11 +66,11 @@ class Value(Sql):

def compile(self, c: Compiler):
    """Render ``self.value`` as SQL literal text.

    * ``bytes``  -> ``b'<decoded text>'``
    * ``str`` / ``ArithString`` -> ``'<text>'``
    * anything else -> ``str(self.value)``

    The compiler argument *c* is not used by this method.

    NOTE(review): the value is interpolated as-is, with no escaping of
    embedded quotes -- confirm callers never pass untrusted strings.
    """
    if isinstance(self.value, bytes):
        return f"b'{self.value.decode()}'"
    elif isinstance(self.value, str):
        # Plain f-string only: a leftover "% self.value" here would raise
        # TypeError (no placeholders) or mangle values containing '%'.
        return f"'{self.value}'"
    elif isinstance(self.value, ArithString):
        return f"'{self.value}'"
    return str(self.value)


Expand Down
4 changes: 2 additions & 2 deletions tests/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class TestConnect(unittest.TestCase):
def test_bad_uris(self):
    """Each of these connection URIs must make ``connect_to_uri`` raise ValueError."""
    self.assertRaises(ValueError, connect_to_uri, "p")
    self.assertRaises(ValueError, connect_to_uri, "postgresql:///bla/foo")
    # Placeholder credentials only: never embed real usernames/passwords in tests.
    self.assertRaises(ValueError, connect_to_uri, "snowflake://user:pass@bya42734/xdiffdev/TEST1")
    self.assertRaises(
        ValueError, connect_to_uri, "snowflake://user:pass@bya42734/xdiffdev/TEST1?warehouse=ha&schema=dup"
    )