Skip to content

Commit 453956f

Browse files
committed
refactor: refactor base filter classes
1 parent e364952 commit 453956f

File tree

3 files changed

+40
-27
lines changed

3 files changed

+40
-27
lines changed

DPF/filters/column_filter.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from abc import ABC, abstractmethod
2-
from typing import List
2+
from typing import List, Dict, Any
33

44
import numpy as np
55
import pandas as pd
@@ -28,11 +28,11 @@ def schema(self) -> List[str]:
2828
pass
2929

3030
@abstractmethod
31-
def process(self, row: dict) -> tuple:
31+
def process(self, row: Dict[str, Any]) -> List[Any]:
3232
pass
3333

34-
def __call__(self, df: pd.DataFrame) -> np.ndarray:
34+
def __call__(self, df: pd.DataFrame) -> List[List[Any]]:
3535
pandarallel.initialize(nb_workers=self.workers)
36-
res = np.array(list(df[self.columns_to_process].parallel_apply(self.process, axis=1)))
36+
res = list(df[self.columns_to_process].parallel_apply(self.process, axis=1))
3737
return res
3838

DPF/filters/data_filter.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from abc import ABC, abstractmethod
2-
from typing import Any, Dict, List, Union
2+
from typing import Any, Dict, List, Union, Tuple
33

44
import pandas as pd
55
from torch.utils.data import DataLoader, Dataset
@@ -51,32 +51,39 @@ def metadata_columns(self) -> List[str]:
5151
pass
5252

5353
@abstractmethod
54-
def preprocess(self, modality2data: ModalityToDataMapping, metadata: Dict[str, str]):
54+
def preprocess_data(
55+
self,
56+
modality2data: ModalityToDataMapping,
57+
metadata: Dict[str, str]
58+
) -> Any:
5559
pass
5660

5761
@abstractmethod
58-
def process_batch(self, batch) -> dict:
62+
def process_batch(self, batch: List[Any]) -> Dict[str, List[Any]]:
5963
pass
6064

6165
@staticmethod
62-
def _add_values_from_batch(main_dict: dict, batch_dict: dict):
66+
def _add_values_from_batch(
67+
main_dict: Dict[str, List[Any]],
68+
batch_dict: Dict[str, List[Any]]
69+
) -> None:
6370
for k, v in batch_dict.items():
6471
main_dict[k].extend(v)
6572

66-
def _generate_dict_from_schema(self):
73+
def _get_dict_from_schema(self) -> Dict[str, List[Any]]:
6774
return {i: [] for i in self.schema}
6875

69-
def run(self, dataset: Dataset) -> pd.DataFrame:
76+
def run(self, dataset: Dataset[Tuple[bool, Any]]) -> pd.DataFrame:
7077
dataloader = DataLoader(dataset, collate_fn=identical_collate_fn, **self.dataloader_kwargs)
71-
df_labels = self._generate_dict_from_schema()
78+
filter_results = self._get_dict_from_schema()
7279

7380
for batch in tqdm(dataloader, disable=not self.pbar, position=self.pbar_position):
7481
# drop Nans
7582
batch_filtered = [b[1] for b in batch if b[0]]
7683
if len(batch_filtered) == 0:
7784
continue
7885

79-
df_batch_labels = self.process_batch(batch_filtered)
80-
self._add_values_from_batch(df_labels, df_batch_labels)
86+
filter_results_batch = self.process_batch(batch_filtered)
87+
self._add_values_from_batch(filter_results, filter_results_batch)
8188

82-
return pd.DataFrame(df_labels)
89+
return pd.DataFrame(filter_results)

DPF/filters/multigpu_filter.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from multiprocessing import Manager, Process
2-
from typing import List, Type, Union
2+
from typing import List, Type, Union, Dict, Any
33

44
import numpy as np
55
import pandas as pd
@@ -21,13 +21,13 @@ def run_one_process(
2121
index: pd.Series,
2222
results: List[pd.DataFrame],
2323
filter_class: Type[DataFilter],
24-
filter_kwargs: dict,
25-
device: str,
26-
filter_run_kwargs: dict
27-
):
24+
filter_kwargs: Dict[str, Any],
25+
device: Union[str, torch.device],
26+
filter_run_kwargs: Dict[str, Any]
27+
) -> None:
2828
reader = DatasetReader(filesystem=fs)
2929
processor = reader.from_df(config, df)
30-
datafilter = filter_class(**filter_kwargs, _pbar_position=i, device=device)
30+
datafilter = filter_class(**filter_kwargs, _pbar_position=i, device=device) # type: ignore
3131
processor.apply_data_filter(datafilter, **filter_run_kwargs)
3232
res = processor.df
3333
res.set_index(index, inplace=True)
@@ -41,16 +41,22 @@ class MultiGPUDataFilter:
4141

4242
def __init__(
4343
self,
44-
devices: List[Union[torch.device | str]],
45-
filter_class: type,
46-
filter_params: dict
44+
devices: List[Union[torch.device, str]],
45+
datafilter_class: Type[DataFilter],
46+
datafilter_params: Dict[str, Any]
4747
):
48-
self.filter_class = filter_class
49-
self.filter_params = filter_params
48+
self.filter_class = datafilter_class
49+
self.filter_params = datafilter_params
5050
self.devices = devices
5151
self.num_parts = len(devices)
5252

53-
def run(self, df: pd.DataFrame, config: DatasetConfig, fs: FileSystem, filter_run_kwargs: dict) -> pd.DataFrame:
53+
def run(
54+
self,
55+
df: pd.DataFrame,
56+
config: DatasetConfig,
57+
fs: FileSystem,
58+
filter_run_kwargs: Dict[str, Any]
59+
) -> pd.DataFrame:
5460
manager = Manager()
5561
shared_results = manager.list()
5662

@@ -63,7 +69,7 @@ def run(self, df: pd.DataFrame, config: DatasetConfig, fs: FileSystem, filter_ru
6369
fs,
6470
df_splits[i],
6571
i,
66-
df_splits[i].index,
72+
df_splits[i].index, # type: ignore
6773
shared_results,
6874
self.filter_class,
6975
self.filter_params,

0 commit comments

Comments (0)