Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit e5b34f3

Browse files
committed
Convert timezone
Changing the session timezone does not change the timestamp output, therefore converting the timezone directly to UTC.
1 parent 3815892 commit e5b34f3

File tree

3 files changed

+45
-11
lines changed

3 files changed

+45
-11
lines changed

sqeleton/databases/snowflake.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,9 @@ def md5_as_int(self, s: str) -> str:
4848
class Mixin_NormalizeValue(AbstractMixin_NormalizeValue):
4949
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
5050
if coltype.rounds:
51-
timestamp = f"to_timestamp(round(date_part(epoch_nanosecond, {value}::timestamp(9))/1000000000, {coltype.precision}))"
51+
timestamp = f"to_timestamp(round(date_part(epoch_nanosecond, convert_timezone('UTC', {value})::timestamp(9))/1000000000, {coltype.precision}))"
5252
else:
53-
timestamp = f"cast({value} as timestamp({coltype.precision}))"
53+
timestamp = f"cast(convert_timezone('UTC', {value}) as timestamp({coltype.precision}))"
5454

5555
return f"to_char({timestamp}, 'YYYY-MM-DD HH24:MI:SS.FF6')"
5656

tests/common.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,12 @@
77
import logging
88
import subprocess
99

10+
import sqeleton
1011
from parameterized import parameterized_class
1112

1213
from sqeleton import databases as db
1314
from sqeleton import connect
15+
from sqeleton.abcs.mixins import AbstractMixin_NormalizeValue
1416
from sqeleton.queries import table
1517
from sqeleton.databases import Database
1618
from sqeleton.query_utils import drop_table
@@ -83,7 +85,8 @@ def get_conn(cls: type, shared: bool = True) -> Database:
8385
_database_instances[cls] = get_conn(cls, shared=False)
8486
return _database_instances[cls]
8587

86-
return connect(CONN_STRINGS[cls], N_THREADS)
88+
con = sqeleton.connect.load_mixins(AbstractMixin_NormalizeValue)
89+
return con(CONN_STRINGS[cls], N_THREADS)
8790

8891

8992
def _print_used_dbs():

tests/test_database.py

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
1-
from typing import Callable, List
2-
from datetime import datetime
31
import unittest
2+
from datetime import datetime
3+
from typing import Callable, List, Tuple
44

5-
from .common import str_to_checksum, TEST_MYSQL_CONN_STRING
6-
from .common import str_to_checksum, test_each_database_in_list, get_conn, random_table_suffix
7-
8-
from sqeleton.queries import table, current_timestamp
5+
import pytz
96

10-
from sqeleton import databases as dbs
117
from sqeleton import connect
12-
8+
from sqeleton import databases as dbs
9+
from sqeleton.queries import table, current_timestamp, NormalizeAsString
10+
from .common import TEST_MYSQL_CONN_STRING
11+
from .common import str_to_checksum, test_each_database_in_list, get_conn, random_table_suffix
1312

1413
TEST_DATABASES = {
1514
dbs.MySQL,
@@ -81,6 +80,37 @@ def test_current_timestamp(self):
8180
res = db.query(current_timestamp(), datetime)
8281
assert isinstance(res, datetime), (res, type(res))
8382

83+
def test_correct_timezone(self):
84+
name = "tbl_" + random_table_suffix()
85+
db = get_conn(self.db_cls)
86+
tbl = table(db.parse_table_name(name), schema={
87+
"id": int, "created_at": "timestamp_tz(9)", "updated_at": "timestamp_tz(9)"
88+
})
89+
90+
db.query(tbl.create())
91+
92+
tz = pytz.timezone('Europe/Berlin')
93+
94+
now = datetime.now(tz)
95+
db.query(table(db.parse_table_name(name)).insert_row("1", now, now))
96+
db.query(db.dialect.set_timezone_to_utc())
97+
98+
t = db.table(name).query_schema()
99+
t.schema["created_at"] = t.schema["created_at"].replace(precision=t.schema["created_at"].precision, rounds=True)
100+
101+
tbl = table(db.parse_table_name(name), schema=t.schema)
102+
103+
results = db.query(tbl.select(NormalizeAsString(tbl[c]) for c in ["created_at", "updated_at"]), List[Tuple])
104+
105+
created_at = results[0][1]
106+
updated_at = results[0][1]
107+
108+
utc = now.astimezone(pytz.UTC)
109+
110+
self.assertEqual(created_at, utc.__format__("%Y-%m-%d %H:%M:%S.%f"))
111+
self.assertEqual(updated_at, utc.__format__("%Y-%m-%d %H:%M:%S.%f"))
112+
113+
db.query(tbl.drop())
84114

85115
@test_each_database
86116
class TestThreePartIds(unittest.TestCase):
@@ -104,3 +134,4 @@ def test_three_part_support(self):
104134
d = db.query_table_schema(part.path)
105135
assert len(d) == 1
106136
db.query(part.drop())
137+

0 commit comments

Comments
 (0)