@@ -74,7 +74,8 @@ void THCudaInit(THCState* state)
7474
7575 state -> cutorchGCFunction = NULL ;
7676 state -> cutorchGCData = NULL ;
77- state -> heapSoftmax = 300000000 ; // 300MB, adjusted upward dynamically
77+ state -> heapSoftmax = 3e8 ; // 300MB, adjusted upward dynamically
78+ state -> heapDelta = 0 ;
7879}
7980
8081void THCudaShutdown (THCState * state )
@@ -494,6 +495,7 @@ void __THCublasCheck(cublasStatus_t status, const char *file, const int line)
494495}
495496
// Process-wide running total of THC heap usage. Shared by every thread
// (not thread-local), so it is modified through THAtomicAddLong.
static long heapSize = 0;

// A state's locally accumulated heapDelta is only flushed into heapSize once
// its magnitude reaches this many bytes; this batches the atomic updates.
static const long heapMaxDelta = 1000000;

// After a GC pass, the softmax grows if the heap still exceeds this fraction
// (80%) of the current softmax.
static const double heapSoftmaxGrowthThresh = 0.8;

// Multiplier applied to the softmax when it grows (a 40% increase).
static const double heapSoftmaxGrowthFactor = 1.4;
499501
@@ -521,27 +523,37 @@ cudaError_t THCudaFree(THCState *state, void *ptr)
521523 return err ;
522524}
523525
526+ static long applyHeapDelta (THCState * state ) {
527+ long newHeapSize = THAtomicAddLong (& heapSize , state -> heapDelta ) + state -> heapDelta ;
528+ state -> heapDelta = 0 ;
529+ return newHeapSize ;
530+ }
531+
524532// Here we maintain a dynamic softmax threshold for THC-allocated storages.
525533// When THC heap size goes above this softmax, the GC hook is triggered.
526534// If heap size is above 80% of the softmax after GC, then the softmax is
527535// increased.
528536static void maybeTriggerGC (THCState * state , long curHeapSize ) {
529537 if (state -> cutorchGCFunction != NULL && curHeapSize > state -> heapSoftmax ) {
530538 (state -> cutorchGCFunction )(state -> cutorchGCData );
531- long newHeapSize = THAtomicGetLong (& heapSize );
539+
540+ // ensure heapSize is accurate before updating heapSoftmax
541+ long newHeapSize = applyHeapDelta (state );
542+
532543 if (newHeapSize > state -> heapSoftmax * heapSoftmaxGrowthThresh ) {
533544 state -> heapSoftmax = state -> heapSoftmax * heapSoftmaxGrowthFactor ;
534545 }
535546 }
536547}
537548
538549void THCHeapUpdate (THCState * state , long size ) {
539- long newHeapSize = THAtomicAddLong ( & heapSize , size ) + size ;
540- #ifdef THC_CHECK_HEAP_UPDATE
541- if (newHeapSize < 0 ) {
542- THError ( "Internal error: THC heapSize < 0" ) ;
550+ state -> heapDelta += size ;
551+ // batch updates to global heapSize to minimize thread contention
552+ if (abs ( state -> heapDelta ) < heapMaxDelta ) {
553+ return ;
543554 }
544- #endif
555+
556+ long newHeapSize = applyHeapDelta (state );
545557 if (size > 0 ) {
546558 maybeTriggerGC (state , newHeapSize );
547559 }
0 commit comments