11from multiprocessing import Manager , Process
2- from typing import List , Type , Union
2+ from typing import List , Type , Union , Dict , Any
33
44import numpy as np
55import pandas as pd
@@ -21,13 +21,13 @@ def run_one_process(
2121 index : pd .Series ,
2222 results : List [pd .DataFrame ],
2323 filter_class : Type [DataFilter ],
24- filter_kwargs : dict ,
25- device : str ,
26- filter_run_kwargs : dict
27- ):
24+ filter_kwargs : Dict [ str , Any ] ,
25+ device : Union [ str , torch . device ] ,
26+ filter_run_kwargs : Dict [ str , Any ]
27+ ) -> None :
2828 reader = DatasetReader (filesystem = fs )
2929 processor = reader .from_df (config , df )
30- datafilter = filter_class (** filter_kwargs , _pbar_position = i , device = device )
30+ datafilter = filter_class (** filter_kwargs , _pbar_position = i , device = device ) # type: ignore
3131 processor .apply_data_filter (datafilter , ** filter_run_kwargs )
3232 res = processor .df
3333 res .set_index (index , inplace = True )
@@ -41,16 +41,22 @@ class MultiGPUDataFilter:
4141
4242 def __init__ (
4343 self ,
44- devices : List [Union [torch .device | str ]],
45- filter_class : type ,
46- filter_params : dict
44+ devices : List [Union [torch .device , str ]],
45+ datafilter_class : Type [ DataFilter ] ,
46+ datafilter_params : Dict [ str , Any ]
4747 ):
48- self .filter_class = filter_class
49- self .filter_params = filter_params
48+ self .filter_class = datafilter_class
49+ self .filter_params = datafilter_params
5050 self .devices = devices
5151 self .num_parts = len (devices )
5252
53- def run (self , df : pd .DataFrame , config : DatasetConfig , fs : FileSystem , filter_run_kwargs : dict ) -> pd .DataFrame :
53+ def run (
54+ self ,
55+ df : pd .DataFrame ,
56+ config : DatasetConfig ,
57+ fs : FileSystem ,
58+ filter_run_kwargs : Dict [str , Any ]
59+ ) -> pd .DataFrame :
5460 manager = Manager ()
5561 shared_results = manager .list ()
5662
@@ -63,7 +69,7 @@ def run(self, df: pd.DataFrame, config: DatasetConfig, fs: FileSystem, filter_ru
6369 fs ,
6470 df_splits [i ],
6571 i ,
66- df_splits [i ].index ,
72+ df_splits [i ].index , # type: ignore
6773 shared_results ,
6874 self .filter_class ,
6975 self .filter_params ,
0 commit comments