Skip to content

Commit 7f6cd7c

Browse files
authored
Fix error message in CUDA forked subprocess (pytorch#1585)
We need to re-call _lazy_init in _CudaBase.__new__ in the subprocess.
1 parent 625850c commit 7f6cd7c

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

torch/cuda/__init__.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ def _after_fork(arg):
9696
if _initialized and _original_pid != os.getpid():
9797
_initialized = False
9898
_in_bad_fork = True
99+
_CudaBase.__new__ = _lazy_new
99100

100101

101102
_register_after_fork(_after_fork, _after_fork)
@@ -260,6 +261,14 @@ def init_err(self):
260261
torch._C.__dict__['_CudaStreamBase'] = _dummy_type('CudaStreamBase')
261262

262263

264+
@staticmethod
265+
def _lazy_new(cls, *args, **kwargs):
266+
_lazy_init()
267+
# We need this method only for lazy init, so we can remove it
268+
del _CudaBase.__new__
269+
return super(_CudaBase, cls).__new__(cls, *args, **kwargs)
270+
271+
263272
class _CudaBase(object):
264273
is_cuda = True
265274
is_sparse = False
@@ -268,11 +277,7 @@ def type(self, *args, **kwargs):
268277
with device(self.get_device()):
269278
return super(_CudaBase, self).type(*args, **kwargs)
270279

271-
def __new__(cls, *args, **kwargs):
272-
_lazy_init()
273-
# We need this method only for lazy init, so we can remove it
274-
del _CudaBase.__new__
275-
return super(_CudaBase, cls).__new__(cls, *args, **kwargs)
280+
__new__ = _lazy_new
276281

277282

278283
class DoubleStorage(_CudaBase, torch._C.CudaDoubleStorageBase, _StorageBase):

0 commit comments

Comments
 (0)