1+ import math
12from typing import Dict , Sequence
23import logging
34
@@ -61,11 +62,14 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
6162 return f"date_format({ value } , 'yyyy-MM-dd HH:mm:ss.{ precision_format } ')"
6263
6364 def normalize_number (self , value : str , coltype : NumericType ) -> str :
64- return self .to_string (f"cast({ value } as decimal(38, { coltype .precision } ))" )
65+ value = f"cast({ value } as decimal(38, { coltype .precision } ))"
66+ if coltype .precision > 0 :
67+ value = f"format_number({ value } , { coltype .precision } )"
68+ return f"replace({ self .to_string (value )} , ',', '')"
6569
6670 def _convert_db_precision_to_digits (self , p : int ) -> int :
67- # Subtracting 1 due to wierd precision issues
68- return max (super ()._convert_db_precision_to_digits (p ) - 1 , 0 )
71+ # Subtracting 2 due to weird precision issues
72+ return max (super ()._convert_db_precision_to_digits (p ) - 2 , 0 )
6973
7074
7175class Databricks (ThreadedDatabase ):
@@ -75,19 +79,19 @@ def __init__(self, *, thread_count, **kw):
7579 logging .getLogger ("databricks.sql" ).setLevel (logging .WARNING )
7680
7781 self ._args = kw
78- self .default_schema = kw .get (' schema' , ' hive_metastore' )
82+ self .default_schema = kw .get (" schema" , " hive_metastore" )
7983 super ().__init__ (thread_count = thread_count )
8084
8185 def create_connection (self ):
8286 databricks = import_databricks ()
8387
8488 try :
8589 return databricks .sql .connect (
86- server_hostname = self ._args [' server_hostname' ],
87- http_path = self ._args [' http_path' ],
88- access_token = self ._args [' access_token' ],
89- catalog = self ._args [' catalog' ],
90- )
90+ server_hostname = self ._args [" server_hostname" ],
91+ http_path = self ._args [" http_path" ],
92+ access_token = self ._args [" access_token" ],
93+ catalog = self ._args [" catalog" ],
94+ )
9195 except databricks .sql .exc .Error as e :
9296 raise ConnectionError (* e .args ) from e
9397
@@ -100,11 +104,9 @@ def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
100104
101105 schema , table = self ._normalize_table_path (path )
102106 with conn .cursor () as cursor :
103- cursor .columns (catalog_name = self ._args [' catalog' ], schema_name = schema , table_name = table )
107+ cursor .columns (catalog_name = self ._args [" catalog" ], schema_name = schema , table_name = table )
104108 try :
105109 rows = cursor .fetchall ()
106- except :
107- rows = None
108110 finally :
109111 conn .close ()
110112 if not rows :
@@ -129,7 +131,7 @@ def _process_table_schema(
129131 row = (row [0 ], row_type , None , None , 0 )
130132
131133 elif issubclass (type_cls , Float ):
132- numeric_precision = self . _convert_db_precision_to_digits (row [2 ])
134+ numeric_precision = math . ceil (row [2 ] / math . log ( 2 , 10 ) )
133135 row = (row [0 ], row_type , None , numeric_precision , None )
134136
135137 elif issubclass (type_cls , Decimal ):
0 commit comments