1313 ColType ,
1414 UnknownColType ,
1515)
16- from  .base  import  MD5_HEXDIGITS , CHECKSUM_HEXDIGITS , BaseDialect , Database , import_helper , parse_table_name 
16+ from  .base  import  MD5_HEXDIGITS , CHECKSUM_HEXDIGITS , BaseDialect , ThreadedDatabase , import_helper , parse_table_name 
1717
1818
1919@import_helper (text = "You can install it using 'pip install databricks-sql-connector'" ) 
@@ -68,43 +68,45 @@ def _convert_db_precision_to_digits(self, p: int) -> int:
6868 return  max (super ()._convert_db_precision_to_digits (p ) -  1 , 0 )
6969
7070
71- class  Databricks (Database ):
71+ class  Databricks (ThreadedDatabase ):
7272 dialect  =  Dialect ()
7373
74-  def  __init__ (
75-  self ,
76-  http_path : str ,
77-  access_token : str ,
78-  server_hostname : str ,
79-  catalog : str  =  "hive_metastore" ,
80-  schema : str  =  "default" ,
81-  ** kwargs ,
82-  ):
83-  databricks  =  import_databricks ()
84- 
85-  self ._conn  =  databricks .sql .connect (
86-  server_hostname = server_hostname , http_path = http_path , access_token = access_token , catalog = catalog 
87-  )
88- 
74+  def  __init__ (self , * , thread_count , ** kw ):
8975 logging .getLogger ("databricks.sql" ).setLevel (logging .WARNING )
9076
91-  self .catalog  =  catalog 
92-  self .default_schema  =  schema 
93-  self . kwargs   =   kwargs 
77+  self ._args  =  kw 
78+  self .default_schema  =  kw . get ( ' schema' ,  'hive_metastore' ) 
79+  super (). __init__ ( thread_count = thread_count ) 
9480
95-  def  _query (self , sql_code : str ) ->  list :
96-  "Uses the standard SQL cursor interface" 
97-  return  self ._query_conn (self ._conn , sql_code )
81+  def  create_connection (self ):
82+  databricks  =  import_databricks ()
83+ 
84+  try :
85+  return  databricks .sql .connect (
86+  server_hostname = self ._args ['server_hostname' ],
87+  http_path = self ._args ['http_path' ],
88+  access_token = self ._args ['access_token' ],
89+  catalog = self ._args ['catalog' ],
90+  )
91+  except  databricks .sql .exc .Error  as  e :
92+  raise  ConnectionError (* e .args ) from  e 
9893
9994 def  query_table_schema (self , path : DbPath ) ->  Dict [str , tuple ]:
10095 # Databricks has INFORMATION_SCHEMA only for Databricks Runtime, not for Databricks SQL. 
10196 # https://docs.databricks.com/spark/latest/spark-sql/language-manual/information-schema/columns.html 
10297 # So, to obtain information about schema, we should use another approach. 
10398
99+  conn  =  self .create_connection ()
100+ 
104101 schema , table  =  self ._normalize_table_path (path )
105-  with  self ._conn .cursor () as  cursor :
106-  cursor .columns (catalog_name = self .catalog , schema_name = schema , table_name = table )
107-  rows  =  cursor .fetchall ()
102+  with  conn .cursor () as  cursor :
103+  cursor .columns (catalog_name = self ._args ['catalog' ], schema_name = schema , table_name = table )
104+  try :
105+  rows  =  cursor .fetchall ()
106+  except :
107+  rows  =  None 
108+  finally :
109+  conn .close ()
108110 if  not  rows :
109111 raise  RuntimeError (f"{ self .name }  : Table '{ '.' .join (path )}  ' does not exist, or has no columns" )
110112
@@ -121,7 +123,7 @@ def _process_table_schema(
121123 resulted_rows  =  []
122124 for  row  in  rows :
123125 row_type  =  "DECIMAL"  if  row [1 ].startswith ("DECIMAL" ) else  row [1 ]
124-  type_cls  =  self .TYPE_CLASSES .get (row_type , UnknownColType )
126+  type_cls  =  self .dialect . TYPE_CLASSES .get (row_type , UnknownColType )
125127
126128 if  issubclass (type_cls , Integer ):
127129 row  =  (row [0 ], row_type , None , None , 0 )
@@ -152,9 +154,6 @@ def parse_table_name(self, name: str) -> DbPath:
152154 path  =  parse_table_name (name )
153155 return  self ._normalize_table_path (path )
154156
155-  def  close (self ):
156-  self ._conn .close ()
157- 
158157 @property  
159158 def  is_autocommit (self ) ->  bool :
160159 return  True 
0 commit comments