1
+ import math
1
2
from typing import Dict , Sequence
2
3
import logging
3
4
@@ -61,11 +62,14 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
61
62
return f"date_format({ value } , 'yyyy-MM-dd HH:mm:ss.{ precision_format } ')"
62
63
63
64
def normalize_number (self , value : str , coltype : NumericType ) -> str :
64
- return self .to_string (f"cast({ value } as decimal(38, { coltype .precision } ))" )
65
+ value = f"cast({ value } as decimal(38, { coltype .precision } ))"
66
+ if coltype .precision > 0 :
67
+ value = f"format_number({ value } , { coltype .precision } )"
68
+ return f"replace({ self .to_string (value )} , ',', '')"
65
69
66
70
def _convert_db_precision_to_digits (self , p : int ) -> int :
67
- # Subtracting 1 due to wierd precision issues
68
- return max (super ()._convert_db_precision_to_digits (p ) - 1 , 0 )
71
+ # Subtracting 2 due to weird precision issues
72
+ return max (super ()._convert_db_precision_to_digits (p ) - 2 , 0 )
69
73
70
74
71
75
class Databricks (ThreadedDatabase ):
@@ -75,19 +79,19 @@ def __init__(self, *, thread_count, **kw):
75
79
logging .getLogger ("databricks.sql" ).setLevel (logging .WARNING )
76
80
77
81
self ._args = kw
78
- self .default_schema = kw .get (' schema' , ' hive_metastore' )
82
+ self .default_schema = kw .get (" schema" , " hive_metastore" )
79
83
super ().__init__ (thread_count = thread_count )
80
84
81
85
def create_connection (self ):
82
86
databricks = import_databricks ()
83
87
84
88
try :
85
89
return databricks .sql .connect (
86
- server_hostname = self ._args [' server_hostname' ],
87
- http_path = self ._args [' http_path' ],
88
- access_token = self ._args [' access_token' ],
89
- catalog = self ._args [' catalog' ],
90
- )
90
+ server_hostname = self ._args [" server_hostname" ],
91
+ http_path = self ._args [" http_path" ],
92
+ access_token = self ._args [" access_token" ],
93
+ catalog = self ._args [" catalog" ],
94
+ )
91
95
except databricks .sql .exc .Error as e :
92
96
raise ConnectionError (* e .args ) from e
93
97
@@ -100,11 +104,9 @@ def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
100
104
101
105
schema , table = self ._normalize_table_path (path )
102
106
with conn .cursor () as cursor :
103
- cursor .columns (catalog_name = self ._args [' catalog' ], schema_name = schema , table_name = table )
107
+ cursor .columns (catalog_name = self ._args [" catalog" ], schema_name = schema , table_name = table )
104
108
try :
105
109
rows = cursor .fetchall ()
106
- except :
107
- rows = None
108
110
finally :
109
111
conn .close ()
110
112
if not rows :
@@ -129,7 +131,7 @@ def _process_table_schema(
129
131
row = (row [0 ], row_type , None , None , 0 )
130
132
131
133
elif issubclass (type_cls , Float ):
132
- numeric_precision = self . _convert_db_precision_to_digits (row [2 ])
134
+ numeric_precision = math . ceil (row [2 ] / math . log ( 2 , 10 ) )
133
135
row = (row [0 ], row_type , None , numeric_precision , None )
134
136
135
137
elif issubclass (type_cls , Decimal ):
0 commit comments