44from pathlib import Path
55from typing import List , Dict , Tuple , Set , Optional
66import yaml
7+ from pydantic import BaseModel
78
89from packaging .version import parse as parse_version
9- import pydantic
10- from dbt_artifacts_parser .parser import parse_run_results , parse_manifest
1110from dbt .config .renderer import ProfileRenderer
11+ from .dbt_config_validators import ManifestJsonConfig , RunResultsJsonConfig
1212
1313from data_diff .errors import (
1414 DataDiffDbtBigQueryUnsupportedMethodError ,
@@ -81,13 +81,13 @@ def legacy_profiles_dir() -> Path:
8181 return Path .home () / ".dbt"
8282
8383
# Diff hunk for TDatadiffModelConfig: only the base-class spelling changes, from the
# module-qualified `pydantic.BaseModel` to the directly imported `BaseModel`.
84- class TDatadiffModelConfig (pydantic . BaseModel ):
84+ class TDatadiffModelConfig (BaseModel ):
# The three fields below are unchanged context lines. Each has a default
# (None for the filter, an empty list for both column lists), so none is required.
8585 where_filter : Optional [str ] = None
8686 include_columns : List [str ] = []
8787 exclude_columns : List [str ] = []
8888
8989
90- class TDatadiffConfig (pydantic . BaseModel ):
90+ class TDatadiffConfig (BaseModel ):
9191 prod_database : Optional [str ] = None
9292 prod_schema : Optional [str ] = None
9393 prod_custom_schema : Optional [str ] = None
@@ -213,7 +213,6 @@ def get_dbt_selection_models(self, dbt_selection: str) -> List[str]:
213213
214214 def get_simple_model_selection (self , dbt_selection : str ):
215215 model_nodes = dict (filter (lambda item : item [0 ].startswith ("model." ), self .dev_manifest_obj .nodes .items ()))
216-
217216 model_unique_key_list = [k for k , v in model_nodes .items () if v .name == dbt_selection ]
218217
219218 # name *should* always be unique, but just in case:
@@ -230,13 +229,13 @@ def get_simple_model_selection(self, dbt_selection: str):
230229
231230 return [model ]
232231
233- def get_run_results_models (self ):
232+ def get_run_results_models (self ) -> List [ ManifestJsonConfig . Nodes ] :
234233 with open (self .project_dir / RUN_RESULTS_PATH ) as run_results :
235234 logger .info (f"Parsing file { RUN_RESULTS_PATH } " )
236235 run_results_dict = json .load (run_results )
237- run_results_obj = parse_run_results ( run_results = run_results_dict )
236+ run_results_validated = RunResultsJsonConfig . parse_obj ( run_results_dict )
238237
239- dbt_version = parse_version (run_results_obj .metadata .dbt_version )
238+ dbt_version = parse_version (run_results_validated .metadata .dbt_version )
240239
241240 if dbt_version < parse_version (LOWER_DBT_V ):
242241 raise DataDiffDbtRunResultsVersionError (
@@ -247,7 +246,8 @@ def get_run_results_models(self):
247246 f"{ dbt_version } is a recent version of dbt and may not be fully tested with data-diff! \n Please report any issues to https://github.com/datafold/data-diff/issues"
248247 )
249248
250- success_models = [x .unique_id for x in run_results_obj .results if x .status .name == "success" ]
249+ success_models = [x .unique_id for x in run_results_validated .results if x .status == x .Status .success ]
250+
251251 models = [self .dev_manifest_obj .nodes .get (x ) for x in success_models ]
252252 if not models :
253253 raise DataDiffDbtNoSuccessfulModelsInRunError (
@@ -256,11 +256,11 @@ def get_run_results_models(self):
256256
257257 return models
258258
# Diff hunk for get_manifest_obj: loads the JSON file at `path` and returns the parsed
# manifest object. The change adds a `ManifestJsonConfig` return annotation and swaps
# the parser call; the file reading and logging lines are unchanged context.
259- def get_manifest_obj (self , path : Path ):
259+ def get_manifest_obj (self , path : Path ) -> ManifestJsonConfig :
# NOTE(review): no `encoding=` is passed to open(), so the platform default applies;
# confirm whether the manifest should be read explicitly as UTF-8.
260260 with open (path ) as manifest :
261261 logger .info (f"Parsing file { path } " )
# The file is decoded into a plain dict first; the parser/validator receives that dict.
262262 manifest_dict = json .load (manifest )
# Removed: parse_manifest (imported from dbt_artifacts_parser.parser).
# Added: ManifestJsonConfig.parse_obj (imported from .dbt_config_validators), fed the same dict.
263- manifest_obj = parse_manifest ( manifest = manifest_dict )
263+ manifest_obj = ManifestJsonConfig . parse_obj ( manifest_dict )
264264 return manifest_obj
265265
266266 def get_project_dict (self ):
@@ -433,7 +433,6 @@ def get_pk_from_model(self, node, unique_columns: dict, pk_tag: str) -> List[str
433433 if from_tags :
434434 logger .debug ("Found PKs via Tags: " + str (from_tags ))
435435 return from_tags
436-
437436 if node .unique_id in unique_columns :
438437 from_uniq = unique_columns .get (node .unique_id )
439438 if from_uniq is not None :
@@ -451,7 +450,7 @@ def get_unique_columns(self) -> Dict[str, Set[str]]:
451450 cols_by_uid = defaultdict (set )
452451 for node in manifest .nodes .values ():
453452 try :
454- if not (node .resource_type . value == "test" and hasattr (node , "test_metadata" )):
453+ if not (node .resource_type == "test" and hasattr (node , "test_metadata" )):
455454 continue
456455
457456 if not node .depends_on or not node .depends_on .nodes :
@@ -465,7 +464,6 @@ def get_unique_columns(self) -> Dict[str, Set[str]]:
465464 continue
466465
467466 model_node = manifest .nodes [uid ]
468-
469467 if node .test_metadata .name == "unique" :
470468 column_name : str = node .test_metadata .kwargs ["column_name" ]
471469 for col in self ._parse_concat_pk_definition (column_name ):
0 commit comments