Skip to content

Commit 30a6f23

Browse files
committed
Update prepare_data.py
Memory leak fix (#28). Note: someone should probably implement a better progress indicator.
1 parent 5fa9a07 commit 30a6f23

File tree

1 file changed

+79
-29
lines changed

1 file changed

+79
-29
lines changed

data/prepare_data.py

Lines changed: 79 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
import argparse
22
from io import BytesIO
33
import multiprocessing
4+
from multiprocessing import Lock, Process, RawValue
45
from functools import partial
6+
from multiprocessing.sharedctypes import RawValue
57
from PIL import Image
68
from tqdm import tqdm
79
from torchvision.transforms import functional as trans_fn
810
import os
911
from pathlib import Path
1012
import lmdb
13+
import numpy as np
14+
import time
1115

1216

1317
def resize_and_convert(img, size, resample):
@@ -35,7 +39,6 @@ def resize_multiple(img, sizes=(16, 128), resample=Image.BICUBIC, lmdb_save=Fals
3539

3640
return [lr_img, hr_img, sr_img]
3741

38-
3942
def resize_worker(img_file, sizes, resample, lmdb_save=False):
4043
img = Image.open(img_file)
4144
img = img.convert('RGB')
@@ -44,6 +47,55 @@ def resize_worker(img_file, sizes, resample, lmdb_save=False):
4447

4548
return img_file.name.split('.')[0], out
4649

50+
class WorkingContext():
    """Process-shared bookkeeping handed to every resize worker.

    Bundles the resize callable and output settings together with a
    cross-process progress counter. RawValue carries no synchronisation
    of its own, so every access goes through an explicit Lock.
    """

    def __init__(self, resize_fn, lmdb_save, out_path, env, sizes):
        # Worker configuration, read directly by the worker processes.
        self.resize_fn = resize_fn
        self.lmdb_save = lmdb_save
        self.out_path = out_path
        self.env = env
        self.sizes = sizes
        # Shared, lock-guarded counter of images processed so far.
        self.counter = RawValue('i', 0)
        self.counter_lock = Lock()

    def inc_get(self):
        """Atomically increment the shared counter and return the new total."""
        self.counter_lock.acquire()
        try:
            self.counter.value += 1
            return self.counter.value
        finally:
            self.counter_lock.release()

    def value(self):
        """Return the current counter value, read under the lock."""
        self.counter_lock.acquire()
        try:
            return self.counter.value
        finally:
            self.counter_lock.release()
69+
70+
def prepare_process_worker(wctx, file_subset):
    """Resize every image in *file_subset* and persist the results.

    Runs inside a worker process. For each file, ``wctx.resize_fn``
    yields ``(name, [lr, hr, sr])``; the triplet is written either as
    PNG files under ``wctx.out_path`` or into the shared LMDB
    environment, and the shared progress counter is advanced.

    Args:
        wctx: WorkingContext carrying resize_fn, lmdb_save, out_path,
            env, sizes and the shared counter.
        file_subset: iterable of image paths assigned to this worker.
    """
    for file in file_subset:
        i, imgs = wctx.resize_fn(file)
        lr_img, hr_img, sr_img = imgs
        if not wctx.lmdb_save:
            lr_img.save(
                '{}/lr_{}/{}.png'.format(wctx.out_path, wctx.sizes[0], i.zfill(5)))
            hr_img.save(
                '{}/hr_{}/{}.png'.format(wctx.out_path, wctx.sizes[1], i.zfill(5)))
            sr_img.save(
                '{}/sr_{}_{}/{}.png'.format(wctx.out_path, wctx.sizes[0], wctx.sizes[1], i.zfill(5)))
            wctx.inc_get()
        else:
            # Single write transaction per image: LMDB serializes writers
            # across processes, so opening a second transaction just to
            # refresh 'length' (as the original did) doubles the lock and
            # commit overhead. Folding the length update into the same
            # commit also keeps 'length' consistent with the stored data.
            curr_total = wctx.inc_get()
            with wctx.env.begin(write=True) as txn:
                txn.put('lr_{}_{}'.format(
                    wctx.sizes[0], i.zfill(5)).encode('utf-8'), lr_img)
                txn.put('hr_{}_{}'.format(
                    wctx.sizes[1], i.zfill(5)).encode('utf-8'), hr_img)
                txn.put('sr_{}_{}_{}'.format(
                    wctx.sizes[0], wctx.sizes[1], i.zfill(5)).encode('utf-8'), sr_img)
                txn.put('length'.encode('utf-8'), str(curr_total).encode('utf-8'))
93+
94+
def all_threads_inactive(worker_threads):
    """Return True once every worker has finished.

    Despite the name, the callers pass multiprocessing.Process objects;
    ``is_alive()`` is False for a process that has terminated (or was
    never started), so an empty list counts as "all inactive".
    """
    # Idiomatic replacement for the hand-rolled early-return loop.
    return not any(worker.is_alive() for worker in worker_threads)
4799

48100
def prepare(img_path, out_path, n_worker, sizes=(16, 128), resample=Image.BICUBIC, lmdb_save=False):
49101
resize_fn = partial(resize_worker, sizes=sizes,
@@ -60,31 +112,29 @@ def prepare(img_path, out_path, n_worker, sizes=(16, 128), resample=Image.BICUBI
60112
else:
61113
env = lmdb.open(out_path, map_size=1024 ** 4, readahead=False)
62114

63-
total = 0
64-
if n_worker>1:
65-
with multiprocessing.Pool(n_worker) as pool:
66-
for i, imgs in tqdm(pool.imap_unordered(resize_fn, files)):
67-
lr_img, hr_img, sr_img = imgs
68-
if not lmdb_save:
69-
lr_img.save(
70-
'{}/lr_{}/{}.png'.format(out_path, sizes[0], i.zfill(5)))
71-
hr_img.save(
72-
'{}/hr_{}/{}.png'.format(out_path, sizes[1], i.zfill(5)))
73-
sr_img.save(
74-
'{}/sr_{}_{}/{}.png'.format(out_path, sizes[0], sizes[1], i.zfill(5)))
75-
else:
76-
with env.begin(write=True) as txn:
77-
txn.put('lr_{}_{}'.format(
78-
sizes[0], i.zfill(5)).encode('utf-8'), lr_img)
79-
txn.put('hr_{}_{}'.format(
80-
sizes[1], i.zfill(5)).encode('utf-8'), hr_img)
81-
txn.put('sr_{}_{}_{}'.format(
82-
sizes[0], sizes[1], i.zfill(5)).encode('utf-8'), sr_img)
83-
total += 1
84-
if lmdb_save:
85-
with env.begin(write=True) as txn:
86-
txn.put('length'.encode('utf-8'), str(total).encode('utf-8'))
115+
if n_worker > 1:
116+
# prepare data subsets
117+
multi_env = None
118+
if lmdb_save:
119+
multi_env = env
120+
121+
file_subsets = np.array_split(files, n_worker)
122+
worker_threads = []
123+
wctx = WorkingContext(resize_fn, lmdb_save, out_path, multi_env, sizes)
124+
125+
# start worker processes, monitor results
126+
for i in range(n_worker):
127+
proc = Process(target=prepare_process_worker, args=(wctx, file_subsets[i]))
128+
proc.start()
129+
worker_threads.append(proc)
130+
131+
total_count = str(len(files))
132+
while not all_threads_inactive(worker_threads):
133+
print("{}/{} images processed".format(wctx.value(), total_count))
134+
time.sleep(0.1)
135+
87136
else:
137+
total = 0
88138
for file in tqdm(files):
89139
i, imgs = resize_fn(file)
90140
lr_img, hr_img, sr_img = imgs
@@ -111,12 +161,12 @@ def prepare(img_path, out_path, n_worker, sizes=(16, 128), resample=Image.BICUBI
111161
if __name__ == '__main__':
112162
parser = argparse.ArgumentParser()
113163
parser.add_argument('--path', '-p', type=str,
114-
default='{}/Dataset/celebahq_256'.format(Path.home()))
164+
default='../dataset/bunchofinputimgsfolder')
115165
parser.add_argument('--out', '-o', type=str,
116-
default='./dataset/celebahq')
166+
default='../dataset/celebahq')
117167

118-
parser.add_argument('--size', type=str, default='16,128')
119-
parser.add_argument('--n_worker', type=int, default=1)
168+
parser.add_argument('--size', type=str, default='64,512')
169+
parser.add_argument('--n_worker', type=int, default=3)
120170
parser.add_argument('--resample', type=str, default='bicubic')
121171
# default save in png format
122172
parser.add_argument('--lmdb', '-l', action='store_true')

0 commit comments

Comments
 (0)