Skip to content

Commit 977a232

Browse files
committed
fix: fix error with multiprocessing in multigpu filter
1 parent c9e9552 commit 977a232

File tree

3 files changed

+17
-3
lines changed

3 files changed

+17
-3
lines changed

DPF/filters/data_filter.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import multiprocessing
2+
import sys
13
from abc import ABC, abstractmethod
24
from typing import Any
35

@@ -19,6 +21,7 @@ def __init__(self, pbar: bool, _pbar_position: int = 0):
1921
super().__init__()
2022
self.pbar = pbar
2123
self.pbar_position = _pbar_position
24+
self._created_by_multigpu_data_filter = False
2225

2326
@property
2427
def schema(self) -> list[str]:
@@ -120,7 +123,15 @@ def run(self, dataset: Dataset[tuple[bool, Any]]) -> pd.DataFrame:
120123
pd.DataFrame
121124
Dataframe with columns from schema property
122125
"""
123-
dataloader = DataLoader(dataset, collate_fn=identical_collate_fn, **self.dataloader_kwargs)
126+
multiprocessing_context = None
127+
if self._created_by_multigpu_data_filter and sys.platform not in {'win32', 'darwin'}:
128+
multiprocessing_context = multiprocessing.get_context('fork')
129+
130+
dataloader = DataLoader(
131+
dataset, collate_fn=identical_collate_fn,
132+
multiprocessing_context=multiprocessing_context,
133+
**self.dataloader_kwargs
134+
)
124135
filter_results = self._get_dict_from_schema()
125136

126137
for batch in tqdm(dataloader, disable=not self.pbar, position=self.pbar_position):

DPF/filters/multigpu_filter.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def run_one_process(
2929
reader = DatasetReader(connector=connector)
3030
processor = reader.from_df(config, df)
3131
datafilter = filter_class(**filter_kwargs, _pbar_position=i, device=device) # type: ignore
32+
datafilter._created_by_multigpu_data_filter = True
3233
processor.apply_data_filter(datafilter, **filter_run_kwargs)
3334
res = processor.df
3435
res.set_index(index, inplace=True)
@@ -118,8 +119,8 @@ def run(
118119
)
119120

120121
processes = []
122+
context = multiprocessing.get_context('spawn')
121123
for param in params:
122-
context = multiprocessing.get_context('spawn')
123124
p = context.Process(target=run_one_process, args=param)
124125
p.start()
125126
processes.append(p)

DPF/filters/videos/lita_filter.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ def __init__(
8282
pretrainers = load_pretrained_model(weights_path, model_base, self.model_name, load_8bit, load_4bit)
8383
self.tokenizer, self.model, self.processor, self.context_len = pretrainers
8484

85+
self.model_num_frames = self.model.config.num_frames
86+
8587
self.conv_mode = "llava_v1"
8688
self.conv = conv_templates[self.conv_mode].copy()
8789

@@ -118,7 +120,7 @@ def preprocess_data(
118120
) -> Any:
119121
key = metadata[self.key_column]
120122
video_file = BytesIO(modality2data['video'])
121-
video_file = load_video(video_file, self.processor, self.model.config.num_frames).unsqueeze(0).half()
123+
video_file = load_video(video_file, self.processor, self.model_num_frames).unsqueeze(0).half()
122124
return key, video_file
123125

124126
def process_batch(self, batch: list[Any]) -> dict[str, list[Any]]:

0 commit comments

Comments
 (0)