 import yaml

 from . import api, exceptions, schemas, utils, validators
-from .datasets import Dataset
+from .datasets import Dataset, DatasetType
 from .models import Model
 from .projects import Project
 from .tasks import TaskType
@@ -274,8 +274,8 @@ def add_model(

         if failed_validations:
             raise exceptions.OpenlayerValidationError(
-                context="There are issues with the model package, as specified above.\n",
-                mitigation="Make sure to fix all of them before uploading the model.",
+                "There are issues with the model package.\n"
+                "Make sure to fix all of the issues listed above before the upload.",
             ) from None

         # ------ Start of temporary workaround for the arguments in the payload ------ #
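Note on the rewritten raise above: the two added string literals are joined
by Python's implicit concatenation, so the exception now receives a single
message instead of the separate `context`/`mitigation` keyword arguments
used before. A minimal sketch of the same mechanism, assuming
`OpenlayerValidationError` accepts one plain message string:

    # The adjacent literals compile to a single string; the trailing comma
    # in the diff ends the argument, not the string.
    message = (
        "There are issues with the model package.\n"
        "Make sure to fix all of the issues listed above before the upload."
    )
    print(message)  # prints both sentences as one two-line message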
@@ -307,6 +307,7 @@ def add_model(
         utils.remove_python_version(model_package_dir)

         # Make sure the resulting model package is less than 2 GB
+        # TODO: this should depend on the subscription plan
         if float(os.path.getsize("model")) / 1e9 > 2:
             raise exceptions.OpenlayerResourceError(
                 context="There's an issue with the specified `model_package_dir`.\n",
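One caveat with the size check above: `os.path.getsize` returns the size of
the single filesystem entry at the given path, so if `"model"` is a
directory rather than an archive file, the check would not account for the
directory's contents. A hedged sketch of a recursive alternative, in case
the package is a directory (the helper name is hypothetical):

    import os

    def package_size_bytes(path: str) -> int:
        """Sum the sizes of all regular files under `path`."""
        total = 0
        for root, _dirs, files in os.walk(path):
            for name in files:
                total += os.path.getsize(os.path.join(root, name))
        return total

    # e.g. package_size_bytes("model") / 1e9 > 2 for the same 2 GB check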
@@ -342,13 +343,15 @@ def add_dataset(
         file_path: str,
         class_names: List[str],
         label_column_name: str,
+        dataset_type: DatasetType = DatasetType.Validation,
         feature_names: List[str] = [],
         text_column_name: Optional[str] = None,
         categorical_feature_names: List[str] = [],
         tag_column_name: Optional[str] = None,
         language: str = "en",
         sep: str = ",",
         commit_message: Optional[str] = None,
+        dataset_config_file_path: Optional[str] = None,
         project_id: str = None,
     ) -> Dataset:
         r"""Uploads a dataset to the Openlayer platform (from a csv).
@@ -365,6 +368,9 @@ def add_dataset(

             .. important::
                 The labels in this column must be zero-indexed integer values.
+        dataset_type : :obj:`DatasetType`, default :obj:`DatasetType.Validation`
+            Type of dataset. E.g. :obj:`DatasetType.Validation` or
+            :obj:`DatasetType.Training`.
         feature_names : List[str], default []
             List of input feature names. Only applicable if your ``task_type`` is
             :obj:`TaskType.TabularClassification` or :obj:`TaskType.TabularRegression`.
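With the new parameter in place, a hypothetical call would look like the
following (the `project` handle, file name, and class names are
illustrative, not taken from this diff):

    from openlayer.datasets import DatasetType

    dataset = project.add_dataset(
        file_path="training_set.csv",
        class_names=["Retained", "Exited"],
        label_column_name="Exited",
        dataset_type=DatasetType.Training,  # overrides the Validation default
    )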
@@ -488,154 +494,36 @@ def add_dataset(
         ... )
         >>> dataset.to_dict()
         """
-        # ---------------------------- Schema validations ---------------------------- #
-        if task_type not in [
-            TaskType.TabularClassification,
-            TaskType.TextClassification,
-        ]:
-            raise exceptions.OpenlayerValidationError(
-                "`task_type` must be either TaskType.TabularClassification or "
-                "TaskType.TextClassification.\n"
-            ) from None
-        dataset_schema = schemas.DatasetSchema()
-        try:
-            dataset_schema.load(
-                {
-                    "file_path": file_path,
-                    "commit_message": commit_message,
-                    "class_names": class_names,
-                    "label_column_name": label_column_name,
-                    "tag_column_name": tag_column_name,
-                    "language": language,
-                    "sep": sep,
-                    "feature_names": feature_names,
-                    "text_column_name": text_column_name,
-                    "categorical_feature_names": categorical_feature_names,
-                }
-            )
-        except ma.ValidationError as err:
+        # ---------------------------- Dataset validations --------------------------- #
+        # TODO: re-think the way the arguments are passed for the dataset upload
+        dataset_config = None
+        if dataset_config_file_path is None:
+            dataset_config = {
+                "file_path": file_path,
+                "class_names": class_names,
+                "label_column_name": label_column_name,
+                "dataset_type": dataset_type.value,
+                "feature_names": feature_names,
+                "text_column_name": text_column_name,
+                "categorical_feature_names": categorical_feature_names,
+                "language": language,
+                "sep": sep,
+            }
+
+        dataset_validator = validators.DatasetValidator(
+            dataset_config_file_path=dataset_config_file_path,
+            dataset_config=dataset_config,
+            dataset_file_path=file_path,
+        )
+        failed_validations = dataset_validator.validate()
+
+        if failed_validations:
             raise exceptions.OpenlayerValidationError(
-                self._format_error_message(err)
+                "There are issues with the dataset and its config.\n"
+                "Make sure to fix all of the issues listed above before the upload.",
             ) from None

-        # --------------------------- Resource validations --------------------------- #
-        exp_file_path = os.path.expanduser(file_path)
         object_name = "original.csv"
-        if not os.path.isfile(exp_file_path):
-            raise exceptions.OpenlayerResourceError(
-                f"File at path `{file_path}` does not exist.\n"
-            ) from None
-
-        with open(exp_file_path, "rt") as f:
-            reader = csv.reader(f, delimiter=sep)
-            headers = next(reader)
-            row_count = sum(1 for _ in reader)
-
-        df = pd.read_csv(file_path, sep=sep)
-
-        # Checking for null values
-        if df.isnull().values.any():
-            raise exceptions.OpenlayerResourceError(
-                context="There's an issue with the specified dataset.\n",
-                message="The dataset contains null values, which is currently "
-                "not supported.\n",
-                mitigation="Make sure to upload a dataset without null values.",
-            ) from None
-
-        # Validating if the labels are zero-indexed ints
-        unique_labels = set(df[label_column_name].unique())
-        zero_indexed_set = set(range(len(class_names)))
-        if unique_labels != zero_indexed_set:
-            raise exceptions.OpenlayerResourceError(
-                context=f"There's an issue with values in the column "
-                f"`{label_column_name}`.\n",
-                message=f"The labels in `{label_column_name}` must be "
-                "zero-indexed integer values.\n",
-                mitigation="Make sure to upload a dataset with zero-indexed "
-                "integer labels that match the list in `class_names`. "
-                f"For example, the class `{class_names[0]}` should be "
-                "represented as a 0 in the dataset, the class "
-                f"`{class_names[1]}` as a 1, and so on.",
-            ) from None
-
-        # Validating the column dtypes
-        supported_dtypes = {"float32", "float64", "int32", "int64", "object"}
-        error_msg = ""
-        for col in df:
-            dtype = df[col].dtype.name
-            if dtype not in supported_dtypes:
-                error_msg += f"- Column `{col}` has dtype {dtype}.\n"
-        if error_msg:
-            raise exceptions.OpenlayerResourceError(
-                context="There is an issue with some of the columns' dtypes.\n",
-                message=error_msg,
-                mitigation=f"The supported dtypes are {supported_dtypes}. "
-                "Make sure to cast the above columns to a supported dtype.",
-            ) from None
-        # ------------------ Resource-schema consistency validations ----------------- #
-        # Label column validations
-        try:
-            headers.index(label_column_name)
-        except ValueError:
-            raise exceptions.OpenlayerDatasetInconsistencyError(
-                f"`{label_column_name}` specified as `label_column_name` is not "
-                "in the dataset.\n"
-            ) from None
-
-        if len(unique_labels) > len(class_names):
-            raise exceptions.OpenlayerDatasetInconsistencyError(
-                f"There are {len(unique_labels)} classes represented in the dataset, "
-                f"but only {len(class_names)} in `class_names`.\n",
-                mitigation=f"Make sure that there are at most {len(class_names)} "
-                "classes in your dataset.",
-            ) from None
-
-        # Feature validations
-        try:
-            if text_column_name:
-                feature_names = [text_column_name]
-            for feature_name in feature_names:
-                headers.index(feature_name)
-        except ValueError:
-            if text_column_name:
-                raise exceptions.OpenlayerDatasetInconsistencyError(
-                    f"`{text_column_name}` specified as `text_column_name` is not in "
-                    "the dataset.\n"
-                ) from None
-            else:
-                features_not_in_dataset = [
-                    feature for feature in feature_names if feature not in headers
-                ]
-                raise exceptions.OpenlayerDatasetInconsistencyError(
-                    f"Features {features_not_in_dataset} specified in `feature_names` "
-                    "are not in the dataset.\n"
-                ) from None
-        # Tag column validation
-        try:
-            if tag_column_name:
-                headers.index(tag_column_name)
-        except ValueError:
-            raise exceptions.OpenlayerDatasetInconsistencyError(
-                f"`{tag_column_name}` specified as `tag_column_name` is not in "
-                "the dataset.\n"
-            ) from None
-
-        # ----------------------- Subscription plan validations ---------------------- #
-        if row_count > self.subscription_plan["datasetRowCount"]:
-            raise exceptions.OpenlayerSubscriptionPlanException(
-                f"The dataset you are trying to upload contains {row_count} rows, "
-                "which exceeds your plan's limit of "
-                f"{self.subscription_plan['datasetRowCount']}.\n"
-            ) from None
-        if task_type == TaskType.TextClassification:
-            max_text_size = df[text_column_name].str.len().max()
-            if max_text_size > 1000:
-                raise exceptions.OpenlayerSubscriptionPlanException(
-                    "The dataset you are trying to upload contains rows with "
-                    f"{max_text_size} characters, which exceeds the 1000 character "
-                    "limit."
-                ) from None
-
         endpoint = f"projects/{project_id}/datasets"
         payload = dict(
             commitMessage=commit_message,
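When `dataset_config_file_path` is provided, the inline dict above is
skipped entirely and the validator reads the config from disk; given the
`yaml` import at the top of the module, the file is presumably YAML. A
hedged sketch of writing such a file with the same keys as the inline dict
(the exact schema `DatasetValidator` expects is not shown in this diff):

    import yaml

    # Key set mirrors the inline `dataset_config` dict built in add_dataset.
    dataset_config = {
        "file_path": "validation_set.csv",
        "class_names": ["Retained", "Exited"],
        "label_column_name": "Exited",
        "dataset_type": "validation",  # assumed value of DatasetType.Validation
        "feature_names": [],
        "text_column_name": None,
        "categorical_feature_names": [],
        "language": "en",
        "sep": ",",
    }

    with open("dataset_config.yaml", "w") as f:
        yaml.safe_dump(dataset_config, f)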
@@ -666,13 +554,15 @@ def add_dataframe(
         df: pd.DataFrame,
         class_names: List[str],
         label_column_name: str,
+        dataset_type: DatasetType = DatasetType.Validation,
         feature_names: List[str] = [],
         text_column_name: Optional[str] = None,
         categorical_feature_names: List[str] = [],
         commit_message: Optional[str] = None,
         tag_column_name: Optional[str] = None,
         language: str = "en",
         project_id: str = None,
+        dataset_config_file_path: Optional[str] = None,
     ) -> Dataset:
         r"""Uploads a dataset to the Openlayer platform (from a pandas DataFrame).

@@ -688,6 +578,9 @@ def add_dataframe(

             .. important::
                 The labels in this column must be zero-indexed integer values.
+        dataset_type : :obj:`DatasetType`, default :obj:`DatasetType.Validation`
+            Type of dataset. E.g. :obj:`DatasetType.Validation` or
+            :obj:`DatasetType.Training`.
         feature_names : List[str], default []
             List of input feature names. Only applicable if your ``task_type`` is
             :obj:`TaskType.TabularClassification` or :obj:`TaskType.TabularRegression`.
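Mirroring the add_dataset example earlier, a hypothetical call with the new
arguments (the DataFrame contents and `project` handle are illustrative):

    import pandas as pd

    from openlayer.datasets import DatasetType

    df = pd.DataFrame({"CreditScore": [608, 850], "Exited": [1, 0]})

    dataset = project.add_dataframe(
        df=df,
        class_names=["Retained", "Exited"],
        label_column_name="Exited",
        dataset_type=DatasetType.Validation,  # the default, shown explicitly
    )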
@@ -820,13 +713,15 @@ def add_dataframe(
             task_type=task_type,
             class_names=class_names,
             label_column_name=label_column_name,
+            dataset_type=dataset_type,
             text_column_name=text_column_name,
             commit_message=commit_message,
             tag_column_name=tag_column_name,
             language=language,
             feature_names=feature_names,
             categorical_feature_names=categorical_feature_names,
             project_id=project_id,
+            dataset_config_file_path=dataset_config_file_path,
         )

     @staticmethod