Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.
13 changes: 12 additions & 1 deletion data_diff/databases/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,18 @@ class Presto(Database):
def __init__(self, **kw):
    """Open a Presto connection from the given connection keyword arguments.

    Keys handled here before the remaining arguments are forwarded to
    ``prestodb.dbapi.connect``:

    * ``schema`` -- when truthy, stored as ``self.default_schema`` (it is
      still forwarded to the driver).
    * ``auth`` -- when equal to ``"basic"``, replaced by a
      ``prestodb.auth.BasicAuthentication`` object built from ``user`` and
      ``password``, both of which are removed from the arguments.
    * ``cert`` -- when present, removed from the arguments and assigned to
      the HTTP session's ``verify`` attribute after connecting.

    Raises:
        KeyError: if ``auth="basic"`` is given without ``user`` or
            ``password``.
    """
    prestodb = import_presto()

    if kw.get("schema"):
        self.default_schema = kw.get("schema")

    if kw.get("auth") == "basic":  # if auth=basic, add basic authenticator for Presto
        # NOTE(review): "user" is popped, so it is no longer forwarded to
        # connect() as a separate argument -- confirm this is intended.
        kw["auth"] = prestodb.auth.BasicAuthentication(kw.pop("user"), kw.pop("password"))

    # "cert" must be stripped from the arguments before connecting; it is
    # applied to the session afterwards rather than passed to the driver.
    # Presence is recorded separately so an explicit cert=None is still
    # applied.
    has_cert = "cert" in kw
    cert = kw.pop("cert", None)

    # Connect exactly once, after all argument rewriting is done.
    self._conn = prestodb.dbapi.connect(**kw)

    if has_cert:  # if a certificate was specified in URI, verify session with cert
        self._conn._http_session.verify = cert

def quote(self, s: str):
    """Return *s* wrapped in double quotes.

    No escaping is applied to double quotes already inside *s*.
    """
    return '"{}"'.format(s)
Expand Down
31 changes: 24 additions & 7 deletions data_diff/databases/snowflake.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import logging

from .database_types import *
from .base import Database, import_helper, _query_conn, CHECKSUM_MASK
from .base import ConnectError, Database, import_helper, _query_conn, CHECKSUM_MASK


@import_helper("snowflake")
def import_snowflake():
    """Import and return the Snowflake connector plus the cryptography helpers.

    The imports are performed inside this function rather than at module
    level.

    Returns:
        A 3-tuple ``(snowflake, serialization, default_backend)``:
        the ``snowflake`` package (with ``snowflake.connector`` loaded),
        ``cryptography.hazmat.primitives.serialization``, and
        ``cryptography.hazmat.backends.default_backend``.
    """
    import snowflake.connector
    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.backends import default_backend

    # Callers unpack three values, so all three must be returned together.
    return snowflake, serialization, default_backend


class Snowflake(Database):
Expand All @@ -26,7 +28,7 @@ class Snowflake(Database):
ROUNDS_ON_PREC_LOSS = False

def __init__(self, *, schema: str, **kw):
snowflake = import_snowflake()
snowflake, serialization, default_backend = import_snowflake()
logging.getLogger("snowflake.connector").setLevel(logging.WARNING)

# Got an error: snowflake.connector.network.RetryRequest: could not find io module state (interpreter shutdown?)
Expand All @@ -35,10 +37,25 @@ def __init__(self, *, schema: str, **kw):
logging.getLogger("snowflake.connector.network").disabled = True

assert '"' not in schema, "Schema name should not contain quotes!"
self._conn = snowflake.connector.connect(
schema=f'"{schema}"',
**kw,
)
if (
"key" in kw
): # if private keys are used for Snowflake connection, read in key from path specified and pass as "private_key" to connector.
with open(kw.get("key"), "rb") as key:
if 'password' in kw:
raise ConnectError("Cannot use password and key at the same time")
p_key = serialization.load_pem_private_key(
key.read(),
password=None,
backend=default_backend(),
)

kw["private_key"] = p_key.private_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)

self._conn = snowflake.connector.connect(schema=f'"{schema}"', **kw)

self.default_schema = schema

Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ preql = "^0.2.19"
mysql-connector-python = "*"
databricks-sql-connector = "*"
snowflake-connector-python = "*"
cryptography = "*"
trino = "^0.314.0"
psycopg2 = "*"
presto-python-client = "*"
Expand All @@ -46,7 +47,7 @@ presto-python-client = "*"
preql = ["preql"]
mysql = ["mysql-connector-python"]
postgresql = ["psycopg2"]
snowflake = ["snowflake-connector-python"]
snowflake = ["snowflake-connector-python", "cryptography"]
presto = ["presto-python-client"]
oracle = ["cx_Oracle"]
databricks = ["databricks-sql-connector"]
Expand Down