
Commit d5e8210

colesbury authored and soumith committed
Make torch.cat not synchronize the host and device
1 parent 5f308b5 commit d5e8210
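The gist of the change: torch.cat previously staged its per-batch input metadata in a stack array and copied it to the GPU with a blocking cudaMemcpy, which stalls the host on every batch. The commit switches to a pinned host buffer obtained from the caching host allocator plus cudaMemcpyAsync on the current stream, with an event recorded so the buffer is not recycled while the copy is still in flight. A minimal standalone sketch of that pattern (not code from this commit; buffer and event names are illustrative):

#include <cuda_runtime.h>

/* Sketch of the async staging pattern this commit adopts.
   `d_meta`, `nbytes`, and `stream` are assumed inputs; error checking trimmed. */
void stage_metadata_async(void *d_meta, size_t nbytes, cudaStream_t stream)
{
  /* Pinned (page-locked) host memory is required for a truly asynchronous
     host-to-device copy; a pageable or stack buffer would force a sync. */
  void *h_meta = NULL;
  cudaHostAlloc(&h_meta, nbytes, cudaHostAllocDefault);

  /* ... fill h_meta with the per-tensor metadata ... */

  /* Queue the copy on the stream; the host does not wait for it. */
  cudaMemcpyAsync(d_meta, h_meta, nbytes, cudaMemcpyHostToDevice, stream);

  /* The buffer must outlive the in-flight copy. Record an event so the
     caller (or a caching allocator) knows when it is safe to reuse it. */
  cudaEvent_t copy_done;
  cudaEventCreateWithFlags(&copy_done, cudaEventDisableTiming);
  cudaEventRecord(copy_done, stream);

  /* Once cudaEventQuery(copy_done) == cudaSuccess, h_meta may be freed
     with cudaFreeHost and copy_done destroyed with cudaEventDestroy. */
}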

File tree

3 files changed (+48, -29 lines)


THCGeneral.c

Lines changed: 21 additions & 1 deletion
@@ -848,6 +848,27 @@ cudaError_t THCudaFree(THCState *state, void *ptr)
   return allocator->free(allocator->state, ptr);
 }
 
+void* THCudaHostAlloc(THCState *state, size_t size)
+{
+  THCudaCheck(cudaGetLastError());
+  THAllocator* allocator = state->cudaHostAllocator;
+  return allocator->malloc(NULL, size);
+}
+
+void THCudaHostFree(THCState *state, void *ptr)
+{
+  THAllocator* allocator = state->cudaHostAllocator;
+  return allocator->free(NULL, ptr);
+}
+
+void THCudaHostRecord(THCState *state, void *ptr)
+{
+  if (state->cudaHostAllocator == &THCCachingHostAllocator) {
+    THCStream* stream = THCState_getStream(state);
+    THCCachingHostAllocator_recordEvent(ptr, stream);
+  }
+}
+
 cudaError_t THCudaMemGetInfo(THCState *state, size_t* freeBytes, size_t* totalBytes)
 {
   size_t cachedBytes = 0;
@@ -932,4 +953,3 @@ float THC_half2float(half h)
   TH_halfbits2float(&h.x, &f);
   return f;
 }
-

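The three helpers above go through state->cudaHostAllocator, and THCudaHostRecord ties a pinned block to the current stream when the caching host allocator is in use, so a block handed back via THCudaHostFree is not recycled until the recorded work has completed. A hedged sketch of how they are meant to compose (mirroring the catArray change below; `state`, `d_dst`, and `nbytes` are assumed to exist):

/* Sketch only, not part of this commit. */
static void stage_to_device_async(THCState *state, void *d_dst, size_t nbytes)
{
  THCStream *stream = THCState_getStream(state);
  void *h_buf = THCudaHostAlloc(state, nbytes);   /* pinned staging buffer */

  /* ... fill h_buf ... */

  THCudaCheck(cudaMemcpyAsync(d_dst, h_buf, nbytes,
                              cudaMemcpyHostToDevice, stream->stream));
  THCudaHostRecord(state, h_buf);  /* remember the stream that reads h_buf */
  THCudaHostFree(state, h_buf);    /* returned to the allocator; the caching
                                      host allocator recycles it only after
                                      the recorded event has completed */
}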
THCGeneral.h.in

Lines changed: 4 additions & 0 deletions
@@ -204,6 +204,10 @@ THC_API void __THCusparseCheck(cusparseStatus_t status, const char *file, const
 
 THC_API cudaError_t THCudaMalloc(THCState *state, void **ptr, size_t size);
 THC_API cudaError_t THCudaFree(THCState *state, void *ptr);
+THC_API void* THCudaHostAlloc(THCState *state, size_t size);
+THC_API void THCudaHostFree(THCState *state, void *ptr);
+THC_API void THCudaHostRecord(THCState *state, void *ptr);
+
 THC_API cudaError_t THCudaMemGetInfo(THCState *state, size_t* freeBytes, size_t* totalBytes);
 THC_API void THCSetGCHandler(THCState *state,
                              void (*torchGCHandlerFunction)(void *data),

generic/THCTensorMath.cu

Lines changed: 23 additions & 28 deletions
@@ -175,23 +175,9 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result,
     real *data = THCTensor_(data)(state, result);
 
     // Kernel Parameter
-    CatArrInputTensor<real, unsigned int> stackInputs[CAT_ARRAY_BATCH_SIZE];
-    CatArrInputTensor<real, unsigned int> *d_inputs;
-
-    // Attempt to re-use stream's scratch space for the input metadata
-    bool usedScratch = false;
     size_t tensorMetadataSize = sizeof(CatArrInputTensor<real, unsigned int>) * CAT_ARRAY_BATCH_SIZE;
-    if (THCState_getCurrentDeviceScratchSpaceSize(state) > tensorMetadataSize) {
-      void* space = THCState_getCurrentDeviceScratchSpace(state);
-      if (space) {
-        d_inputs = (CatArrInputTensor<real, unsigned int> *) space;
-        usedScratch = true;
-      }
-    }
-    if (!usedScratch) {
-      // Fallback to allocating GPU memory
-      THCudaCheck(THCudaMalloc(state, (void**) &d_inputs, tensorMetadataSize));
-    }
+    CatArrInputTensor<real, unsigned int> *d_inputs;
+    THCudaCheck(THCudaMalloc(state, (void**) &d_inputs, tensorMetadataSize));
 
     OutputTensorSizeStride<unsigned int, CAT_ARRAY_MAX_INPUT_DIMS> param;
 
@@ -201,13 +187,17 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result,
       param.outputStride[i] = THCTensor_(stride)(state, result, i);
     }
 
+    THCStream* stream = THCState_getStream(state);
+
     // Template Declarations for dim = 1, 2, 3, 4
 #define HANDLE_CASE(DIMS) \
-  CatArrayBatchedCopy<real, unsigned int, DIMS><<<applyGrid, applyBlock, 0, THCState_getCurrentStream(state)>>>(data, d_inputs, param, cat_dimension, param.outputStride[cat_dimension]);
+  CatArrayBatchedCopy<real, unsigned int, DIMS><<<applyGrid, applyBlock, 0, stream->stream>>>(data, d_inputs, param, cat_dimension, param.outputStride[cat_dimension]);
 
     // Now we loop
     offset = 0;
     for (i = 0; i < numInputs; i += CAT_ARRAY_BATCH_SIZE) {
+      // Re-allocate stackInputs every iteration to avoid read-after-write hazard
+      CatArrInputTensor<real, unsigned int>* stackInputs = (CatArrInputTensor<real, unsigned int>*) THCudaHostAlloc(state, tensorMetadataSize);
       cohortMax = 0;
       for (j = 0; j < CAT_ARRAY_BATCH_SIZE && (i+j) < numInputs; ++j) {
         long dimSize = cat_dimension < THCTensor_(nDimension)(state, inputs[i+j])
@@ -223,7 +213,14 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result,
         // update offset
        offset += dimSize;
       }
-      THCudaCheck(cudaMemcpy(d_inputs, stackInputs, j * sizeof(CatArrInputTensor<real, unsigned int>), cudaMemcpyHostToDevice));
+      THCudaCheck(cudaMemcpyAsync(
+          d_inputs,
+          stackInputs,
+          j * sizeof(CatArrInputTensor<real, unsigned int>),
+          cudaMemcpyHostToDevice,
+          stream->stream));
+      THCudaHostRecord(state, stackInputs);
+      THCudaHostFree(state, stackInputs);
 
       // Next, let's consider how we set our kernel launch parameters.
       // We borrow from THCApply, which the kernel's internal indexing
@@ -256,9 +253,7 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result,
       }
       THCudaCheck(cudaGetLastError());
     }
-    if (!usedScratch) {
-      THCudaCheck(THCudaFree(state, (void *)d_inputs));
-    }
+    THCudaCheck(THCudaFree(state, d_inputs));
 #undef HANDLE_CASE
   } else {
     offset = 0;
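The comment added inside the batching loop ("Re-allocate stackInputs every iteration to avoid read-after-write hazard") is the heart of the change: once the metadata copy is asynchronous, refilling a single host buffer on the next iteration could overwrite it while the previous copy is still reading it. Taking a fresh block from the caching host allocator each iteration, and recording the stream on it before freeing, avoids that because the allocator only reuses a block whose recorded event has fired. A conceptual sketch of that reuse check (illustrative only, not the real THCCachingHostAllocator; field and function names are invented):

#include <cuda_runtime.h>

struct PinnedBlock {
  void       *ptr;        /* page-locked host memory                   */
  cudaEvent_t last_use;   /* recorded on the stream that last read ptr */
  int         freed;      /* handed back by the caller                 */
};

/* A cached block may be handed out again only when it was freed AND the
   device has finished the copy that was recorded on it. */
static int block_is_reusable(const struct PinnedBlock *b)
{
  return b->freed && cudaEventQuery(b->last_use) == cudaSuccess;
}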
@@ -399,10 +394,10 @@ void THCTensor_(linspace)(THCState *state, THCTensor *r_, real a, real b, long n
   if (THCTensor_(nElement)(state, r_) != n) THCTensor_(resize1d)(state, r_, n);
   if (n == 1) THCTensor_(fill)(state, r_, a);
   else {
-    THCTensor *r = THCTensor_(isContiguous)(state, r_)
+    THCTensor *r = THCTensor_(isContiguous)(state, r_)
       ? r_ // if r_ is contiguous we can direct work on it
       : THCTensor_(newContiguous)(state, r_);
-    real step = THCNumerics<real>::div(THCNumerics<real>::sub(b, a),
+    real step = THCNumerics<real>::div(THCNumerics<real>::sub(b, a),
                                        ScalarConvert<long,real>::to(n - 1));
     LinspaceOp<real> linspace_method(a, step);
     thrust::device_ptr<real> data_(THCTensor_(data)(state, r));
@@ -420,10 +415,10 @@ void THCTensor_(logspace)(THCState *state, THCTensor *r_, real a, real b, long n
   if (THCTensor_(nElement)(state, r_) != n) THCTensor_(resize1d)(state, r_, n);
   if (n == 1) THCTensor_(fill)(state, r_, THCNumerics<real>::exp10(a));
   else {
-    THCTensor *r = THCTensor_(isContiguous)(state, r_)
-      ? r_
+    THCTensor *r = THCTensor_(isContiguous)(state, r_)
+      ? r_
       : THCTensor_(newContiguous)(state, r_);
-    real step = THCNumerics<real>::div(THCNumerics<real>::sub(b, a),
+    real step = THCNumerics<real>::div(THCNumerics<real>::sub(b, a),
                                        ScalarConvert<long,real>::to(n - 1));
     LogspaceOp<real> logspace_method(a, step);
     thrust::device_ptr<real> data_(THCTensor_(data)(state, r));
@@ -444,8 +439,8 @@ void THCTensor_(range)(THCState *state, THCTensor *r_, accreal xmin, accreal xma
              , 2, "upper bound and larger bound incoherent with step sign");
   ptrdiff_t size = (ptrdiff_t) (((xmax - xmin) / step) + 1);
   if (THCTensor_(nElement)(state, r_) != size) THCTensor_(resize1d)(state, r_, size);
-  THCTensor *r = THCTensor_(isContiguous)(state, r_)
-    ? r_
+  THCTensor *r = THCTensor_(isContiguous)(state, r_)
+    ? r_
     : THCTensor_(newContiguous)(state, r_);
   LinspaceOp<real,accreal> linspace_method(xmin, step);
   thrust::device_ptr<real> data_(THCTensor_(data)(state, r));
