Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit fb0d8e6

Browse files
authored
Merge pull request #421 from datafold/optional_dbt
Make DBT dependency optional
2 parents 584206b + 1922214 commit fb0d8e6

File tree

3 files changed

+67
-63
lines changed

3 files changed

+67
-63
lines changed

data_diff/dbt.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,21 @@
33
import os
44
import time
55
import rich
6-
import yaml
76
from dataclasses import dataclass
87
from packaging.version import parse as parse_version
98
from typing import List, Optional, Dict
109

1110
import requests
12-
from dbt_artifacts_parser.parser import parse_run_results, parse_manifest
13-
from dbt.config.renderer import ProfileRenderer
11+
12+
def import_dbt():
13+
try:
14+
from dbt_artifacts_parser.parser import parse_run_results, parse_manifest
15+
from dbt.config.renderer import ProfileRenderer
16+
import yaml
17+
except ImportError:
18+
raise RuntimeError("Could not import 'dbt' package. You can install it using: pip install 'data-diff[dbt]'.")
19+
20+
return parse_run_results, parse_manifest, ProfileRenderer, yaml
1421

1522
from .tracking import (
1623
set_entrypoint_name,
@@ -263,13 +270,15 @@ def __init__(self, profiles_dir_override: str, project_dir_override: str, is_clo
263270
self.project_dict = None
264271
self.requires_upper = False
265272

273+
self.parse_run_results, self.parse_manifest, self.ProfileRenderer, self.yaml = import_dbt()
274+
266275
def get_datadiff_variables(self) -> dict:
267276
return self.project_dict.get("vars").get("data_diff")
268277

269278
def get_models(self):
270279
with open(self.project_dir + RUN_RESULTS_PATH) as run_results:
271280
run_results_dict = json.load(run_results)
272-
run_results_obj = parse_run_results(run_results=run_results_dict)
281+
run_results_obj = self.parse_run_results(run_results=run_results_dict)
273282

274283
dbt_version = parse_version(run_results_obj.metadata.dbt_version)
275284

@@ -280,7 +289,7 @@ def get_models(self):
280289

281290
with open(self.project_dir + MANIFEST_PATH) as manifest:
282291
manifest_dict = json.load(manifest)
283-
manifest_obj = parse_manifest(manifest=manifest_dict)
292+
manifest_obj = self.parse_manifest(manifest=manifest_dict)
284293

285294
success_models = [x.unique_id for x in run_results_obj.results if x.status.name == "success"]
286295
models = [manifest_obj.nodes.get(x) for x in success_models]
@@ -295,11 +304,11 @@ def get_primary_keys(self, model):
295304

296305
def set_project_dict(self):
297306
with open(self.project_dir + PROJECT_FILE) as project:
298-
self.project_dict = yaml.safe_load(project)
307+
self.project_dict = self.yaml.safe_load(project)
299308

300309
def set_connection(self):
301310
with open(self.profiles_dir + PROFILES_FILE) as profiles:
302-
profiles = yaml.safe_load(profiles)
311+
profiles = self.yaml.safe_load(profiles)
303312

304313
dbt_profile = self.project_dict.get("profile")
305314
profile_outputs = profiles.get(dbt_profile)
@@ -308,7 +317,7 @@ def set_connection(self):
308317
conn_type = credentials.get("type").lower()
309318

310319
# values can contain env_vars
311-
rendered_credentials = ProfileRenderer().render_data(credentials)
320+
rendered_credentials = self.ProfileRenderer().render_data(credentials)
312321

313322
if conn_type == "snowflake":
314323
if rendered_credentials.get("password") is None or rendered_credentials.get("private_key_path") is not None:

pyproject.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ trino = {version="^0.314.0", optional=true}
3838
presto-python-client = {version="*", optional=true}
3939
clickhouse-driver = {version="*", optional=true}
4040
duckdb = {version="^0.6.0", optional=true}
41-
dbt-artifacts-parser = "^0.2.4"
42-
dbt-core = "^1.0.0"
41+
dbt-artifacts-parser = {version="^0.2.4", optional=true}
42+
dbt-core = {version="^1.0.0", optional=true}
4343

4444
[tool.poetry.dev-dependencies]
4545
parameterized = "*"
@@ -54,6 +54,8 @@ presto-python-client = "*"
5454
clickhouse-driver = "*"
5555
vertica-python = "*"
5656
duckdb = "^0.6.0"
57+
dbt-artifacts-parser = "^0.2.4"
58+
dbt-core = "^1.0.0"
5759
# google-cloud-bigquery = "*"
5860
# databricks-sql-connector = "*"
5961

@@ -70,6 +72,7 @@ trino = ["trino"]
7072
clickhouse = ["clickhouse-driver"]
7173
vertica = ["vertica-python"]
7274
duckdb = ["duckdb"]
75+
dbt = ["dbt-core", "dbt-artifacts-parser"]
7376

7477
[build-system]
7578
requires = ["poetry-core>=1.0.0"]

tests/test_dbt.py

Lines changed: 45 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import json
22
import os
33

4+
import yaml
45
from data_diff.diff_tables import Algorithm
56
from .test_cli import run_datadiff_cli
67

@@ -49,112 +50,102 @@ def test_get_datadiff_variables_empty(self):
4950
DbtParser.get_datadiff_variables(mock_self)
5051

5152
@patch("builtins.open", new_callable=mock_open, read_data="{}")
52-
@patch("data_diff.dbt.parse_run_results")
53-
@patch("data_diff.dbt.parse_manifest")
54-
def test_get_models(self, mock_manifest_parser, mock_run_parser, mock_open):
53+
def test_get_models(self, mock_open):
5554
expected_value = "expected_value"
5655
mock_self = Mock()
5756
mock_self.project_dir = ""
5857
mock_run_results = Mock()
5958
mock_success_result = Mock()
6059
mock_failed_result = Mock()
6160
mock_manifest = Mock()
62-
mock_run_parser.return_value = mock_run_results
61+
mock_self.parse_run_results.return_value = mock_run_results
6362
mock_run_results.metadata.dbt_version = "1.0.0"
6463
mock_success_result.unique_id = "success_unique_id"
6564
mock_failed_result.unique_id = "failed_unique_id"
6665
mock_success_result.status.name = "success"
6766
mock_failed_result.status.name = "failed"
6867
mock_run_results.results = [mock_success_result, mock_failed_result]
69-
mock_manifest_parser.return_value = mock_manifest
68+
mock_self.parse_manifest.return_value = mock_manifest
7069
mock_manifest.nodes = {"success_unique_id": expected_value}
7170

7271
models = DbtParser.get_models(mock_self)
7372

7473
self.assertEqual(expected_value, models[0])
7574
mock_open.assert_any_call(RUN_RESULTS_PATH)
7675
mock_open.assert_any_call(MANIFEST_PATH)
77-
mock_run_parser.assert_called_once_with(run_results={})
78-
mock_manifest_parser.assert_called_once_with(manifest={})
76+
mock_self.parse_run_results.assert_called_once_with(run_results={})
77+
mock_self.parse_manifest.assert_called_once_with(manifest={})
7978

8079
@patch("builtins.open", new_callable=mock_open, read_data="{}")
81-
@patch("data_diff.dbt.parse_run_results")
82-
@patch("data_diff.dbt.parse_manifest")
83-
def test_get_models_bad_lower_dbt_version(self, mock_manifest_parser, mock_run_parser, mock_open):
80+
def test_get_models_bad_lower_dbt_version(self, mock_open):
8481
mock_self = Mock()
8582
mock_self.project_dir = ""
8683
mock_run_results = Mock()
87-
mock_run_parser.return_value = mock_run_results
84+
mock_self.parse_run_results.return_value = mock_run_results
8885
mock_run_results.metadata.dbt_version = "0.19.0"
8986

9087
with self.assertRaises(Exception) as ex:
9188
DbtParser.get_models(mock_self)
9289

9390
mock_open.assert_called_once_with(RUN_RESULTS_PATH)
94-
mock_run_parser.assert_called_once_with(run_results={})
95-
mock_manifest_parser.assert_not_called()
91+
mock_self.parse_run_results.assert_called_once_with(run_results={})
92+
mock_self.parse_manifest.assert_not_called()
9693
self.assertIn("version to be", ex.exception.args[0])
9794

9895
@patch("builtins.open", new_callable=mock_open, read_data="{}")
99-
@patch("data_diff.dbt.parse_run_results")
100-
@patch("data_diff.dbt.parse_manifest")
101-
def test_get_models_bad_upper_dbt_version(self, mock_manifest_parser, mock_run_parser, mock_open):
96+
def test_get_models_bad_upper_dbt_version(self, mock_open):
10297
mock_self = Mock()
10398
mock_self.project_dir = ""
10499
mock_run_results = Mock()
105-
mock_run_parser.return_value = mock_run_results
100+
mock_self.parse_run_results.return_value = mock_run_results
106101
mock_run_results.metadata.dbt_version = "1.5.1"
107102

108103
with self.assertRaises(Exception) as ex:
109104
DbtParser.get_models(mock_self)
110105

111106
mock_open.assert_called_once_with(RUN_RESULTS_PATH)
112-
mock_run_parser.assert_called_once_with(run_results={})
113-
mock_manifest_parser.assert_not_called()
107+
mock_self.parse_run_results.assert_called_once_with(run_results={})
108+
mock_self.parse_manifest.assert_not_called()
114109
self.assertIn("version to be", ex.exception.args[0])
115110

116111
@patch("builtins.open", new_callable=mock_open, read_data="{}")
117-
@patch("data_diff.dbt.parse_run_results")
118-
@patch("data_diff.dbt.parse_manifest")
119-
def test_get_models_no_success(self, mock_manifest_parser, mock_run_parser, mock_open):
112+
def test_get_models_no_success(self, mock_open):
120113
mock_self = Mock()
121114
mock_self.project_dir = ""
122115
mock_run_results = Mock()
123116
mock_success_result = Mock()
124117
mock_failed_result = Mock()
125118
mock_manifest = Mock()
126-
mock_run_parser.return_value = mock_run_results
119+
mock_self.parse_run_results.return_value = mock_run_results
127120
mock_run_results.metadata.dbt_version = "1.0.0"
128121
mock_failed_result.unique_id = "failed_unique_id"
129122
mock_success_result.status.name = "success"
130123
mock_failed_result.status.name = "failed"
131124
mock_run_results.results = [mock_failed_result]
132-
mock_manifest_parser.return_value = mock_manifest
125+
mock_self.parse_manifest.return_value = mock_manifest
133126
mock_manifest.nodes = {"success_unique_id": "a_unique_id"}
134127

135128
with self.assertRaises(Exception):
136129
DbtParser.get_models(mock_self)
137130

138131
mock_open.assert_any_call(RUN_RESULTS_PATH)
139132
mock_open.assert_any_call(MANIFEST_PATH)
140-
mock_run_parser.assert_called_once_with(run_results={})
141-
mock_manifest_parser.assert_called_once_with(manifest={})
133+
mock_self.parse_run_results.assert_called_once_with(run_results={})
134+
mock_self.parse_manifest.assert_called_once_with(manifest={})
142135

143-
@patch("data_diff.dbt.yaml.safe_load")
144136
@patch("builtins.open", new_callable=mock_open, read_data="key:\n value")
145-
def test_set_project_dict(self, mock_open, mock_yaml_parse):
137+
def test_set_project_dict(self, mock_open):
146138
expected_dict = {"key1": "value1"}
147139
mock_self = Mock()
148140
mock_self.project_dir = ""
149-
mock_yaml_parse.return_value = expected_dict
141+
mock_self.yaml.safe_load.return_value = expected_dict
150142
DbtParser.set_project_dict(mock_self)
151143

152144
self.assertEqual(mock_self.project_dict, expected_dict)
153145
mock_open.assert_called_once_with(PROJECT_FILE)
154146

155-
@patch("data_diff.dbt.yaml.safe_load")
156147
@patch("builtins.open", new_callable=mock_open, read_data="key:\n value")
157-
def test_set_connection_snowflake(self, mock_open_file, mock_yaml_parse):
148+
def test_set_connection_snowflake(self, mock_open_file):
158149
expected_driver = "snowflake"
159150
expected_password = "password_value"
160151
profiles_dict = {
@@ -172,19 +163,19 @@ def test_set_connection_snowflake(self, mock_open_file, mock_yaml_parse):
172163
mock_self = Mock()
173164
mock_self.profiles_dir = ""
174165
mock_self.project_dict = {"profile": "profile_name"}
175-
mock_yaml_parse.return_value = profiles_dict
166+
mock_self.yaml.safe_load.return_value = profiles_dict
167+
mock_self.ProfileRenderer().render_data.return_value = profiles_dict["profile_name"]["outputs"]["connection1"]
176168
DbtParser.set_connection(mock_self)
177169

178170
self.assertIsInstance(mock_self.connection, dict)
179171
self.assertEqual(mock_self.connection.get("driver"), expected_driver)
180172
self.assertEqual(mock_self.connection.get("password"), expected_password)
181173
self.assertEqual(mock_self.requires_upper, True)
182174
mock_open_file.assert_called_once_with(PROFILES_FILE)
183-
mock_yaml_parse.assert_called_once_with(mock_open_file())
175+
mock_self.yaml.safe_load.assert_called_once_with(mock_open_file())
184176

185-
@patch("data_diff.dbt.yaml.safe_load")
186177
@patch("builtins.open", new_callable=mock_open, read_data="key:\n value")
187-
def test_set_connection_snowflake_no_password(self, mock_open_file, mock_yaml_parse):
178+
def test_set_connection_snowflake_no_password(self, mock_open_file):
188179
expected_driver = "snowflake"
189180
profiles_dict = {
190181
"profile_name": {
@@ -196,18 +187,18 @@ def test_set_connection_snowflake_no_password(self, mock_open_file, mock_yaml_pa
196187
mock_self = Mock()
197188
mock_self.profiles_dir = ""
198189
mock_self.project_dict = {"profile": "profile_name"}
199-
mock_yaml_parse.return_value = profiles_dict
190+
mock_self.yaml.safe_load.return_value = profiles_dict
191+
mock_self.ProfileRenderer().render_data.return_value = profiles_dict["profile_name"]["outputs"]["connection1"]
200192

201193
with self.assertRaises(Exception):
202194
DbtParser.set_connection(mock_self)
203195

204196
mock_open_file.assert_called_once_with(PROFILES_FILE)
205-
mock_yaml_parse.assert_called_once_with(mock_open_file())
197+
mock_self.yaml.safe_load.assert_called_once_with(mock_open_file())
206198
self.assertNotIsInstance(mock_self.connection, dict)
207199

208-
@patch("data_diff.dbt.yaml.safe_load")
209200
@patch("builtins.open", new_callable=mock_open, read_data="key:\n value")
210-
def test_set_connection_bigquery(self, mock_open_file, mock_yaml_parse):
201+
def test_set_connection_bigquery(self, mock_open_file):
211202
expected_driver = "bigquery"
212203
expected_method = "oauth"
213204
expected_project = "a_project"
@@ -229,19 +220,19 @@ def test_set_connection_bigquery(self, mock_open_file, mock_yaml_parse):
229220
mock_self = Mock()
230221
mock_self.profiles_dir = ""
231222
mock_self.project_dict = {"profile": "profile_name"}
232-
mock_yaml_parse.return_value = profiles_dict
223+
mock_self.yaml.safe_load.return_value = profiles_dict
224+
mock_self.ProfileRenderer().render_data.return_value = profiles_dict["profile_name"]["outputs"]["connection1"]
233225
DbtParser.set_connection(mock_self)
234226

235227
self.assertIsInstance(mock_self.connection, dict)
236228
self.assertEqual(mock_self.connection.get("driver"), expected_driver)
237229
self.assertEqual(mock_self.connection.get("project"), expected_project)
238230
self.assertEqual(mock_self.connection.get("dataset"), expected_dataset)
239231
mock_open_file.assert_called_once_with(PROFILES_FILE)
240-
mock_yaml_parse.assert_called_once_with(mock_open_file())
232+
mock_self.yaml.safe_load.assert_called_once_with(mock_open_file())
241233

242-
@patch("data_diff.dbt.yaml.safe_load")
243234
@patch("builtins.open", new_callable=mock_open, read_data="key:\n value")
244-
def test_set_connection_bigquery_not_oauth(self, mock_open_file, mock_yaml_parse):
235+
def test_set_connection_bigquery_not_oauth(self, mock_open_file):
245236
expected_driver = "bigquery"
246237
expected_method = "not_oauth"
247238
expected_project = "a_project"
@@ -263,17 +254,17 @@ def test_set_connection_bigquery_not_oauth(self, mock_open_file, mock_yaml_parse
263254
mock_self = Mock()
264255
mock_self.profiles_dir = ""
265256
mock_self.project_dict = {"profile": "profile_name"}
266-
mock_yaml_parse.return_value = profiles_dict
257+
mock_self.yaml.safe_load.return_value = profiles_dict
258+
mock_self.ProfileRenderer().render_data.return_value = profiles_dict["profile_name"]["outputs"]["connection1"]
267259
with self.assertRaises(Exception):
268260
DbtParser.set_connection(mock_self)
269261

270262
mock_open_file.assert_called_once_with(PROFILES_FILE)
271-
mock_yaml_parse.assert_called_once_with(mock_open_file())
263+
mock_self.yaml.safe_load.assert_called_once_with(mock_open_file())
272264
self.assertNotIsInstance(mock_self.connection, dict)
273265

274-
@patch("data_diff.dbt.yaml.safe_load")
275266
@patch("builtins.open", new_callable=mock_open, read_data="key:\n value")
276-
def test_set_connection_key_error(self, mock_open_file, mock_yaml_parse):
267+
def test_set_connection_key_error(self, mock_open_file):
277268
profiles_dict = {
278269
"profile_name": {
279270
"outputs": {
@@ -290,17 +281,17 @@ def test_set_connection_key_error(self, mock_open_file, mock_yaml_parse):
290281
mock_self.profiles_dir = ""
291282
mock_self.project_dir = ""
292283
mock_self.project_dict = {"profile": "bad_key"}
293-
mock_yaml_parse.return_value = profiles_dict
284+
mock_self.yaml.safe_load.return_value = profiles_dict
285+
mock_self.ProfileRenderer().render_data.return_value = profiles_dict["profile_name"]["outputs"]["connection1"]
294286
with self.assertRaises(Exception):
295287
DbtParser.set_connection(mock_self)
296288

297289
mock_open_file.assert_called_once_with(PROFILES_FILE)
298-
mock_yaml_parse.assert_called_once_with(mock_open_file())
290+
mock_self.yaml.safe_load.assert_called_once_with(mock_open_file())
299291
self.assertNotIsInstance(mock_self.connection, dict)
300292

301-
@patch("data_diff.dbt.yaml.safe_load")
302293
@patch("builtins.open", new_callable=mock_open, read_data="key:\n value")
303-
def test_set_connection_not_implemented(self, mock_open_file, mock_yaml_parse):
294+
def test_set_connection_not_implemented(self, mock_open_file):
304295
expected_driver = "not_implemented"
305296
profiles_dict = {
306297
"profile_name": {
@@ -317,12 +308,13 @@ def test_set_connection_not_implemented(self, mock_open_file, mock_yaml_parse):
317308
mock_self.profiles_dir = ""
318309
mock_self.project_dir = ""
319310
mock_self.project_dict = {"profile": "profile_name"}
320-
mock_yaml_parse.return_value = profiles_dict
311+
mock_self.yaml.safe_load.return_value = profiles_dict
312+
mock_self.ProfileRenderer().render_data.return_value = profiles_dict["profile_name"]["outputs"]["connection1"]
321313
with self.assertRaises(NotImplementedError):
322314
DbtParser.set_connection(mock_self)
323315

324316
mock_open_file.assert_called_once_with(PROFILES_FILE)
325-
mock_yaml_parse.assert_called_once_with(mock_open_file())
317+
mock_self.yaml.safe_load.assert_called_once_with(mock_open_file())
326318
self.assertNotIsInstance(mock_self.connection, dict)
327319

328320

0 commit comments

Comments
 (0)