@@ -175,23 +175,9 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result,
175175 real *data = THCTensor_ (data)(state, result);
176176
177177 // Kernel Parameter
178- CatArrInputTensor<real, unsigned int > stackInputs[CAT_ARRAY_BATCH_SIZE];
179- CatArrInputTensor<real, unsigned int > *d_inputs;
180-
181- // Attempt to re-use stream's scratch space for the input metadata
182- bool usedScratch = false ;
183178 size_t tensorMetadataSize = sizeof (CatArrInputTensor<real, unsigned int >) * CAT_ARRAY_BATCH_SIZE;
184- if (THCState_getCurrentDeviceScratchSpaceSize (state) > tensorMetadataSize) {
185- void * space = THCState_getCurrentDeviceScratchSpace (state);
186- if (space) {
187- d_inputs = (CatArrInputTensor<real, unsigned int > *) space;
188- usedScratch = true ;
189- }
190- }
191- if (!usedScratch) {
192- // Fallback to allocating GPU memory
193- THCudaCheck (THCudaMalloc (state, (void **) &d_inputs, tensorMetadataSize));
194- }
179+ CatArrInputTensor<real, unsigned int > *d_inputs;
180+ THCudaCheck (THCudaMalloc (state, (void **) &d_inputs, tensorMetadataSize));
195181
196182 OutputTensorSizeStride<unsigned int , CAT_ARRAY_MAX_INPUT_DIMS> param;
197183
@@ -201,13 +187,17 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result,
201187 param.outputStride [i] = THCTensor_ (stride)(state, result, i);
202188 }
203189
190+ THCStream* stream = THCState_getStream (state);
191+
204192 // Template Declarations for dim = 1, 2, 3, 4
205193#define HANDLE_CASE (DIMS ) \
206- CatArrayBatchedCopy<real, unsigned int , DIMS><<<applyGrid, applyBlock, 0 , THCState_getCurrentStream(state) >>> (data, d_inputs, param, cat_dimension, param.outputStride [cat_dimension]);
194+ CatArrayBatchedCopy<real, unsigned int , DIMS><<<applyGrid, applyBlock, 0 , stream->stream >>> (data, d_inputs, param, cat_dimension, param.outputStride [cat_dimension]);
207195
208196 // Now we loop
209197 offset = 0 ;
210198 for (i = 0 ; i < numInputs; i += CAT_ARRAY_BATCH_SIZE) {
199+ // Re-allocate stackInputs every iteration to avoid read-after-write hazard
200+ CatArrInputTensor<real, unsigned int >* stackInputs = (CatArrInputTensor<real, unsigned int >*) THCudaHostAlloc (state, tensorMetadataSize);
211201 cohortMax = 0 ;
212202 for (j = 0 ; j < CAT_ARRAY_BATCH_SIZE && (i+j) < numInputs; ++j) {
213203 long dimSize = cat_dimension < THCTensor_ (nDimension)(state, inputs[i+j])
@@ -223,7 +213,14 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result,
223213 // update offset
224214 offset += dimSize;
225215 }
226- THCudaCheck (cudaMemcpy (d_inputs, stackInputs, j * sizeof (CatArrInputTensor<real, unsigned int >), cudaMemcpyHostToDevice));
216+ THCudaCheck (cudaMemcpyAsync (
217+ d_inputs,
218+ stackInputs,
219+ j * sizeof (CatArrInputTensor<real, unsigned int >),
220+ cudaMemcpyHostToDevice,
221+ stream->stream ));
222+ THCudaHostRecord (state, stackInputs);
223+ THCudaHostFree (state, stackInputs);
227224
228225 // Next, let's consider how we set our kernel launch parameters.
229226 // We borrow from THCApply, which the kernel's internal indexing
@@ -256,9 +253,7 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result,
256253 }
257254 THCudaCheck (cudaGetLastError ());
258255 }
259- if (!usedScratch) {
260- THCudaCheck (THCudaFree (state, (void *)d_inputs));
261- }
256+ THCudaCheck (THCudaFree (state, d_inputs));
262257#undef HANDLE_CASE
263258 } else {
264259 offset = 0 ;
@@ -399,10 +394,10 @@ void THCTensor_(linspace)(THCState *state, THCTensor *r_, real a, real b, long n
399394 if (THCTensor_ (nElement)(state, r_) != n) THCTensor_ (resize1d)(state, r_, n);
400395 if (n == 1 ) THCTensor_ (fill)(state, r_, a);
401396 else {
402- THCTensor *r = THCTensor_ (isContiguous)(state, r_)
397+ THCTensor *r = THCTensor_ (isContiguous)(state, r_)
403398 ? r_ // if r_ is contiguous we can direct work on it
404399 : THCTensor_ (newContiguous)(state, r_);
405- real step = THCNumerics<real>::div (THCNumerics<real>::sub (b, a),
400+ real step = THCNumerics<real>::div (THCNumerics<real>::sub (b, a),
406401 ScalarConvert<long ,real>::to (n - 1 ));
407402 LinspaceOp<real> linspace_method (a, step);
408403 thrust::device_ptr<real> data_ (THCTensor_ (data)(state, r));
@@ -420,10 +415,10 @@ void THCTensor_(logspace)(THCState *state, THCTensor *r_, real a, real b, long n
420415 if (THCTensor_ (nElement)(state, r_) != n) THCTensor_ (resize1d)(state, r_, n);
421416 if (n == 1 ) THCTensor_ (fill)(state, r_, THCNumerics<real>::exp10 (a));
422417 else {
423- THCTensor *r = THCTensor_ (isContiguous)(state, r_)
424- ? r_
418+ THCTensor *r = THCTensor_ (isContiguous)(state, r_)
419+ ? r_
425420 : THCTensor_ (newContiguous)(state, r_);
426- real step = THCNumerics<real>::div (THCNumerics<real>::sub (b, a),
421+ real step = THCNumerics<real>::div (THCNumerics<real>::sub (b, a),
427422 ScalarConvert<long ,real>::to (n - 1 ));
428423 LogspaceOp<real> logspace_method (a, step);
429424 thrust::device_ptr<real> data_ (THCTensor_ (data)(state, r));
@@ -444,8 +439,8 @@ void THCTensor_(range)(THCState *state, THCTensor *r_, accreal xmin, accreal xma
444439 , 2 , " upper bound and larger bound incoherent with step sign" );
445440 ptrdiff_t size = (ptrdiff_t ) (((xmax - xmin) / step) + 1 );
446441 if (THCTensor_ (nElement)(state, r_) != size) THCTensor_ (resize1d)(state, r_, size);
447- THCTensor *r = THCTensor_ (isContiguous)(state, r_)
448- ? r_
442+ THCTensor *r = THCTensor_ (isContiguous)(state, r_)
443+ ? r_
449444 : THCTensor_ (newContiguous)(state, r_);
450445 LinspaceOp<real,accreal> linspace_method (xmin, step);
451446 thrust::device_ptr<real> data_ (THCTensor_ (data)(state, r));
0 commit comments