1
1
from dataclasses import field
2
2
from datetime import datetime
3
3
from typing import Any , Generator , List , Optional , Sequence , Tuple , Union
4
+ from uuid import UUID
4
5
5
6
from runtype import dataclass
6
7
@@ -298,22 +299,29 @@ class TablePath(ExprNode, ITable):
298
299
path : DbPath
299
300
schema : Optional [Schema ] = field (default = None , repr = False )
300
301
301
- def create (self , if_not_exists = False ):
302
- if not self .schema :
303
- raise ValueError ("Schema must have a value to create table" )
304
- return CreateTable (self , if_not_exists = if_not_exists )
302
+ def create (self , source_table : ITable = None , * , if_not_exists = False ):
303
+ if source_table is None and not self .schema :
304
+ raise ValueError ("Either schema or source table needed to create table" )
305
+ if isinstance (source_table , TablePath ):
306
+ source_table = source_table .select ()
307
+ return CreateTable (self , source_table , if_not_exists = if_not_exists )
305
308
306
309
def drop (self , if_exists = False ):
307
310
return DropTable (self , if_exists = if_exists )
308
311
309
- def insert_rows (self , rows ):
312
+ def truncate (self ):
313
+ return TruncateTable (self )
314
+
315
+ def insert_rows (self , rows , * , columns = None ):
310
316
rows = list (rows )
311
- return InsertToTable (self , ConstantTable (rows ))
317
+ return InsertToTable (self , ConstantTable (rows ), columns = columns )
312
318
313
- def insert_row (self , * values ):
314
- return InsertToTable (self , ConstantTable ([values ]))
319
+ def insert_row (self , * values , columns = None ):
320
+ return InsertToTable (self , ConstantTable ([values ]), columns = columns )
315
321
316
322
def insert_expr (self , expr : Expr ):
323
+ if isinstance (expr , TablePath ):
324
+ expr = expr .select ()
317
325
return InsertToTable (self , expr )
318
326
319
327
@property
@@ -598,17 +606,21 @@ def compile(self, c: Compiler) -> str:
598
606
599
607
@dataclass
600
608
class ConstantTable (ExprNode ):
601
- rows : List [ Tuple ]
609
+ rows : Sequence [ Sequence ]
602
610
603
611
def compile (self , c : Compiler ) -> str :
604
612
raise NotImplementedError ()
605
613
606
614
def _value (self , v ):
607
- if isinstance (v , str ):
615
+ if v is None :
616
+ return "NULL"
617
+ elif isinstance (v , str ):
608
618
return f"'{ v } '"
609
619
elif isinstance (v , datetime ):
610
620
return f"timestamp '{ v } '"
611
- return str (v )
621
+ elif isinstance (v , UUID ):
622
+ return f"'{ v } '"
623
+ return repr (v )
612
624
613
625
def compile_for_insert (self , c : Compiler ):
614
626
values = ", " .join ("(%s)" % ", " .join (self ._value (v ) for v in row ) for row in self .rows )
@@ -633,11 +645,15 @@ class Statement(Compilable):
633
645
@dataclass
634
646
class CreateTable (Statement ):
635
647
path : TablePath
648
+ source_table : Expr = None
636
649
if_not_exists : bool = False
637
650
638
651
def compile (self , c : Compiler ) -> str :
639
- schema = ", " .join (f"{ k } { c .database .type_repr (v )} " for k , v in self .path .schema .items ())
640
652
ne = "IF NOT EXISTS " if self .if_not_exists else ""
653
+ if self .source_table :
654
+ return f"CREATE TABLE { ne } { c .compile (self .path )} AS { c .compile (self .source_table )} "
655
+
656
+ schema = ", " .join (f"{ c .database .quote (k )} { c .database .type_repr (v )} " for k , v in self .path .schema .items ())
641
657
return f"CREATE TABLE { ne } { c .compile (self .path )} ({ schema } )"
642
658
643
659
@@ -651,19 +667,30 @@ def compile(self, c: Compiler) -> str:
651
667
return f"DROP TABLE { ie } { c .compile (self .path )} "
652
668
653
669
670
+ @dataclass
671
+ class TruncateTable (Statement ):
672
+ path : TablePath
673
+
674
+ def compile (self , c : Compiler ) -> str :
675
+ return f"TRUNCATE TABLE { c .compile (self .path )} "
676
+
677
+
654
678
@dataclass
655
679
class InsertToTable (Statement ):
656
680
# TODO Support insert for only some columns
657
681
path : TablePath
658
682
expr : Expr
683
+ columns : List [str ] = None
659
684
660
685
def compile (self , c : Compiler ) -> str :
661
686
if isinstance (self .expr , ConstantTable ):
662
687
expr = self .expr .compile_for_insert (c )
663
688
else :
664
689
expr = c .compile (self .expr )
665
690
666
- return f"INSERT INTO { c .compile (self .path )} { expr } "
691
+ columns = f"(%s)" % ", " .join (map (c .quote , self .columns )) if self .columns is not None else ""
692
+
693
+ return f"INSERT INTO { c .compile (self .path )} { columns } { expr } "
667
694
668
695
669
696
@dataclass
0 commit comments