Skip to content

Commit a19f804

Browse files
committed
Merge pull request pytorch#206 from dominikgrewe/scattergather
CUDA implementations for scatter & gather.
2 parents 20562f0 + 9123d31 commit a19f804

File tree

3 files changed

+407
-0
lines changed

3 files changed

+407
-0
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ SET(src-cuda
9191
THCTensorIndex.cu
9292
THCTensorConv.cu
9393
THCTensorRandom.cu
94+
THCTensorScatterGather.cu
9495
THCApply.cu
9596
THCTensorSort.cu
9697
)

THCTensorMath.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,10 @@ THC_API void THCudaTensor_maskedFillByte(THCState* state, THCudaTensor *tensor,
130130
THC_API void THCudaTensor_maskedCopyByte(THCState* state, THCudaTensor *tensor, THByteTensor *mask, THCudaTensor *src);
131131
THC_API void THCudaTensor_maskedSelectByte(THCState* state, THCudaTensor *tensor, THCudaTensor *src, THByteTensor *mask);
132132

133+
THC_API void THCudaTensor_gather(THCState* state, THCudaTensor *tensor, THCudaTensor *src, int dim, THCudaTensor *index);
134+
THC_API void THCudaTensor_scatter(THCState* state, THCudaTensor *tensor, int dim, THCudaTensor *index, THCudaTensor *src);
135+
THC_API void THCudaTensor_scatterFill(THCState* state, THCudaTensor *tensor, int dim, THCudaTensor *index, float value);
136+
133137
THC_API int THCudaTensor_logicalall(THCState *state, THCudaTensor *self);
134138
THC_API int THCudaTensor_logicalany(THCState *state, THCudaTensor *self);
135139

0 commit comments

Comments
 (0)