Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.
14 changes: 12 additions & 2 deletions data_diff/databases/mysql.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, ClassVar, Dict, Type
from typing import Any, ClassVar, Dict, Type, Union

import attrs

Expand All @@ -19,7 +19,8 @@
ThreadedDatabase,
import_helper,
ConnectError,
BaseDialect,
BaseDialect,
ThreadLocalInterpreter,
)
from data_diff.databases.base import (
MD5_HEXDIGITS,
Expand Down Expand Up @@ -148,3 +149,12 @@ def create_connection(self):
elif e.errno == mysql.errorcode.ER_BAD_DB_ERROR:
raise ConnectError("Database does not exist") from e
raise ConnectError(*e.args) from e

def _query_in_worker(self, sql_code: Union[str, ThreadLocalInterpreter]):
"This method runs in a worker thread"
if self._init_error:
raise self._init_error
if not self.thread_local.conn.is_connected():
self.thread_local.conn.ping(reconnect=True, attempts=3, delay=5)
return self._query_conn(self.thread_local.conn, sql_code)

3 changes: 2 additions & 1 deletion data_diff/databases/postgresql.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import Any, ClassVar, Dict, List, Type

from urllib.parse import unquote
import attrs

from data_diff.abcs.database_types import (
Expand Down Expand Up @@ -168,6 +168,7 @@ def create_connection(self):

pg = import_postgresql()
try:
self._args["password"] = unquote(self._args["password"])
self._conn = pg.connect(
**self._args, keepalives=1, keepalives_idle=5, keepalives_interval=2, keepalives_count=2
)
Expand Down
43 changes: 42 additions & 1 deletion tests/test_postgresql.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import unittest

from urllib.parse import quote
from data_diff.queries.api import table, commit
from data_diff import TableSegment, HashDiffer
from data_diff import databases as db
from tests.common import get_conn, random_table_suffix
from tests.common import get_conn, random_table_suffix, connect
from data_diff import connect_to_table


class TestUUID(unittest.TestCase):
Expand Down Expand Up @@ -113,3 +115,42 @@ def test_100_fields(self):
id_ = diff[0][1][0]
result = (id_,) + tuple("1" for x in range(100))
self.assertEqual(diff, [("-", result)])


class TestSpecialCharacterPassword(unittest.TestCase):
def setUp(self) -> None:
self.connection = get_conn(db.PostgreSQL)

table_suffix = random_table_suffix()

self.table_name = f"table{table_suffix}"
self.table = table(self.table_name)

def test_special_char_password(self):
password = "passw!!!@rd"
# Setup user with special character '@' in password
self.connection.query("DROP USER IF EXISTS test;", None)
self.connection.query(f"CREATE USER test WITH PASSWORD '{password}';", None)

password_quoted = quote(password)
db_config = {
"driver": "postgresql",
"host": "localhost",
"port": 5432,
"dbname": "postgres",
"user": "test",
"password": password_quoted,
}

# verify pythonic connection method
connect_to_table(
db_config,
self.table_name,
)

# verify connection method with URL string unquoted after it's verified
db_url = f"postgresql://{db_config['user']}:{db_config['password']}@{db_config['host']}:{db_config['port']}/{db_config['dbname']}"

connection_verified = connect(db_url)
assert connection_verified._args.get('password') == password