@@ -16,30 +16,31 @@ def parallel_apply(modules, inputs, kwargs_tup=None):
     lock = threading.Lock()
     results = {}
 
-    def _worker(module, input, kwargs, results, lock):
+    def _worker(i, module, input, kwargs, results, lock):
         var_input = input
         while not isinstance(var_input, Variable):
             var_input = var_input[0]
         try:
             with torch.cuda.device_of(var_input):
                 output = module(*input, **kwargs)
             with lock:
-                results[input] = output
+                results[i] = output
         except Exception as e:
             with lock:
-                results[input] = e
+                results[i] = e
 
     threads = [threading.Thread(target=_worker,
-                                args=(module, input, kwargs, results, lock),
+                                args=(i, module, input, kwargs, results, lock),
                                 )
-               for module, input, kwargs in zip(modules, inputs, kwargs_tup)]
+               for i, (module, input, kwargs) in
+               enumerate(zip(modules, inputs, kwargs_tup))]
 
     for thread in threads:
         thread.start()
     for thread in threads:
         thread.join()
     outputs = []
-    for i in inputs:
+    for i in range(len(inputs)):
         output = results[i]
         if isinstance(output, Exception):
             raise output
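
For context, a minimal sketch of how the patched function might be driven: two replicas of a module run concurrently, and outputs come back in input order because each worker now keys `results` by its `enumerate` index rather than by the input object (which could collide when the same input appears twice). The replica/device setup below is illustrative only and assumes two CUDA devices and a PyTorch build where `replicate` and `parallel_apply` are exposed under `torch.nn.parallel`:

```python
import torch
import torch.nn as nn
from torch.nn.parallel import replicate, parallel_apply

# Illustrative sketch, assuming two CUDA devices are available.
module = nn.Linear(4, 2).cuda(0)
replicas = replicate(module, [0, 1])        # one copy of the module per device

# Each element of `inputs` is the argument tuple for the matching replica.
inputs = [(torch.randn(3, 4, device='cuda:0'),),
          (torch.randn(3, 4, device='cuda:1'),)]

# One worker thread per (replica, input) pair; results are gathered by index,
# so outputs[k] is the result of replicas[k](*inputs[k]).
outputs = parallel_apply(replicas, inputs)
print([tuple(o.shape) for o in outputs])    # [(3, 2), (3, 2)]
```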