11"""Provides classes for performing a table diff 
22""" 
33
4+ from  abc  import  ABC , abstractmethod 
45import  time 
56from  operator  import  attrgetter , methodcaller 
67from  collections  import  defaultdict 
7- from  typing  import  List , Tuple , Iterator , Optional ,  Mapping 
8+ from  typing  import  List , Tuple , Iterator , Optional 
89import  logging 
910from  concurrent .futures  import  ThreadPoolExecutor 
1011
@@ -36,24 +37,44 @@ def parse_table_name(t):
3637 return  tuple (t .split ("." ))
3738
3839
39- class  CaseInsensitiveDict (Mapping ):
40-  def  __init__ (self , initial = ()):
41-  self ._dict  =  {k .lower (): v  for  k , v  in  dict (initial ).items ()}
40+ class  Schema (ABC ):
41+  @abstractmethod  
42+  def  get_key (self , key : str ) ->  str :
43+  ...
4244
43-  def  __setitem__ (self , key , value ):
44-  self ._dict [key .lower ()] =  value 
45+  @abstractmethod  
46+  def  __getitem__ (self , key : str ) ->  str :
47+  ...
4548
46-  def  __getitem__ (self , key ):
47-  try :
48-  return  self ._dict [key .lower ()]
49-  except  KeyError :
50-  raise 
49+  @abstractmethod  
50+  def  __setitem__ (self , key : str , value ):
51+  ...
5152
52-  def  __iter__ (self ):
53-  return  iter (self ._dict )
53+  @abstractmethod  
54+  def  __contains__ (self , key : str ) ->  bool :
55+  ...
5456
55-  def  __len__ (self ):
56-  return  len (self ._dict )
57+ 
58+ class  Schema_CaseSensitive (dict , Schema ):
59+  def  get_key (self , key ):
60+  return  key 
61+ 
62+ 
63+ class  Schema_CaseInsensitive (Schema ):
64+  def  __init__ (self , initial ):
65+  self ._dict  =  {k .lower (): (k , v ) for  k , v  in  dict (initial ).items ()}
66+ 
67+  def  get_key (self , key : str ) ->  str :
68+  return  self ._dict [key .lower ()][0 ]
69+ 
70+  def  __getitem__ (self , key : str ) ->  str :
71+  return  self ._dict [key .lower ()][1 ]
72+ 
73+  def  __setitem__ (self , key : str , value ):
74+  self ._dict [key .lower ()] =  key , value 
75+ 
76+  def  __contains__ (self , key ):
77+  return  key .lower () in  self ._dict 
5778
5879
5980@dataclass (frozen = False ) 
@@ -88,8 +109,8 @@ class TableSegment:
88109 min_update : DbTime  =  None 
89110 max_update : DbTime  =  None 
90111
91-  quote_columns : bool  =  True 
92-  _schema : Mapping [ str ,  ColType ]  =  None 
112+  case_sensitive : bool  =  True 
113+  _schema : Schema  =  None 
93114
94115 def  __post_init__ (self ):
95116 if  not  self .update_column  and  (self .min_update  or  self .max_update ):
@@ -110,17 +131,24 @@ def _update_column(self):
110131 return  self ._quote_column (self .update_column )
111132
112133 def  _quote_column (self , c ):
113-  if  self .quote_columns :
114-  return   self .database . quote (c )
115-  return  c 
134+  if  self ._schema :
135+  c   =   self ._schema . get_key (c )
136+  return  self . database . quote ( c ) 
116137
117138 def  with_schema (self ) ->  "TableSegment" :
118139 "Queries the table schema from the database, and returns a new instance of TableSegmentWithSchema." 
119140 if  self ._schema :
120141 return  self 
121142 schema  =  self .database .query_table_schema (self .table_path )
122-  if  not  self .quote_columns :
123-  schema  =  CaseInsensitiveDict (schema )
143+  if  self .case_sensitive :
144+  schema  =  Schema_CaseSensitive (schema )
145+  else :
146+  if  len ({k .lower () for  k  in  schema }) <  len (schema ):
147+  logger .warn (
148+  f'Ambiguous schema for { self .database }  :{ "." .join (self .table_path )}   | Columns = { ", " .join (list (schema ))}  ' 
149+  )
150+  logger .warn ("We recommend to disable case-insensitivity (remove --any-case)." )
151+  schema  =  Schema_CaseInsensitive (schema )
124152 return  self .new (_schema = schema )
125153
126154 def  _make_key_range (self ):
0 commit comments