Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 6b533b2

Browse files
committed
Tests refactor: Convert more code to use query-builder
1 parent 1b1ccca commit 6b533b2

File tree

4 files changed

+184
-165
lines changed

4 files changed

+184
-165
lines changed

data_diff/databases/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ def _refine_coltypes(self, table_path: DbPath, col_dict: Dict[str, ColType], whe
251251
if not text_columns:
252252
return
253253

254-
fields = [self.normalize_uuid(c, String_UUID()) for c in text_columns]
254+
fields = [self.normalize_uuid(self.quote(c), String_UUID()) for c in text_columns]
255255
samples_by_row = self.query(table(*table_path).select(*fields).where(where or SKIP).limit(16), list)
256256
if not samples_by_row:
257257
raise ValueError(f"Table {table_path} is empty.")

data_diff/queries/ast_classes.py

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from dataclasses import field
22
from datetime import datetime
33
from typing import Any, Generator, List, Optional, Sequence, Tuple, Union
4+
from uuid import UUID
45

56
from runtype import dataclass
67

@@ -298,22 +299,29 @@ class TablePath(ExprNode, ITable):
298299
path: DbPath
299300
schema: Optional[Schema] = field(default=None, repr=False)
300301

301-
def create(self, if_not_exists=False):
302-
if not self.schema:
303-
raise ValueError("Schema must have a value to create table")
304-
return CreateTable(self, if_not_exists=if_not_exists)
302+
def create(self, source_table: ITable = None, *, if_not_exists=False):
303+
if source_table is None and not self.schema:
304+
raise ValueError("Either schema or source table needed to create table")
305+
if isinstance(source_table, TablePath):
306+
source_table = source_table.select()
307+
return CreateTable(self, source_table, if_not_exists=if_not_exists)
305308

306309
def drop(self, if_exists=False):
307310
return DropTable(self, if_exists=if_exists)
308311

309-
def insert_rows(self, rows):
312+
def truncate(self):
313+
return TruncateTable(self)
314+
315+
def insert_rows(self, rows, *, columns=None):
310316
rows = list(rows)
311-
return InsertToTable(self, ConstantTable(rows))
317+
return InsertToTable(self, ConstantTable(rows), columns=columns)
312318

313-
def insert_row(self, *values):
314-
return InsertToTable(self, ConstantTable([values]))
319+
def insert_row(self, *values, columns=None):
320+
return InsertToTable(self, ConstantTable([values]), columns=columns)
315321

316322
def insert_expr(self, expr: Expr):
323+
if isinstance(expr, TablePath):
324+
expr = expr.select()
317325
return InsertToTable(self, expr)
318326

319327
@property
@@ -598,17 +606,21 @@ def compile(self, c: Compiler) -> str:
598606

599607
@dataclass
600608
class ConstantTable(ExprNode):
601-
rows: List[Tuple]
609+
rows: Sequence[Sequence]
602610

603611
def compile(self, c: Compiler) -> str:
604612
raise NotImplementedError()
605613

606614
def _value(self, v):
607-
if isinstance(v, str):
615+
if v is None:
616+
return "NULL"
617+
elif isinstance(v, str):
608618
return f"'{v}'"
609619
elif isinstance(v, datetime):
610620
return f"timestamp '{v}'"
611-
return str(v)
621+
elif isinstance(v, UUID):
622+
return f"'{v}'"
623+
return repr(v)
612624

613625
def compile_for_insert(self, c: Compiler):
614626
values = ", ".join("(%s)" % ", ".join(self._value(v) for v in row) for row in self.rows)
@@ -633,11 +645,15 @@ class Statement(Compilable):
633645
@dataclass
634646
class CreateTable(Statement):
635647
path: TablePath
648+
source_table: Expr = None
636649
if_not_exists: bool = False
637650

638651
def compile(self, c: Compiler) -> str:
639-
schema = ", ".join(f"{k} {c.database.type_repr(v)}" for k, v in self.path.schema.items())
640652
ne = "IF NOT EXISTS " if self.if_not_exists else ""
653+
if self.source_table:
654+
return f"CREATE TABLE {ne}{c.compile(self.path)} AS {c.compile(self.source_table)}"
655+
656+
schema = ", ".join(f"{c.database.quote(k)} {c.database.type_repr(v)}" for k, v in self.path.schema.items())
641657
return f"CREATE TABLE {ne}{c.compile(self.path)}({schema})"
642658

643659

@@ -651,19 +667,30 @@ def compile(self, c: Compiler) -> str:
651667
return f"DROP TABLE {ie}{c.compile(self.path)}"
652668

653669

670+
@dataclass
671+
class TruncateTable(Statement):
672+
path: TablePath
673+
674+
def compile(self, c: Compiler) -> str:
675+
return f"TRUNCATE TABLE {c.compile(self.path)}"
676+
677+
654678
@dataclass
655679
class InsertToTable(Statement):
656680
# TODO Support insert for only some columns
657681
path: TablePath
658682
expr: Expr
683+
columns: List[str] = None
659684

660685
def compile(self, c: Compiler) -> str:
661686
if isinstance(self.expr, ConstantTable):
662687
expr = self.expr.compile_for_insert(c)
663688
else:
664689
expr = c.compile(self.expr)
665690

666-
return f"INSERT INTO {c.compile(self.path)} {expr}"
691+
columns = f"(%s)" % ", ".join(map(c.quote, self.columns)) if self.columns is not None else ""
692+
693+
return f"INSERT INTO {c.compile(self.path)}{columns} {expr}"
667694

668695

669696
@dataclass

0 commit comments

Comments
 (0)