Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 04ab11b

Browse files
committed
squash add pg 3 part id support
1 parent 2fe17a1 commit 04ab11b

File tree

2 files changed

+52
-0
lines changed

2 files changed

+52
-0
lines changed

sqeleton/databases/postgresql.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from ..abcs.database_types import (
2+
DbPath,
23
Timestamp,
34
TimestampTZ,
45
Float,
@@ -122,3 +123,27 @@ def create_connection(self):
122123
return c
123124
except pg.OperationalError as e:
124125
raise ConnectError(*e.args) from e
126+
127+
def select_table_schema(self, path: DbPath) -> str:
128+
database, schema, table = self._normalize_table_path(path)
129+
130+
info_schema_path = ["information_schema", "columns"]
131+
if database:
132+
info_schema_path.insert(0, database)
133+
134+
return (
135+
f"SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale FROM {'.'.join(info_schema_path)} "
136+
f"WHERE table_name = '{table}' AND table_schema = '{schema}'"
137+
)
138+
139+
def _normalize_table_path(self, path: DbPath) -> DbPath:
140+
if len(path) == 1:
141+
return None, self.default_schema, path[0]
142+
elif len(path) == 2:
143+
return None, path[0], path[1]
144+
elif len(path) == 3:
145+
return path
146+
147+
raise ValueError(
148+
f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected format: table, schema.table, or database.schema.table"
149+
)

tests/test_database.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,3 +80,30 @@ def test_current_timestamp(self):
8080
db = get_conn(self.db_cls)
8181
res = db.query(current_timestamp(), datetime)
8282
assert isinstance(res, datetime), (res, type(res))
83+
84+
85+
@test_each_database
86+
class TestThreePartIds(unittest.TestCase):
87+
def test_three_part_support(self):
88+
if self.db_cls not in [dbs.PostgreSQL, dbs.Redshift, dbs.Snowflake]:
89+
self.skipTest('Limited support for 3 part ids')
90+
91+
table_name = "tbl_" + random_table_suffix()
92+
db = get_conn(self.db_cls)
93+
db_res = db.query("SELECT CURRENT_DATABASE()")
94+
schema_res = db.query("SELECT CURRENT_SCHEMA()")
95+
db_name = db_res.rows[0][0]
96+
schema_name = schema_res.rows[0][0]
97+
98+
table_one_part = table((table_name,), schema={"id": int})
99+
table_two_part = table((schema_name, table_name), schema={"id": int})
100+
table_three_part = table((db_name, schema_name, table_name), schema={"id": int})
101+
102+
db.query(table_one_part.create())
103+
db.query(table_one_part.drop())
104+
105+
db.query(table_two_part.create())
106+
db.query(table_two_part.drop())
107+
108+
db.query(table_three_part.create())
109+
db.query(table_three_part.drop())

0 commit comments

Comments
 (0)