Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions ads/opctl/operator/common/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,16 @@
from ads.opctl.operator import __operators__


class OperatorSchemaYamlError(Exception):
    """Exception raised when there is an issue with the schema.

    The final message is a fixed explanatory preamble followed, on a new
    line, by the caller-supplied error details.

    Parameters
    ----------
    error: str
        Details of the schema violation, appended verbatim to the message.
    """

    def __init__(self, error: str):
        super().__init__(
            "Invalid operator specification. Check the YAML structure and ensure it "
            "complies with the required schema for the operator. \n"
            f"{error}"
        )


class OperatorNotFoundError(Exception):
def __init__(self, operator: str):
super().__init__(
Expand Down
11 changes: 4 additions & 7 deletions ads/opctl/operator/common/operator_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@

from ads.common.serializer import DataClassSerializable

from .utils import OperatorValidator
from ads.opctl.operator.common.utils import OperatorValidator
from ads.opctl.operator.common.errors import OperatorSchemaYamlError


@dataclass(repr=True)
Expand Down Expand Up @@ -52,19 +53,15 @@ def _validate_dict(cls, obj_dict: Dict) -> bool:

Raises
------
ValueError
OperatorSchemaYamlError
In case of wrong specification format.
"""
schema = cls._load_schema()
validator = OperatorValidator(schema)
result = validator.validate(obj_dict)

if not result:
raise ValueError(
"Invalid operator specification. Check the YAML structure and ensure it "
"complies with the required schema for the operator. \n"
f"{json.dumps(validator.errors, indent=2)}"
)
raise OperatorSchemaYamlError(json.dumps(validator.errors, indent=2))
return True

@classmethod
Expand Down
26 changes: 26 additions & 0 deletions ads/opctl/operator/lowcode/forecast/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*--

# Copyright (c) 2023 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

class ForecastSchemaYamlError(Exception):
    """Exception raised when there is an issue with the schema.

    The final message is a fixed explanatory preamble followed, on a new
    line, by the caller-supplied error details.

    Parameters
    ----------
    error: str
        Details of the schema problem, appended verbatim to the message.
    """

    def __init__(self, error: str):
        super().__init__(
            "Invalid forecast operator specification. Check the YAML structure and ensure it "
            "complies with the required schema for forecast operator. \n"
            f"{error}"
        )


class ForecastInputDataError(Exception):
    """Exception raised when there is an issue with input data.

    The final message is a fixed explanatory preamble followed, on a new
    line, by the caller-supplied error details.

    Parameters
    ----------
    error: str
        Details of the input data problem, appended verbatim to the message.
    """

    def __init__(self, error: str):
        super().__init__(
            "Invalid input data. Check the input data and ensure it "
            "complies with the validation criteria. \n"
            f"{error}"
        )
2 changes: 1 addition & 1 deletion ads/opctl/operator/lowcode/forecast/schema.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ spec:

target_category_columns:
type: list
required: false
required: true
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the code break without specifying this series column? If so we should open a ticket to change that

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it breaks, and the error message is not meaningful, so I made this field mandatory.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ahosler @prasankh let's open a ticket to fix this, as Allen suggested?

schema:
type: string

Expand Down
10 changes: 5 additions & 5 deletions ads/opctl/operator/lowcode/forecast/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from ads.dataset.label_encoder import DataFrameLabelEncoder
from .const import SupportedModels, MAX_COLUMNS_AUTOMLX

from .errors import ForecastInputDataError, ForecastSchemaYamlError

def _label_encode_dataframe(df, no_encode=set()):
df_to_encode = df[list(set(df.columns) - no_encode)]
Expand Down Expand Up @@ -135,7 +135,7 @@ def _load_data(filename, format, storage_options, columns, **kwargs):
# keep only these columns, done after load because only CSV supports stream filtering
data = data[columns]
return data
raise ValueError(f"Unrecognized format: {format}")
raise ForecastInputDataError(f"Unrecognized format: {format}")


def _write_data(data, filename, format, storage_options, index=False, **kwargs):
Expand All @@ -147,7 +147,7 @@ def _write_data(data, filename, format, storage_options, index=False, **kwargs):
return _call_pandas_fsspec(
write_fn, filename, index=index, storage_options=storage_options
)
raise ValueError(f"Unrecognized format: {format}")
raise ForecastInputDataError(f"Unrecognized format: {format}")


def _merge_category_columns(data, target_category_columns):
Expand Down Expand Up @@ -178,7 +178,7 @@ def _clean_data(data, target_column, datetime_column, target_category_columns=No

return df.fillna(0), new_target_columns

raise ValueError(
raise ForecastSchemaYamlError(
f"Either target_columns, target_category_columns, or datetime_column not specified."
)

Expand Down Expand Up @@ -297,7 +297,7 @@ def _build_indexed_datasets(
new_target_columns = list(df_by_target.keys())
remaining_categories = set(unique_categories) - set(invalid_categories)
if not len(remaining_categories):
raise ValueError(
raise ForecastInputDataError(
"Stopping forecast operator as there is no data that meets the validation criteria."
)
return df_by_target, new_target_columns, remaining_categories
Expand Down
8 changes: 2 additions & 6 deletions ads/opctl/operator/runtime/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ads.common.extended_enum import ExtendedEnum
from ads.common.serializer import DataClassSerializable
from ads.opctl.operator.common.utils import _load_yaml_from_uri

from ads.opctl.operator.common.errors import OperatorSchemaYamlError

class OPERATOR_LOCAL_RUNTIME_TYPE(ExtendedEnum):
PYTHON = "python"
Expand Down Expand Up @@ -56,11 +56,7 @@ def _validate_dict(cls, obj_dict: Dict) -> bool:
result = validator.validate(obj_dict)

if not result:
raise ValueError(
"Invalid runtime specification. Check the YAML structure and ensure it "
"complies with the required schema for the runtime. \n"
f"{json.dumps(validator.errors, indent=2)}"
)
raise OperatorSchemaYamlError(json.dumps(validator.errors, indent=2))
return True


Expand Down