Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 134 additions & 0 deletions openlayer/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,9 @@ def validate(self) -> List[str]:
if not self.failed_validations:
self._validate_bundle_resources()

if not self.failed_validations:
self._validate_resource_consistency()

if not self.failed_validations:
logger.info(
"----------------------------------------------------------------------------\n"
Expand Down Expand Up @@ -401,6 +404,137 @@ def _load_model_config_from_bundle(self) -> Dict[str, Any]:

return model_config

def _validate_resource_consistency(self):
    """Cross-checks the configs of the resources present in the bundle.

    The checks only run when both the training and the validation sets are in
    the bundle. In that case, the `featureNames` and `classNames` fields of the
    two dataset configs are compared with each other and with the ones from the
    model config (when a model is in the bundle).

    Every failure message is passed to `_list_failed_validation_messages` and
    appended to `self.failed_validations`.
    """
    failures = []
    resources = self._bundle_resources

    if all(dataset in resources for dataset in ("training", "validation")):
        # Configs are loaded in a fixed order: model (optional), training,
        # validation. A missing model is represented by an empty config.
        model_config = (
            self._load_model_config_from_bundle() if "model" in resources else {}
        )
        training_config = self._load_dataset_config_from_bundle("training")
        validation_config = self._load_dataset_config_from_bundle("validation")

        training_feature_names = training_config.get("featureNames")
        validation_feature_names = validation_config.get("featureNames")

        # `featureNames` are only compared when at least one of the datasets
        # declares a non-empty list of them.
        has_feature_names = training_feature_names or validation_feature_names
        if has_feature_names and not self._feature_names_consistent(
            model_feature_names=model_config.get("featureNames"),
            training_feature_names=training_feature_names,
            validation_feature_names=validation_feature_names,
        ):
            failures.append(
                "The `featureNames` in the provided resources are inconsistent."
                " The training and validation set feature names must have some overlap."
                " Furthermore, if a model is provided, its feature names must be a subset"
                " of the feature names in the training and validation sets."
            )

        # `classNames` are always compared.
        class_names_ok = self._class_names_consistent(
            model_class_names=model_config.get("classNames"),
            training_class_names=training_config.get("classNames"),
            validation_class_names=validation_config.get("classNames"),
        )
        if not class_names_ok:
            failures.append(
                "The `classNames` in the provided resources are inconsistent."
                " The validation set's class names need to contain the training set's."
                " Furthermore, if a model is provided, its class names must be contained"
                " in the training and validation sets' class names."
                " Note that the order of the items in the `classNames` list matters."
            )

    # Report the failures, if any
    if failures:
        logger.error("Bundle resource consistency failed validations:")
        _list_failed_validation_messages(failures)

    # Record the failures alongside the ones from the other validation steps
    self.failed_validations.extend(failures)

@staticmethod
def _feature_names_consistent(
model_feature_names: Optional[List[str]],
training_feature_names: List[str],
validation_feature_names: List[str],
) -> bool:
"""Checks whether the feature names in the training, validation and model
configs are consistent.

Parameters
----------
model_feature_names : List[str]
The feature names in the model config.
training_feature_names : List[str]
The feature names in the training dataset config.
validation_feature_names : List[str]
The feature names in the validation dataset config.

Returns
-------
bool
True if the feature names are consistent, False otherwise.
"""
train_val_intersection = set(training_feature_names).intersection(
set(validation_feature_names)
)
if model_feature_names is None:
return len(train_val_intersection) != 0
return set(model_feature_names).issubset(train_val_intersection)

@staticmethod
def _class_names_consistent(
model_class_names: Optional[List[str]],
training_class_names: List[str],
validation_class_names: List[str],
) -> bool:
"""Checks whether the class names in the training and model configs
are consistent.

Parameters
----------
model_class_names : List[str]
The class names in the model config.
training_class_names : List[str]
The class names in the training dataset config.
validation_class_names : List[str]
The class names in the validation dataset config.

Returns
-------
bool
True if the class names are consistent, False otherwise.
"""
if model_class_names is not None:
num_model_classes = len(model_class_names)
try:
return (
training_class_names[:num_model_classes] == model_class_names
and validation_class_names[:num_model_classes] == model_class_names
)
except IndexError:
return False
num_training_classes = len(training_class_names)
try:
return validation_class_names[:num_training_classes] == training_class_names
except IndexError:
return False


class CommitValidator:
"""Validates the commit prior to the upload.
Expand Down