33#else
44
55THC_API void
6- THCTensor_ (sum)(THCState* state, THCTensor *self, THCTensor *src, long dimension) {
6+ THCTensor_ (sum)(THCState* state, THCTensor *self, THCTensor *src, long dimension, int keepdim ) {
77 THCAssertSameGPU (THCTensor_ (checkGPU)(state, 2 , self, src));
88 if (!THC_reduceDim (state, self, src,
99 thrust::identity<real>(),
1010 ReduceAdd<real, real>(),
1111 ScalarConvert<int , real>::to (0 ),
12- dimension)) {
12+ dimension,
13+ keepdim)) {
1314 THArgCheck (false , 2 , CUTORCH_DIM_WARNING);
1415 }
1516
1617 THCudaCheck (cudaGetLastError ());
1718}
1819
1920THC_API void
20- THCTensor_ (prod)(THCState* state, THCTensor *self, THCTensor *src, long dimension) {
21+ THCTensor_ (prod)(THCState* state, THCTensor *self, THCTensor *src, long dimension, int keepdim ) {
2122 THCAssertSameGPU (THCTensor_ (checkGPU)(state, 2 , self, src));
2223 if (!THC_reduceDim (state, self, src,
2324 thrust::identity<real>(),
2425 ReduceMultiply<real, real>(),
2526 ScalarConvert<int , real>::to (1 ),
26- dimension)) {
27+ dimension,
28+ keepdim)) {
2729 THArgCheck (false , 2 , CUTORCH_DIM_WARNING);
2830 }
2931
3032 THCudaCheck (cudaGetLastError ());
3133}
3234
3335THC_API void
34- THCTensor_ (mean)(THCState *state, THCTensor *self, THCTensor *src, long dim)
36+ THCTensor_ (mean)(THCState *state, THCTensor *self, THCTensor *src, long dim, int keepdim )
3537{
3638 THCAssertSameGPU (THCTensor_ (checkGPU)(state, 2 , self, src));
37- THCTensor_ (sum)(state, self, src, dim);
39+ THCTensor_ (sum)(state, self, src, dim, keepdim );
3840 THCTensor_ (div)(state, self, self, ScalarConvert<long , real>::to (THCTensor_ (size)(state, src, dim)));
3941}
4042
@@ -70,7 +72,7 @@ THCTensor_(renorm)(THCState *state, THCTensor* self, THCTensor* src, real value,
7072}
7173
7274THC_API void
73- THCTensor_ (std)(THCState *state, THCTensor *self_, THCTensor *src, long dimension, int flag)
75+ THCTensor_ (std)(THCState *state, THCTensor *self_, THCTensor *src, long dimension, int flag, int keepdim )
7476{
7577 THCAssertSameGPU (THCTensor_ (checkGPU)(state, 2 , self_, src));
7678 THLongStorage *dim = THCTensor_ (newSizeOf)(state, src);
@@ -89,10 +91,14 @@ THCTensor_(std)(THCState *state, THCTensor *self_, THCTensor *src, long dimensio
8991
9092 THCTensor_ (free )(state, src);
9193 THCTensor_ (freeCopyTo)(state, self, self_);
94+
95+ if (!keepdim) {
96+ THCTensor_ (squeeze1d)(state, self_, self_, dimension);
97+ }
9298}
9399
94100THC_API void
95- THCTensor_ (var)(THCState *state, THCTensor *self_, THCTensor *src, long dimension, int flag)
101+ THCTensor_ (var)(THCState *state, THCTensor *self_, THCTensor *src, long dimension, int flag, int keepdim )
96102{
97103 THCAssertSameGPU (THCTensor_ (checkGPU)(state, 2 , self_, src));
98104 THLongStorage *dim = THCTensor_ (newSizeOf)(state, src);
@@ -111,6 +117,10 @@ THCTensor_(var)(THCState *state, THCTensor *self_, THCTensor *src, long dimensio
111117
112118 THCTensor_ (free )(state, src);
113119 THCTensor_ (freeCopyTo)(state, self, self_);
120+
121+ if (!keepdim) {
122+ THCTensor_ (squeeze1d)(state, self_, self_, dimension);
123+ }
114124}
115125
116126THC_API accreal
@@ -146,28 +156,28 @@ THCTensor_(varall)(THCState *state, THCTensor *self)
146156}
147157
148158THC_API void
149- THCTensor_ (norm)(THCState *state, THCTensor* self, THCTensor* src, real value, long dimension)
159+ THCTensor_ (norm)(THCState *state, THCTensor* self, THCTensor* src, real value, long dimension, int keepdim )
150160{
151161 THCAssertSameGPU (THCTensor_ (checkGPU)(state, 2 , self, src));
152162 if (THCNumerics<real>::eq (value, ScalarConvert<float , real>::to (0.0 ))) {
153163 THC_reduceDim (state, self, src,
154164 TensorNonZeroOp<real>(), ReduceAdd<real, real>(),
155- ScalarConvert<float , real>::to (0.0 ), dimension);
165+ ScalarConvert<float , real>::to (0.0 ), dimension, keepdim );
156166 } else if (THCNumerics<real>::eq (value, ScalarConvert<float , real>::to (1.0 ))) {
157167 THC_reduceDim (state, self, src,
158168 TensorNormOp<real, 1 >(value), ReduceAdd<real, real>(),
159- ScalarConvert<float , real>::to (0.0 ), dimension);
169+ ScalarConvert<float , real>::to (0.0 ), dimension, keepdim );
160170
161171 } else if (THCNumerics<real>::eq (value, ScalarConvert<float , real>::to (2.0 ))) {
162172 THC_reduceDim (state, self, src,
163173 TensorNormOp<real, 2 >(value), ReduceAdd<real, real>(),
164- ScalarConvert<float , real>::to (0.0 ), dimension);
174+ ScalarConvert<float , real>::to (0.0 ), dimension, keepdim );
165175 THCTensor_ (pow)(state, self, self, ScalarConvert<float , real>::to (0.5 ));
166176
167177 } else {
168178 THC_reduceDim (state, self, src,
169179 TensorNormOp<real, -1 >(value), ReduceAdd<real, real>(),
170- ScalarConvert<float , real>::to (0.0 ), dimension);
180+ ScalarConvert<float , real>::to (0.0 ), dimension, keepdim );
171181 THCTensor_ (pow)(state, self, self, THCNumerics<real>::cinv (value));
172182 }
173183
@@ -325,7 +335,8 @@ THCTensor_(max)(THCState *state,
325335 THCTensor *values,
326336 THCudaLongTensor *indices,
327337 THCTensor *src,
328- long dimension) {
338+ long dimension,
339+ int keepdim) {
329340 THCAssertSameGPU (THCTensor_ (checkGPU)(state, 3 , values, indices, src));
330341
331342 thrust::pair<typename TensorUtils<THCTensor>::DataType, long >
@@ -334,7 +345,7 @@ THCTensor_(max)(THCState *state,
334345 THCNumerics<typename TensorUtils<THCTensor>::DataType>::min (), 1 );
335346
336347 return THC_reduceDimIndex (
337- state, values, indices, src, dimension, init,
348+ state, values, indices, src, dimension, keepdim, init,
338349 MaxValuePair<typename TensorUtils<THCTensor>::DataType, long >());
339350}
340351
@@ -343,7 +354,8 @@ THCTensor_(min)(THCState *state,
343354 THCTensor *values,
344355 THCudaLongTensor *indices,
345356 THCTensor *src,
346- long dimension) {
357+ long dimension,
358+ int keepdim) {
347359 THCAssertSameGPU (THCTensor_ (checkGPU)(state, 3 , values, indices, src));
348360
349361 thrust::pair<typename TensorUtils<THCTensor>::DataType, long >
@@ -352,7 +364,7 @@ THCTensor_(min)(THCState *state,
352364 THCNumerics<typename TensorUtils<THCTensor>::DataType>::max (), 1 );
353365
354366 return THC_reduceDimIndex (
355- state, values, indices, src, dimension, init,
367+ state, values, indices, src, dimension, keepdim, init,
356368 MinValuePair<typename TensorUtils<THCTensor>::DataType, long >());
357369}
358370
0 commit comments