@@ -183,6 +183,93 @@ void THCTensor_(scatter)(THCState* state, THCTensor *tensor, int dim, THCudaLong
 
 #undef RUN
 
+#define RUN(TYPE, DIMS, REAL)                                    \
+  THCudaTensor_scatterAddKernel<TYPE, REAL, DIMS>                \
+      <<<grid, block, 0, THCState_getCurrentStream(state)>>>(    \
+          tensorInfo, srcInfo, indexInfo, dim, (TYPE)totalElements);
+
+void THCTensor_(scatterAdd)(THCState* state, THCTensor *tensor, int dim, THCudaLongTensor *index, THCTensor *src) {
+  THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, tensor, src));
+  THCAssertSameGPU(THCudaLongTensor_checkGPU(state, 1, index));
+
+  THArgCheck(dim >= 0 && dim < THCTensor_(nDimension)(state, tensor), 2,
+             "Index dimension is out of bounds");
+  THArgCheck(THCudaLongTensor_nDimension(state, index) == THCTensor_(nDimension)(state, src), 3,
+             "Index tensor must have same dimensions as input tensor");
+  THArgCheck(THCTensor_(nDimension)(state, src) == THCTensor_(nDimension)(state, tensor), 4,
+             "Input tensor must have same dimensions as output tensor");
+  THLongStorage *indexDims = THCudaLongTensor_newSizeOf(state, index);
+  THArgCheck(THCTensor_(isSize)(state, src, indexDims), 3,
+             "Index tensor must have the same size as input tensor.");
+  THLongStorage_free(indexDims);
+
+  for (int d = 0; d < THCTensor_(nDimension)(state, tensor); d++) {
+    if (d != dim) {
+      THArgCheck(THCTensor_(size)(state, tensor, d) == THCTensor_(size)(state, src, d), 4,
+                 "Input tensor must have same size as output tensor apart from the specified dimension");
+    }
+  }
+
+  THArgCheck(THCTensor_(nDimension)(state, tensor) <= MAX_CUTORCH_DIMS,
+             1, CUTORCH_DIM_WARNING);
+
+  const ptrdiff_t totalElements = THCudaLongTensor_nElement(state, index);
+  const dim3 block = getApplyBlock();
+  dim3 grid;
+  THArgCheck(getApplyGrid(state, totalElements, grid), 1, CUTORCH_DIM_WARNING);
+
+  // If the output has overlapping memory, scatter into a contiguous copy
+  // first; the result is copied back to the original tensor below.
+  THCTensor* oldTensor = NULL;
+  if (TensorUtils<THCTensor>::overlappingIndices(state, tensor)) {
+    oldTensor = tensor;
+    tensor = THCTensor_(newContiguous)(state, tensor);
+  }
+
+  // Use 32-bit index math when every tensor's indexing fits in 32 bits.
+  if (TensorUtils<THCTensor>::canUse32BitIndexMath(state, tensor) &&
+      TensorUtils<THCTensor>::canUse32BitIndexMath(state, src) &&
+      TensorUtils<THCudaLongTensor>::canUse32BitIndexMath(state, index)) {
+    TensorInfo<real, unsigned int> tensorInfo =
+      getTensorInfo<THCTensor, unsigned int>(state, tensor);
+    TensorInfo<real, unsigned int> srcInfo =
+      getTensorInfo<THCTensor, unsigned int>(state, src);
+    TensorInfo<long, unsigned int> indexInfo =
+      getTensorInfo<THCudaLongTensor, unsigned int>(state, index);
+
+    // Specialize for a small number of dimensions.
+    switch (indexInfo.dims) {
+      case 1:
+        RUN(unsigned int, 1, real);
+        break;
+      case 2:
+        RUN(unsigned int, 2, real);
+        break;
+      case 3:
+        RUN(unsigned int, 3, real);
+        break;
+      default:
+        RUN(unsigned int, -1, real);
+        break;
+    }
+  } else {
+    TensorInfo<real, unsigned long> tensorInfo =
+      getTensorInfo<THCTensor, unsigned long>(state, tensor);
+    TensorInfo<real, unsigned long> srcInfo =
+      getTensorInfo<THCTensor, unsigned long>(state, src);
+    TensorInfo<long, unsigned long> indexInfo =
+      getTensorInfo<THCudaLongTensor, unsigned long>(state, index);
+
+    RUN(unsigned long, -1, real)
+  }
+
+  if (oldTensor) {
+    TensorUtils<THCTensor>::copyIgnoringOverlaps(state, oldTensor, tensor);
+    THCTensor_(free)(state, tensor);
+    tensor = oldTensor;
+  }
+  THCudaCheck(cudaGetLastError());
+}
+
+#undef RUN
+
 #define RUN(TYPE, DIMS, REAL)                                    \
   THCudaTensor_scatterFillKernel<TYPE, REAL, DIMS>               \
       <<<grid, block, 0, THCState_getCurrentStream(state)>>>(    \
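Note: the hunk above adds only the host-side dispatcher; the device kernel THCudaTensor_scatterAddKernel it launches is defined elsewhere and is not shown here. As a rough, self-contained illustration of the operation being dispatched (not the cutorch kernel itself), the standalone CUDA program below scatter-adds a row-major 2-D float tensor along dim 0, i.e. out[index[i][j]][j] += src[i][j], and resolves collisions in index with atomicAdd. The kernel name scatterAddDim0 and the fixed 2-D/float/dim-0 case are illustrative assumptions.

// Minimal sketch only, not the cutorch kernel: scatter-add along dim 0 of
// row-major 2-D tensors, out[index[i][j]][j] += src[i][j].
#include <cstdio>
#include <cuda_runtime.h>

__global__ void scatterAddDim0(float* out, const float* src, const long* index,
                               int srcRows, int cols) {
  int i = blockIdx.y * blockDim.y + threadIdx.y;  // row in src/index
  int j = blockIdx.x * blockDim.x + threadIdx.x;  // column
  if (i < srcRows && j < cols) {
    long target = index[i * cols + j];             // destination row in out
    // atomicAdd keeps concurrent writes correct when several index entries
    // point at the same output element.
    atomicAdd(&out[target * cols + j], src[i * cols + j]);
  }
}

int main() {
  const int srcRows = 2, outRows = 3, cols = 4;
  float hSrc[srcRows * cols] = {1, 1, 1, 1, 2, 2, 2, 2};
  long  hIdx[srcRows * cols] = {0, 1, 2, 0, 2, 0, 1, 2};
  float hOut[outRows * cols] = {0};

  float *dSrc, *dOut;
  long *dIdx;
  cudaMalloc(&dSrc, sizeof(hSrc));
  cudaMalloc(&dIdx, sizeof(hIdx));
  cudaMalloc(&dOut, sizeof(hOut));
  cudaMemcpy(dSrc, hSrc, sizeof(hSrc), cudaMemcpyHostToDevice);
  cudaMemcpy(dIdx, hIdx, sizeof(hIdx), cudaMemcpyHostToDevice);
  cudaMemcpy(dOut, hOut, sizeof(hOut), cudaMemcpyHostToDevice);

  dim3 block(16, 16);
  dim3 grid((cols + block.x - 1) / block.x, (srcRows + block.y - 1) / block.y);
  scatterAddDim0<<<grid, block>>>(dOut, dSrc, dIdx, srcRows, cols);
  cudaMemcpy(hOut, dOut, sizeof(hOut), cudaMemcpyDeviceToHost);

  for (int r = 0; r < outRows; r++) {
    for (int c = 0; c < cols; c++) printf("%g ", hOut[r * cols + c]);
    printf("\n");
  }
  cudaFree(dSrc); cudaFree(dIdx); cudaFree(dOut);
  return 0;
}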