Commit 1790a38

support multithreading for databricks

The Databricks connector is not thread-safe, so we should inherit from the ThreadedDatabase class.

1 parent: b82b3ed
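
For context, here is a minimal sketch of the thread-per-connection pattern that inheriting ThreadedDatabase implies. This is illustrative, not data-diff's actual base class: the only names taken from the diff below are ThreadedDatabase, thread_count, and create_connection(); everything else is assumed.

# Sketch: each worker thread lazily opens its own driver connection, so a
# connection from a non-thread-safe connector is never shared across threads.
import threading
from concurrent.futures import ThreadPoolExecutor

class ThreadedDatabase:
    def __init__(self, thread_count=1):
        self._local = threading.local()  # per-thread connection slot
        self._pool = ThreadPoolExecutor(max_workers=thread_count)

    def create_connection(self):
        # Subclasses (e.g. Databricks, below) open the actual driver connection.
        raise NotImplementedError

    def query(self, sql_code: str):
        # Dispatch to a worker thread; each worker reuses its own connection.
        return self._pool.submit(self._query_in_thread, sql_code).result()

    def _query_in_thread(self, sql_code: str):
        # Runs inside a worker thread: create this thread's connection on
        # first use, then keep reusing it for later queries.
        if not hasattr(self._local, "conn"):
            self._local.conn = self.create_connection()
        with self._local.conn.cursor() as cursor:
            cursor.execute(sql_code)
            return cursor.fetchall()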


data_diff/databases/databricks.py

Lines changed: 29 additions & 30 deletions
@@ -13,7 +13,7 @@
     ColType,
     UnknownColType,
 )
-from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, BaseDialect, Database, import_helper, parse_table_name
+from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, BaseDialect, ThreadedDatabase, import_helper, parse_table_name
 
 
 @import_helper(text="You can install it using 'pip install databricks-sql-connector'")
@@ -68,43 +68,45 @@ def _convert_db_precision_to_digits(self, p: int) -> int:
         return max(super()._convert_db_precision_to_digits(p) - 1, 0)
 
 
-class Databricks(Database):
+class Databricks(ThreadedDatabase):
     dialect = Dialect()
 
-    def __init__(
-        self,
-        http_path: str,
-        access_token: str,
-        server_hostname: str,
-        catalog: str = "hive_metastore",
-        schema: str = "default",
-        **kwargs,
-    ):
-        databricks = import_databricks()
-
-        self._conn = databricks.sql.connect(
-            server_hostname=server_hostname, http_path=http_path, access_token=access_token, catalog=catalog
-        )
-
+    def __init__(self, *, thread_count, **kw):
         logging.getLogger("databricks.sql").setLevel(logging.WARNING)
 
-        self.catalog = catalog
-        self.default_schema = schema
-        self.kwargs = kwargs
+        self._args = kw
+        self.default_schema = kw.get('schema', 'hive_metastore')
+        super().__init__(thread_count=thread_count)
 
-    def _query(self, sql_code: str) -> list:
-        "Uses the standard SQL cursor interface"
-        return self._query_conn(self._conn, sql_code)
+    def create_connection(self):
+        databricks = import_databricks()
+
+        try:
+            return databricks.sql.connect(
+                server_hostname=self._args['server_hostname'],
+                http_path=self._args['http_path'],
+                access_token=self._args['access_token'],
+                catalog=self._args['catalog'],
+            )
+        except databricks.sql.exc.Error as e:
+            raise ConnectionError(*e.args) from e
 
     def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
         # Databricks has INFORMATION_SCHEMA only for Databricks Runtime, not for Databricks SQL.
         # https://docs.databricks.com/spark/latest/spark-sql/language-manual/information-schema/columns.html
         # So, to obtain information about schema, we should use another approach.
 
+        conn = self.create_connection()
+
         schema, table = self._normalize_table_path(path)
-        with self._conn.cursor() as cursor:
-            cursor.columns(catalog_name=self.catalog, schema_name=schema, table_name=table)
-            rows = cursor.fetchall()
+        with conn.cursor() as cursor:
+            cursor.columns(catalog_name=self._args['catalog'], schema_name=schema, table_name=table)
+            try:
+                rows = cursor.fetchall()
+            except:
+                rows = None
+            finally:
+                conn.close()
         if not rows:
             raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns")
 
@@ -121,7 +123,7 @@ def _process_table_schema(
         resulted_rows = []
         for row in rows:
             row_type = "DECIMAL" if row[1].startswith("DECIMAL") else row[1]
-            type_cls = self.TYPE_CLASSES.get(row_type, UnknownColType)
+            type_cls = self.dialect.TYPE_CLASSES.get(row_type, UnknownColType)
 
             if issubclass(type_cls, Integer):
                 row = (row[0], row_type, None, None, 0)
@@ -152,9 +154,6 @@ def parse_table_name(self, name: str) -> DbPath:
         path = parse_table_name(name)
         return self._normalize_table_path(path)
 
-    def close(self):
-        self._conn.close()
-
     @property
     def is_autocommit(self) -> bool:
         return True
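
After this change, construction is keyword-only: thread_count is forwarded to ThreadedDatabase, and the remaining keywords land in self._args for create_connection() to read later. A hedged usage sketch follows; the keyword names come from the diff above, while the call site and credential values are placeholders:

# Hypothetical call site; values are placeholders, not real credentials.
db = Databricks(
    thread_count=4,                              # forwarded to ThreadedDatabase.__init__
    server_hostname="xxx.cloud.databricks.com",  # read back via self._args['server_hostname']
    http_path="/sql/1.0/warehouses/abc123",
    access_token="dapi...",
    catalog="hive_metastore",                    # required: create_connection() indexes self._args['catalog']
    schema="default",                            # optional: read via kw.get('schema', ...)
)

Note the design shift in query_table_schema(): instead of holding a shared self._conn (removed along with close()), it opens a short-lived connection via create_connection() and closes it when done, so no connection object is ever shared between threads.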
