1- import  re 
2- 
31from  .database_types  import  * 
4- from  .base  import  Database , import_helper 
5- from  .base  import  (
6-  MD5_HEXDIGITS ,
7-  CHECKSUM_HEXDIGITS ,
8-  TIMESTAMP_PRECISION_POS ,
9-  DEFAULT_DATETIME_PRECISION ,
10- )
2+ from  .presto  import  Presto 
3+ from  .base  import  import_helper 
4+ from  .base  import  TIMESTAMP_PRECISION_POS 
115
126
137@import_helper ("trino" ) 
@@ -17,49 +11,12 @@ def import_trino():
1711 return  trino 
1812
1913
20- class  Trino (Database ):
21-  default_schema  =  "public" 
22-  TYPE_CLASSES  =  {
23-  # Timestamps 
24-  "timestamp with time zone" : TimestampTZ ,
25-  "timestamp without time zone" : Timestamp ,
26-  "timestamp" : Timestamp ,
27-  # Numbers 
28-  "integer" : Integer ,
29-  "bigint" : Integer ,
30-  "real" : Float ,
31-  "double" : Float ,
32-  # Text 
33-  "varchar" : Text ,
34-  }
35-  ROUNDS_ON_PREC_LOSS  =  True 
36- 
14+ class  Trino (Presto ):
3715 def  __init__ (self , ** kw ):
3816 trino  =  import_trino ()
3917
4018 self ._conn  =  trino .dbapi .connect (** kw )
4119
42-  def  quote (self , s : str ):
43-  return  f'"{ s }  
44- 
45-  def  md5_to_int (self , s : str ) ->  str :
46-  return  f"cast(from_base(substr(to_hex(md5(to_utf8({ s } { 1  +  MD5_HEXDIGITS  -  CHECKSUM_HEXDIGITS }  
47- 
48-  def  to_string (self , s : str ):
49-  return  f"cast({ s }  
50- 
51-  def  _query (self , sql_code : str ) ->  list :
52-  """Uses the standard SQL cursor interface""" 
53-  c  =  self ._conn .cursor ()
54-  c .execute (sql_code )
55-  if  sql_code .lower ().startswith ("select" ):
56-  return  c .fetchall ()
57-  if  re .match (r"(insert|create|truncate|drop)" , sql_code , re .IGNORECASE ):
58-  return  c .fetchone ()
59- 
60-  def  close (self ):
61-  self ._conn .close ()
62- 
6320 def  normalize_timestamp (self , value : str , coltype : TemporalType ) ->  str :
6421 if  coltype .rounds :
6522 s  =  f"date_format(cast({ value } { coltype .precision }  
@@ -70,52 +27,5 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
7027 f"RPAD(RPAD({ s } { TIMESTAMP_PRECISION_POS  +  coltype .precision } { TIMESTAMP_PRECISION_POS  +  6 }  
7128 )
7229
73-  def  normalize_number (self , value : str , coltype : FractionalType ) ->  str :
74-  return  self .to_string (f"cast({ value } { coltype .precision }  )
75- 
76-  def  select_table_schema (self , path : DbPath ) ->  str :
77-  schema , table  =  self ._normalize_table_path (path )
78- 
79-  return  (
80-  f"SELECT column_name, data_type, 3 as datetime_precision, 3 as numeric_precision FROM INFORMATION_SCHEMA.COLUMNS " 
81-  f"WHERE table_name = '{ table } { schema }  
82-  )
83- 
84-  def  _parse_type (
85-  self ,
86-  table_path : DbPath ,
87-  col_name : str ,
88-  type_repr : str ,
89-  datetime_precision : int  =  None ,
90-  numeric_precision : int  =  None ,
91-  ) ->  ColType :
92-  timestamp_regexps  =  {
93-  r"timestamp\((\d)\)" : Timestamp ,
94-  r"timestamp\((\d)\) with time zone" : TimestampTZ ,
95-  }
96-  for  regexp , t_cls  in  timestamp_regexps .items ():
97-  m  =  re .match (regexp  +  "$" , type_repr )
98-  if  m :
99-  datetime_precision  =  int (m .group (1 ))
100-  return  t_cls (
101-  precision = datetime_precision  if  datetime_precision  is  not None  else  DEFAULT_DATETIME_PRECISION ,
102-  rounds = self .ROUNDS_ON_PREC_LOSS ,
103-  )
104- 
105-  number_regexps  =  {r"decimal\((\d+),(\d+)\)" : Decimal }
106-  for  regexp , n_cls  in  number_regexps .items ():
107-  m  =  re .match (regexp  +  "$" , type_repr )
108-  if  m :
109-  prec , scale  =  map (int , m .groups ())
110-  return  n_cls (scale )
111- 
112-  string_regexps  =  {r"varchar\((\d+)\)" : Text , r"char\((\d+)\)" : Text }
113-  for  regexp , n_cls  in  string_regexps .items ():
114-  m  =  re .match (regexp  +  "$" , type_repr )
115-  if  m :
116-  return  n_cls ()
117- 
118-  return  super ()._parse_type (table_path , col_name , type_repr , datetime_precision , numeric_precision )
119- 
12030 def  normalize_uuid (self , value : str , coltype : ColType_UUID ) ->  str :
12131 return  f"TRIM({ value }  
0 commit comments