Skip to content

Commit 194d740

Browse files
committed
Merge commit '5f308b50fb558a620253443ef45f7cf3a91be410'
2 parents 0d53824 + 5f308b5 commit 194d740

File tree

8 files changed

+67
-28
lines changed

8 files changed

+67
-28
lines changed

torch/lib/THC/THCReduce.cuh

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,8 @@ bool THC_reduceDim(THCState* state,
168168
const ModifyOp& modifyOp,
169169
const ReduceOp& reduceOp,
170170
typename TensorUtils<TensorType>::DataType init,
171-
int dim) {
171+
int dim,
172+
int keepdim) {
172173
ptrdiff_t inElements = TensorUtils<TensorType>::getNumElements(state, in);
173174

174175
long reductionSize = TensorUtils<TensorType>::getSize(state, in, dim);
@@ -315,6 +316,10 @@ bool THC_reduceDim(THCState* state,
315316
#undef HANDLE_IN_CASE
316317
#undef HANDLE_OUT_CASE
317318

319+
320+
if (!keepdim) {
321+
TensorUtils<TensorType>::squeeze1d(state, out, out, dim);
322+
}
318323
return true;
319324
}
320325

torch/lib/THC/THCTensorMathReduce.cuh

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -625,6 +625,7 @@ THC_reduceDimIndex(THCState *state,
625625
TensorTypeIndex *tgt2_,
626626
TensorTypeK *src,
627627
long dimension,
628+
int keepdim,
628629
const thrust::pair<
629630
typename TensorUtils<TensorTypeK>::DataType,
630631
typename TensorUtils<TensorTypeIndex>::DataType>& init,
@@ -653,6 +654,10 @@ THC_reduceDimIndex(THCState *state,
653654
TensorUtils<TensorTypeK>::free(state, src);
654655
TensorUtils<TensorTypeK>::freeCopyTo(state, tgt1, tgt1_);
655656
TensorUtils<TensorTypeIndex>::freeCopyTo(state, tgt2, tgt2_);
657+
if (!keepdim) {
658+
TensorUtils<TensorTypeK>::squeeze1d(state, tgt1_, tgt1_, dimension);
659+
TensorUtils<TensorTypeIndex>::squeeze1d(state, tgt2_, tgt2_, dimension);
660+
}
656661
}
657662

658663
template <typename T, typename Index>

torch/lib/THC/THCTensorTypeUtils.cu

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,14 @@ TensorUtils<TENSOR_TYPE>::resizeAs(THCState* state, \
7373
TENSOR_TYPE##_resizeAs(state, dst, src); \
7474
} \
7575
\
76+
void \
77+
TensorUtils<TENSOR_TYPE>::squeeze1d(THCState *state, \
78+
TENSOR_TYPE *dst, \
79+
TENSOR_TYPE *src, \
80+
int dimension) { \
81+
TENSOR_TYPE##_squeeze1d(state, dst, src, dimension); \
82+
} \
83+
\
7684
DATA_TYPE* \
7785
TensorUtils<TENSOR_TYPE>::getData(THCState* state, \
7886
TENSOR_TYPE* t) { \

torch/lib/THC/THCTensorTypeUtils.cuh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ struct TensorUtils {
4949
THLongStorage* strides); \
5050
static void resizeAs(THCState* state, TENSOR_TYPE* dst, \
5151
TENSOR_TYPE* src); \
52+
static void squeeze1d(THCState *state, TENSOR_TYPE *dst, \
53+
TENSOR_TYPE *src, int dimension); \
5254
static DATA_TYPE* getData(THCState* state, TENSOR_TYPE* t); \
5355
static ptrdiff_t getNumElements(THCState* state, TENSOR_TYPE* t); \
5456
static long getSize(THCState* state, TENSOR_TYPE* t, int dim); \

torch/lib/THC/generic/THCTensorMathReduce.cu

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,38 +3,40 @@
33
#else
44

55
THC_API void
6-
THCTensor_(sum)(THCState* state, THCTensor *self, THCTensor *src, long dimension) {
6+
THCTensor_(sum)(THCState* state, THCTensor *self, THCTensor *src, long dimension, int keepdim) {
77
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self, src));
88
if (!THC_reduceDim(state, self, src,
99
thrust::identity<real>(),
1010
ReduceAdd<real, real>(),
1111
ScalarConvert<int, real>::to(0),
12-
dimension)) {
12+
dimension,
13+
keepdim)) {
1314
THArgCheck(false, 2, CUTORCH_DIM_WARNING);
1415
}
1516

1617
THCudaCheck(cudaGetLastError());
1718
}
1819

1920
THC_API void
20-
THCTensor_(prod)(THCState* state, THCTensor *self, THCTensor *src, long dimension) {
21+
THCTensor_(prod)(THCState* state, THCTensor *self, THCTensor *src, long dimension, int keepdim) {
2122
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self, src));
2223
if (!THC_reduceDim(state, self, src,
2324
thrust::identity<real>(),
2425
ReduceMultiply<real, real>(),
2526
ScalarConvert<int, real>::to(1),
26-
dimension)) {
27+
dimension,
28+
keepdim)) {
2729
THArgCheck(false, 2, CUTORCH_DIM_WARNING);
2830
}
2931

3032
THCudaCheck(cudaGetLastError());
3133
}
3234

3335
THC_API void
34-
THCTensor_(mean)(THCState *state, THCTensor *self, THCTensor *src, long dim)
36+
THCTensor_(mean)(THCState *state, THCTensor *self, THCTensor *src, long dim, int keepdim)
3537
{
3638
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self, src));
37-
THCTensor_(sum)(state, self, src, dim);
39+
THCTensor_(sum)(state, self, src, dim, keepdim);
3840
THCTensor_(div)(state, self, self, ScalarConvert<long, real>::to(THCTensor_(size)(state, src, dim)));
3941
}
4042

@@ -70,7 +72,7 @@ THCTensor_(renorm)(THCState *state, THCTensor* self, THCTensor* src, real value,
7072
}
7173

7274
THC_API void
73-
THCTensor_(std)(THCState *state, THCTensor *self_, THCTensor *src, long dimension, int flag)
75+
THCTensor_(std)(THCState *state, THCTensor *self_, THCTensor *src, long dimension, int flag, int keepdim)
7476
{
7577
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src));
7678
THLongStorage *dim = THCTensor_(newSizeOf)(state, src);
@@ -89,10 +91,14 @@ THCTensor_(std)(THCState *state, THCTensor *self_, THCTensor *src, long dimensio
8991

9092
THCTensor_(free)(state, src);
9193
THCTensor_(freeCopyTo)(state, self, self_);
94+
95+
if (!keepdim) {
96+
THCTensor_(squeeze1d)(state, self_, self_, dimension);
97+
}
9298
}
9399

94100
THC_API void
95-
THCTensor_(var)(THCState *state, THCTensor *self_, THCTensor *src, long dimension, int flag)
101+
THCTensor_(var)(THCState *state, THCTensor *self_, THCTensor *src, long dimension, int flag, int keepdim)
96102
{
97103
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src));
98104
THLongStorage *dim = THCTensor_(newSizeOf)(state, src);
@@ -111,6 +117,10 @@ THCTensor_(var)(THCState *state, THCTensor *self_, THCTensor *src, long dimensio
111117

112118
THCTensor_(free)(state, src);
113119
THCTensor_(freeCopyTo)(state, self, self_);
120+
121+
if (!keepdim) {
122+
THCTensor_(squeeze1d)(state, self_, self_, dimension);
123+
}
114124
}
115125

116126
THC_API accreal
@@ -146,28 +156,28 @@ THCTensor_(varall)(THCState *state, THCTensor *self)
146156
}
147157

148158
THC_API void
149-
THCTensor_(norm)(THCState *state, THCTensor* self, THCTensor* src, real value, long dimension)
159+
THCTensor_(norm)(THCState *state, THCTensor* self, THCTensor* src, real value, long dimension, int keepdim)
150160
{
151161
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self, src));
152162
if (THCNumerics<real>::eq(value, ScalarConvert<float, real>::to(0.0))) {
153163
THC_reduceDim(state, self, src,
154164
TensorNonZeroOp<real>(), ReduceAdd<real, real>(),
155-
ScalarConvert<float, real>::to(0.0), dimension);
165+
ScalarConvert<float, real>::to(0.0), dimension, keepdim);
156166
} else if (THCNumerics<real>::eq(value, ScalarConvert<float, real>::to(1.0))) {
157167
THC_reduceDim(state, self, src,
158168
TensorNormOp<real, 1>(value), ReduceAdd<real, real>(),
159-
ScalarConvert<float, real>::to(0.0), dimension);
169+
ScalarConvert<float, real>::to(0.0), dimension, keepdim);
160170

161171
} else if (THCNumerics<real>::eq(value, ScalarConvert<float, real>::to(2.0))) {
162172
THC_reduceDim(state, self, src,
163173
TensorNormOp<real, 2>(value), ReduceAdd<real, real>(),
164-
ScalarConvert<float, real>::to(0.0), dimension);
174+
ScalarConvert<float, real>::to(0.0), dimension, keepdim);
165175
THCTensor_(pow)(state, self, self, ScalarConvert<float, real>::to(0.5));
166176

167177
} else {
168178
THC_reduceDim(state, self, src,
169179
TensorNormOp<real, -1>(value), ReduceAdd<real, real>(),
170-
ScalarConvert<float, real>::to(0.0), dimension);
180+
ScalarConvert<float, real>::to(0.0), dimension, keepdim);
171181
THCTensor_(pow)(state, self, self, THCNumerics<real>::cinv(value));
172182
}
173183

@@ -325,7 +335,8 @@ THCTensor_(max)(THCState *state,
325335
THCTensor *values,
326336
THCudaLongTensor *indices,
327337
THCTensor *src,
328-
long dimension) {
338+
long dimension,
339+
int keepdim) {
329340
THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, values, indices, src));
330341

331342
thrust::pair<typename TensorUtils<THCTensor>::DataType, long>
@@ -334,7 +345,7 @@ THCTensor_(max)(THCState *state,
334345
THCNumerics<typename TensorUtils<THCTensor>::DataType>::min(), 1);
335346

336347
return THC_reduceDimIndex(
337-
state, values, indices, src, dimension, init,
348+
state, values, indices, src, dimension, keepdim, init,
338349
MaxValuePair<typename TensorUtils<THCTensor>::DataType, long>());
339350
}
340351

@@ -343,7 +354,8 @@ THCTensor_(min)(THCState *state,
343354
THCTensor *values,
344355
THCudaLongTensor *indices,
345356
THCTensor *src,
346-
long dimension) {
357+
long dimension,
358+
int keepdim) {
347359
THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, values, indices, src));
348360

349361
thrust::pair<typename TensorUtils<THCTensor>::DataType, long>
@@ -352,7 +364,7 @@ THCTensor_(min)(THCState *state,
352364
THCNumerics<typename TensorUtils<THCTensor>::DataType>::max(), 1);
353365

354366
return THC_reduceDimIndex(
355-
state, values, indices, src, dimension, init,
367+
state, values, indices, src, dimension, keepdim, init,
356368
MinValuePair<typename TensorUtils<THCTensor>::DataType, long>());
357369
}
358370

torch/lib/THC/generic/THCTensorMathReduce.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,19 @@
55
#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF)
66

77
THC_API void THCTensor_(renorm)(THCState *state, THCTensor* self, THCTensor* src, real value, long dimension, real max_norm);
8-
THC_API void THCTensor_(std)(THCState *state, THCTensor *self, THCTensor *src, long dim, int flag);
9-
THC_API void THCTensor_(norm)(THCState *state, THCTensor* self, THCTensor* src, real value, long dimension);
10-
THC_API void THCTensor_(var)(THCState *state, THCTensor *self, THCTensor *src, long dim, int flag);
8+
THC_API void THCTensor_(std)(THCState *state, THCTensor *self, THCTensor *src, long dim, int flag, int keepdim);
9+
THC_API void THCTensor_(norm)(THCState *state, THCTensor* self, THCTensor* src, real value, long dimension, int keepdim);
10+
THC_API void THCTensor_(var)(THCState *state, THCTensor *self, THCTensor *src, long dim, int flag, int keepdim);
1111

1212
THC_API accreal THCTensor_(stdall)(THCState *state, THCTensor *self);
1313
THC_API accreal THCTensor_(normall)(THCState *state, THCTensor *self, real value);
1414
THC_API accreal THCTensor_(varall)(THCState *state, THCTensor *self);
1515

1616
#endif
1717

18-
THC_API void THCTensor_(sum)(THCState *state, THCTensor *self, THCTensor *src, long dim);
19-
THC_API void THCTensor_(prod)(THCState *state, THCTensor *self, THCTensor *src, long dim);
20-
THC_API void THCTensor_(mean)(THCState *state, THCTensor *self, THCTensor *src, long dim);
18+
THC_API void THCTensor_(sum)(THCState *state, THCTensor *self, THCTensor *src, long dim, int keepdim);
19+
THC_API void THCTensor_(prod)(THCState *state, THCTensor *self, THCTensor *src, long dim, int keepdim);
20+
THC_API void THCTensor_(mean)(THCState *state, THCTensor *self, THCTensor *src, long dim, int keepdim);
2121

2222
THC_API accreal THCTensor_(sumall)(THCState *state, THCTensor *self);
2323
THC_API accreal THCTensor_(prodall)(THCState *state, THCTensor *self);
@@ -26,11 +26,11 @@ THC_API accreal THCTensor_(meanall)(THCState *state, THCTensor *self);
2626
THC_API void THCTensor_(min)(THCState *state,
2727
THCTensor *values,
2828
THCudaLongTensor *indices,
29-
THCTensor *src, long dim);
29+
THCTensor *src, long dim, int keepdim);
3030
THC_API void THCTensor_(max)(THCState *state,
3131
THCTensor *values,
3232
THCudaLongTensor *indices,
33-
THCTensor *src, long dim);
33+
THCTensor *src, long dim, int keepdim);
3434

3535
THC_API real THCTensor_(minall)(THCState *state, THCTensor *self);
3636
THC_API real THCTensor_(maxall)(THCState *state, THCTensor *self);

torch/lib/THC/generic/THCTensorMode.cu

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,8 @@ THC_API void THCTensor_(mode)(THCState *state,
159159
THCTensor *values,
160160
THCudaLongTensor *indices,
161161
THCTensor *input,
162-
int dimension) {
162+
int dimension,
163+
int keepdim) {
163164
THLongStorage *dim;
164165
THCTensor *transposed, *contiguous, *valuesTransposed;
165166
THLongStorage *position;
@@ -301,6 +302,11 @@ THC_API void THCTensor_(mode)(THCState *state,
301302
THCudaLongTensor_free(state, indicesTransposed);
302303
THCudaLongStorage_free(state, sortBuffer);
303304
}
305+
306+
if (!keepdim) {
307+
THCTensor_(squeeze1d)(state, values, values, dimension);
308+
THCudaLongTensor_squeeze1d(state, indices, indices, dimension);
309+
}
304310
}
305311

306312
#undef MAX_GRID_SIZE

torch/lib/THC/generic/THCTensorMode.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ THC_API void THCTensor_(mode)(THCState *state,
88
THCTensor *values,
99
THCudaLongTensor *indices,
1010
THCTensor *input,
11-
int dimension);
11+
int dimension,
12+
int keepdim);
1213

1314
#endif // THC_GENERIC_FILE

0 commit comments

Comments
 (0)