Skip to content

Commit 5d100ac

Browse files
authored
Merge pull request #34 from ai-forever/multi-gpu-filter
Multi-gpu filter
2 parents 04eb9f6 + d3cfd19 commit 5d100ac

24 files changed

+166
-115
lines changed

DPF/dataset_reader.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ def from_config(
321321
config: DatasetConfig,
322322
**kwargs
323323
) -> DatasetProcessor:
324-
"""Creates DatasetConfig dataset
324+
"""Creates DatasetProcessor from config
325325
326326
Parameters
327327
----------
@@ -345,3 +345,32 @@ def from_config(
345345
raise ValueError(f"Unsupported config: {config}")
346346
return processor
347347

348+
def from_df(self, config: DatasetConfig, df: pd.DataFrame) -> DatasetProcessor:
349+
"""Creates DatasetProcessor from config and dataframe
350+
351+
Parameters
352+
----------
353+
config: DatasetConfig
354+
Config of DatasetConfig type
355+
df: pd.DataFrame
356+
Dataframe for DatasetProcessor.df
357+
358+
Returns
359+
-------
360+
DatasetProcessor
361+
Instance of DatasetProcessor dataset
362+
"""
363+
if isinstance(config, ShardsDatasetConfig):
364+
processor_class = ShardsDatasetProcessor
365+
elif isinstance(config, ShardedFilesDatasetConfig):
366+
processor_class = ShardedFilesDatasetProcessor
367+
elif isinstance(config, FilesDatasetConfig):
368+
processor_class = FilesDatasetProcessor
369+
else:
370+
raise ValueError(f"Unsupported config: {config}")
371+
372+
return processor_class(
373+
filesystem=self.filesystem,
374+
config=config,
375+
df=df
376+
)

DPF/filters/data_filter.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@ class DataFilter(ABC):
1212
Abstract class for all filters that use dataloaders.
1313
"""
1414

15-
def __init__(self, pbar: bool):
15+
def __init__(self, pbar: bool, _pbar_position: int = 0):
1616
super().__init__()
1717
self.pbar = pbar
18+
self.pbar_position = _pbar_position
1819

1920
@property
2021
@abstractmethod
@@ -66,7 +67,7 @@ def run(self, dataset: Dataset) -> pd.DataFrame:
6667
dataloader = DataLoader(dataset, collate_fn=identical_collate_fn, **self.dataloader_kwargs)
6768
df_labels = self._generate_dict_from_schema()
6869

69-
for batch in tqdm(dataloader, disable=not self.pbar):
70+
for batch in tqdm(dataloader, disable=not self.pbar, position=self.pbar_position):
7071
# drop Nans
7172
batch_filtered = [b[1] for b in batch if b[0]]
7273
if len(batch_filtered) == 0:

DPF/filters/images/aesthetic_filter.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,9 @@ def __init__(
6060
workers: int = 16,
6161
batch_size: int = 64,
6262
pbar: bool = True,
63+
_pbar_position: int = 0
6364
):
64-
super().__init__(pbar)
65+
super().__init__(pbar, _pbar_position)
6566

6667
self.num_workers = workers
6768
self.batch_size = batch_size

DPF/filters/images/aesthetic_improved_filter.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,9 @@ def __init__(
7979
workers: int = 16,
8080
batch_size: int = 64,
8181
pbar: bool = True,
82+
_pbar_position: int = 0
8283
):
83-
super().__init__(pbar)
84+
super().__init__(pbar, _pbar_position)
8485

8586
self.num_workers = workers
8687
self.batch_size = batch_size

DPF/filters/images/blip_captioning_filter.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,15 @@ class BLIPCaptioningFilter(ImageFilter):
1616
BLIPCaptioningFilter class
1717
"""
1818

19-
def __init__(self, workers=16, batch_size=64, device="cuda:0", pbar=True):
20-
super().__init__(pbar)
19+
def __init__(
20+
self,
21+
workers: int = 16,
22+
batch_size: int = 64,
23+
device: str = "cuda:0",
24+
pbar: bool = True,
25+
_pbar_position: int = 0
26+
):
27+
super().__init__(pbar, _pbar_position)
2128

2229
self.num_workers = workers
2330
self.batch_size = batch_size

DPF/filters/images/cliplabels_filter.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,6 @@ class CLIPLabelsFilter(ImageFilter):
3434
Batch size for model
3535
pbar: bool = True
3636
Flag for displaying progress bar
37-
38-
Attributes
39-
----------
40-
schema: List[str]
41-
List of columns to be added with this filter.
42-
dataloader_kwargs: dict:
43-
Parameters for dataloader (batch_size, num_workers, collate_fn, etc.)
4437
"""
4538

4639
def __init__(
@@ -53,8 +46,9 @@ def __init__(
5346
workers: int = 16,
5447
batch_size: int = 64,
5548
pbar: bool = True,
49+
_pbar_position: int = 0
5650
):
57-
super().__init__(pbar)
51+
super().__init__(pbar, _pbar_position)
5852

5953
if templates is None:
6054
templates = ["{}", "photo of a {}"]

DPF/filters/images/hash_filters.py

Lines changed: 8 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,14 @@ class PHashFilter(ImageFilter):
3535
PHashFilter class
3636
"""
3737

38-
def __init__(self, sim_hash_size: int = 8, workers: int = 16, pbar: bool = True):
39-
super().__init__(pbar)
40-
38+
def __init__(
39+
self,
40+
sim_hash_size: int = 8,
41+
workers: int = 16,
42+
pbar: bool = True,
43+
_pbar_position: int = 0
44+
):
45+
super().__init__(pbar, _pbar_position)
4146
self.num_workers = workers
4247
self.sim_hash_size = sim_hash_size
4348

@@ -68,39 +73,3 @@ def process_batch(self, batch) -> dict:
6873
df_batch_labels[f"image_phash_{self.sim_hash_size}"].extend(img_simhashes)
6974

7075
return df_batch_labels
71-
72-
73-
class MD5Filter(ImageFilter):
74-
"""
75-
MD5Filter class
76-
"""
77-
78-
def __init__(
79-
self,
80-
pbar: bool = True,
81-
workers: int = 16,
82-
):
83-
super().__init__(pbar)
84-
85-
self.num_workers = workers
86-
87-
self.schema = ["image_path", "image_md5"]
88-
self.dataloader_kwargs = {
89-
"num_workers": self.num_workers,
90-
"batch_size": 1,
91-
"drop_last": False,
92-
}
93-
94-
def preprocess(self, img_bytes: bytes, data: dict):
95-
image_path = data["image_path"]
96-
img_md5 = get_md5_hash(img_bytes)
97-
return image_path, img_md5
98-
99-
def process_batch(self, batch) -> dict:
100-
df_batch_labels = self._generate_dict_from_schema()
101-
102-
image_paths, img_md5s = list(zip(*batch))
103-
df_batch_labels["image_path"].extend(image_paths)
104-
df_batch_labels["image_md5"].extend(img_md5s)
105-
106-
return df_batch_labels

DPF/filters/images/info_filter.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,8 @@ class ImageInfoFilter(ImageFilter):
5151
ImageInfoFilter class
5252
"""
5353

54-
def __init__(self, workers: int = 16, pbar: bool = True):
55-
super().__init__(pbar)
56-
54+
def __init__(self, workers: int = 16, pbar: bool = True, _pbar_position: int = 0):
55+
super().__init__(pbar, _pbar_position)
5756
self.num_workers = workers
5857

5958
@property

DPF/filters/images/llava_captioning_filter.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,11 @@ def __init__(
2828
prompt: str = 'detailed-long',
2929
workers: int = 16,
3030
batch_size: int = 16,
31-
device="cuda:0",
32-
pbar=True
31+
device: str = "cuda:0",
32+
pbar: bool = True,
33+
_pbar_position: int = 0
3334
):
34-
super().__init__(pbar)
35+
super().__init__(pbar, _pbar_position)
3536
self.batch_size = batch_size
3637
self.num_workers = workers
3738
self.device = device

DPF/filters/images/nsfw_filter.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,9 @@ def __init__(
7474
batch_size: int = 64,
7575
device: str = "cuda:0",
7676
pbar: bool = True,
77+
_pbar_position: int = 0
7778
):
78-
super().__init__(pbar)
79+
super().__init__(pbar, _pbar_position)
7980

8081
self.num_workers = workers
8182
self.batch_size = batch_size

0 commit comments

Comments
 (0)