Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 9c93229

Browse files
committed
fix float value precision calculation
1 parent 1790a38 commit 9c93229

File tree

1 file changed

+15
-13
lines changed

1 file changed

+15
-13
lines changed

data_diff/databases/databricks.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import math
12
from typing import Dict, Sequence
23
import logging
34

@@ -61,11 +62,14 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
6162
return f"date_format({value}, 'yyyy-MM-dd HH:mm:ss.{precision_format}')"
6263

6364
def normalize_number(self, value: str, coltype: NumericType) -> str:
64-
return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))")
65+
value = f"cast({value} as decimal(38, {coltype.precision}))"
66+
if coltype.precision > 0:
67+
value = f"format_number({value}, {coltype.precision})"
68+
return f"replace({self.to_string(value)}, ',', '')"
6569

6670
def _convert_db_precision_to_digits(self, p: int) -> int:
67-
# Subtracting 1 due to wierd precision issues
68-
return max(super()._convert_db_precision_to_digits(p) - 1, 0)
71+
# Subtracting 2 due to wierd precision issues
72+
return max(super()._convert_db_precision_to_digits(p) - 2, 0)
6973

7074

7175
class Databricks(ThreadedDatabase):
@@ -75,19 +79,19 @@ def __init__(self, *, thread_count, **kw):
7579
logging.getLogger("databricks.sql").setLevel(logging.WARNING)
7680

7781
self._args = kw
78-
self.default_schema = kw.get('schema', 'hive_metastore')
82+
self.default_schema = kw.get("schema", "hive_metastore")
7983
super().__init__(thread_count=thread_count)
8084

8185
def create_connection(self):
8286
databricks = import_databricks()
8387

8488
try:
8589
return databricks.sql.connect(
86-
server_hostname=self._args['server_hostname'],
87-
http_path=self._args['http_path'],
88-
access_token=self._args['access_token'],
89-
catalog=self._args['catalog'],
90-
)
90+
server_hostname=self._args["server_hostname"],
91+
http_path=self._args["http_path"],
92+
access_token=self._args["access_token"],
93+
catalog=self._args["catalog"],
94+
)
9195
except databricks.sql.exc.Error as e:
9296
raise ConnectionError(*e.args) from e
9397

@@ -100,11 +104,9 @@ def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
100104

101105
schema, table = self._normalize_table_path(path)
102106
with conn.cursor() as cursor:
103-
cursor.columns(catalog_name=self._args['catalog'], schema_name=schema, table_name=table)
107+
cursor.columns(catalog_name=self._args["catalog"], schema_name=schema, table_name=table)
104108
try:
105109
rows = cursor.fetchall()
106-
except:
107-
rows = None
108110
finally:
109111
conn.close()
110112
if not rows:
@@ -129,7 +131,7 @@ def _process_table_schema(
129131
row = (row[0], row_type, None, None, 0)
130132

131133
elif issubclass(type_cls, Float):
132-
numeric_precision = self._convert_db_precision_to_digits(row[2])
134+
numeric_precision = math.ceil(row[2] / math.log(2, 10))
133135
row = (row[0], row_type, None, numeric_precision, None)
134136

135137
elif issubclass(type_cls, Decimal):

0 commit comments

Comments
 (0)