@@ -337,257 +337,3 @@ def _cloud_diff(diff_vars: DiffVars, datasource_id: int, api: DatafoldAPI) -> None:
 
 def _diff_output_base(dev_path: str, prod_path: str) -> str:
     return f"[green]{prod_path} <> {dev_path}[/]\n"
-
-
-class DbtParser:
-    def __init__(self, profiles_dir_override: str, project_dir_override: str) -> None:
-        self.parse_run_results, self.parse_manifest, self.ProfileRenderer, self.yaml = import_dbt()
-        self.profiles_dir = Path(profiles_dir_override or default_profiles_dir())
-        self.project_dir = Path(project_dir_override or default_project_dir())
-        self.connection = None
-        self.project_dict = self.get_project_dict()
-        self.manifest_obj = self.get_manifest_obj()
-        self.dbt_user_id = self.manifest_obj.metadata.user_id
-        self.dbt_version = self.manifest_obj.metadata.dbt_version
-        self.dbt_project_id = self.manifest_obj.metadata.project_id
-        self.requires_upper = False
-        self.threads = None
-        self.unique_columns = self.get_unique_columns()
-
-    def get_datadiff_variables(self) -> dict:
-        vars_dict = get_from_dict_with_raise(self.project_dict, "vars", "No vars: section found in dbt_project.yml.")
-        return get_from_dict_with_raise(vars_dict, "data_diff", "data_diff: section not found under vars: in dbt_project.yml.")
-
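For orientation, a sketch of the `vars:` block this lookup expects in `dbt_project.yml`; the inner keys here are purely illustrative:

```python
# Hypothetical parse of a dbt_project.yml (inner keys illustrative).
project_dict = {
    "name": "my_project",
    "vars": {
        "data_diff": {
            "prod_database": "analytics_prod",
            "prod_schema": "core",
        },
    },
}

# get_datadiff_variables() returns the inner "data_diff" mapping:
assert project_dict["vars"]["data_diff"]["prod_schema"] == "core"
```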
-    def get_models(self):
-        with open(self.project_dir / RUN_RESULTS_PATH) as run_results:
-            run_results_dict = json.load(run_results)
-            run_results_obj = self.parse_run_results(run_results=run_results_dict)
-
-        dbt_version = parse_version(run_results_obj.metadata.dbt_version)
-
-        if dbt_version < parse_version("1.3.0"):
-            self.profiles_dir = legacy_profiles_dir()
-
-        if dbt_version < parse_version(LOWER_DBT_V) or dbt_version >= parse_version(UPPER_DBT_V):
-            raise Exception(
-                f"Found dbt v{dbt_version}. Expected the dbt project's version to be >= {LOWER_DBT_V} and < {UPPER_DBT_V}."
-            )
-
-        success_models = [x.unique_id for x in run_results_obj.results if x.status.name == "success"]
-        models = [self.manifest_obj.nodes.get(x) for x in success_models]
-        if not models:
-            raise ValueError("Expected > 0 successful model runs from the last dbt command.")
-
-        print(f"Running with data-diff={__version__}\n")
-        return models
-
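A minimal sketch of the version gate above, assuming `parse_version` is `packaging.version.parse` and with illustrative bounds standing in for this module's `LOWER_DBT_V`/`UPPER_DBT_V` constants:

```python
from packaging.version import parse as parse_version

LOWER_DBT_V = "1.0.0"  # illustrative stand-ins for the module constants
UPPER_DBT_V = "1.5.0"

v = parse_version("1.2.9")
assert v < parse_version("1.3.0")  # pre-1.3.0: profiles.yml lives in the legacy dir
# The supported window is half-open: LOWER_DBT_V <= v < UPPER_DBT_V.
assert parse_version(LOWER_DBT_V) <= v < parse_version(UPPER_DBT_V)
```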
-    def get_manifest_obj(self):
-        with open(self.project_dir / MANIFEST_PATH) as manifest:
-            manifest_dict = json.load(manifest)
-            manifest_obj = self.parse_manifest(manifest=manifest_dict)
-        return manifest_obj
-
-    def get_project_dict(self):
-        with open(self.project_dir / PROJECT_FILE) as project:
-            project_dict = self.yaml.safe_load(project)
-        return project_dict
-
-    def _get_connection_creds(self) -> Tuple[Dict[str, str], str]:
-        profiles_path = self.profiles_dir / PROFILES_FILE
-        with open(profiles_path) as profiles_file:
-            profiles = self.yaml.safe_load(profiles_file)
-
-        dbt_profile_var = self.project_dict.get("profile")
-
-        profile = get_from_dict_with_raise(
-            profiles, dbt_profile_var, f"No profile '{dbt_profile_var}' found in '{profiles_path}'."
-        )
-        # values can contain env_vars
-        rendered_profile = self.ProfileRenderer().render_data(profile)
-        profile_target = get_from_dict_with_raise(
-            rendered_profile, "target", f"No target found in profile '{dbt_profile_var}' in '{profiles_path}'."
-        )
-        outputs = get_from_dict_with_raise(
-            rendered_profile, "outputs", f"No outputs found in profile '{dbt_profile_var}' in '{profiles_path}'."
-        )
-        credentials = get_from_dict_with_raise(
-            outputs,
-            profile_target,
-            f"No credentials found for target '{profile_target}' in profile '{dbt_profile_var}' in '{profiles_path}'.",
-        )
-        conn_type = get_from_dict_with_raise(
-            credentials,
-            "type",
-            f"No type found for target '{profile_target}' in profile '{dbt_profile_var}' in '{profiles_path}'.",
-        )
-        conn_type = conn_type.lower()
-
-        return credentials, conn_type
-
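As a worked example, the parsed form of a minimal `profiles.yml` (all names hypothetical) that this traversal accepts:

```python
# What self.yaml.safe_load(profiles_file) might return (names hypothetical).
profiles = {
    "my_profile": {  # selected via project_dict["profile"]
        "target": "dev",
        "outputs": {
            "dev": {
                "type": "postgres",
                "host": "localhost",
                "user": "me",
                "password": "secret",
                "port": 5432,
                "dbname": "analytics",
            },
        },
    },
}

# The method resolves profile -> target -> outputs[target] and returns:
credentials = profiles["my_profile"]["outputs"]["dev"]
conn_type = credentials["type"].lower()  # "postgres"
```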
-    def set_connection(self):
-        credentials, conn_type = self._get_connection_creds()
-
-        if conn_type == "snowflake":
-            conn_info = {
-                "driver": conn_type,
-                "user": credentials.get("user"),
-                "account": credentials.get("account"),
-                "database": credentials.get("database"),
-                "warehouse": credentials.get("warehouse"),
-                "role": credentials.get("role"),
-                "schema": credentials.get("schema"),
-            }
-            self.threads = credentials.get("threads")
-            self.requires_upper = True
-
-            if credentials.get("private_key_path") is not None:
-                if credentials.get("password") is not None:
-                    raise Exception("Cannot use password and key at the same time")
-                conn_info["key"] = credentials.get("private_key_path")
-                conn_info["private_key_passphrase"] = credentials.get("private_key_passphrase")
-            elif credentials.get("authenticator") is not None:
-                conn_info["authenticator"] = credentials.get("authenticator")
-                conn_info["password"] = credentials.get("password")
-            elif credentials.get("password") is not None:
-                conn_info["password"] = credentials.get("password")
-            else:
-                raise Exception("Snowflake: unsupported auth method")
-        elif conn_type == "bigquery":
-            method = credentials.get("method")
-            # there are many connection methods: https://docs.getdbt.com/reference/warehouse-setups/bigquery-setup#oauth-via-gcloud
-            # this assumes the user is auth'd via `gcloud auth application-default login`
-            if method != "oauth":
-                raise Exception("OAuth is currently the only supported method for BigQuery.")
-            conn_info = {
-                "driver": conn_type,
-                "project": credentials.get("project"),
-                "dataset": credentials.get("dataset"),
-            }
-            self.threads = credentials.get("threads")
-        elif conn_type == "duckdb":
-            conn_info = {
-                "driver": conn_type,
-                "filepath": credentials.get("path"),
-            }
-        elif conn_type == "redshift":
-            if (credentials.get("pass") is None and credentials.get("password") is None) or credentials.get("method") == "iam":
-                raise Exception("Only password authentication is currently supported for Redshift.")
-            conn_info = {
-                "driver": conn_type,
-                "host": credentials.get("host"),
-                "user": credentials.get("user"),
-                "password": credentials.get("password") or credentials.get("pass"),
-                "port": credentials.get("port"),
-                "dbname": credentials.get("dbname"),
-            }
-            self.threads = credentials.get("threads")
-        elif conn_type == "databricks":
-            conn_info = {
-                "driver": conn_type,
-                "catalog": credentials.get("catalog"),
-                "server_hostname": credentials.get("host"),
-                "http_path": credentials.get("http_path"),
-                "schema": credentials.get("schema"),
-                "access_token": credentials.get("token"),
-            }
-            self.threads = credentials.get("threads")
-        elif conn_type == "postgres":
-            conn_info = {
-                "driver": "postgresql",
-                "host": credentials.get("host"),
-                "user": credentials.get("user"),
-                "password": credentials.get("password"),
-                "port": credentials.get("port"),
-                "dbname": credentials.get("dbname") or credentials.get("database"),
-            }
-            self.threads = credentials.get("threads")
-        else:
-            raise NotImplementedError(f"Provider {conn_type} is not yet supported for dbt diffs")
-
-        self.connection = conn_info
-
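Continuing that hypothetical postgres profile, a sketch of what the branch above leaves on `self.connection`, including the `dbname`/`database` alias normalization:

```python
credentials = {"type": "postgres", "host": "localhost", "user": "me",
               "password": "secret", "port": 5432, "database": "analytics"}

conn_info = {
    "driver": "postgresql",  # dbt's "postgres" maps to data-diff's driver name
    "host": credentials.get("host"),
    "user": credentials.get("user"),
    "password": credentials.get("password"),
    "port": credentials.get("port"),
    # dbt profiles may use either key; the branch normalizes both spellings
    "dbname": credentials.get("dbname") or credentials.get("database"),
}
assert conn_info["dbname"] == "analytics"
```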
-    def get_pk_from_model(self, node, unique_columns: dict, pk_tag: str) -> List[str]:
-        try:
-            # Get a set of all the column names
-            column_names = {name for name, params in node.columns.items()}
-            # Check if the tag is present on a table level
-            if pk_tag in node.meta:
-                # Get all the PKs that are also present as a column
-                pks = [pk for pk in node.meta[pk_tag] if pk in column_names]
-                if pks:
-                    # If there are any left, return them
-                    logger.debug("Found PKs via Table META: " + str(pks))
-                    return pks
-
-            from_meta = [name for name, params in node.columns.items() if pk_tag in params.meta] or None
-            if from_meta:
-                logger.debug("Found PKs via META: " + str(from_meta))
-                return from_meta
-
-            from_tags = [name for name, params in node.columns.items() if pk_tag in params.tags] or None
-            if from_tags:
-                logger.debug("Found PKs via Tags: " + str(from_tags))
-                return from_tags
-
-            if node.unique_id in unique_columns:
-                from_uniq = unique_columns.get(node.unique_id)
-                if from_uniq is not None:
-                    logger.debug("Found PKs via Uniqueness tests: " + str(from_uniq))
-                    return list(from_uniq)
-
-        except (KeyError, IndexError, TypeError) as e:
-            raise e
-
-        logger.debug("Found no PKs")
-        return []
-
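Since the method only reads a few node attributes, the lookup precedence (table meta, then column meta, then column tags, then uniqueness tests) can be sketched with a fake node; the attribute names follow dbt's manifest schema, and everything else here is hypothetical:

```python
from types import SimpleNamespace

id_col = SimpleNamespace(meta={}, tags=["primary-key"])
node = SimpleNamespace(
    unique_id="model.demo.users",
    meta={},                 # no table-level "primary-key" entry: first check misses
    columns={"id": id_col},  # empty column meta: second check misses
)

# With pk_tag="primary-key", the column-tag pass matches, so
# get_pk_from_model(node, {}, "primary-key") would return ["id"]
# without ever reaching the uniqueness-test fallback.
```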
-    def get_unique_columns(self) -> Dict[str, Set[str]]:
-        manifest = self.manifest_obj
-        cols_by_uid = defaultdict(set)
-        for node in manifest.nodes.values():
-            try:
-                if not (node.resource_type.value == "test" and hasattr(node, "test_metadata")):
-                    continue
-
-                if not node.depends_on or not node.depends_on.nodes:
-                    continue
-
-                uid = node.depends_on.nodes[0]
-
-                # sources can have tests and are not in manifest.nodes
-                # skip as source unique columns are not needed
-                if uid.startswith("source."):
-                    continue
-
-                model_node = manifest.nodes[uid]
-
-                if node.test_metadata.name == "unique":
-                    column_name: str = node.test_metadata.kwargs["column_name"]
-                    for col in self._parse_concat_pk_definition(column_name):
-                        if model_node is None or col in model_node.columns:
-                            # skip anything that is not a column,
-                            # e.g. string literals used in concat
-                            # like "pk1 || '-' || pk2"
-                            cols_by_uid[uid].add(col)
-
-                if node.test_metadata.name == "unique_combination_of_columns":
-                    for col in node.test_metadata.kwargs["combination_of_columns"]:
-                        cols_by_uid[uid].add(col)
-
-            except (KeyError, IndexError, TypeError) as e:
-                logger.warning("Failure while finding unique cols: %s", e)
-
-        return cols_by_uid
-
-    def _parse_concat_pk_definition(self, definition: str) -> List[str]:
-        definition = definition.strip()
-        if definition.lower().startswith("concat(") and definition.endswith(")"):
-            definition = definition[7:-1]  # removes "concat(" and ")"
-            columns = definition.split(",")
-        else:
-            columns = definition.split("||")
-
-        stripped_columns = [col.strip('" ()') for col in columns]
-        return stripped_columns
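For instance, given an already-constructed `DbtParser` instance `parser`, both concat spellings reduce to the same list; note the quoted separator keeps its quotes (only `"`, spaces, and parentheses are stripped), which is why `get_unique_columns` later drops tokens that are not real columns:

```python
assert parser._parse_concat_pk_definition("concat(pk1, '-', pk2)") == ["pk1", "'-'", "pk2"]
assert parser._parse_concat_pk_definition("pk1 || '-' || pk2") == ["pk1", "'-'", "pk2"]
```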