@@ -96,6 +96,7 @@ def _after_fork(arg):
9696 if _initialized and _original_pid != os .getpid ():
9797 _initialized = False
9898 _in_bad_fork = True
99+ _CudaBase .__new__ = _lazy_new
99100
100101
101102_register_after_fork (_after_fork , _after_fork )
@@ -260,6 +261,14 @@ def init_err(self):
260261 torch ._C .__dict__ ['_CudaStreamBase' ] = _dummy_type ('CudaStreamBase' )
261262
262263
264+ @staticmethod
265+ def _lazy_new (cls , * args , ** kwargs ):
266+ _lazy_init ()
267+ # We need this method only for lazy init, so we can remove it
268+ del _CudaBase .__new__
269+ return super (_CudaBase , cls ).__new__ (cls , * args , ** kwargs )
270+
271+
263272class _CudaBase (object ):
264273 is_cuda = True
265274 is_sparse = False
@@ -268,11 +277,7 @@ def type(self, *args, **kwargs):
268277 with device (self .get_device ()):
269278 return super (_CudaBase , self ).type (* args , ** kwargs )
270279
271- def __new__ (cls , * args , ** kwargs ):
272- _lazy_init ()
273- # We need this method only for lazy init, so we can remove it
274- del _CudaBase .__new__
275- return super (_CudaBase , cls ).__new__ (cls , * args , ** kwargs )
280+ __new__ = _lazy_new
276281
277282
278283class DoubleStorage (_CudaBase , torch ._C .CudaDoubleStorageBase , _StorageBase ):
0 commit comments