66from  DPF .filters  import  ColumnFilter , DataFilter 
77from  DPF .filters .multigpu_filter  import  MultiGPUDataFilter 
88from  DPF .processors  import  DatasetProcessor 
9+ from  DPF .transforms  import  BaseFilesTransforms 
910from  DPF .utils .logger  import  init_logger , init_stdout_logger 
1011
1112from  .pipeline_stages  import  (
1415 FilterPipelineStage ,
1516 PipelineStage ,
1617 ShufflePipelineStage ,
18+  TransformPipelineStage ,
1719)
1820from  .types  import  OnErrorOptions 
1921
@@ -39,25 +41,25 @@ def add_datafilter(
3941 datafilter : type [DataFilter ],
4042 datafilter_kwargs : dict [str , Any ],
4143 devices : Optional [list [str ]] =  None ,
42-  processor_run_kwargs : Optional [dict [str , Any ]] =  None ,
44+  processor_apply_kwargs : Optional [dict [str , Any ]] =  None ,
4345 on_error : OnErrorOptions  =  "stop" ,
4446 skip_if_columns_exist : bool  =  True 
4547 ) ->  None :
46-  if  processor_run_kwargs  is  None :
47-  processor_run_kwargs  =  {}
48+  if  processor_apply_kwargs  is  None :
49+  processor_apply_kwargs  =  {}
4850
4951 if  devices  is  None :
5052 stage  =  FilterPipelineStage (
5153 'datafilter' , filter_class = datafilter ,
52-  filter_kwargs = datafilter_kwargs , processor_run_kwargs = processor_run_kwargs ,
54+  filter_kwargs = datafilter_kwargs , processor_apply_kwargs = processor_apply_kwargs ,
5355 skip_if_columns_exist = skip_if_columns_exist 
5456 )
5557 elif  len (devices ) ==  0 :
5658 new_kwargs  =  datafilter_kwargs .copy ()
5759 new_kwargs ['device' ] =  devices [0 ]
5860 stage  =  FilterPipelineStage (
5961 'datafilter' , filter_class = datafilter ,
60-  filter_kwargs = new_kwargs , processor_run_kwargs = processor_run_kwargs ,
62+  filter_kwargs = new_kwargs , processor_apply_kwargs = processor_apply_kwargs ,
6163 skip_if_columns_exist = skip_if_columns_exist 
6264 )
6365 else :
@@ -68,7 +70,7 @@ def add_datafilter(
6870 "datafilter_class" : datafilter ,
6971 "datafilter_params" : datafilter_kwargs 
7072 },
71-  processor_run_kwargs = processor_run_kwargs ,
73+  processor_apply_kwargs = processor_apply_kwargs ,
7274 skip_if_columns_exist = skip_if_columns_exist 
7375 )
7476
@@ -80,16 +82,16 @@ def add_columnfilter(
8082 self ,
8183 columnfilter : type [ColumnFilter ],
8284 columnfilter_kwargs : dict [str , Any ],
83-  processor_run_kwargs : Optional [dict [str , Any ]] =  None ,
85+  processor_apply_kwargs : Optional [dict [str , Any ]] =  None ,
8486 on_error : OnErrorOptions  =  "stop" ,
8587 skip_if_columns_exist : bool  =  True 
8688 ) ->  None :
87-  if  processor_run_kwargs  is  None :
88-  processor_run_kwargs  =  {}
89+  if  processor_apply_kwargs  is  None :
90+  processor_apply_kwargs  =  {}
8991
9092 stage  =  FilterPipelineStage (
9193 'columnfilter' , filter_class = columnfilter ,
92-  filter_kwargs = columnfilter_kwargs , processor_run_kwargs = processor_run_kwargs ,
94+  filter_kwargs = columnfilter_kwargs , processor_apply_kwargs = processor_apply_kwargs ,
9395 skip_if_columns_exist = skip_if_columns_exist 
9496 )
9597
@@ -123,6 +125,21 @@ def add_dataframe_filter(
123125 PipelineStageRunner (stage , on_error = on_error )
124126 )
125127
128+  def  add_transforms (
129+  self ,
130+  transforms_class : type [BaseFilesTransforms ],
131+  transforms_kwargs : dict [str , Any ],
132+  processor_apply_kwargs : Optional [dict [str , Any ]] =  None ,
133+  on_error : OnErrorOptions  =  "stop" 
134+  ) ->  None :
135+  stage  =  TransformPipelineStage (
136+  transforms_class , transforms_kwargs ,
137+  processor_apply_kwargs = processor_apply_kwargs 
138+  )
139+  self .stages .append (
140+  PipelineStageRunner (stage , on_error = on_error )
141+  )
142+ 
126143 def  _log_dataset_info (self , processor : DatasetProcessor ) ->  None :
127144 self .logger .info (f'Dataset path: { processor .config .path }  ' )
128145 self .logger .info (f'Dataset modalities: { processor .modalities }  ' )
0 commit comments