77from .sql import SqlOrStr , Compiler
88
99
# Module-level logger shared by the database connectors in this file.
logger = logging.getLogger("database")
1111
1212
def import_postgres():
    """Import psycopg2 on demand and return the module.

    Before returning, registers ``psycopg2.extras.wait_select`` as the
    driver's wait callback.
    """
    # Imported inside the function rather than at module level; presumably so
    # the driver is only required when a Postgres connection is requested.
    import psycopg2
    import psycopg2.extras

    psycopg2.extensions.set_wait_callback(psycopg2.extras.wait_select)
    return psycopg2
1819
20+
def import_mysql():
    """Import ``mysql.connector`` on demand and return that submodule."""
    import mysql.connector

    return mysql.connector
2225
26+
def import_snowflake():
    """Import ``snowflake.connector`` on demand and return the ``snowflake`` package.

    The top-level package (not the ``connector`` submodule) is returned, so
    callers reach the driver as ``snowflake.connector``.
    """
    import snowflake.connector

    return snowflake
2631
32+
def import_mssql():
    """Import ``pymssql`` on demand and return the module."""
    import pymssql

    return pymssql
3037
38+
def import_oracle():
    """Import ``cx_Oracle`` on demand and return the module."""
    import cx_Oracle

    return cx_Oracle
3443
3544
@@ -38,12 +47,13 @@ class ConnectError(Exception):
3847
3948
def _one(seq):
    """Return the single element of ``seq``.

    Raises:
        ValueError: if ``seq`` does not contain exactly one element (this is
            the error raised by the tuple unpacking below).
    """
    (x,) = seq
    return x
4352
53+
4454class Database (ABC ):
4555 """Base abstract class for databases.
46-
56+
4757 Used for providing connection code and implementation specific SQL utilities.
4858 """
4959
@@ -60,10 +70,10 @@ def query(self, sql_ast: SqlOrStr, res_type: type):
6070 res = self ._query (sql_code )
6171 if res_type is int :
6272 res = _one (_one (res ))
63- if res is None : # May happen due to sum() of 0 items
73+ if res is None : # May happen due to sum() of 0 items
6474 return None
6575 return int (res )
66- elif getattr (res_type , ' __origin__' , None ) is list and len (res_type .__args__ ) == 1 :
76+ elif getattr (res_type , " __origin__" , None ) is list and len (res_type .__args__ ) == 1 :
6777 if res_type .__args__ == (int ,):
6878 return [_one (row ) for row in res ]
6979 elif res_type .__args__ == (Tuple ,):
@@ -87,12 +97,14 @@ def md5_to_int(self, s: str) -> str:
8797 "Provide SQL for computing md5 and returning an int"
8898 ...
8999
# Number of hex digits of the MD5 digest that are kept as the checksum.
# 15 hex digits are 60 bits; presumably the limit exists so the value still
# fits in a signed 64-bit integer column.
CHECKSUM_HEXDIGITS = 15  # Must be 15 or lower
MD5_HEXDIGITS = 32  # Length of an MD5 digest written in hex

_CHECKSUM_BITSIZE = CHECKSUM_HEXDIGITS << 2  # 4 bits per hex digit
CHECKSUM_MASK = (2**_CHECKSUM_BITSIZE) - 1  # All-ones mask, _CHECKSUM_BITSIZE bits wide
95106
107+
96108class Postgres (Database ):
97109 def __init__ (self , host , port , database , user , password ):
98110 postgres = import_postgres ()
@@ -110,18 +122,18 @@ def md5_to_int(self, s: str) -> str:
110122 return f"('x' || substring(md5({ s } ), { 1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS } ))::bit({ _CHECKSUM_BITSIZE } )::bigint"
111123
112124 def to_string (self , s : str ):
113- return f' { s } ::varchar'
125+ return f" { s } ::varchar"
114126
115127
116128class MySQL (Database ):
117129 def __init__ (self , host , port , database , user , password ):
118130 mysql = import_mysql ()
119131
120132 args = dict (host = host , port = port , database = database , user = user , password = password )
121- self ._args = {k :v for k , v in args .items () if v is not None }
133+ self ._args = {k : v for k , v in args .items () if v is not None }
122134
123135 try :
124- self ._conn = mysql .connect (charset = ' utf8' , use_unicode = True , ** self ._args )
136+ self ._conn = mysql .connect (charset = " utf8" , use_unicode = True , ** self ._args )
125137 except mysql .Error as e :
126138 if e .errno == mysql .errorcode .ER_ACCESS_DENIED_ERROR :
127139 raise ConnectError ("Bad user name or password" ) from e
@@ -131,13 +143,13 @@ def __init__(self, host, port, database, user, password):
131143 raise ConnectError (* e .args ) from e
132144
133145 def quote (self , s : str ):
134- return f' `{ s } `'
146+ return f" `{ s } `"
135147
136148 def md5_to_int (self , s : str ) -> str :
137149 return f"cast(conv(substring(md5({ s } ), { 1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS } ), 16, 10) as unsigned)"
138150
139151 def to_string (self , s : str ):
140- return f' cast({ s } as char)'
152+ return f" cast({ s } as char)"
141153
142154
143155class Oracle (Database ):
@@ -155,31 +167,33 @@ def md5_to_int(self, s: str) -> str:
155167 return f"to_number(substr(standard_hash({ s } , 'MD5'), 18), 'xxxxxxxxxxxxxxx')"
156168
157169 def quote (self , s : str ):
158- return f' { s } '
170+ return f" { s } "
159171
160172 def to_string (self , s : str ):
161- return f'cast({ s } as varchar(1024))'
173+ return f"cast({ s } as varchar(1024))"
174+
162175
class Redshift(Postgres):
    """Postgres subclass that overrides only the md5-to-integer SQL."""

    def md5_to_int(self, s: str) -> str:
        """Return SQL that hashes ``s`` with md5 and yields a ``decimal(38)``.

        The last CHECKSUM_HEXDIGITS hex digits of the digest are parsed as a
        base-16 number with ``strtol``.
        """
        return f"strtol(substring(md5({s}), {1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS}), 16)::decimal(38)"
166179
180+
167181class MsSQL (Database ):
168182 "AKA sql-server"
169183
170184 def __init__ (self , host , port , database , user , password ):
171185 mssql = import_mssql ()
172186
173187 args = dict (server = host , port = port , database = database , user = user , password = password )
174- self ._args = {k :v for k , v in args .items () if v is not None }
188+ self ._args = {k : v for k , v in args .items () if v is not None }
175189
176190 try :
177191 self ._conn = mssql .connect (** self ._args )
178192 except mssql .Error as e :
179193 raise ConnectError (* e .args ) from e
180194
181195 def quote (self , s : str ):
182- return f' [{ s } ]'
196+ return f" [{ s } ]"
183197
184198 def md5_to_int (self , s : str ) -> str :
185199 return f"CONVERT(decimal(38,0), CONVERT(bigint, HashBytes('MD5', { s } ), 2))"
@@ -188,13 +202,15 @@ def md5_to_int(self, s: str) -> str:
188202 def to_string (self , s : str ):
189203 return f"CONVERT(varchar, { s } )"
190204
205+
191206class BigQuery (Database ):
192207 def __init__ (self , project , dataset ):
193208 from google .cloud import bigquery
209+
194210 self ._client = bigquery .Client (project )
195211
196212 def quote (self , s : str ):
197- return f' `{ s } `'
213+ return f" `{ s } `"
198214
199215 def md5_to_int (self , s : str ) -> str :
200216 return f"cast(cast( ('0x' || substr(TO_HEX(md5({ s } )), 18)) as int64) as numeric)"
@@ -206,29 +222,27 @@ def _canonize_value(self, value):
206222
207223 def _query (self , sql_code : str ):
208224 from google .cloud import bigquery
225+
209226 try :
210227 res = list (self ._client .query (sql_code ))
211228 except Exception as e :
212229 msg = "Exception when trying to execute SQL code:\n %s\n \n Got error: %s"
213- raise ConnectError (msg % (sql_code , e ))
230+ raise ConnectError (msg % (sql_code , e ))
214231
215232 if res and isinstance (res [0 ], bigquery .table .Row ):
216233 res = [tuple (self ._canonize_value (v ) for v in row .values ()) for row in res ]
217234 return res
218235
219236 def to_string (self , s : str ):
220- return f'cast({ s } as string)'
237+ return f"cast({ s } as string)"
238+
221239
222240class Snowflake (Database ):
223241 def __init__ (self , account , user , password , path , schema , database , print_sql = False ):
224242 snowflake = import_snowflake ()
225- logging .getLogger (' snowflake.connector' ).setLevel (logging .WARNING )
243+ logging .getLogger (" snowflake.connector" ).setLevel (logging .WARNING )
226244
227- self ._conn = snowflake .connector .connect (
228- user = user ,
229- password = password ,
230- account = account
231- )
245+ self ._conn = snowflake .connector .connect (user = user , password = password , account = account )
232246 self ._conn .cursor ().execute (f"USE WAREHOUSE { path .lstrip ('/' )} " )
233247 self ._conn .cursor ().execute (f"USE DATABASE { database } " )
234248 self ._conn .cursor ().execute (f"USE SCHEMA { schema } " )
@@ -240,7 +254,7 @@ def md5_to_int(self, s: str) -> str:
240254 return f"BITAND(md5_number_lower64({ s } ), { CHECKSUM_MASK } )"
241255
242256 def to_string (self , s : str ):
243- return f' cast({ s } as string)'
257+ return f" cast({ s } as string)"
244258
245259
def connect_to_uri(db_uri: str) -> Database:
    """Parse ``db_uri`` and return a connected Database for its scheme.

    Schemes handled: postgres, mysql, snowflake, mssql, bigquery, redshift
    and oracle. The URI path, when present, is passed to the connector as the
    database name (for Snowflake as ``path``, for BigQuery as the dataset).

    Raises:
        NotImplementedError: for a URI with more than one scheme, or a scheme
            that is not listed above.
        ValueError: when the URI has more than one path component.
    """
    dsn = dsnparse.parse(db_uri)
    if len(dsn.schemes) > 1:
        raise NotImplementedError("No support for multiple schemes")
    (scheme,) = dsn.schemes

    if len(dsn.paths) == 0:
        path = ""
    elif len(dsn.paths) == 1:
        (path,) = dsn.paths
    else:
        raise ValueError("Bad value for uri, too many paths: %s" % db_uri)

    # Scheme names are compared exactly, without surrounding whitespace.
    if scheme == "postgres":
        return Postgres(dsn.host, dsn.port, path, dsn.user, dsn.password)
    elif scheme == "mysql":
        return MySQL(dsn.host, dsn.port, path, dsn.user, dsn.password)
    elif scheme == "snowflake":
        # Query-string parameters of the URI become keyword arguments.
        return Snowflake(dsn.host, dsn.user, dsn.password, path, **dsn.query)
    elif scheme == "mssql":
        return MsSQL(dsn.host, dsn.port, path, dsn.user, dsn.password)
    elif scheme == "bigquery":
        return BigQuery(dsn.host, path)
    elif scheme == "redshift":
        return Redshift(dsn.host, dsn.port, path, dsn.user, dsn.password)
    elif scheme == "oracle":
        return Oracle(dsn.host, dsn.port, path, dsn.user, dsn.password)

    raise NotImplementedError(f"Scheme {dsn.scheme} currently not supported")
0 commit comments