1- from typing import Callable , List
2- from datetime import datetime
31import unittest
2+ from datetime import datetime
3+ from typing import Callable , List , Tuple
44
5- from .common import str_to_checksum , TEST_MYSQL_CONN_STRING
6- from .common import str_to_checksum , test_each_database_in_list , get_conn , random_table_suffix
7-
8- from sqeleton .queries import table , current_timestamp
5+ import pytz
96
10- from sqeleton import databases as dbs
117from sqeleton import connect
12-
8+ from sqeleton import databases as dbs
9+ from sqeleton .queries import table , current_timestamp , NormalizeAsString
10+ from .common import TEST_MYSQL_CONN_STRING
11+ from .common import str_to_checksum , test_each_database_in_list , get_conn , random_table_suffix
1312
1413TEST_DATABASES = {
1514 dbs .MySQL ,
@@ -81,6 +80,37 @@ def test_current_timestamp(self):
8180 res = db .query (current_timestamp (), datetime )
8281 assert isinstance (res , datetime ), (res , type (res ))
8382
83+ def test_correct_timezone (self ):
84+ name = "tbl_" + random_table_suffix ()
85+ db = get_conn (self .db_cls )
86+ tbl = table (db .parse_table_name (name ), schema = {
87+ "id" : int , "created_at" : "timestamp_tz(9)" , "updated_at" : "timestamp_tz(9)"
88+ })
89+
90+ db .query (tbl .create ())
91+
92+ tz = pytz .timezone ('Europe/Berlin' )
93+
94+ now = datetime .now (tz )
95+ db .query (table (db .parse_table_name (name )).insert_row ("1" , now , now ))
96+ db .query (db .dialect .set_timezone_to_utc ())
97+
98+ t = db .table (name ).query_schema ()
99+ t .schema ["created_at" ] = t .schema ["created_at" ].replace (precision = t .schema ["created_at" ].precision , rounds = True )
100+
101+ tbl = table (db .parse_table_name (name ), schema = t .schema )
102+
103+ results = db .query (tbl .select (NormalizeAsString (tbl [c ]) for c in ["created_at" , "updated_at" ]), List [Tuple ])
104+
105+ created_at = results [0 ][1 ]
106+ updated_at = results [0 ][1 ]
107+
108+ utc = now .astimezone (pytz .UTC )
109+
110+ self .assertEqual (created_at , utc .__format__ ("%Y-%m-%d %H:%M:%S.%f" ))
111+ self .assertEqual (updated_at , utc .__format__ ("%Y-%m-%d %H:%M:%S.%f" ))
112+
113+ db .query (tbl .drop ())
84114
85115@test_each_database
86116class TestThreePartIds (unittest .TestCase ):
@@ -104,3 +134,4 @@ def test_three_part_support(self):
104134 d = db .query_table_schema (part .path )
105135 assert len (d ) == 1
106136 db .query (part .drop ())
137+
0 commit comments