@@ -330,10 +330,10 @@ def __init__(
330330 account : str ,
331331 user : str ,
332332 password : str ,
333- path : str ,
334- role : str ,
333+ warehouse : str ,
335334 schema : str ,
336335 database : str ,
336+ role : str = None ,
337337 print_sql : bool = False ,
338338 ):
339339 snowflake = import_snowflake ()
@@ -350,7 +350,7 @@ def __init__(
350350 account = account ,
351351 role = role ,
352352 database = database ,
353- warehouse = path . lstrip ( '/' ) ,
353+ warehouse = warehouse ,
354354 schema = schema ,
355355 )
356356
@@ -389,6 +389,14 @@ def connect_to_uri(db_uri: str, thread_count: Optional[int] = 1) -> Database:
389389 raise NotImplementedError ("No support for multiple schemes" )
390390 (scheme ,) = dsn .schemes
391391
392+ if scheme == 'snowflake' :
393+ database , schema = dsn .paths
394+ try :
395+ warehouse = dsn .query ['warehouse' ]
396+ except KeyError :
397+ raise ValueError ("Must provide warehouse. Format: 'snowflake://<user>:<pass>@<account>/<database>/<schema>?warehouse=<warehouse>'" )
398+ return Snowflake (dsn .host , dsn .user , dsn .password , warehouse = warehouse , database = database , schema = schema )
399+
392400 if len (dsn .paths ) == 0 :
393401 path = ""
394402 elif len (dsn .paths ) == 1 :
@@ -400,8 +408,6 @@ def connect_to_uri(db_uri: str, thread_count: Optional[int] = 1) -> Database:
400408 return Postgres (dsn .host , dsn .port , path , dsn .user , dsn .password , thread_count = thread_count )
401409 elif scheme == "mysql" :
402410 return MySQL (dsn .host , dsn .port , path , dsn .user , dsn .password , thread_count = thread_count )
403- elif scheme == "snowflake" :
404- return Snowflake (dsn .host , dsn .user , dsn .password , path , ** dsn .query )
405411 elif scheme == "mssql" :
406412 return MsSQL (dsn .host , dsn .port , path , dsn .user , dsn .password , thread_count = thread_count )
407413 elif scheme == "bigquery" :
0 commit comments