@@ -532,7 +532,8 @@ def inject_fake_data(self, tmpdir, config):
532532 self ._create_bbox_txt (base_folder , num_images )
533533 self ._create_landmarks_txt (base_folder , num_images )
534534
535-  return  dict (num_examples = num_images_per_split [config ["split" ]], attr_names = attr_names )
535+  num_samples  =  num_images_per_split .get (config ["split" ], 0 ) if  isinstance (config ["split" ], str ) else  0 
536+  return  dict (num_examples = num_samples , attr_names = attr_names )
536537
537538 def  _create_split_txt (self , root ):
538539 num_images_per_split  =  dict (train = 4 , valid = 3 , test = 2 )
@@ -635,6 +636,28 @@ def test_transforms_v2_wrapper_spawn(self):
635636 with  self .create_dataset (target_type = target_type , transform = v2 .Resize (size = expected_size )) as  (dataset , _ ):
636637 datasets_utils .check_transforms_v2_wrapper_spawn (dataset , expected_size = expected_size )
637638
639+  def  test_invalid_split_list (self ):
640+  with  pytest .raises (ValueError , match = "Expected type str for argument split, but got type <class 'list'>." ):
641+  with  self .create_dataset (split = [1 ]):
642+  pass 
643+ 
644+  def  test_invalid_split_int (self ):
645+  with  pytest .raises (ValueError , match = "Expected type str for argument split, but got type <class 'int'>." ):
646+  with  self .create_dataset (split = 1 ):
647+  pass 
648+ 
649+  def  test_invalid_split_value (self ):
650+  with  pytest .raises (
651+  ValueError ,
652+  match = "Unknown value '{value}' for argument {arg}. Valid values are {{{valid_values}}}." .format (
653+  value = "invalid" ,
654+  arg = "split" ,
655+  valid_values = ("train" , "valid" , "test" , "all" ),
656+  ),
657+  ):
658+  with  self .create_dataset (split = "invalid" ):
659+  pass 
660+ 
638661
639662class  VOCSegmentationTestCase (datasets_utils .ImageDatasetTestCase ):
640663 DATASET_CLASS  =  datasets .VOCSegmentation 
0 commit comments