This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit e391403

update after rebasing to the latest master

1 parent: dd2330f

File tree: 2 files changed, +4 −256 lines

data_diff/dbt.py

Lines changed: 0 additions & 254 deletions
@@ -337,257 +337,3 @@ def _cloud_diff(diff_vars: DiffVars, datasource_id: int, api: DatafoldAPI) -> None:
 
 def _diff_output_base(dev_path: str, prod_path: str) -> str:
     return f"[green]{prod_path} <> {dev_path}[/] \n"
-
-
-class DbtParser:
-    def __init__(self, profiles_dir_override: str, project_dir_override: str) -> None:
-        self.parse_run_results, self.parse_manifest, self.ProfileRenderer, self.yaml = import_dbt()
-        self.profiles_dir = Path(profiles_dir_override or default_profiles_dir())
-        self.project_dir = Path(project_dir_override or default_project_dir())
-        self.connection = None
-        self.project_dict = self.get_project_dict()
-        self.manifest_obj = self.get_manifest_obj()
-        self.dbt_user_id = self.manifest_obj.metadata.user_id
-        self.dbt_version = self.manifest_obj.metadata.dbt_version
-        self.dbt_project_id = self.manifest_obj.metadata.project_id
-        self.requires_upper = False
-        self.threads = None
-        self.unique_columns = self.get_unique_columns()
-
-    def get_datadiff_variables(self) -> dict:
-        vars = get_from_dict_with_raise(self.project_dict, "vars", f"No vars: found in dbt_project.yml.")
-        return get_from_dict_with_raise(vars, "data_diff", f"data_diff: section not found in dbt_project.yml vars:.")
-
-    def get_models(self):
-        with open(self.project_dir / RUN_RESULTS_PATH) as run_results:
-            run_results_dict = json.load(run_results)
-            run_results_obj = self.parse_run_results(run_results=run_results_dict)
-
-        dbt_version = parse_version(run_results_obj.metadata.dbt_version)
-
-        if dbt_version < parse_version("1.3.0"):
-            self.profiles_dir = legacy_profiles_dir()
-
-        if dbt_version < parse_version(LOWER_DBT_V) or dbt_version >= parse_version(UPPER_DBT_V):
-            raise Exception(
-                f"Found dbt: v{dbt_version} Expected the dbt project's version to be >= {LOWER_DBT_V} and < {UPPER_DBT_V}"
-            )
-
-        success_models = [x.unique_id for x in run_results_obj.results if x.status.name == "success"]
-        models = [self.manifest_obj.nodes.get(x) for x in success_models]
-        if not models:
-            raise ValueError("Expected > 0 successful models runs from the last dbt command.")
-
-        print(f"Running with data-diff={__version__}\n")
-        return models
-
-    def get_manifest_obj(self):
-        with open(self.project_dir / MANIFEST_PATH) as manifest:
-            manifest_dict = json.load(manifest)
-            manifest_obj = self.parse_manifest(manifest=manifest_dict)
-        return manifest_obj
-
-    def get_project_dict(self):
-        with open(self.project_dir / PROJECT_FILE) as project:
-            project_dict = self.yaml.safe_load(project)
-        return project_dict
-
-    def _get_connection_creds(self) -> Tuple[Dict[str, str], str]:
-        profiles_path = self.profiles_dir / PROFILES_FILE
-        with open(profiles_path) as profiles:
-            profiles = self.yaml.safe_load(profiles)
-
-        dbt_profile_var = self.project_dict.get("profile")
-
-        profile = get_from_dict_with_raise(
-            profiles, dbt_profile_var, f"No profile '{dbt_profile_var}' found in '{profiles_path}'."
-        )
-        # values can contain env_vars
-        rendered_profile = self.ProfileRenderer().render_data(profile)
-        profile_target = get_from_dict_with_raise(
-            rendered_profile, "target", f"No target found in profile '{dbt_profile_var}' in '{profiles_path}'."
-        )
-        outputs = get_from_dict_with_raise(
-            rendered_profile, "outputs", f"No outputs found in profile '{dbt_profile_var}' in '{profiles_path}'."
-        )
-        credentials = get_from_dict_with_raise(
-            outputs,
-            profile_target,
-            f"No credentials found for target '{profile_target}' in profile '{dbt_profile_var}' in '{profiles_path}'.",
-        )
-        conn_type = get_from_dict_with_raise(
-            credentials,
-            "type",
-            f"No type found for target '{profile_target}' in profile '{dbt_profile_var}' in '{profiles_path}'.",
-        )
-        conn_type = conn_type.lower()
-
-        return credentials, conn_type
-
-    def set_connection(self):
-        credentials, conn_type = self._get_connection_creds()
-
-        if conn_type == "snowflake":
-            conn_info = {
-                "driver": conn_type,
-                "user": credentials.get("user"),
-                "account": credentials.get("account"),
-                "database": credentials.get("database"),
-                "warehouse": credentials.get("warehouse"),
-                "role": credentials.get("role"),
-                "schema": credentials.get("schema"),
-            }
-            self.threads = credentials.get("threads")
-            self.requires_upper = True
-
-            if credentials.get("private_key_path") is not None:
-                if credentials.get("password") is not None:
-                    raise Exception("Cannot use password and key at the same time")
-                conn_info["key"] = credentials.get("private_key_path")
-                conn_info["private_key_passphrase"] = credentials.get("private_key_passphrase")
-            elif credentials.get("authenticator") is not None:
-                conn_info["authenticator"] = credentials.get("authenticator")
-                conn_info["password"] = credentials.get("password")
-            elif credentials.get("password") is not None:
-                conn_info["password"] = credentials.get("password")
-            else:
-                raise Exception("Snowflake: unsupported auth method")
-        elif conn_type == "bigquery":
-            method = credentials.get("method")
-            # there are many connection types https://docs.getdbt.com/reference/warehouse-setups/bigquery-setup#oauth-via-gcloud
-            # this assumes that the user is auth'd via `gcloud auth application-default login`
-            if method is None or method != "oauth":
-                raise Exception("Oauth is the current method supported for Big Query.")
-            conn_info = {
-                "driver": conn_type,
-                "project": credentials.get("project"),
-                "dataset": credentials.get("dataset"),
-            }
-            self.threads = credentials.get("threads")
-        elif conn_type == "duckdb":
-            conn_info = {
-                "driver": conn_type,
-                "filepath": credentials.get("path"),
-            }
-        elif conn_type == "redshift":
-            if (credentials.get("pass") is None and credentials.get("password") is None) or credentials.get(
-                "method"
-            ) == "iam":
-                raise Exception("Only password authentication is currently supported for Redshift.")
-            conn_info = {
-                "driver": conn_type,
-                "host": credentials.get("host"),
-                "user": credentials.get("user"),
-                "password": credentials.get("password") or credentials.get("pass"),
-                "port": credentials.get("port"),
-                "dbname": credentials.get("dbname"),
-            }
-            self.threads = credentials.get("threads")
-        elif conn_type == "databricks":
-            conn_info = {
-                "driver": conn_type,
-                "catalog": credentials.get("catalog"),
-                "server_hostname": credentials.get("host"),
-                "http_path": credentials.get("http_path"),
-                "schema": credentials.get("schema"),
-                "access_token": credentials.get("token"),
-            }
-            self.threads = credentials.get("threads")
-        elif conn_type == "postgres":
-            conn_info = {
-                "driver": "postgresql",
-                "host": credentials.get("host"),
-                "user": credentials.get("user"),
-                "password": credentials.get("password"),
-                "port": credentials.get("port"),
-                "dbname": credentials.get("dbname") or credentials.get("database"),
-            }
-            self.threads = credentials.get("threads")
-        else:
-            raise NotImplementedError(f"Provider {conn_type} is not yet supported for dbt diffs")
-
-        self.connection = conn_info
-
-    def get_pk_from_model(self, node, unique_columns: dict, pk_tag: str) -> List[str]:
-        try:
-            # Get a set of all the column names
-            column_names = {name for name, params in node.columns.items()}
-            # Check if the tag is present on a table level
-            if pk_tag in node.meta:
-                # Get all the PKs that are also present as a column
-                pks = [pk for pk in node.meta[pk_tag] if pk in column_names]
-                if pks:
-                    # If there are any left, return it
-                    logger.debug("Found PKs via Table META: " + str(pks))
-                    return pks
-
-            from_meta = [name for name, params in node.columns.items() if pk_tag in params.meta] or None
-            if from_meta:
-                logger.debug("Found PKs via META: " + str(from_meta))
-                return from_meta
-
-            from_tags = [name for name, params in node.columns.items() if pk_tag in params.tags] or None
-            if from_tags:
-                logger.debug("Found PKs via Tags: " + str(from_tags))
-                return from_tags
-
-            if node.unique_id in unique_columns:
-                from_uniq = unique_columns.get(node.unique_id)
-                if from_uniq is not None:
-                    logger.debug("Found PKs via Uniqueness tests: " + str(from_uniq))
-                    return list(from_uniq)
-
-        except (KeyError, IndexError, TypeError) as e:
-            raise e
-
-        logger.debug("Found no PKs")
-        return []
-
-    def get_unique_columns(self) -> Dict[str, Set[str]]:
-        manifest = self.manifest_obj
-        cols_by_uid = defaultdict(set)
-        for node in manifest.nodes.values():
-            try:
-                if not (node.resource_type.value == "test" and hasattr(node, "test_metadata")):
-                    continue
-
-                if not node.depends_on or not node.depends_on.nodes:
-                    continue
-
-                uid = node.depends_on.nodes[0]
-
-                # sources can have tests and are not in manifest.nodes
-                # skip as source unique columns are not needed
-                if uid.startswith("source."):
-                    continue
-
-                model_node = manifest.nodes[uid]
-
-                if node.test_metadata.name == "unique":
-                    column_name: str = node.test_metadata.kwargs["column_name"]
-                    for col in self._parse_concat_pk_definition(column_name):
-                        if model_node is None or col in model_node.columns:
-                            # skip anything that is not a column.
-                            # for example, string literals used in concat
-                            # like "pk1 || '-' || pk2"
-                            cols_by_uid[uid].add(col)
-
-                if node.test_metadata.name == "unique_combination_of_columns":
-                    for col in node.test_metadata.kwargs["combination_of_columns"]:
-                        cols_by_uid[uid].add(col)
-
-            except (KeyError, IndexError, TypeError) as e:
-                logger.warning("Failure while finding unique cols: %s", e)
-
-        return cols_by_uid
-
-    def _parse_concat_pk_definition(self, definition: str) -> List[str]:
-        definition = definition.strip()
-        if definition.lower().startswith("concat(") and definition.endswith(")"):
-            definition = definition[7:-1]  # Removes concat( and )
-            columns = definition.split(",")
-        else:
-            columns = definition.split("||")
-
-        stripped_columns = [col.strip('" ()') for col in columns]
-        return stripped_columns
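
For context on the deleted code: get_unique_columns collects the columns behind dbt unique and unique_combination_of_columns tests, and _parse_concat_pk_definition splits a unique test's column expression written either as concat(pk1, pk2) or as pk1 || pk2. A minimal standalone sketch of that parsing behavior (the function is lifted out of the class for illustration; the sample inputs are hypothetical):

from typing import List

def parse_concat_pk_definition(definition: str) -> List[str]:
    # Handle both concat(pk1, pk2) and "pk1 || pk2" styles of
    # uniqueness-test expressions, as the removed helper did.
    definition = definition.strip()
    if definition.lower().startswith("concat(") and definition.endswith(")"):
        definition = definition[7:-1]  # drop "concat(" and the trailing ")"
        columns = definition.split(",")
    else:
        columns = definition.split("||")
    # Strip double quotes, spaces, and parentheses from each piece.
    return [col.strip('" ()') for col in columns]

print(parse_concat_pk_definition("concat(pk1, pk2)"))   # ['pk1', 'pk2']
print(parse_concat_pk_definition('"pk1" || pk2'))       # ['pk1', 'pk2']
print(parse_concat_pk_definition("pk1 || '-' || pk2"))  # ['pk1', "'-'", 'pk2']

Note that a single-quoted literal such as '-' survives with its quotes, which is why the caller only keeps parsed pieces that appear in model_node.columns, filtering out string literals used in concatenation.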

data_diff/dbt_parser.py

Lines changed: 4 additions & 2 deletions
@@ -178,13 +178,15 @@ def set_connection(self):
                 "filepath": credentials.get("path"),
             }
         elif conn_type == "redshift":
-            if credentials.get("password") is None or credentials.get("method") == "iam":
+            if (credentials.get("pass") is None and credentials.get("password") is None) or credentials.get(
+                "method"
+            ) == "iam":
                 raise Exception("Only password authentication is currently supported for Redshift.")
             conn_info = {
                 "driver": conn_type,
                 "host": credentials.get("host"),
                 "user": credentials.get("user"),
-                "password": credentials.get("password"),
+                "password": credentials.get("password") or credentials.get("pass"),
                 "port": credentials.get("port"),
                 "dbname": credentials.get("dbname"),
             }
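
The substantive change in this hunk: a Redshift profile may supply its password under dbt's pass alias instead of password, and the old check rejected such profiles even though password authentication was configured. A minimal sketch of the new fallback (the profile dict is illustrative, not taken from the source):

# Hypothetical rendered profile target; the raw YAML may use "pass"
# rather than "password" (illustrative values only).
credentials = {
    "type": "redshift",
    "host": "example.cluster.redshift.amazonaws.com",
    "user": "dbt_user",
    "pass": "s3cret",  # alias form; no "password" key present
}

# Old check: credentials.get("password") is None -> raised immediately.
# New check passes, and conn_info falls back to the alias:
password = credentials.get("password") or credentials.get("pass")
assert password == "s3cret"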
