File tree Expand file tree Collapse file tree 3 files changed +39
-4
lines changed
Expand file tree Collapse file tree 3 files changed +39
-4
lines changed Original file line number Diff line number Diff line change 22
33#include < stdint.h>
44#ifdef AT_CUDA_ENABLED
5+ #include < cuda.h>
56#include < cuda_runtime.h>
67#include < cuda_fp16.h>
78#endif
@@ -24,7 +25,15 @@ template<typename To, typename From> To convert(From f) {
2425typedef struct AT_ALIGN (2 ) {
2526 unsigned short x;
2627#ifdef AT_CUDA_ENABLED
27- operator half () { return half { x }; }
28+ #if CUDA_VERSION < 9000
29+ operator half () { return half{ x }; }
30+ #else
31+ operator half () {
32+ __half_raw x_raw;
33+ x_raw.x = x;
34+ return half (x_raw);
35+ }
36+ #endif
2837#endif
2938 operator double ();
3039} Half;
@@ -41,11 +50,25 @@ inline Half::operator double() {
4150template <> half convert (double d);
4251#endif
4352
44-
4553template <typename To, typename From>
4654static inline To HalfFix (From h) {
4755 return To { h.x };
4856}
4957
58+ #ifdef AT_CUDA_ENABLED
59+ #if CUDA_VERSION >= 9000
60+ template <>
61+ inline __half HalfFix<__half, Half>(Half h) {
62+ __half_raw raw;
63+ raw.x = h.x ;
64+ return __half { raw };
65+ }
5066
67+ template <>
68+ inline Half HalfFix<Half, __half>(__half h) {
69+ __half_raw raw (h);
70+ return Half { raw.x };
71+ }
72+ #endif
73+ #endif
5174}
Original file line number Diff line number Diff line change @@ -24,7 +24,14 @@ template<> int64_t convert(Half f) {
2424
2525#ifdef AT_CUDA_ENABLED
2626template <> half convert (double d) {
27- return half { convert<Half,double >(d).x };
27+
28+ #if CUDA_VERSION < 9000
29+ return half {convert<Half,double >(d).x };
30+ #else
31+ __half_raw raw;
32+ raw.x = convert<Half,double >(d).x ;
33+ return half {raw};
34+ #endif
2835}
2936#endif
3037
Original file line number Diff line number Diff line change @@ -27,7 +27,12 @@ class Scalar {
2727#ifdef AT_CUDA_ENABLED
2828 Scalar (half vv)
2929 : tag(Tag::HAS_d) {
30- v.d = convert<double ,Half>(Half{vv.x });
30+ #if CUDA_VERSION < 9000
31+ v.d = convert<double , Half>(Half{vv.x });
32+ #else
33+ __half_raw vv_raw (vv);
34+ v.d = convert<double ,Half>(Half{vv_raw.x });
35+ #endif
3136 }
3237#endif
3338
You can’t perform that action at this time.
0 commit comments