1
1
from dataclasses import field
2
2
from datetime import datetime
3
- from typing import Any , Generator , Optional , Sequence , Tuple , Union
3
+ from typing import Any , Generator , List , Optional , Sequence , Tuple , Union
4
+ from uuid import UUID
4
5
5
6
from runtype import dataclass
6
7
@@ -298,18 +299,29 @@ class TablePath(ExprNode, ITable):
298
299
path : DbPath
299
300
schema : Optional [Schema ] = field (default = None , repr = False )
300
301
301
- def create (self , if_not_exists = False ):
302
- if not self .schema :
303
- raise ValueError ("Schema must have a value to create table" )
304
- return CreateTable (self , if_not_exists = if_not_exists )
302
+ def create (self , source_table : ITable = None , * , if_not_exists = False ):
303
+ if source_table is None and not self .schema :
304
+ raise ValueError ("Either schema or source table needed to create table" )
305
+ if isinstance (source_table , TablePath ):
306
+ source_table = source_table .select ()
307
+ return CreateTable (self , source_table , if_not_exists = if_not_exists )
305
308
306
309
def drop (self , if_exists = False ):
307
310
return DropTable (self , if_exists = if_exists )
308
311
309
- def insert_values (self , rows ):
310
- raise NotImplementedError ()
312
+ def truncate (self ):
313
+ return TruncateTable (self )
314
+
315
+ def insert_rows (self , rows , * , columns = None ):
316
+ rows = list (rows )
317
+ return InsertToTable (self , ConstantTable (rows ), columns = columns )
318
+
319
+ def insert_row (self , * values , columns = None ):
320
+ return InsertToTable (self , ConstantTable ([values ]), columns = columns )
311
321
312
322
def insert_expr (self , expr : Expr ):
323
+ if isinstance (expr , TablePath ):
324
+ expr = expr .select ()
313
325
return InsertToTable (self , expr )
314
326
315
327
@property
@@ -592,6 +604,29 @@ def compile(self, c: Compiler) -> str:
592
604
return c .database .random ()
593
605
594
606
607
+ @dataclass
608
+ class ConstantTable (ExprNode ):
609
+ rows : Sequence [Sequence ]
610
+
611
+ def compile (self , c : Compiler ) -> str :
612
+ raise NotImplementedError ()
613
+
614
+ def _value (self , v ):
615
+ if v is None :
616
+ return "NULL"
617
+ elif isinstance (v , str ):
618
+ return f"'{ v } '"
619
+ elif isinstance (v , datetime ):
620
+ return f"timestamp '{ v } '"
621
+ elif isinstance (v , UUID ):
622
+ return f"'{ v } '"
623
+ return repr (v )
624
+
625
+ def compile_for_insert (self , c : Compiler ):
626
+ values = ", " .join ("(%s)" % ", " .join (self ._value (v ) for v in row ) for row in self .rows )
627
+ return f"VALUES { values } "
628
+
629
+
595
630
@dataclass
596
631
class Explain (ExprNode ):
597
632
select : Select
@@ -610,11 +645,15 @@ class Statement(Compilable):
610
645
@dataclass
611
646
class CreateTable (Statement ):
612
647
path : TablePath
648
+ source_table : Expr = None
613
649
if_not_exists : bool = False
614
650
615
651
def compile (self , c : Compiler ) -> str :
616
- schema = ", " .join (f"{ k } { c .database .type_repr (v )} " for k , v in self .path .schema .items ())
617
652
ne = "IF NOT EXISTS " if self .if_not_exists else ""
653
+ if self .source_table :
654
+ return f"CREATE TABLE { ne } { c .compile (self .path )} AS { c .compile (self .source_table )} "
655
+
656
+ schema = ", " .join (f"{ c .database .quote (k )} { c .database .type_repr (v )} " for k , v in self .path .schema .items ())
618
657
return f"CREATE TABLE { ne } { c .compile (self .path )} ({ schema } )"
619
658
620
659
@@ -628,14 +667,30 @@ def compile(self, c: Compiler) -> str:
628
667
return f"DROP TABLE { ie } { c .compile (self .path )} "
629
668
630
669
670
+ @dataclass
671
+ class TruncateTable (Statement ):
672
+ path : TablePath
673
+
674
+ def compile (self , c : Compiler ) -> str :
675
+ return f"TRUNCATE TABLE { c .compile (self .path )} "
676
+
677
+
631
678
@dataclass
632
679
class InsertToTable (Statement ):
633
680
# TODO Support insert for only some columns
634
681
path : TablePath
635
682
expr : Expr
683
+ columns : List [str ] = None
636
684
637
685
def compile (self , c : Compiler ) -> str :
638
- return f"INSERT INTO { c .compile (self .path )} { c .compile (self .expr )} "
686
+ if isinstance (self .expr , ConstantTable ):
687
+ expr = self .expr .compile_for_insert (c )
688
+ else :
689
+ expr = c .compile (self .expr )
690
+
691
+ columns = f"(%s)" % ", " .join (map (c .quote , self .columns )) if self .columns is not None else ""
692
+
693
+ return f"INSERT INTO { c .compile (self .path )} { columns } { expr } "
639
694
640
695
641
696
@dataclass
0 commit comments