Skip to content

Commit 0409b42

Browse files
committed
Merge commit '3abe5c80d2073f0e72f79b88f11b2a9d320fb116'
2 parents c39d48e + 3abe5c8 commit 0409b42

File tree

4 files changed

+136
-42
lines changed

4 files changed

+136
-42
lines changed

torch/lib/TH/THMath.h

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,20 @@ static inline double TH_lerp(double a, double b, double weight) {
1717
return a + weight * (b-a);
1818
}
1919

20-
#endif // _THMATH_H
20+
static inline float TH_sigmoidf(float value) {
21+
return 1.0f / (1.0f + expf(-value));
22+
}
23+
24+
static inline float TH_fracf(float x) {
25+
return x - truncf(x);
26+
}
27+
28+
static inline float TH_rsqrtf(float x) {
29+
return 1.0f / sqrtf(x);
30+
}
2131

32+
static inline float TH_lerpf(float a, float b, float weight) {
33+
return a + weight * (b-a);
34+
}
35+
36+
#endif // _THMATH_H

torch/lib/TH/cmake/FindSSE.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ SET(AVX2_CODE "
7373
7474
int main()
7575
{
76-
__m256i a;
76+
__m256i a = {0};
7777
a = _mm256_abs_epi16(a);
7878
return 0;
7979
}

torch/lib/TH/generic/THTensorCopy.c

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,84 @@
22
#define TH_GENERIC_FILE "generic/THTensorCopy.c"
33
#else
44

5+
int THTensor_(copyTransposeValid)(THTensor *tensor, THTensor *src) {
6+
const int MIN_SZ = 60 * 60;
7+
return THTensor_(isContiguous)(tensor) &&
8+
THTensor_(nDimension)(src) == 2 &&
9+
THTensor_(stride)(src, 0) == 1 &&
10+
THTensor_(stride)(src, 1) == THTensor_(size)(src, 0) &&
11+
THTensor_(nElement)(tensor) >= MIN_SZ;
12+
}
13+
14+
// special case copy where tensor is contiguous and src is a transposed matrix
15+
// This can be generalized to most copies, but it's tricker
16+
void THTensor_(copyTranspose)(THTensor *tensor, THTensor *src) {
17+
#define MIN(x, y) (((x) < (y)) ? (x) : (y))
18+
#define MAX(x, y) (((x) > (y)) ? (x) : (y))
19+
20+
#ifdef TH_REAL_IS_BYTE
21+
const int BLOCK_SZ = 120;
22+
#else
23+
const int BLOCK_SZ = 60;
24+
#endif
25+
26+
THTensor *buf = THTensor_(newWithSize2d)(BLOCK_SZ, BLOCK_SZ);
27+
real *sp = THTensor_(data)(src);
28+
real *rp = THTensor_(data)(tensor);
29+
real *bp = THTensor_(data)(buf);
30+
31+
long NR = THTensor_(size)(src, 0);
32+
long NC = THTensor_(size)(src, 1);
33+
for (long R = 0; R < NR; R += BLOCK_SZ) {
34+
for (long C = 0; C < NC; C += BLOCK_SZ) {
35+
real *spo = sp + R + C * NR;
36+
real *rpo = rp + C + R * NC;
37+
38+
int nr = MIN(NR - R, BLOCK_SZ);
39+
int nc = MIN(NC - C, BLOCK_SZ);
40+
41+
// 1. copy columns from src to buf
42+
for (int c = 0; c < nc; c++) {
43+
memcpy(bp + c * BLOCK_SZ, spo + c * NR, nr * sizeof(real));
44+
}
45+
46+
// 2. transpose buf in place
47+
int rc_max = MAX(nr, nc);
48+
int rc_min = MIN(nr, nc);
49+
for (int r = 0; r < rc_max; r++) {
50+
int end = MIN(r, rc_min);
51+
for (int c = 0; c < end; c++) {
52+
real tmp = bp[r + BLOCK_SZ * c];
53+
bp[r + BLOCK_SZ * c] = bp[r * BLOCK_SZ + c];
54+
bp[r * BLOCK_SZ + c] = tmp;
55+
}
56+
}
57+
58+
// 3. copy rows from buf to dst
59+
for (int r = 0; r < nr; r++) {
60+
memcpy(rpo + r * NC, bp + r * BLOCK_SZ, nc * sizeof(real));
61+
}
62+
}
63+
}
64+
THTensor_(free)(buf);
65+
#undef MIN
66+
#undef MAX
67+
}
68+
569
void THTensor_(copy)(THTensor *tensor, THTensor *src)
670
{
771
if (THTensor_(isContiguous)(tensor) && THTensor_(isContiguous)(src) && THTensor_(nElement)(tensor) == THTensor_(nElement)(src)) {
872
real *sp = THTensor_(data)(src);
973
real *rp = THTensor_(data)(tensor);
1074
ptrdiff_t sz = THTensor_(nElement)(tensor);
1175
#ifndef TH_REAL_IS_HALF
12-
THVector_(copy)(rp, sp, sz);
76+
THVector_(copy)(rp, sp, sz);
1377
#else
1478
memcpy(rp, sp, sz * sizeof(real));
79+
#endif
80+
#ifndef TH_REAL_IS_HALF
81+
} else if (THTensor_(copyTransposeValid)(tensor, src)) {
82+
THTensor_(copyTranspose)(tensor, src);
1583
#endif
1684
} else {
1785
TH_TENSOR_APPLY2(real, tensor, real, src, *tensor_data = *src_data;)

torch/lib/TH/generic/THTensorMath.c

Lines changed: 50 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -2746,43 +2746,50 @@ TENSOR_IMPLEMENT_LOGICAL_SUM(logicalany, ||, 0)
27462746
/* floating point only now */
27472747
#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
27482748

2749-
LAB_IMPLEMENT_BASIC_FUNCTION(log,log)
2750-
LAB_IMPLEMENT_BASIC_FUNCTION(lgamma,lgamma)
2751-
LAB_IMPLEMENT_BASIC_FUNCTION(log1p,log1p)
2752-
LAB_IMPLEMENT_BASIC_FUNCTION(sigmoid,TH_sigmoid)
2753-
LAB_IMPLEMENT_BASIC_FUNCTION(exp,exp)
2754-
LAB_IMPLEMENT_BASIC_FUNCTION(cos,cos)
2755-
LAB_IMPLEMENT_BASIC_FUNCTION(acos,acos)
2756-
LAB_IMPLEMENT_BASIC_FUNCTION(cosh,cosh)
2757-
LAB_IMPLEMENT_BASIC_FUNCTION(sin,sin)
2758-
LAB_IMPLEMENT_BASIC_FUNCTION(asin,asin)
2759-
LAB_IMPLEMENT_BASIC_FUNCTION(sinh,sinh)
2760-
LAB_IMPLEMENT_BASIC_FUNCTION(tan,tan)
2761-
LAB_IMPLEMENT_BASIC_FUNCTION(atan,atan)
2762-
LAB_IMPLEMENT_BASIC_FUNCTION(tanh,tanh)
2763-
LAB_IMPLEMENT_BASIC_FUNCTION_VALUE(pow,pow)
2764-
LAB_IMPLEMENT_BASIC_FUNCTION(sqrt,sqrt)
2765-
LAB_IMPLEMENT_BASIC_FUNCTION(rsqrt,TH_rsqrt)
2766-
LAB_IMPLEMENT_BASIC_FUNCTION(ceil,ceil)
2767-
LAB_IMPLEMENT_BASIC_FUNCTION(floor,floor)
2768-
LAB_IMPLEMENT_BASIC_FUNCTION(round,round)
2769-
LAB_IMPLEMENT_BASIC_FUNCTION(abs,fabs)
2770-
LAB_IMPLEMENT_BASIC_FUNCTION(trunc,trunc)
2771-
LAB_IMPLEMENT_BASIC_FUNCTION(frac,TH_frac)
2749+
#if defined (TH_REAL_IS_FLOAT)
2750+
#define TH_MATH_NAME(fn) fn##f
2751+
#else
2752+
#define TH_MATH_NAME(fn) fn
2753+
#endif
2754+
2755+
LAB_IMPLEMENT_BASIC_FUNCTION(log,TH_MATH_NAME(log))
2756+
LAB_IMPLEMENT_BASIC_FUNCTION(lgamma,TH_MATH_NAME(lgamma))
2757+
LAB_IMPLEMENT_BASIC_FUNCTION(log1p,TH_MATH_NAME(log1p))
2758+
LAB_IMPLEMENT_BASIC_FUNCTION(sigmoid,TH_MATH_NAME(TH_sigmoid))
2759+
LAB_IMPLEMENT_BASIC_FUNCTION(exp,TH_MATH_NAME(exp))
2760+
LAB_IMPLEMENT_BASIC_FUNCTION(cos,TH_MATH_NAME(cos))
2761+
LAB_IMPLEMENT_BASIC_FUNCTION(acos,TH_MATH_NAME(acos))
2762+
LAB_IMPLEMENT_BASIC_FUNCTION(cosh,TH_MATH_NAME(cosh))
2763+
LAB_IMPLEMENT_BASIC_FUNCTION(sin,TH_MATH_NAME(sin))
2764+
LAB_IMPLEMENT_BASIC_FUNCTION(asin,TH_MATH_NAME(asin))
2765+
LAB_IMPLEMENT_BASIC_FUNCTION(sinh,TH_MATH_NAME(sinh))
2766+
LAB_IMPLEMENT_BASIC_FUNCTION(tan,TH_MATH_NAME(tan))
2767+
LAB_IMPLEMENT_BASIC_FUNCTION(atan,TH_MATH_NAME(atan))
2768+
LAB_IMPLEMENT_BASIC_FUNCTION(tanh,TH_MATH_NAME(tanh))
2769+
LAB_IMPLEMENT_BASIC_FUNCTION_VALUE(pow,TH_MATH_NAME(pow))
2770+
LAB_IMPLEMENT_BASIC_FUNCTION(sqrt,TH_MATH_NAME(sqrt))
2771+
LAB_IMPLEMENT_BASIC_FUNCTION(rsqrt,TH_MATH_NAME(TH_rsqrt))
2772+
LAB_IMPLEMENT_BASIC_FUNCTION(ceil,TH_MATH_NAME(ceil))
2773+
LAB_IMPLEMENT_BASIC_FUNCTION(floor,TH_MATH_NAME(floor))
2774+
LAB_IMPLEMENT_BASIC_FUNCTION(round,TH_MATH_NAME(round))
2775+
LAB_IMPLEMENT_BASIC_FUNCTION(abs,TH_MATH_NAME(fabs))
2776+
LAB_IMPLEMENT_BASIC_FUNCTION(trunc,TH_MATH_NAME(trunc))
2777+
LAB_IMPLEMENT_BASIC_FUNCTION(frac,TH_MATH_NAME(TH_frac))
27722778
LAB_IMPLEMENT_BASIC_FUNCTION(neg,-)
2773-
LAB_IMPLEMENT_BASIC_FUNCTION(cinv, 1.0 / )
2779+
LAB_IMPLEMENT_BASIC_FUNCTION(cinv, TH_MATH_NAME(1.0) / )
2780+
27742781

27752782
void THTensor_(atan2)(THTensor *r_, THTensor *tx, THTensor *ty)
27762783
{
27772784
THTensor_(resizeAs)(r_, tx);
2778-
TH_TENSOR_APPLY3(real, r_, real, tx, real, ty, *r__data = atan2(*tx_data,*ty_data););
2785+
TH_TENSOR_APPLY3(real, r_, real, tx, real, ty, *r__data = TH_MATH_NAME(atan2)(*tx_data,*ty_data););
27792786
}
27802787

27812788
void THTensor_(lerp)(THTensor *r_, THTensor *a, THTensor *b, real weight)
27822789
{
27832790
THArgCheck(THTensor_(nElement)(a) == THTensor_(nElement)(b), 2, "sizes do not match");
27842791
THTensor_(resizeAs)(r_, a);
2785-
TH_TENSOR_APPLY3(real, r_, real, a, real, b, *r__data = TH_lerp(*a_data, *b_data, weight););
2792+
TH_TENSOR_APPLY3(real, r_, real, a, real, b, *r__data = TH_MATH_NAME(TH_lerp)(*a_data, *b_data, weight););
27862793
}
27872794

27882795
void THTensor_(mean)(THTensor *r_, THTensor *t, int dimension, int keepdim)
@@ -2823,15 +2830,15 @@ void THTensor_(std)(THTensor *r_, THTensor *t, int dimension, int flag, int keep
28232830
sum2 /= t_size;
28242831
sum2 -= sum*sum;
28252832
sum2 = (sum2 < 0 ? 0 : sum2);
2826-
*r__data = (real)sqrt(sum2);
2833+
*r__data = (real)TH_MATH_NAME(sqrt)(sum2);
28272834
}
28282835
else
28292836
{
28302837
sum /= t_size;
28312838
sum2 /= t_size-1;
28322839
sum2 -= ((real)t_size)/((real)(t_size-1))*sum*sum;
28332840
sum2 = (sum2 < 0 ? 0 : sum2);
2834-
*r__data = (real)sqrt(sum2);
2841+
*r__data = (real)TH_MATH_NAME(sqrt)(sum2);
28352842
});
28362843

28372844
if (!keepdim) {
@@ -2907,9 +2914,11 @@ void THTensor_(norm)(THTensor *r_, THTensor *t, real value, int dimension, int k
29072914
TH_TENSOR_DIM_APPLY2(real, t, real, r_, dimension,
29082915
accreal sum = 0;
29092916
long i;
2910-
for(i = 0; i < t_size; i++)
2911-
sum += pow(fabs(t_data[i*t_stride]), value);
2912-
*r__data = pow(sum, 1.0/value);)
2917+
for(i = 0; i < t_size; i++) {
2918+
sum += TH_MATH_NAME(pow)(
2919+
TH_MATH_NAME(fabs)(t_data[i*t_stride]), value);
2920+
}
2921+
*r__data = TH_MATH_NAME(pow)(sum, 1.0/value);)
29132922
}
29142923

29152924
if (!keepdim) {
@@ -2924,14 +2933,14 @@ accreal THTensor_(normall)(THTensor *tensor, real value)
29242933
TH_TENSOR_APPLY(real, tensor, sum += *tensor_data != 0.0;);
29252934
return sum;
29262935
} else if(value == 1) {
2927-
TH_TENSOR_APPLY(real, tensor, sum += fabs(*tensor_data););
2936+
TH_TENSOR_APPLY(real, tensor, sum += TH_MATH_NAME(fabs)(*tensor_data););
29282937
return sum;
29292938
} else if(value == 2) {
29302939
TH_TENSOR_APPLY(real, tensor, accreal z = *tensor_data; sum += z*z;);
29312940
return sqrt(sum);
29322941
} else {
2933-
TH_TENSOR_APPLY(real, tensor, sum += pow(fabs(*tensor_data), value););
2934-
return pow(sum, 1.0/value);
2942+
TH_TENSOR_APPLY(real, tensor, sum += TH_MATH_NAME(pow)(TH_MATH_NAME(fabs)(*tensor_data), value););
2943+
return TH_MATH_NAME(pow)(sum, 1.0/value);
29352944
}
29362945
}
29372946

@@ -2963,7 +2972,7 @@ void THTensor_(renorm)(THTensor *res, THTensor *src, real value, int dimension,
29632972
} else if (value == 2) {
29642973
TH_TENSOR_APPLY(real, rowS, accreal z = *rowS_data; norm += z*z;);
29652974
} else {
2966-
TH_TENSOR_APPLY(real, rowS, norm += pow(fabs(*rowS_data), value););
2975+
TH_TENSOR_APPLY(real, rowS, norm += TH_MATH_NAME(pow)(TH_MATH_NAME(fabs)(*rowS_data), value););
29672976
}
29682977

29692978
norm = pow(norm, 1/value);
@@ -2989,8 +2998,9 @@ accreal THTensor_(dist)(THTensor *tensor, THTensor *src, real value)
29892998
{
29902999
real sum = 0;
29913000
TH_TENSOR_APPLY2(real, tensor, real, src,
2992-
sum += pow(fabs(*tensor_data - *src_data), value);)
2993-
return pow(sum, 1.0/value);
3001+
sum += TH_MATH_NAME(pow)(
3002+
TH_MATH_NAME(fabs)(*tensor_data - *src_data), value););
3003+
return TH_MATH_NAME(pow)(sum, 1.0/value);
29943004
}
29953005

29963006
accreal THTensor_(meanall)(THTensor *tensor)
@@ -3048,12 +3058,12 @@ void THTensor_(logspace)(THTensor *r_, real a, real b, long n)
30483058

30493059
if(n == 1) {
30503060
TH_TENSOR_APPLY(real, r_,
3051-
*r__data = pow(10.0, a);
3061+
*r__data = TH_MATH_NAME(pow)(10.0, a);
30523062
i++;
30533063
);
30543064
} else {
30553065
TH_TENSOR_APPLY(real, r_,
3056-
*r__data = pow(10.0, a + i*(b-a)/((real)(n-1)));
3066+
*r__data = TH_MATH_NAME(pow)(10.0, a + i*(b-a)/((real)(n-1)));
30573067
i++;
30583068
);
30593069
}
@@ -3141,6 +3151,7 @@ void THTensor_(bhistc)(THTensor *hist, THTensor *tensor, long nbins, real minval
31413151
);
31423152
}
31433153

3154+
#undef TH_MATH_NAME
31443155
#endif /* floating point only part */
31453156
#undef IS_NONZERO
31463157
#endif

0 commit comments

Comments
 (0)