Skip to content

Commit 5b2aac7

Browse files
committed
Merge commit '224f5eabf5cfb3a19abc1819f7dac230500b6bdb'
2 parents fd490c6 + 224f5ea commit 5b2aac7

File tree

3 files changed: 29 additions (+29) and 90 deletions (-90)

torch/lib/THC/THCGeneral.c

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -916,3 +916,20 @@ void THCHeapUpdate(THCState *state, ptrdiff_t size) {
916916

917917
#include "THCStorage.c"
918918
#include "THCAllocator.c"
919+
920+
/* from THCHalf.h */
921+
922+
half THC_float2half(float f)
923+
{
924+
half h;
925+
TH_float2halfbits(&f, &h.x);
926+
return h;
927+
}
928+
929+
float THC_half2float(half h)
930+
{
931+
float f;
932+
TH_halfbits2float(&h.x, &f);
933+
return f;
934+
}
935+

torch/lib/THC/THCHalf.cu

Lines changed: 0 additions & 90 deletions
Original file line number | Diff line number | Diff line change
@@ -33,96 +33,6 @@ void THCHalf2Float(THCState *state, float *out, half *in, ptrdiff_t len) {
3333
in, in + len, out, __half2floatOp());
3434
}
3535

36-
// FixMe: could call TH_half2float
37-
// and convert types here, but maybe slower?
38-
float THC_half2float(half h)
39-
{
40-
unsigned sign = ((h.x >> 15) & 1);
41-
unsigned exponent = ((h.x >> 10) & 0x1f);
42-
unsigned mantissa = ((h.x & 0x3ff) << 13);
43-
44-
if (exponent == 0x1f) { /* NaN or Inf */
45-
mantissa = (mantissa ? (sign = 0, 0x7fffff) : 0);
46-
exponent = 0xff;
47-
} else if (!exponent) { /* Denorm or Zero */
48-
if (mantissa) {
49-
unsigned int msb;
50-
exponent = 0x71;
51-
do {
52-
msb = (mantissa & 0x400000);
53-
mantissa <<= 1; /* normalize */
54-
--exponent;
55-
} while (!msb);
56-
mantissa &= 0x7fffff; /* 1.mantissa is implicit */
57-
}
58-
} else {
59-
exponent += 0x70;
60-
}
61-
62-
int temp = ((sign << 31) | (exponent << 23) | mantissa);
63-
64-
float x;
65-
memcpy(&x,&temp,sizeof(float));
66-
return x;
67-
}
68-
69-
half THC_float2half(float f)
70-
{
71-
half ret;
72-
73-
unsigned x;
74-
memcpy(&x,&f,sizeof(f));
75-
unsigned u = (x & 0x7fffffff), remainder, shift, lsb, lsb_s1, lsb_m1;
76-
unsigned sign, exponent, mantissa;
77-
78-
// Get rid of +NaN/-NaN case first.
79-
if (u > 0x7f800000) {
80-
ret.x = 0x7fffU;
81-
return ret;
82-
}
83-
84-
sign = ((x >> 16) & 0x8000);
85-
86-
// Get rid of +Inf/-Inf, +0/-0.
87-
if (u > 0x477fefff) {
88-
ret.x = sign | 0x7c00U;
89-
return ret;
90-
}
91-
if (u < 0x33000001) {
92-
ret.x = (sign | 0x0000);
93-
return ret;
94-
}
95-
96-
exponent = ((u >> 23) & 0xff);
97-
mantissa = (u & 0x7fffff);
98-
99-
if (exponent > 0x70) {
100-
shift = 13;
101-
exponent -= 0x70;
102-
} else {
103-
shift = 0x7e - exponent;
104-
exponent = 0;
105-
mantissa |= 0x800000;
106-
}
107-
lsb = (1 << shift);
108-
lsb_s1 = (lsb >> 1);
109-
lsb_m1 = (lsb - 1);
110-
111-
// Round to nearest even.
112-
remainder = (mantissa & lsb_m1);
113-
mantissa >>= shift;
114-
if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) {
115-
++mantissa;
116-
if (!(mantissa & 0x3ff)) {
117-
++exponent;
118-
mantissa = 0;
119-
}
120-
}
121-
122-
ret.x = (sign | (exponent << 10) | mantissa);
123-
return ret;
124-
}
125-
12636
THC_EXTERNC int THC_nativeHalfInstructions(THCState *state) {
12737
cudaDeviceProp* prop =
12838
THCState_getCurrentDeviceProperties(state);

torch/lib/THC/THCTensorTopK.cuh

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -112,21 +112,33 @@ struct TopKTypeConfig<double> {
112112
}
113113
};
114114

115+
#ifdef CUDA_HALF_TENSOR
115116
template <>
116117
struct TopKTypeConfig<half> {
117118
typedef unsigned int RadixType;
118119

119120
static inline __device__ RadixType convert(half v) {
121+
#if defined(__CUDACC_VER__) && __CUDACC_VER__ >= 80000
120122
RadixType x = __half_as_ushort(v);
121123
RadixType mask = -((x >> 15)) | 0x8000;
122124
return (x ^ mask);
125+
#else
126+
assert(false);
127+
return 0u;
128+
#endif
123129
}
124130

125131
static inline __device__ half deconvert(RadixType v) {
132+
#if defined(__CUDACC_VER__) && __CUDACC_VER__ >= 80000
126133
RadixType mask = ((v >> 15) - 1) | 0x8000;
127134
return __ushort_as_half(v ^ mask);
135+
#else
136+
assert(false);
137+
return ScalarConvert<int, half>::to(0);
138+
#endif
128139
}
129140
};
141+
#endif // CUDA_HALF_TENSOR
130142

131143
// This function counts the distribution of all input values in a
132144
// slice we are selecting by radix digit at `radixDigitPos`, but only

0 commit comments

Comments
 (0)