Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 5e879fa

Browse files
committed
Queries: Added Param mechanism, to help speed up query construction.
1 parent f8d24ea commit 5e879fa

File tree

2 files changed

+25
-2
lines changed

2 files changed

+25
-2
lines changed

data_diff/queries/ast_classes.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from data_diff.utils import ArithString, join_iter
88

9-
from .compiler import Compilable, Compiler
9+
from .compiler import Compilable, Compiler, cv_params
1010
from .base import SKIP, CompileError, DbPath, Schema, args_as_tuple
1111

1212

@@ -691,3 +691,18 @@ def compile(self, c: Compiler) -> str:
691691
class Commit(Statement):
692692
def compile(self, c: Compiler) -> str:
693693
return "COMMIT" if not c.database.is_autocommit else SKIP
694+
695+
@dataclass
696+
class Param(ExprNode, ITable):
697+
"""A value placeholder, to be specified at compilation time using the `cv_params` context variable."""
698+
699+
name: str
700+
701+
@property
702+
def source_table(self):
703+
return self
704+
705+
def compile(self, c: Compiler) -> str:
706+
params = cv_params.get()
707+
return c._compile(params[self.name])
708+

data_diff/queries/compiler.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,15 @@
88
from data_diff.utils import ArithString
99
from data_diff.databases.database_types import AbstractDialect, DbPath
1010

11+
import contextvars
12+
13+
cv_params = contextvars.ContextVar("params")
14+
1115

1216
@dataclass
1317
class Compiler:
1418
database: AbstractDialect
19+
params: dict = {}
1520
in_select: bool = False # Compilation runtime flag
1621
in_join: bool = False # Compilation runtime flag
1722

@@ -21,7 +26,10 @@ class Compiler:
2126

2227
_counter: List = [0]
2328

24-
def compile(self, elem) -> str:
29+
def compile(self, elem, params=None) -> str:
30+
if params:
31+
cv_params.set(params)
32+
2533
res = self._compile(elem)
2634
if self.root and self._subqueries:
2735
subq = ", ".join(f"\n {k} AS ({v})" for k, v in self._subqueries.items())

0 commit comments

Comments
 (0)