Skip to content

Commit 7d3511f

Browse files
csarofeensoumith
authored andcommitted
Half fixes for ATen and CUDA 9.0
1 parent 25b591e commit 7d3511f

File tree

3 files changed

+39
-4
lines changed

3 files changed

+39
-4
lines changed

Half.h

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include<stdint.h>
44
#ifdef AT_CUDA_ENABLED
5+
#include <cuda.h>
56
#include <cuda_runtime.h>
67
#include <cuda_fp16.h>
78
#endif
@@ -24,7 +25,15 @@ template<typename To, typename From> To convert(From f) {
2425
typedef struct AT_ALIGN(2) {
2526
unsigned short x;
2627
#ifdef AT_CUDA_ENABLED
27-
operator half() { return half { x }; }
28+
#if CUDA_VERSION < 9000
29+
operator half() { return half{ x }; }
30+
#else
31+
operator half() {
32+
__half_raw x_raw;
33+
x_raw.x = x;
34+
return half(x_raw);
35+
}
36+
#endif
2837
#endif
2938
operator double();
3039
} Half;
@@ -41,11 +50,25 @@ inline Half::operator double() {
4150
template<> half convert(double d);
4251
#endif
4352

44-
4553
template<typename To, typename From>
4654
static inline To HalfFix(From h) {
4755
return To { h.x };
4856
}
4957

58+
#ifdef AT_CUDA_ENABLED
59+
#if CUDA_VERSION >= 9000
60+
template<>
61+
inline __half HalfFix<__half, Half>(Half h) {
62+
__half_raw raw;
63+
raw.x = h.x;
64+
return __half { raw };
65+
}
5066

67+
template<>
68+
inline Half HalfFix<Half, __half>(__half h) {
69+
__half_raw raw(h);
70+
return Half { raw.x };
71+
}
72+
#endif
73+
#endif
5174
}

Scalar.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,14 @@ template<> int64_t convert(Half f) {
2424

2525
#ifdef AT_CUDA_ENABLED
2626
template<> half convert(double d) {
27-
return half { convert<Half,double>(d).x };
27+
28+
#if CUDA_VERSION < 9000
29+
return half {convert<Half,double>(d).x};
30+
#else
31+
__half_raw raw;
32+
raw.x = convert<Half,double>(d).x;
33+
return half {raw};
34+
#endif
2835
}
2936
#endif
3037

Scalar.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,12 @@ class Scalar {
2727
#ifdef AT_CUDA_ENABLED
2828
Scalar(half vv)
2929
: tag(Tag::HAS_d) {
30-
v.d = convert<double,Half>(Half{vv.x});
30+
#if CUDA_VERSION < 9000
31+
v.d = convert<double, Half>(Half{vv.x});
32+
#else
33+
__half_raw vv_raw(vv);
34+
v.d = convert<double,Half>(Half{vv_raw.x});
35+
#endif
3136
}
3237
#endif
3338

0 commit comments

Comments
 (0)