@@ -53,7 +53,7 @@ void THCudaInit(THCState* state)
5353
5454 /* Allocate scratch space for each stream */
5555 res -> devScratchSpacePerStream = (void * * ) malloc (sizeof (void * ));
56- THCudaCheck (cudaMalloc ( & res -> devScratchSpacePerStream [0 ],
56+ THCudaCheck (THCudaMalloc ( state , & res -> devScratchSpacePerStream [0 ],
5757 sizePerStream ));
5858 }
5959
@@ -71,6 +71,10 @@ void THCudaInit(THCState* state)
7171 THCState_reserveBlasHandles (state , 1 );
7272 state -> currentPerDeviceBlasHandle = 1 ;
7373 state -> currentBlasHandle = THCState_getDeviceBlasHandle (state , device , 1 );
74+
75+ state -> cutorchGCFunction = NULL ;
76+ state -> cutorchGCData = NULL ;
77+ state -> heapSoftmax = 300000000 ; // 300MB, adjusted upward dynamically
7478}
7579
7680void THCudaShutdown (THCState * state )
@@ -100,7 +104,7 @@ void THCudaShutdown(THCState* state)
100104 /* Free per-stream scratch space; starts at 0 because there is space for
101105 the default stream as well*/
102106 for (int stream = 0 ; stream <= state -> numUserStreams ; ++ stream ) {
103- THCudaCheck (cudaFree ( THCState_getDeviceScratchSpace (state , dev , stream )));
107+ THCudaCheck (THCudaFree ( state , THCState_getDeviceScratchSpace (state , dev , stream )));
104108 }
105109
106110 free (state -> resourcesPerDevice [dev ].streams );
@@ -199,7 +203,7 @@ void THCState_reserveStreams(THCState* state, int numStreams)
199203 newStreams [stream ] = NULL ;
200204 THCudaCheck (cudaStreamCreate (newStreams + stream ));
201205 newScratchSpace [stream ] = NULL ;
202- THCudaCheck (cudaMalloc ( & newScratchSpace [stream ], scratchSpaceSize ));
206+ THCudaCheck (THCudaMalloc ( state , & newScratchSpace [stream ], scratchSpaceSize ));
203207 }
204208
205209 THCCudaResourcesPerDevice * res = THCState_getDeviceResourcePtr (state , dev );
@@ -489,4 +493,58 @@ void __THCublasCheck(cublasStatus_t status, const char *file, const int line)
489493 }
490494}
491495
// Tuning knobs for the dynamic soft heap limit ("softmax") kept in THCState.
// After a GC pass, the limit is raised when the heap still exceeds this
// fraction (80%) of it ...
static const double heapSoftmaxGrowthThresh = 0.8;
// ... and each raise multiplies the limit by this factor (+40%).
static const double heapSoftmaxGrowthFactor = 1.4;

// Running heap-size counter maintained by THCHeapUpdate via the THAtomic*
// helpers. It has file scope, so it is shared by every THCState and every
// thread (it is not thread-local).
static long heapSize = 0;
499+
500+ void THCSetGCHandler (THCState * state , void (* cutorchGCFunction_ )(void * data ), void * data )
501+ {
502+ state -> cutorchGCFunction = cutorchGCFunction_ ;
503+ state -> cutorchGCData = data ;
504+ }
505+
506+ cudaError_t THCudaMalloc (THCState * state , void * * ptr , size_t size )
507+ {
508+ THCudaCheck (cudaGetLastError ());
509+ cudaError_t err = cudaMalloc (ptr , size );
510+ if (state -> cutorchGCFunction != NULL && err != cudaSuccess ) {
511+ cudaGetLastError (); // reset OOM error
512+ (state -> cutorchGCFunction )(state -> cutorchGCData );
513+ err = cudaMalloc (ptr , size );
514+ }
515+ return err ;
516+ }
517+
518+ cudaError_t THCudaFree (THCState * state , void * ptr )
519+ {
520+ cudaError_t err = cudaFree (ptr );
521+ return err ;
522+ }
523+
524+ // Here we maintain a dynamic softmax threshold for THC-allocated storages.
525+ // When THC heap size goes above this softmax, the GC hook is triggered.
526+ // If heap size is above 80% of the softmax after GC, then the softmax is
527+ // increased.
528+ static void maybeTriggerGC (THCState * state , long curHeapSize ) {
529+ if (state -> cutorchGCFunction != NULL && curHeapSize > state -> heapSoftmax ) {
530+ (state -> cutorchGCFunction )(state -> cutorchGCData );
531+ long newHeapSize = THAtomicGetLong (& heapSize );
532+ if (newHeapSize > state -> heapSoftmax * heapSoftmaxGrowthThresh ) {
533+ state -> heapSoftmax = state -> heapSoftmax * heapSoftmaxGrowthFactor ;
534+ }
535+ }
536+ }
537+
538+ void THCHeapUpdate (THCState * state , long size ) {
539+ long newHeapSize = THAtomicAddLong (& heapSize , size ) + size ;
540+ #ifdef THC_CHECK_HEAP_UPDATE
541+ if (newHeapSize < 0 ) {
542+ THError ("Internal error: THC heapSize < 0" );
543+ }
544+ #endif
545+ if (size > 0 ) {
546+ maybeTriggerGC (state , newHeapSize );
547+ }
548+ }
549+
492550#undef GLOBAL_SCRATCH_SPACE_PER_SM_STREAM
0 commit comments