1
1
import argparse
2
2
from io import BytesIO
3
3
import multiprocessing
4
+ from multiprocessing import Lock , Process , RawValue
4
5
from functools import partial
6
+ from multiprocessing .sharedctypes import RawValue
5
7
from PIL import Image
6
8
from tqdm import tqdm
7
9
from torchvision .transforms import functional as trans_fn
8
10
import os
9
11
from pathlib import Path
10
12
import lmdb
13
+ import numpy as np
14
+ import time
11
15
12
16
13
17
def resize_and_convert (img , size , resample ):
@@ -35,7 +39,6 @@ def resize_multiple(img, sizes=(16, 128), resample=Image.BICUBIC, lmdb_save=Fals
35
39
36
40
return [lr_img , hr_img , sr_img ]
37
41
38
-
39
42
def resize_worker (img_file , sizes , resample , lmdb_save = False ):
40
43
img = Image .open (img_file )
41
44
img = img .convert ('RGB' )
@@ -44,6 +47,55 @@ def resize_worker(img_file, sizes, resample, lmdb_save=False):
44
47
45
48
return img_file .name .split ('.' )[0 ], out
46
49
50
class WorkingContext():
    """Bundle of shared state handed to every worker process.

    Carries the resize callable plus output configuration, and owns a
    process-shared progress counter guarded by an explicit lock so the
    parent can poll progress while workers increment it.
    """

    def __init__(self, resize_fn, lmdb_save, out_path, env, sizes):
        # Output configuration shared by all workers.
        self.resize_fn = resize_fn
        self.lmdb_save = lmdb_save
        self.out_path = out_path
        self.env = env
        self.sizes = sizes

        # RawValue has no internal synchronization; every access must
        # go through counter_lock.
        self.counter = RawValue('i', 0)
        self.counter_lock = Lock()

    def inc_get(self):
        """Atomically bump the processed-image count and return the new total."""
        self.counter_lock.acquire()
        try:
            self.counter.value += 1
            new_total = self.counter.value
        finally:
            self.counter_lock.release()
        return new_total

    def value(self):
        """Return the current processed-image count (lock-protected read)."""
        self.counter_lock.acquire()
        try:
            return self.counter.value
        finally:
            self.counter_lock.release()
69
+
70
def prepare_process_worker(wctx, file_subset):
    """Worker-process entry point: resize every image in *file_subset* and
    persist the results, either as PNG files on disk or into a shared LMDB
    environment, depending on ``wctx.lmdb_save``.

    ``wctx`` is a WorkingContext carrying the resize callable, output paths,
    target sizes, the (optional) LMDB env, and the shared progress counter.
    """
    for file in file_subset:
        # resize_fn returns (key, [lr, hr, sr]); the key is the file's stem,
        # assumed to be a numeric-string id — zfill(5) pads it for sorting.
        i, imgs = wctx.resize_fn(file)
        lr_img, hr_img, sr_img = imgs
        if not wctx.lmdb_save:
            # PNG mode: write into pre-created lr_/hr_/sr_ subdirectories.
            lr_img.save(
                '{}/lr_{}/{}.png'.format(wctx.out_path, wctx.sizes[0], i.zfill(5)))
            hr_img.save(
                '{}/hr_{}/{}.png'.format(wctx.out_path, wctx.sizes[1], i.zfill(5)))
            sr_img.save(
                '{}/sr_{}_{}/{}.png'.format(wctx.out_path, wctx.sizes[0], wctx.sizes[1], i.zfill(5)))
        else:
            # LMDB mode: imgs are stored directly as values, so in this
            # branch they are assumed to already be encoded bytes, not PIL
            # images — TODO confirm against resize_worker's lmdb_save path.
            # NOTE(review): wctx.env is an lmdb env opened in the parent and
            # used after fork in each worker; lmdb discourages sharing an
            # env across fork — verify this is safe on the target platform.
            with wctx.env.begin(write=True) as txn:
                txn.put('lr_{}_{}'.format(
                    wctx.sizes[0], i.zfill(5)).encode('utf-8'), lr_img)
                txn.put('hr_{}_{}'.format(
                    wctx.sizes[1], i.zfill(5)).encode('utf-8'), hr_img)
                txn.put('sr_{}_{}_{}'.format(
                    wctx.sizes[0], wctx.sizes[1], i.zfill(5)).encode('utf-8'), sr_img)
        # Shared counter doubles as progress (polled by the parent) and,
        # in LMDB mode, as the dataset 'length' record; rewriting 'length'
        # after every image keeps the DB self-consistent if interrupted.
        curr_total = wctx.inc_get()
        if wctx.lmdb_save:
            with wctx.env.begin(write=True) as txn:
                txn.put('length'.encode('utf-8'), str(curr_total).encode('utf-8'))
93
+
94
def all_threads_inactive(worker_threads):
    """Return True once no worker in *worker_threads* is still alive.

    An empty sequence counts as inactive (vacuously True).
    """
    return not any(worker.is_alive() for worker in worker_threads)
47
99
48
100
def prepare (img_path , out_path , n_worker , sizes = (16 , 128 ), resample = Image .BICUBIC , lmdb_save = False ):
49
101
resize_fn = partial (resize_worker , sizes = sizes ,
@@ -60,31 +112,29 @@ def prepare(img_path, out_path, n_worker, sizes=(16, 128), resample=Image.BICUBI
60
112
else :
61
113
env = lmdb .open (out_path , map_size = 1024 ** 4 , readahead = False )
62
114
63
- total = 0
64
- if n_worker > 1 :
65
- with multiprocessing .Pool (n_worker ) as pool :
66
- for i , imgs in tqdm (pool .imap_unordered (resize_fn , files )):
67
- lr_img , hr_img , sr_img = imgs
68
- if not lmdb_save :
69
- lr_img .save (
70
- '{}/lr_{}/{}.png' .format (out_path , sizes [0 ], i .zfill (5 )))
71
- hr_img .save (
72
- '{}/hr_{}/{}.png' .format (out_path , sizes [1 ], i .zfill (5 )))
73
- sr_img .save (
74
- '{}/sr_{}_{}/{}.png' .format (out_path , sizes [0 ], sizes [1 ], i .zfill (5 )))
75
- else :
76
- with env .begin (write = True ) as txn :
77
- txn .put ('lr_{}_{}' .format (
78
- sizes [0 ], i .zfill (5 )).encode ('utf-8' ), lr_img )
79
- txn .put ('hr_{}_{}' .format (
80
- sizes [1 ], i .zfill (5 )).encode ('utf-8' ), hr_img )
81
- txn .put ('sr_{}_{}_{}' .format (
82
- sizes [0 ], sizes [1 ], i .zfill (5 )).encode ('utf-8' ), sr_img )
83
- total += 1
84
- if lmdb_save :
85
- with env .begin (write = True ) as txn :
86
- txn .put ('length' .encode ('utf-8' ), str (total ).encode ('utf-8' ))
115
+ if n_worker > 1 :
116
+ # prepare data subsets
117
+ multi_env = None
118
+ if lmdb_save :
119
+ multi_env = env
120
+
121
+ file_subsets = np .array_split (files , n_worker )
122
+ worker_threads = []
123
+ wctx = WorkingContext (resize_fn , lmdb_save , out_path , multi_env , sizes )
124
+
125
+ # start worker processes, monitor results
126
+ for i in range (n_worker ):
127
+ proc = Process (target = prepare_process_worker , args = (wctx , file_subsets [i ]))
128
+ proc .start ()
129
+ worker_threads .append (proc )
130
+
131
+ total_count = str (len (files ))
132
+ while not all_threads_inactive (worker_threads ):
133
+ print ("{}/{} images processed" .format (wctx .value (), total_count ))
134
+ time .sleep (0.1 )
135
+
87
136
else :
137
+ total = 0
88
138
for file in tqdm (files ):
89
139
i , imgs = resize_fn (file )
90
140
lr_img , hr_img , sr_img = imgs
@@ -111,12 +161,12 @@ def prepare(img_path, out_path, n_worker, sizes=(16, 128), resample=Image.BICUBI
111
161
if __name__ == '__main__' :
112
162
parser = argparse .ArgumentParser ()
113
163
parser .add_argument ('--path' , '-p' , type = str ,
114
- default = '{}/Dataset/celebahq_256' . format ( Path . home ()) )
164
+ default = '../dataset/bunchofinputimgsfolder' )
115
165
parser .add_argument ('--out' , '-o' , type = str ,
116
- default = './dataset/celebahq' )
166
+ default = '.. /dataset/celebahq' )
117
167
118
- parser .add_argument ('--size' , type = str , default = '16,128 ' )
119
- parser .add_argument ('--n_worker' , type = int , default = 1 )
168
+ parser .add_argument ('--size' , type = str , default = '64,512 ' )
169
+ parser .add_argument ('--n_worker' , type = int , default = 3 )
120
170
parser .add_argument ('--resample' , type = str , default = 'bicubic' )
121
171
# default save in png format
122
172
parser .add_argument ('--lmdb' , '-l' , action = 'store_true' )
0 commit comments