
Commit 5435781

Completes OPEN-3366 Refactor dataset validations
1 parent cede997 commit 5435781

File tree: 8 files changed (+360 −202 lines)


openlayer/__init__.py

Lines changed: 42 additions & 147 deletions
@@ -10,7 +10,7 @@
 import yaml
 
 from . import api, exceptions, schemas, utils, validators
-from .datasets import Dataset
+from .datasets import Dataset, DatasetType
 from .models import Model
 from .projects import Project
 from .tasks import TaskType
@@ -274,8 +274,8 @@ def add_model(
 
         if failed_validations:
             raise exceptions.OpenlayerValidationError(
-                context="There are issues with the model package, as specified above. \n",
-                mitigation="Make sure to fix all of them before uploading the model.",
+                "There are issues with the model package. \n"
+                "Make sure to fix all of the issues listed above before the upload.",
             ) from None
 
         # ------ Start of temporary workaround for the arguments in the payload ------ #
@@ -307,6 +307,7 @@ def add_model(
         utils.remove_python_version(model_package_dir)
 
         # Make sure the resulting model package is less than 2 GB
+        # TODO: this should depend on the subscription plan
         if float(os.path.getsize("model")) / 1e9 > 2:
             raise exceptions.OpenlayerResourceError(
                 context="There's an issue with the specified `model_package_dir`. \n",
@@ -342,13 +343,15 @@ def add_dataset(
         file_path: str,
         class_names: List[str],
         label_column_name: str,
+        dataset_type: DatasetType,
         feature_names: List[str] = [],
         text_column_name: Optional[str] = None,
         categorical_feature_names: List[str] = [],
         tag_column_name: Optional[str] = None,
         language: str = "en",
         sep: str = ",",
         commit_message: Optional[str] = None,
+        dataset_config_file_path: Optional[str] = None,
         project_id: str = None,
     ) -> Dataset:
         r"""Uploads a dataset to the Openlayer platform (from a csv).
@@ -365,6 +368,9 @@ def add_dataset(
 
             .. important::
                 The labels in this column must be zero-indexed integer values.
+        dataset_type : :obj:`DatasetType`
+            Type of dataset. E.g. :obj:`DatasetType.Validation` or
+            :obj:`DatasetType.Training`.
         feature_names : List[str], default []
             List of input feature names. Only applicable if your ``task_type`` is
             :obj:`TaskType.TabularClassification` or :obj:`TaskType.TabularRegression`.
@@ -488,154 +494,36 @@ def add_dataset(
         ... )
         >>> dataset.to_dict()
         """
-        # ---------------------------- Schema validations ---------------------------- #
-        if task_type not in [
-            TaskType.TabularClassification,
-            TaskType.TextClassification,
-        ]:
-            raise exceptions.OpenlayerValidationError(
-                "`task_type` must be either TaskType.TabularClassification or "
-                "TaskType.TextClassification. \n"
-            ) from None
-        dataset_schema = schemas.DatasetSchema()
-        try:
-            dataset_schema.load(
-                {
-                    "file_path": file_path,
-                    "commit_message": commit_message,
-                    "class_names": class_names,
-                    "label_column_name": label_column_name,
-                    "tag_column_name": tag_column_name,
-                    "language": language,
-                    "sep": sep,
-                    "feature_names": feature_names,
-                    "text_column_name": text_column_name,
-                    "categorical_feature_names": categorical_feature_names,
-                }
-            )
-        except ma.ValidationError as err:
+        # ---------------------------- Dataset validations --------------------------- #
+        # TODO: re-think the way the arguments are passed for the dataset upload
+        dataset_config = None
+        if dataset_config_file_path is None:
+            dataset_config = {
+                "file_path": file_path,
+                "class_names": class_names,
+                "label_column_name": label_column_name,
+                "dataset_type": dataset_type.value,
+                "feature_names": feature_names,
+                "text_column_name": text_column_name,
+                "categorical_feature_names": categorical_feature_names,
+                "language": language,
+                "sep": sep,
+            }
+
+        dataset_validator = validators.DatasetValidator(
+            dataset_config_file_path=dataset_config_file_path,
+            dataset_config=dataset_config,
+            dataset_file_path=file_path,
+        )
+        failed_validations = dataset_validator.validate()
+
+        if failed_validations:
             raise exceptions.OpenlayerValidationError(
-                self._format_error_message(err)
+                "There are issues with the dataset and its config. \n"
+                "Make sure to fix all of the issues listed above before the upload.",
             ) from None
 
-        # --------------------------- Resource validations --------------------------- #
-        exp_file_path = os.path.expanduser(file_path)
         object_name = "original.csv"
-        if not os.path.isfile(exp_file_path):
-            raise exceptions.OpenlayerResourceError(
-                f"File at path `{file_path}` does not contain the dataset. \n"
-            ) from None
-
-        with open(exp_file_path, "rt") as f:
-            reader = csv.reader(f, delimiter=sep)
-            headers = next(reader)
-            row_count = sum(1 for _ in reader)
-
-        df = pd.read_csv(file_path, sep=sep)
-
-        # Checking for null values
-        if df.isnull().values.any():
-            raise exceptions.OpenlayerResourceError(
-                context="There's an issue with the specified dataset. \n",
-                message="The dataset contains null values, which is currently "
-                "not supported. \n",
-                mitigation="Make sure to upload a dataset without null values.",
-            ) from None
-
-        # Validating if the labels are zero indexed ints
-        unique_labels = set(df[label_column_name].unique())
-        zero_indexed_set = set(range(len(class_names)))
-        if unique_labels != zero_indexed_set:
-            raise exceptions.OpenlayerResourceError(
-                context=f"There's an issue with values in the column "
-                f"`{label_column_name}` of the dataset. \n",
-                message=f"The labels in `{label_column_name}` must be "
-                "zero-indexed integer values. \n",
-                mitigation="Make sure to upload a dataset with zero-indexed "
-                "integer labels that match the list in `class_names`. "
-                f"For example, the class `{class_names[0]}` should be "
-                "represented as a 0 in the dataset, the class "
-                f"`{class_names[1]}` should be a 1, and so on.",
-            ) from None
-
-        # Validating the column dtypes
-        supported_dtypes = {"float32", "float64", "int32", "int64", "object"}
-        error_msg = ""
-        for col in df:
-            dtype = df[col].dtype.name
-            if dtype not in supported_dtypes:
-                error_msg += f"- Column `{col}` is of dtype {dtype}. \n"
-        if error_msg:
-            raise exceptions.OpenlayerResourceError(
-                context="There is an issue with some of the columns dtypes.\n",
-                message=error_msg,
-                mitigation=f"The supported dtypes are {supported_dtypes}. "
-                "Make sure to cast the above columns to a supported dtype.",
-            ) from None
-        # ------------------ Resource-schema consistency validations ----------------- #
-        # Label column validations
-        try:
-            headers.index(label_column_name)
-        except ValueError:
-            raise exceptions.OpenlayerDatasetInconsistencyError(
-                f"`{label_column_name}` specified as `label_column_name` is not "
-                "in the dataset. \n"
-            ) from None
-
-        if len(unique_labels) > len(class_names):
-            raise exceptions.OpenlayerDatasetInconsistencyError(
-                f"There are {len(unique_labels)} classes represented in the dataset, "
-                f"but only {len(class_names)} items in your `class_names`. \n",
-                mitigation=f"Make sure that there are at most {len(class_names)} "
-                "classes in your dataset.",
-            ) from None
-
-        # Feature validations
-        try:
-            if text_column_name:
-                feature_names = [text_column_name]
-            for feature_name in feature_names:
-                headers.index(feature_name)
-        except ValueError:
-            if text_column_name:
-                raise exceptions.OpenlayerDatasetInconsistencyError(
-                    f"`{text_column_name}` specified as `text_column_name` is not in "
-                    "the dataset. \n"
-                ) from None
-            else:
-                features_not_in_dataset = [
-                    feature for feature in feature_names if feature not in headers
-                ]
-                raise exceptions.OpenlayerDatasetInconsistencyError(
-                    f"Features {features_not_in_dataset} specified in `feature_names` "
-                    "are not in the dataset. \n"
-                ) from None
-        # Tag column validation
-        try:
-            if tag_column_name:
-                headers.index(tag_column_name)
-        except ValueError:
-            raise exceptions.OpenlayerDatasetInconsistencyError(
-                f"`{tag_column_name}` specified as `tag_column_name` is not in "
-                "the dataset. \n"
-            ) from None
-
-        # ----------------------- Subscription plan validations ---------------------- #
-        if row_count > self.subscription_plan["datasetRowCount"]:
-            raise exceptions.OpenlayerSubscriptionPlanException(
-                f"The dataset your are trying to upload contains {row_count} rows, "
-                "which exceeds your plan's limit of "
-                f"{self.subscription_plan['datasetRowCount']}. \n"
-            ) from None
-        if task_type == TaskType.TextClassification:
-            max_text_size = df[text_column_name].str.len().max()
-            if max_text_size > 1000:
-                raise exceptions.OpenlayerSubscriptionPlanException(
-                    "The dataset you are trying to upload contains rows with "
-                    f"{max_text_size} characters, which exceeds the 1000 character "
-                    "limit."
-                ) from None
-
         endpoint = f"projects/{project_id}/datasets"
         payload = dict(
             commitMessage=commit_message,
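
The net effect of this hunk: the inline schema, resource, consistency, and subscription-plan checks that previously lived in `add_dataset` are delegated to a single `validators.DatasetValidator`, and any failures surface as one aggregated `OpenlayerValidationError`. A minimal sketch of the new flow, assuming `validate()` returns a list of human-readable failure descriptions (the validator's internals are not part of this diff, and the file name and columns below are illustrative):

    from openlayer import exceptions, validators

    dataset_config = {
        "file_path": "validation.csv",
        "class_names": ["not_spam", "spam"],
        "label_column_name": "label",
        "dataset_type": "validation",  # i.e. DatasetType.Validation.value
    }
    dataset_validator = validators.DatasetValidator(
        dataset_config_file_path=None,  # alternatively, point this at a config file
        dataset_config=dataset_config,
        dataset_file_path="validation.csv",
    )
    failed_validations = dataset_validator.validate()

    if failed_validations:  # truthy whenever any check failed
        raise exceptions.OpenlayerValidationError(
            "There are issues with the dataset and its config. \n"
            "Make sure to fix all of the issues listed above before the upload.",
        )

Building `dataset_config` only when `dataset_config_file_path` is `None` keeps existing keyword-argument call sites working while opening a config-file path for new callers.
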
@@ -666,13 +554,15 @@ def add_dataframe(
         df: pd.DataFrame,
         class_names: List[str],
         label_column_name: str,
+        dataset_type: DatasetType,
         feature_names: List[str] = [],
         text_column_name: Optional[str] = None,
         categorical_feature_names: List[str] = [],
         commit_message: Optional[str] = None,
         tag_column_name: Optional[str] = None,
         language: str = "en",
         project_id: str = None,
+        dataset_config_file_path: Optional[str] = None,
     ) -> Dataset:
         r"""Uploads a dataset to the Openlayer platform (from a pandas DataFrame).
 
@@ -688,6 +578,9 @@ def add_dataframe(
 
             .. important::
                 The labels in this column must be zero-indexed integer values.
+        dataset_type : :obj:`DatasetType`
+            Type of dataset. E.g. :obj:`DatasetType.Validation` or
+            :obj:`DatasetType.Training`.
         feature_names : List[str], default []
             List of input feature names. Only applicable if your ``task_type`` is
             :obj:`TaskType.TabularClassification` or :obj:`TaskType.TabularRegression`.
 
@@ -820,13 +713,15 @@ def add_dataframe(
             task_type=task_type,
             class_names=class_names,
             label_column_name=label_column_name,
+            dataset_type=dataset_type,
            text_column_name=text_column_name,
             commit_message=commit_message,
             tag_column_name=tag_column_name,
             language=language,
             feature_names=feature_names,
             categorical_feature_names=categorical_feature_names,
             project_id=project_id,
+            dataset_config_file_path=dataset_config_file_path,
         )
 
     @staticmethod
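
For callers, the visible change in both `add_dataset` and `add_dataframe` is the new `dataset_type` argument plus the optional `dataset_config_file_path`. A hedged usage sketch; the client construction, project id, and data below are illustrative assumptions rather than code from this diff:

    import pandas as pd

    from openlayer import OpenlayerClient
    from openlayer.datasets import DatasetType
    from openlayer.tasks import TaskType

    client = OpenlayerClient("YOUR_API_KEY")  # hypothetical setup
    df = pd.DataFrame({"text": ["great product", "never again"], "label": [1, 0]})

    dataset = client.add_dataframe(
        df=df,
        task_type=TaskType.TextClassification,
        class_names=["negative", "positive"],  # label 0 is negative, 1 is positive
        label_column_name="label",
        dataset_type=DatasetType.Validation,  # new in this commit
        text_column_name="text",
        project_id="YOUR_PROJECT_ID",
    )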

openlayer/datasets.py

Lines changed: 12 additions & 0 deletions
@@ -1,3 +1,15 @@
+from enum import Enum
+
+
+class DatasetType(Enum):
+    """The different dataset types that are supported by Openlayer."""
+
+    #: For validation sets.
+    Validation = "validation"
+    #: For training sets.
+    Training = "training"
+
+
 class Dataset:
     """An object containing information about a dataset on the Openlayer platform."""
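
The enum's string values are what end up in configs and payloads (note that `add_dataset` above sends `dataset_type.value`). A quick sketch of how it reads in practice:

    from openlayer.datasets import DatasetType

    dataset_type = DatasetType.Validation
    print(dataset_type.value)       # prints "validation", the string sent in payloads
    print(DatasetType("training"))  # lookup by value gives back DatasetType.Training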

openlayer/exceptions.py

Lines changed: 4 additions & 32 deletions
@@ -20,43 +20,15 @@ def __init__(self, message, errcode=None):
         super().__init__(f"<Response> {message}")
 
 
-class OpenlayerResourceError(OpenlayerException):
-    def __init__(self, message, context=None, mitigation=None):
-        if not context:
-            context = "There is a problem with the specified file path. \n"
-        if not mitigation:
-            mitigation = (
-                "Make sure that the specified filepath contains the expected resource."
-            )
-        super().__init__(context + message + mitigation)
-
-
 class OpenlayerValidationError(OpenlayerException):
-    def __init__(self, message, context=None, mitigation=None):
-        if not context:
-            context = "There are issues with some of the arguments: \n"
-        if not mitigation:
-            mitigation = (
-                "Make sure to respect the datatypes and constraints specified above."
-            )
-        super().__init__(context + message + mitigation)
-
-
-class OpenlayerDatasetInconsistencyError(OpenlayerException):
-    def __init__(self, message, context=None, mitigation=None):
-        if not context:
-            context = "There are inconsistencies between the dataset and some of the arguments: \n"
-        if not mitigation:
-            mitigation = "Make sure that the value specified in the argument is a column header in the dataframe or csv being uploaded."
-        super().__init__(context + message + mitigation)
+    def __init__(self, message):
+        super().__init__(message)
 
 
 class OpenlayerSubscriptionPlanException(OpenlayerException):
     def __init__(self, message, context=None, mitigation=None):
-        if not context:
-            context = "You have reached your subscription plan's limits. \n"
-        if not mitigation:
-            mitigation = "To upgrade your plan, visit https://openlayer.com"
+        context = context or "You have reached your subscription plan's limits. \n"
+        mitigation = mitigation or "To upgrade your plan, visit https://openlayer.com"
         super().__init__(context + message + mitigation)
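
After this change, `OpenlayerValidationError` takes one fully assembled message instead of separate context/mitigation parts, while the subscription-plan exception keeps its defaults via the terser `x = x or default` idiom. A small sketch of both call styles, assuming the classes exactly as they stand after this diff (the base `OpenlayerException` shown above prefixes "<Response> " to whatever it receives):

    from openlayer import exceptions

    try:
        raise exceptions.OpenlayerValidationError(
            "There are issues with the dataset and its config. \n"
            "Make sure to fix all of the issues listed above before the upload.",
        )
    except exceptions.OpenlayerException as exc:
        print(exc)  # the message exactly as given, behind the "<Response> " prefix

    try:
        raise exceptions.OpenlayerSubscriptionPlanException("Row limit exceeded. \n")
    except exceptions.OpenlayerException as exc:
        print(exc)  # default plan-limit context and upgrade mitigation wrap the message
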
openlayer/models.py

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ class ModelType(Enum):
     """
 
     #: For custom built models.
-    custom = "Custom"
+    custom = "custom"
     #: For models built with `fastText <https://fasttext.cc/>`_.
     fasttext = "fasttext"
     #: For models built with `Keras <https://keras.io/>`_.

openlayer/schemas.py

Lines changed: 6 additions & 4 deletions
@@ -1,5 +1,6 @@
 import marshmallow as ma
 
+from .datasets import DatasetType
 from .models import ModelType
 
 
@@ -55,9 +56,11 @@ class DatasetSchema(ma.Schema):
             max=140,
         ),
     )
-    tag_column_name = ma.fields.List(
-        ma.fields.Str(),
-        allow_none=True,
+    dataset_type = ma.fields.Str(
+        validate=ma.validate.OneOf(
+            [dataset_type.value for dataset_type in DatasetType],
+            error=f"`dataset_type` must be one of the supported frameworks. Check out our API reference for a full list https://reference.openlayer.com/reference/api/openlayer.DatasetType.html.\n ",
+        ),
     )
     class_names = ma.fields.List(
         ma.fields.Str(),
 
@@ -73,7 +76,6 @@ class DatasetSchema(ma.Schema):
     sep = ma.fields.Str()
     feature_names = ma.fields.List(
         ma.fields.Str(),
-        allow_none=True,
     )
     text_column_name = ma.fields.Str(
         allow_none=True,
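
The `OneOf` validator keeps the schema in lockstep with the enum: any `dataset_type` outside `DatasetType`'s values is rejected at load time. A self-contained sketch of the same marshmallow pattern (the enum is inlined and the schema's other fields omitted so the snippet runs on its own):

    from enum import Enum

    import marshmallow as ma


    class DatasetType(Enum):
        Validation = "validation"
        Training = "training"


    class DatasetSchema(ma.Schema):
        dataset_type = ma.fields.Str(
            validate=ma.validate.OneOf(
                [dataset_type.value for dataset_type in DatasetType],
            ),
        )


    DatasetSchema().load({"dataset_type": "validation"})  # accepted
    try:
        DatasetSchema().load({"dataset_type": "test"})  # not a DatasetType value
    except ma.ValidationError as err:
        print(err.messages)  # {'dataset_type': ['Must be one of: validation, training.']}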

openlayer/utils.py

Lines changed: 0 additions & 15 deletions
@@ -99,18 +99,3 @@ def remove_python_version(dir: str):
         dir (str): the directory to remove the file from.
     """
     os.remove(f"{dir}/python_version")
-
-
-def copy_to_tmp_dir(dir: str) -> str:
-    """Copies the contents of the specified directory (`dir`) to a temporary directory.
-
-    Args:
-        dir (str): the directory to copy the contents from.
-
-    Returns:
-        str: the path to the temporary directory.
-    """
-    tmp_dir = tempfile.mkdtemp()
-    distutils.dir_util.copy_tree(dir, tmp_dir)
-
-    return tmp_dir
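
`copy_to_tmp_dir` is deleted with no in-tree replacement. For anyone who still needs the behavior, a hedged stdlib sketch (not part of this commit; `shutil` stands in for the deprecated `distutils.dir_util`):

    import shutil
    import tempfile


    def copy_to_tmp_dir(src_dir: str) -> str:
        """Copy the contents of `src_dir` into a fresh temporary directory."""
        tmp_dir = tempfile.mkdtemp()  # creates the temporary directory
        shutil.copytree(src_dir, tmp_dir, dirs_exist_ok=True)  # requires Python 3.8+
        return tmp_dir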
