Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 483d284

Browse files
committed
Fix 'drop table' for Oracle
1 parent b4541f5 commit 483d284

File tree

2 files changed

+21
-12
lines changed

2 files changed

+21
-12
lines changed

data_diff/database.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ def import_presto():
6363
class ConnectError(Exception):
6464
pass
6565

66+
class QueryError(Exception):
67+
pass
68+
6669

6770
def _one(seq):
6871
(x,) = seq
@@ -481,12 +484,18 @@ def __init__(self, host, port, user, password, *, database, thread_count, **kw):
481484
super().__init__(thread_count=thread_count)
482485

483486
def create_connection(self):
484-
oracle = import_oracle()
487+
self._oracle = import_oracle()
485488
try:
486-
return oracle.connect(**self.kwargs)
489+
return self._oracle.connect(**self.kwargs)
487490
except Exception as e:
488491
raise ConnectError(*e.args) from e
489492

493+
def _query(self, sql_code: str):
494+
try:
495+
return super()._query(sql_code)
496+
except self._oracle.DatabaseError as e:
497+
raise QueryError(e)
498+
490499
def md5_to_int(self, s: str) -> str:
491500
# standard_hash is faster than DBMS_CRYPTO.Hash
492501
# TODO: Find a way to use UTL_RAW.CAST_TO_BINARY_INTEGER ?

tests/test_database_types.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from contextlib import suppress
12
import unittest
23
import time
34
from data_diff import database as db
@@ -195,6 +196,12 @@ def _insert_to_table(conn, table, values):
195196
if not isinstance(conn, db.BigQuery):
196197
conn.query("COMMIT", None)
197198

199+
def _drop_table_if_exists(conn, table):
200+
with suppress(db.QueryError):
201+
if isinstance(conn, db.Oracle):
202+
conn.query(f"DROP TABLE {table}", None)
203+
else:
204+
conn.query(f"DROP TABLE IF EXISTS {table}", None)
198205

199206
class TestDiffCrossDatabaseTables(unittest.TestCase):
200207
@parameterized.expand(type_pairs, name_func=expand_params)
@@ -212,21 +219,13 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego
212219
src_table = src_conn.quote(".".join(src_table_path))
213220
dst_table = dst_conn.quote(".".join(dst_table_path))
214221

215-
if isinstance(src_conn, db.Oracle):
216-
src_conn.query(f"DROP TABLE {src_table}", None)
217-
else:
218-
src_conn.query(f"DROP TABLE IF EXISTS {src_table}", None)
219-
222+
_drop_table_if_exists(src_conn, src_table)
220223
src_conn.query(f"CREATE TABLE {src_table}(id int, col {source_type})", None)
221224
_insert_to_table(src_conn, src_table, enumerate(sample_values, 1))
222225

223226
values_in_source = src_conn.query(f"SELECT id, col FROM {src_table}", list)
224227

225-
if isinstance(dst_conn, db.Oracle):
226-
dst_conn.query(f"DROP TABLE {dst_table}", None)
227-
else:
228-
dst_conn.query(f"DROP TABLE IF EXISTS {dst_table}", None)
229-
228+
_drop_table_if_exists(dst_conn, dst_table)
230229
dst_conn.query(f"CREATE TABLE {dst_table}(id int, col {target_type})", None)
231230
_insert_to_table(dst_conn, dst_table, values_in_source)
232231

@@ -251,3 +250,4 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego
251250

252251
duration = time.time() - start
253252
# print(f"source_db={source_db.__name__} target_db={target_db.__name__} source_type={source_type} target_type={target_type} duration={round(duration * 1000, 2)}ms")
253+

0 commit comments

Comments
 (0)