Skip to content

Commit f0788af

Browse files
committed
lazily initialize cuda so that we behave similar to PyTorch
1 parent a4dc7dc commit f0788af

File tree

2 files changed

+30
-7
lines changed

2 files changed

+30
-7
lines changed

Context.cpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,17 @@ static inline void argErrorHandler(int arg, const char * msg, void * data) {
2020
throw std::runtime_error(new_error.str());
2121
}
2222

23-
Context::Context() {
23+
Context::Context()
24+
: thc_state(nullptr) {
2425

2526
THSetDefaultErrorHandler(errorHandler,nullptr);
2627
THSetDefaultArgErrorHandler(argErrorHandler,nullptr);
2728

29+
generator_registry[static_cast<int>(Backend::CPU)]
30+
.reset(new CPUGenerator(this));
31+
Type::registerAll(this);
32+
}
33+
void Context::doInitCUDA() {
2834
#ifdef AT_CUDA_ENABLED
2935
thc_state = THCState_alloc();
3036
THCState_setDeviceAllocator(thc_state, THCCachingAllocator_get());
@@ -33,15 +39,11 @@ Context::Context() {
3339
generator_registry[static_cast<int>(Backend::CUDA)]
3440
.reset(new CUDAGenerator(this));
3541
#endif
36-
37-
generator_registry[static_cast<int>(Backend::CPU)]
38-
.reset(new CPUGenerator(this));
39-
Type::registerAll(this);
4042
}
41-
4243
Context::~Context() {
4344
#ifdef AT_CUDA_ENABLED
44-
THCState_free(thc_state);
45+
if(thc_state)
46+
THCState_free(thc_state);
4547
#endif
4648
}
4749

Context.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
22

33
#include <memory>
4+
#include <mutex>
45
#include "ATen/Generator.h"
56
#include "ATen/Type.h"
67
#include "ATen/Utils.h"
@@ -13,29 +14,49 @@ class Context {
1314
public:
1415
Context();
1516
Type & getType(Backend p, ScalarType s) {
17+
initCUDAIfNeeded(p);
1618
auto & type = type_registry[static_cast<int>(p)][static_cast<int>(s)];
1719
if(!type)
1820
runtime_error("%s%s%sType is not enabled.",toString(p),toString(s));
1921
return *type;
2022
}
2123
Generator & defaultGenerator(Backend p) {
24+
initCUDAIfNeeded(p);
2225
auto & generator = generator_registry[static_cast<int>(p)];
2326
if(!generator)
2427
runtime_error("%s backend type not enabled.",toString(p));
2528
return *generator;
2629
}
2730
bool hasCUDA() const;
31+
// defined in header so that getType has ability to inline
32+
// call_once check. getType is called fairly frequently
33+
THCState* lazyInitCUDA() {
34+
std::call_once(thc_init,[&] {
35+
doInitCUDA();
36+
});
37+
return thc_state;
38+
}
2839
~Context();
2940
std::unique_ptr<Generator>
3041
generator_registry[static_cast<int>(Backend::NumOptions)];
3142
std::unique_ptr<Type> type_registry
3243
[static_cast<int>(Backend::NumOptions)]
3344
[static_cast<int>(ScalarType::NumOptions)];
3445
THCState * thc_state;
46+
private:
47+
void initCUDAIfNeeded(Backend p) {
48+
if(p == Backend::CUDA)
49+
lazyInitCUDA();
50+
}
51+
void doInitCUDA();
52+
std::once_flag thc_init;
3553
};
3654

3755
Context & globalContext();
3856

57+
static inline void init() {
58+
globalContext();
59+
}
3960

4061
static inline Type& getType(Backend p, ScalarType s) {
4162
return globalContext().getType(p,s);

0 commit comments

Comments
 (0)