11#pragma once
22
33#include < memory>
4+ #include < mutex>
45#include " ATen/Generator.h"
56#include " ATen/Type.h"
67#include " ATen/Utils.h"
@@ -13,29 +14,49 @@ class Context {
1314public:
1415 Context ();
1516 Type & getType (Backend p, ScalarType s) {
17+ initCUDAIfNeeded (p);
1618 auto & type = type_registry[static_cast <int >(p)][static_cast <int >(s)];
1719 if (!type)
1820 runtime_error (" %s%s%sType is not enabled." ,toString (p),toString (s));
1921 return *type;
2022 }
2123 Generator & defaultGenerator (Backend p) {
24+ initCUDAIfNeeded (p);
2225 auto & generator = generator_registry[static_cast <int >(p)];
2326 if (!generator)
2427 runtime_error (" %s backend type not enabled." ,toString (p));
2528 return *generator;
2629 }
2730 bool hasCUDA () const ;
31+ // defined in header so that getType has ability to inline
32+ // call_once check. getType is called fairly frequently
33+ THCState* lazyInitCUDA () {
34+ std::call_once (thc_init,[&] {
35+ doInitCUDA ();
36+ });
37+ return thc_state;
38+ }
2839 ~Context ();
2940 std::unique_ptr<Generator>
3041 generator_registry[static_cast <int >(Backend::NumOptions)];
3142 std::unique_ptr<Type> type_registry
3243 [static_cast <int >(Backend::NumOptions)]
3344 [static_cast <int >(ScalarType::NumOptions)];
3445 THCState * thc_state;
46+ private:
47+ void initCUDAIfNeeded (Backend p) {
48+ if (p == Backend::CUDA)
49+ lazyInitCUDA ();
50+ }
51+ void doInitCUDA ();
52+ std::once_flag thc_init;
3553};
3654
3755Context & globalContext ();
3856
57+ static inline void init () {
58+ globalContext ();
59+ }
3960
4061static inline Type& getType (Backend p, ScalarType s) {
4162 return globalContext ().getType (p,s);
0 commit comments