feat: add transforms in pipelines

Commit 23de547 (parent: ffae303)

2 files changed, 63 insertions(+), 16 deletions(-)

DPF/pipelines/filter_pipeline.py (27 additions, 10 deletions)

@@ -6,6 +6,7 @@
 from DPF.filters import ColumnFilter, DataFilter
 from DPF.filters.multigpu_filter import MultiGPUDataFilter
 from DPF.processors import DatasetProcessor
+from DPF.transforms import BaseFilesTransforms
 from DPF.utils.logger import init_logger, init_stdout_logger
 
 from .pipeline_stages import (
@@ -14,6 +15,7 @@
     FilterPipelineStage,
     PipelineStage,
     ShufflePipelineStage,
+    TransformPipelineStage,
 )
 from .types import OnErrorOptions
 
@@ -39,25 +41,25 @@ def add_datafilter(
         datafilter: type[DataFilter],
         datafilter_kwargs: dict[str, Any],
         devices: Optional[list[str]] = None,
-        processor_run_kwargs: Optional[dict[str, Any]] = None,
+        processor_apply_kwargs: Optional[dict[str, Any]] = None,
         on_error: OnErrorOptions = "stop",
         skip_if_columns_exist: bool = True
     ) -> None:
-        if processor_run_kwargs is None:
-            processor_run_kwargs = {}
+        if processor_apply_kwargs is None:
+            processor_apply_kwargs = {}
 
         if devices is None:
             stage = FilterPipelineStage(
                 'datafilter', filter_class=datafilter,
-                filter_kwargs=datafilter_kwargs, processor_run_kwargs=processor_run_kwargs,
+                filter_kwargs=datafilter_kwargs, processor_apply_kwargs=processor_apply_kwargs,
                 skip_if_columns_exist=skip_if_columns_exist
             )
         elif len(devices) == 1:
             new_kwargs = datafilter_kwargs.copy()
             new_kwargs['device'] = devices[0]
             stage = FilterPipelineStage(
                 'datafilter', filter_class=datafilter,
-                filter_kwargs=new_kwargs, processor_run_kwargs=processor_run_kwargs,
+                filter_kwargs=new_kwargs, processor_apply_kwargs=processor_apply_kwargs,
                 skip_if_columns_exist=skip_if_columns_exist
             )
         else:
@@ -68,7 +70,7 @@ def add_datafilter(
                     "datafilter_class": datafilter,
                     "datafilter_params": datafilter_kwargs
                 },
-                processor_run_kwargs=processor_run_kwargs,
+                processor_apply_kwargs=processor_apply_kwargs,
                 skip_if_columns_exist=skip_if_columns_exist
             )
 
@@ -80,16 +82,16 @@ def add_columnfilter(
         self,
         columnfilter: type[ColumnFilter],
         columnfilter_kwargs: dict[str, Any],
-        processor_run_kwargs: Optional[dict[str, Any]] = None,
+        processor_apply_kwargs: Optional[dict[str, Any]] = None,
         on_error: OnErrorOptions = "stop",
         skip_if_columns_exist: bool = True
     ) -> None:
-        if processor_run_kwargs is None:
-            processor_run_kwargs = {}
+        if processor_apply_kwargs is None:
+            processor_apply_kwargs = {}
 
         stage = FilterPipelineStage(
             'columnfilter', filter_class=columnfilter,
-            filter_kwargs=columnfilter_kwargs, processor_run_kwargs=processor_run_kwargs,
+            filter_kwargs=columnfilter_kwargs, processor_apply_kwargs=processor_apply_kwargs,
             skip_if_columns_exist=skip_if_columns_exist
         )
 
@@ -123,6 +125,21 @@ def add_dataframe_filter(
             PipelineStageRunner(stage, on_error=on_error)
         )
 
+    def add_transforms(
+        self,
+        transforms_class: type[BaseFilesTransforms],
+        transforms_kwargs: dict[str, Any],
+        processor_apply_kwargs: Optional[dict[str, Any]] = None,
+        on_error: OnErrorOptions = "stop"
+    ) -> None:
+        stage = TransformPipelineStage(
+            transforms_class, transforms_kwargs,
+            processor_apply_kwargs=processor_apply_kwargs
+        )
+        self.stages.append(
+            PipelineStageRunner(stage, on_error=on_error)
+        )
+
     def _log_dataset_info(self, processor: DatasetProcessor) -> None:
         self.logger.info(f'Dataset path: {processor.config.path}')
         self.logger.info(f'Dataset modalities: {processor.modalities}')
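
For orientation, here is a minimal usage sketch of the new add_transforms hook. The pipeline class name FilterPipeline is inferred from the module name and MyTransforms is a hypothetical BaseFilesTransforms subclass; neither appears in this commit, and the kwargs values are placeholders.

# Minimal sketch. FilterPipeline, MyTransforms and the kwarg values are assumptions;
# only add_transforms and its parameter names come from this commit.
from DPF.pipelines import FilterPipeline  # assumed import path for the pipeline class
from DPF.transforms import BaseFilesTransforms


class MyTransforms(BaseFilesTransforms):
    """Hypothetical stand-in for a concrete BaseFilesTransforms subclass."""


def register_transforms(pipeline: FilterPipeline) -> None:
    # The class is stored, not instantiated, here: TransformPipelineStage builds
    # MyTransforms(**transforms_kwargs) only when the stage actually runs.
    pipeline.add_transforms(
        MyTransforms,
        transforms_kwargs={},            # constructor kwargs for MyTransforms
        processor_apply_kwargs=None,     # defaults to {} inside the stage
        on_error="stop",                 # OnErrorOptions value
    )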

DPF/pipelines/pipeline_stages.py (36 additions, 6 deletions)

@@ -1,12 +1,13 @@
 import logging
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Union
+from typing import Any, Callable, Optional, Union
 
 import pandas as pd
 
 from DPF.filters import ColumnFilter, DataFilter
 from DPF.filters.multigpu_filter import MultiGPUDataFilter
 from DPF.processors import DatasetProcessor
+from DPF.transforms import BaseFilesTransforms
 
 from .types import FilterTypes
 
@@ -66,13 +67,17 @@ def __init__(
         filter_type: FilterTypes,
         filter_class: Union[type[DataFilter], type[ColumnFilter], type[MultiGPUDataFilter]],
         filter_kwargs: dict[str, Any],
-        processor_run_kwargs: dict[str, Any],
+        processor_apply_kwargs: Optional[dict[str, Any]] = None,
         skip_if_columns_exist: bool = True
     ):
         self.filter_type = filter_type
         self.filter_class = filter_class
         self.filter_kwargs = filter_kwargs
-        self.processor_run_kwargs = processor_run_kwargs
+
+        self.processor_apply_kwargs = processor_apply_kwargs
+        if self.processor_apply_kwargs is None:
+            self.processor_apply_kwargs = {}
+
         self.skip_if_columns_exist = skip_if_columns_exist
 
     @property
@@ -96,10 +101,35 @@ def run(self, processor: DatasetProcessor, logger: logging.Logger) -> None:
             processor.df.drop(columns=columns_to_be_added, inplace=True, errors='ignore')
 
         if self.filter_type == 'datafilter':
-            processor.apply_data_filter(filter_obj, **self.processor_run_kwargs)  # type: ignore
+            processor.apply_data_filter(filter_obj, **self.processor_apply_kwargs)  # type: ignore
         elif self.filter_type == 'columnfilter':
-            processor.apply_column_filter(filter_obj, **self.processor_run_kwargs)  # type: ignore
+            processor.apply_column_filter(filter_obj, **self.processor_apply_kwargs)  # type: ignore
         elif self.filter_type == 'multigpufilter':
-            processor.apply_multi_gpu_data_filter(filter_obj, **self.processor_run_kwargs)
+            processor.apply_multi_gpu_data_filter(filter_obj, **self.processor_apply_kwargs)  # type: ignore
         else:
             raise ValueError(f"Unknown filter type: {self.filter_type}")
+
+
+class TransformPipelineStage(PipelineStage):
+
+    def __init__(
+        self,
+        transforms_class: type[BaseFilesTransforms],
+        transforms_kwargs: dict[str, Any],
+        processor_apply_kwargs: Optional[dict[str, Any]] = None,
+    ):
+        self.transforms_class = transforms_class
+        self.transforms_kwargs = transforms_kwargs
+
+        self.processor_apply_kwargs = processor_apply_kwargs
+        if self.processor_apply_kwargs is None:
+            self.processor_apply_kwargs = {}
+
+    @property
+    def stage_name(self) -> str:
+        return f"TransformPipelineStage(transforms_class={self.transforms_class}, transforms_kwargs={self.transforms_kwargs})"
+
+    def run(self, processor: DatasetProcessor, logger: logging.Logger) -> None:
+        transforms = self.transforms_class(**self.transforms_kwargs)
+
+        processor.apply_transform(transforms, **self.processor_apply_kwargs)  # type: ignore
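
TransformPipelineStage mirrors FilterPipelineStage: it stores the transforms class and its kwargs, normalizes processor_apply_kwargs from None to an empty dict, and instantiates the transforms only when run() is called. The sketch below drives the stage directly against an existing processor; MyTransforms is again a hypothetical stand-in, and the construction of the DatasetProcessor is assumed and not shown.

# Minimal sketch of running the new stage outside a pipeline. MyTransforms is a
# hypothetical BaseFilesTransforms subclass; the DatasetProcessor instance is assumed.
import logging

from DPF.pipelines.pipeline_stages import TransformPipelineStage
from DPF.processors import DatasetProcessor
from DPF.transforms import BaseFilesTransforms


class MyTransforms(BaseFilesTransforms):
    """Hypothetical stand-in for a concrete BaseFilesTransforms subclass."""


def run_transform_stage(processor: DatasetProcessor) -> None:
    logger = logging.getLogger("pipeline")

    stage = TransformPipelineStage(
        MyTransforms,                 # instantiated inside run(), not here
        transforms_kwargs={},         # forwarded to MyTransforms(...)
        processor_apply_kwargs=None,  # normalized to {} in __init__
    )

    # run() builds MyTransforms(**transforms_kwargs) and calls
    # processor.apply_transform(transforms, **processor_apply_kwargs).
    stage.run(processor, logger)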
