Skip to content

Commit 5ce4a4a

Browse files
committed
Merge commit '3f1f3f97343d2ab7eb522cac7330f6b7478bd4da'
2 parents 3e9caed + 3f1f3f9 commit 5ce4a4a

File tree

3 files changed

+116
-0
lines changed

3 files changed

+116
-0
lines changed

torch/lib/THC/THCTensorScatterGather.cu

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "THCTensorMath.h"
22
#include "THCGeneral.h"
3+
#include "THCAtomics.cuh"
34
#include "THCApply.cuh"
45

56
// Compute the offsets into the given tensors for a linear index. For the 't2'
@@ -127,6 +128,33 @@ __global__ void THCudaTensor_scatterKernel(
127128
}
128129
}
129130

// Scatter-add kernel: for every element of `index`, adds the matching element
// of `src` into `tensor`, at the position whose coordinate along `dim` is the
// (TH_INDEX_BASE-adjusted) value read from `index`.
//
// Template parameters:
//   IndexType - integer type used for linear ids and offsets.
//   Real      - element type of `tensor` and `src`.
//   Dims      - forwarded to IndexToScatterGatherOffsets (defined elsewhere in
//               this file); NOTE(review): callers pass -1 for the generic
//               case - confirm against that helper.
//
// Launch layout: 1-D grid of 1-D blocks. The loop below is a grid-stride loop,
// so any grid size covers all `totalElements` linear ids.
//
// Several `index` entries may select the same destination element, which is
// why the accumulation goes through atomicAdd.
template <typename IndexType, typename Real, int Dims>
__global__ void THCudaTensor_scatterAddKernel(
    TensorInfo<Real, IndexType> tensor,
    TensorInfo<Real, IndexType> src,
    TensorInfo<long, IndexType> index,
    const int dim,
    const IndexType totalElements) {
  // Distance between two consecutive linear ids handled by the same thread.
  const IndexType stride = gridDim.x * blockDim.x;

  IndexType linearId = blockIdx.x * blockDim.x + threadIdx.x;
  while (linearId < totalElements) {
    IndexType dstOffset = 0;
    IndexType fromOffset = 0;
    IndexType idxOffset = 0;

    // Fills in the three offsets for this linear id. Presumably the `dim`
    // contribution to the destination offset is left out here, since it is
    // added explicitly below - confirm against the helper.
    IndexToScatterGatherOffsets<IndexType, Real, Dims>::compute(
        linearId, dim,
        index, &idxOffset,
        src, &fromOffset,
        tensor, &dstOffset);

    // Convert the stored index to a zero-based coordinate along `dim`.
    const long target = index.data[idxOffset] - TH_INDEX_BASE;
    assert(target >= 0 && target < tensor.sizes[dim]);
    dstOffset += target * tensor.strides[dim];

    atomicAdd(&tensor.data[dstOffset], src.data[fromOffset]);

    linearId += stride;
  }
}
157+
130158
template <typename IndexType, typename Real, int Dims>
131159
__global__ void THCudaTensor_scatterFillKernel(
132160
TensorInfo<Real, IndexType> tensor,

torch/lib/THC/generic/THCTensorScatterGather.cu

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,93 @@ void THCTensor_(scatter)(THCState* state, THCTensor *tensor, int dim, THCudaLong
183183

184184
#undef RUN
185185

// Launches THCudaTensor_scatterAddKernel on the current stream. Expands to a
// full statement and relies on `grid`, `block`, `state`, `tensorInfo`,
// `srcInfo`, `indexInfo`, `dim` and `totalElements` being in scope at the
// expansion site.
#define RUN(TYPE, DIMS, REAL)                                           \
  THCudaTensor_scatterAddKernel<TYPE, REAL, DIMS>                       \
      <<<grid, block, 0, THCState_getCurrentStream(state)>>>(           \
          tensorInfo, srcInfo, indexInfo, dim, (TYPE)totalElements);

// Accumulates `src` into `tensor` along dimension `dim`, using `index` to pick
// the destination coordinate along `dim` for every source element (see
// THCudaTensor_scatterAddKernel for the per-element operation).
//
// Argument requirements, enforced with THArgCheck:
//   - `dim` is a valid dimension of `tensor`.
//   - `index`, `src` and `tensor` all have the same number of dimensions.
//   - `src` has exactly the size of `index`.
//   - `tensor` and `src` agree in size on every dimension except `dim`.
//   - `tensor` has at most MAX_CUTORCH_DIMS dimensions.
// All three tensors must live on the same GPU.
//
// If `tensor` has overlapping indices, the kernel runs on a contiguous copy
// and the result is copied back into the original afterwards.
void THCTensor_(scatterAdd)(THCState* state, THCTensor *tensor, int dim, THCudaLongTensor *index, THCTensor *src) {
  THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, tensor, src));
  THCAssertSameGPU(THCudaLongTensor_checkGPU(state, 1, index));

  THArgCheck(dim >= 0 && dim < THCTensor_(nDimension)(state, tensor), 2,
             "Index dimension is out of bounds");
  THArgCheck(THCudaLongTensor_nDimension(state, index) == THCTensor_(nDimension)(state, src), 3,
             "Index tensor must have same dimensions as input tensor");
  THArgCheck(THCTensor_(nDimension)(state, src) == THCTensor_(nDimension)(state, tensor), 4,
             "Input tensor must have same dimensions as output tensor");

  // Evaluate the size comparison and release the temporary size storage
  // BEFORE running the check: a failing THArgCheck hands control to the error
  // handler and may not return, which would otherwise leak `indexDims`.
  THLongStorage *indexDims = THCudaLongTensor_newSizeOf(state, index);
  const int srcMatchesIndex = THCTensor_(isSize)(state, src, indexDims);
  THLongStorage_free(indexDims);
  THArgCheck(srcMatchesIndex, 3,
             "Index tensor must have the same size as input tensor.");

  for (int d = 0; d < THCTensor_(nDimension)(state, tensor); d++) {
    if (d != dim) {
      THArgCheck(THCTensor_(size)(state, tensor, d) == THCTensor_(size)(state, src, d), 4,
                 "Input tensor must have same size as output tensor apart from the specified dimension");
    }
  }

  THArgCheck(THCTensor_(nDimension)(state, tensor) <= MAX_CUTORCH_DIMS,
             1, CUTORCH_DIM_WARNING);

  // One kernel work item per element of `index`.
  const ptrdiff_t totalElements = THCudaLongTensor_nElement(state, index);
  const dim3 block = getApplyBlock();
  dim3 grid;
  THArgCheck(getApplyGrid(state, totalElements, grid), 1, CUTORCH_DIM_WARNING);

  // When distinct indices of `tensor` alias the same memory, work on a
  // contiguous copy and write the result back once the kernel is queued.
  THCTensor* oldTensor = NULL;
  if (TensorUtils<THCTensor>::overlappingIndices(state, tensor)) {
    oldTensor = tensor;
    tensor = THCTensor_(newContiguous)(state, tensor);
  }

  if (TensorUtils<THCTensor>::canUse32BitIndexMath(state, tensor) &&
      TensorUtils<THCTensor>::canUse32BitIndexMath(state, src) &&
      TensorUtils<THCudaLongTensor>::canUse32BitIndexMath(state, index)) {
    TensorInfo<real, unsigned int> tensorInfo =
      getTensorInfo<THCTensor, unsigned int>(state, tensor);
    TensorInfo<real, unsigned int> srcInfo =
      getTensorInfo<THCTensor, unsigned int>(state, src);
    TensorInfo<long, unsigned int> indexInfo =
      getTensorInfo<THCudaLongTensor, unsigned int>(state, index);

    // Specialize for a small number of dimensions.
    switch (indexInfo.dims) {
      case 1:
        RUN(unsigned int, 1, real);
        break;
      case 2:
        RUN(unsigned int, 2, real);
        break;
      case 3:
        RUN(unsigned int, 3, real);
        break;
      default:
        RUN(unsigned int, -1, real);
        break;
    }
  } else {
    // At least one tensor needs 64-bit offsets; only the generic
    // (Dims == -1) kernel is instantiated for this index type.
    TensorInfo<real, unsigned long> tensorInfo =
      getTensorInfo<THCTensor, unsigned long>(state, tensor);
    TensorInfo<real, unsigned long> srcInfo =
      getTensorInfo<THCTensor, unsigned long>(state, src);
    TensorInfo<long, unsigned long> indexInfo =
      getTensorInfo<THCudaLongTensor, unsigned long>(state, index);

    RUN(unsigned long, -1, real);
  }

  if (oldTensor) {
    // Copy the accumulated result back into the caller's tensor and drop the
    // temporary contiguous copy.
    TensorUtils<THCTensor>::copyIgnoringOverlaps(state, oldTensor, tensor);
    THCTensor_(free)(state, tensor);
  }
  THCudaCheck(cudaGetLastError());
}

#undef RUN
272+
186273
#define RUN(TYPE, DIMS, REAL) \
187274
THCudaTensor_scatterFillKernel<TYPE, REAL, DIMS> \
188275
<<<grid, block, 0, THCState_getCurrentStream(state)>>>( \

torch/lib/THC/generic/THCTensorScatterGather.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
THC_API void THCTensor_(gather)(THCState* state, THCTensor *tensor, THCTensor *src, int dim, THCudaLongTensor *index);
66
THC_API void THCTensor_(scatter)(THCState* state, THCTensor *tensor, int dim, THCudaLongTensor *index, THCTensor *src);
7+
/* Adds the elements of `src` into `tensor` along dimension `dim`; `index`
 * (same size as `src`) supplies, per source element, the destination
 * coordinate along `dim`. Elements that map to the same destination are
 * accumulated rather than overwritten. */
THC_API void THCTensor_(scatterAdd)(THCState* state, THCTensor *tensor, int dim, THCudaLongTensor *index, THCTensor *src);
78
THC_API void THCTensor_(scatterFill)(THCState* state, THCTensor *tensor, int dim, THCudaLongTensor *index, real value);
89

910
#endif

0 commit comments

Comments
 (0)