Skip to content

Commit 540c947

Browse files
committed
cutorch gc
1 parent d874a07 commit 540c947

File tree

7 files changed

+115
-29
lines changed

7 files changed

+115
-29
lines changed

THCGeneral.c

Lines changed: 61 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ void THCudaInit(THCState* state)
5353

5454
/* Allocate scratch space for each stream */
5555
res->devScratchSpacePerStream = (void**) malloc(sizeof(void*));
56-
THCudaCheck(cudaMalloc(&res->devScratchSpacePerStream[0],
56+
THCudaCheck(THCudaMalloc(state, &res->devScratchSpacePerStream[0],
5757
sizePerStream));
5858
}
5959

@@ -71,6 +71,10 @@ void THCudaInit(THCState* state)
7171
THCState_reserveBlasHandles(state, 1);
7272
state->currentPerDeviceBlasHandle = 1;
7373
state->currentBlasHandle = THCState_getDeviceBlasHandle(state, device, 1);
74+
75+
state->cutorchGCFunction = NULL;
76+
state->cutorchGCData = NULL;
77+
state->heapSoftmax = 300000000; // 300MB, adjusted upward dynamically
7478
}
7579

7680
void THCudaShutdown(THCState* state)
@@ -100,7 +104,7 @@ void THCudaShutdown(THCState* state)
100104
/* Free per-stream scratch space; starts at 0 because there is space for
101105
the default stream as well*/
102106
for (int stream = 0; stream <= state->numUserStreams; ++stream) {
103-
THCudaCheck(cudaFree(THCState_getDeviceScratchSpace(state, dev, stream)));
107+
THCudaCheck(THCudaFree(state, THCState_getDeviceScratchSpace(state, dev, stream)));
104108
}
105109

106110
free(state->resourcesPerDevice[dev].streams);
@@ -199,7 +203,7 @@ void THCState_reserveStreams(THCState* state, int numStreams)
199203
newStreams[stream] = NULL;
200204
THCudaCheck(cudaStreamCreate(newStreams + stream));
201205
newScratchSpace[stream] = NULL;
202-
THCudaCheck(cudaMalloc(&newScratchSpace[stream], scratchSpaceSize));
206+
THCudaCheck(THCudaMalloc(state, &newScratchSpace[stream], scratchSpaceSize));
203207
}
204208

205209
THCCudaResourcesPerDevice* res = THCState_getDeviceResourcePtr(state, dev);
@@ -489,4 +493,58 @@ void __THCublasCheck(cublasStatus_t status, const char *file, const int line)
489493
}
490494
}
491495

496+
static long heapSize = 0; // not thread-local
497+
static const double heapSoftmaxGrowthThresh = 0.8; // grow softmax if >80% max after GC
498+
static const double heapSoftmaxGrowthFactor = 1.4; // grow softmax by 40%
499+
500+
void THCSetGCHandler(THCState *state, void (*cutorchGCFunction_)(void *data), void *data )
501+
{
502+
state->cutorchGCFunction = cutorchGCFunction_;
503+
state->cutorchGCData = data;
504+
}
505+
506+
cudaError_t THCudaMalloc(THCState *state, void** ptr, size_t size)
507+
{
508+
THCudaCheck(cudaGetLastError());
509+
cudaError_t err = cudaMalloc(ptr, size);
510+
if (state->cutorchGCFunction != NULL && err != cudaSuccess) {
511+
cudaGetLastError(); // reset OOM error
512+
(state->cutorchGCFunction)(state->cutorchGCData);
513+
err = cudaMalloc(ptr, size);
514+
}
515+
return err;
516+
}
517+
518+
cudaError_t THCudaFree(THCState *state, void *ptr)
519+
{
520+
cudaError_t err = cudaFree(ptr);
521+
return err;
522+
}
523+
524+
// Here we maintain a dynamic softmax threshold for THC-allocated storages.
525+
// When THC heap size goes above this softmax, the GC hook is triggered.
526+
// If heap size is above 80% of the softmax after GC, then the softmax is
527+
// increased.
528+
static void maybeTriggerGC(THCState *state, long curHeapSize) {
529+
if (state->cutorchGCFunction != NULL && curHeapSize > state->heapSoftmax) {
530+
(state->cutorchGCFunction)(state->cutorchGCData);
531+
long newHeapSize = THAtomicGetLong(&heapSize);
532+
if (newHeapSize > state->heapSoftmax * heapSoftmaxGrowthThresh) {
533+
state->heapSoftmax = state->heapSoftmax * heapSoftmaxGrowthFactor;
534+
}
535+
}
536+
}
537+
538+
void THCHeapUpdate(THCState *state, long size) {
539+
long newHeapSize = THAtomicAddLong(&heapSize, size) + size;
540+
#ifdef THC_CHECK_HEAP_UPDATE
541+
if (newHeapSize < 0) {
542+
THError("Internal error: THC heapSize < 0");
543+
}
544+
#endif
545+
if (size > 0) {
546+
maybeTriggerGC(state, newHeapSize);
547+
}
548+
}
549+
492550
#undef GLOBAL_SCRATCH_SPACE_PER_SM_STREAM

THCGeneral.h.in

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,10 @@ typedef struct THCState
7474
int currentPerDeviceBlasHandle;
7575
/* Allocator using cudaMallocHost. */
7676
THAllocator* cudaHostAllocator;
77+
78+
void (*cutorchGCFunction)(void *data);
79+
void *cutorchGCData;
80+
long heapSoftmax;
7781
} THCState;
7882

7983
THC_API void THCudaInit(THCState* state);
@@ -116,4 +120,11 @@ THC_API size_t THCState_getDeviceScratchSpaceSize(THCState* state, int device);
116120
THC_API void __THCudaCheck(cudaError_t err, const char *file, const int line);
117121
THC_API void __THCublasCheck(cublasStatus_t status, const char *file, const int line);
118122

123+
THC_API cudaError_t THCudaMalloc(THCState *state, void **ptr, size_t size);
124+
THC_API cudaError_t THCudaFree(THCState *state, void *ptr);
125+
THC_API void THCSetGCHandler(THCState *state,
126+
void (*torchGCHandlerFunction)(void *data),
127+
void *data );
128+
THC_API void THCHeapUpdate(THCState *state, long size);
129+
119130
#endif

THCStorage.c

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,15 @@ THCudaStorage* THCudaStorage_newWithSize(THCState *state, long size)
3333
if(size > 0)
3434
{
3535
THCudaStorage *storage = (THCudaStorage*)THAlloc(sizeof(THCudaStorage));
36-
THCudaCheck(cudaMalloc((void**)&(storage->data), size * sizeof(float)));
36+
37+
// update heap *before* attempting malloc, to free space for the malloc
38+
THCHeapUpdate(state, size * sizeof(float));
39+
cudaError_t err =
40+
THCudaMalloc(state, (void**)&(storage->data), size * sizeof(float));
41+
if(err != cudaSuccess){
42+
THCHeapUpdate(state, -size * sizeof(float));
43+
}
44+
THCudaCheck(err);
3745

3846
storage->size = size;
3947
storage->refcount = 1;
@@ -110,7 +118,8 @@ void THCudaStorage_free(THCState *state, THCudaStorage *self)
110118
if (THAtomicDecrementRef(&self->refcount))
111119
{
112120
if(self->flag & TH_STORAGE_FREEMEM) {
113-
THCudaCheck(cudaFree(self->data));
121+
THCHeapUpdate(state, -self->size * sizeof(float));
122+
THCudaCheck(THCudaFree(state, self->data));
114123
}
115124
THFree(self);
116125
}

THCStorage.cu

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,23 +26,31 @@ void THCudaStorage_resize(THCState *state, THCudaStorage *self, long size)
2626
if(size == 0)
2727
{
2828
if(self->flag & TH_STORAGE_FREEMEM) {
29-
THCudaCheck(cudaFree(self->data));
29+
THCudaCheck(THCudaFree(state, self->data));
30+
THCHeapUpdate(state, -self->size * sizeof(float));
3031
}
3132
self->data = NULL;
3233
self->size = 0;
3334
}
3435
else
3536
{
3637
float *data = NULL;
37-
THCudaCheck(cudaMalloc((void**)(&data), size * sizeof(float)));
38+
// update heap *before* attempting malloc, to free space for the malloc
39+
THCHeapUpdate(state, size * sizeof(float));
40+
cudaError_t err = THCudaMalloc(state, (void**)(&data), size * sizeof(float));
41+
if(err != cudaSuccess) {
42+
THCHeapUpdate(state, -size * sizeof(float));
43+
}
44+
THCudaCheck(err);
3845

3946
if (self->data) {
4047
THCudaCheck(cudaMemcpyAsync(data,
4148
self->data,
4249
THMin(self->size, size) * sizeof(float),
4350
cudaMemcpyDeviceToDevice,
4451
THCState_getCurrentStream(state)));
45-
THCudaCheck(cudaFree(self->data));
52+
THCudaCheck(THCudaFree(state, self->data));
53+
THCHeapUpdate(state, -self->size * sizeof(float));
4654
}
4755

4856
self->data = data;

THCTensorIndex.cu

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ void THCudaTensor_indexCopy(THCState *state, THCudaTensor *res_, int dim, THCuda
115115
dim3 nthreads(16, 16);
116116
dim3 nblocks(ceil((float)nRes / nIndex / (16*16)));
117117

118-
THCudaCheck(cudaMalloc((void**)&stride_, res_->nDimension * sizeof(long)));
118+
THCudaCheck(THCudaMalloc(state, (void**)&stride_, res_->nDimension * sizeof(long)));
119119
THCudaCheck(cudaMemcpy(stride_, res_->stride, res_->nDimension * sizeof(long), cudaMemcpyHostToDevice));
120120

121121
THCudaTensor_kernel_indexCopy<<<nblocks, nthreads, 0, THCState_getCurrentStream(state)>>>(
@@ -125,7 +125,7 @@ void THCudaTensor_indexCopy(THCState *state, THCudaTensor *res_, int dim, THCuda
125125
THCudaTensor_nElement(state, src), res_->size[dim]
126126
);
127127

128-
THCudaCheck(cudaFree(stride_));
128+
THCudaCheck(THCudaFree(state, stride_));
129129
THCudaTensor_free(state, indices);
130130
THCudaTensor_free(state, src);
131131
}
@@ -159,15 +159,15 @@ void THCudaTensor_indexFill(THCState *state, THCudaTensor *res_, int dim, THCuda
159159
dim3 nthreads(16, 16);
160160
dim3 nblocks(ceil((float)nRes / nIndex / (16*16)));
161161

162-
THCudaCheck(cudaMalloc((void**)&stride_, res_->nDimension * sizeof(long)));
162+
THCudaCheck(THCudaMalloc(state, (void**)&stride_, res_->nDimension * sizeof(long)));
163163
THCudaCheck(cudaMemcpy(stride_, res_->stride, res_->nDimension * sizeof(long), cudaMemcpyHostToDevice));
164164

165165
THCudaTensor_kernel_indexFill<<<nblocks, nthreads, 0, THCState_getCurrentStream(state)>>>(
166166
THCudaTensor_data(state, res_), stride_, THCudaTensor_data(state, indices),
167167
res_->nDimension, dim, nIndex, nRes, res_->size[dim], val
168168
);
169169

170-
THCudaCheck(cudaFree(stride_));
170+
THCudaCheck(THCudaFree(state, stride_));
171171
THCudaTensor_free(state, indices);
172172
}
173173

@@ -299,7 +299,7 @@ void THCudaTensor_indexSelect(THCState *state, THCudaTensor *res_, THCudaTensor
299299
dim3 nthreads(16, 16);
300300
dim3 nblocks(ceil((float)nRes / nIndex / (16*16)));
301301

302-
THCudaCheck(cudaMalloc((void**)&stride_, src->nDimension * sizeof(long)));
302+
THCudaCheck(THCudaMalloc(state, (void**)&stride_, src->nDimension * sizeof(long)));
303303
THCudaCheck(cudaMemcpy(stride_, src->stride, src->nDimension * sizeof(long), cudaMemcpyHostToDevice));
304304

305305
THCudaTensor_kernel_indexSelect<<<nblocks, nthreads, 0, stream>>>(
@@ -308,7 +308,7 @@ void THCudaTensor_indexSelect(THCState *state, THCudaTensor *res_, THCudaTensor
308308
src->nDimension, dim, nIndex, nRes, src->size[dim]
309309
);
310310

311-
THCudaCheck(cudaFree(stride_));
311+
THCudaCheck(THCudaFree(state, stride_));
312312
THCudaTensor_free(state, indices);
313313
THCudaTensor_freeCopyTo(state, res, res_);
314314
}

THCTensorMathBlas.cu

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -341,9 +341,9 @@ void THCudaTensor_baddbmm(THCState *state, THCudaTensor *result, float beta, THC
341341
// Copy pointers to device.
342342
const float **d_matrices1, **d_matrices2;
343343
float **d_result_matrices;
344-
THCudaCheck(cudaMalloc((void**)&d_matrices1, matrices_size));
345-
THCudaCheck(cudaMalloc((void**)&d_matrices2, matrices_size));
346-
THCudaCheck(cudaMalloc((void**)&d_result_matrices, matrices_size));
344+
THCudaCheck(THCudaMalloc(state, (void**)&d_matrices1, matrices_size));
345+
THCudaCheck(THCudaMalloc(state, (void**)&d_matrices2, matrices_size));
346+
THCudaCheck(THCudaMalloc(state, (void**)&d_result_matrices, matrices_size));
347347

348348
THCudaCheck(cudaMemcpyAsync(d_matrices1, matrices1, matrices_size,
349349
cudaMemcpyHostToDevice, THCState_getCurrentStream(state)));
@@ -366,9 +366,9 @@ void THCudaTensor_baddbmm(THCState *state, THCudaTensor *result, float beta, THC
366366
d_result_matrices, ldc,
367367
num_batches);
368368

369-
cudaFree(d_matrices1);
370-
cudaFree(d_matrices2);
371-
cudaFree(d_result_matrices);
369+
THCudaFree(state, d_matrices1);
370+
THCudaFree(state, d_matrices2);
371+
THCudaFree(state, d_result_matrices);
372372
THFree(matrices1);
373373
THFree(matrices2);
374374
THFree(result_matrices);

THCTensorRandom.cu

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,23 +15,23 @@
1515
#define BLOCK_SIZE 256
1616

1717
/* Sets up generator. Allocates but does not create the generator states. */
18-
__host__ void initializeGenerator(Generator* gen)
18+
__host__ void initializeGenerator(THCState *state, Generator* gen)
1919
{
20-
THCudaCheck(cudaMalloc((void**)&gen->gen_states, MAX_NUM_BLOCKS * sizeof(curandStateMtgp32)));
21-
THCudaCheck(cudaMalloc((void**)&gen->kernel_params, sizeof(mtgp32_kernel_params)));
20+
THCudaCheck(THCudaMalloc(state, (void**)&gen->gen_states, MAX_NUM_BLOCKS * sizeof(curandStateMtgp32)));
21+
THCudaCheck(THCudaMalloc(state, (void**)&gen->kernel_params, sizeof(mtgp32_kernel_params)));
2222
}
2323

2424
/* Frees memory allocated during setup. */
25-
__host__ void destroyGenerator(Generator* gen)
25+
__host__ void destroyGenerator(THCState *state, Generator* gen)
2626
{
2727
if (gen->gen_states)
2828
{
29-
THCudaCheck(cudaFree(gen->gen_states));
29+
THCudaCheck(THCudaFree(state, gen->gen_states));
3030
gen->gen_states = NULL;
3131
}
3232
if (gen->kernel_params)
3333
{
34-
THCudaCheck(cudaFree(gen->kernel_params));
34+
THCudaCheck(THCudaFree(state, gen->kernel_params));
3535
gen->kernel_params = NULL;
3636
}
3737
}
@@ -66,7 +66,7 @@ __host__ void THCRandom_init(THCState* state, int devices, int current_device)
6666
rng_state->current_gen = &rng_state->gen[current_device];
6767
// Initialize the generator for the current device. Other generators will be
6868
// initialized on-demand in THCRandom_setGenerator.
69-
initializeGenerator(rng_state->current_gen);
69+
initializeGenerator(state, rng_state->current_gen);
7070
THCRandom_seed(state);
7171
}
7272

@@ -77,7 +77,7 @@ __host__ void THCRandom_shutdown(THCState* state)
7777
if (rng_state->gen == NULL) return;
7878
for (int i = 0; i < rng_state->num_devices; ++i)
7979
{
80-
destroyGenerator(&rng_state->gen[i]);
80+
destroyGenerator(state, &rng_state->gen[i]);
8181
}
8282
free(rng_state->gen);
8383
rng_state->gen = NULL;
@@ -92,7 +92,7 @@ __host__ void THCRandom_setGenerator(THCState* state, int device)
9292
rng_state->current_gen = &rng_state->gen[device];
9393
if (rng_state->current_gen->initf == 0)
9494
{
95-
initializeGenerator(rng_state->current_gen);
95+
initializeGenerator(state, rng_state->current_gen);
9696
THCRandom_seed(state);
9797
}
9898
}

0 commit comments

Comments
 (0)