Skip to content

Commit 274b5c9

Browse files
nhynesapaszke
authored andcommitted
Allow unhashable inputs to parallel_apply
1 parent dfa2d26 commit 274b5c9

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

torch/nn/parallel/parallel_apply.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,30 +16,31 @@ def parallel_apply(modules, inputs, kwargs_tup=None):
1616
lock = threading.Lock()
1717
results = {}
1818

19-
def _worker(module, input, kwargs, results, lock):
19+
def _worker(i, module, input, kwargs, results, lock):
2020
var_input = input
2121
while not isinstance(var_input, Variable):
2222
var_input = var_input[0]
2323
try:
2424
with torch.cuda.device_of(var_input):
2525
output = module(*input, **kwargs)
2626
with lock:
27-
results[input] = output
27+
results[i] = output
2828
except Exception as e:
2929
with lock:
30-
results[input] = e
30+
results[i] = e
3131

3232
threads = [threading.Thread(target=_worker,
33-
args=(module, input, kwargs, results, lock),
33+
args=(i, module, input, kwargs, results, lock),
3434
)
35-
for module, input, kwargs in zip(modules, inputs, kwargs_tup)]
35+
for i, (module, input, kwargs) in
36+
enumerate(zip(modules, inputs, kwargs_tup))]
3637

3738
for thread in threads:
3839
thread.start()
3940
for thread in threads:
4041
thread.join()
4142
outputs = []
42-
for i in inputs:
43+
for i in range(len(inputs)):
4344
output = results[i]
4445
if isinstance(output, Exception):
4546
raise output

0 commit comments

Comments
 (0)