11#include "THGeneral.h"
2+ #include "THAtomic.h"
23
34#ifndef TH_HAVE_THREAD
45#define __thread
56#endif
67
7- #if defined(TH_DISABLE_HEAP_TRACKING )
8- #elif (defined(__unix ) || defined(_WIN32 ))
8+ #if (defined(__unix ) || defined(_WIN32 ))
99#include <malloc.h>
1010#elif defined(__APPLE__ )
1111#include <malloc/malloc.h>
@@ -101,8 +101,10 @@ void THSetArgErrorHandler( void (*torchArgErrorHandlerFunction_)(int argNumber,
101101
102102static __thread void (* torchGCFunction )(void * data ) = NULL ;
103103static __thread void * torchGCData ;
104- static __thread long torchHeapSize = 0 ;
105- static __thread long torchHeapSizeSoftMax = 300000000 ; // 300MB, adjusted upward dynamically
104+ static long heapSize = 0 ;
105+ static __thread long heapSoftmax = 300000000 ; // 300MB, adjusted upward dynamically
106+ static const double heapSoftmaxGrowthThresh = 0.8 ; // grow softmax if >80% max after GC
107+ static const double heapSoftmaxGrowthFactor = 1.4 ; // grow softmax by 40%
106108
107109/* Optional hook for integrating with a garbage-collected frontend.
108110 *
@@ -121,9 +123,7 @@ void THSetGCHandler( void (*torchGCFunction_)(void *data), void *data )
121123}
122124
123125static long getAllocSize (void * ptr ) {
124- #if defined(TH_DISABLE_HEAP_TRACKING )
125- return 0 ;
126- #elif defined(__unix )
126+ #if defined(__unix )
127127 return malloc_usable_size (ptr );
128128#elif defined(__APPLE__ )
129129 return malloc_size (ptr );
@@ -138,20 +138,29 @@ static long getAllocSize(void *ptr) {
138138 * (2) if post-GC heap size exceeds 80% of the soft max, increase the
139139 * soft max by 40%
140140 */
141- static void maybeTriggerGC () {
142- if (torchGCFunction && torchHeapSize > torchHeapSizeSoftMax ) {
141+ static void maybeTriggerGC (long curHeapSize ) {
142+ if (torchGCFunction && curHeapSize > heapSoftmax ) {
143143 torchGCFunction (torchGCData );
144- if (torchHeapSize > torchHeapSizeSoftMax * 0.8 ) {
145- torchHeapSizeSoftMax = torchHeapSizeSoftMax * 1.4 ;
144+ long newHeapSize = THAtomicGetLong (& heapSize );
145+ if (newHeapSize > heapSoftmax * heapSoftmaxGrowthThresh ) {
146+ heapSoftmax = heapSoftmax * heapSoftmaxGrowthFactor ;
146147 }
147148 }
148149}
149150
150151// hooks into the TH heap tracking
151152void THHeapUpdate (long size ) {
152- torchHeapSize += size ;
153- if (size > 0 )
154- maybeTriggerGC ();
153+ long newHeapSize = THAtomicAddLong (& heapSize , size ) + size ;
154+
155+ # ifdef TH_CHECK_HEAP_UPDATE
156+ if (newHeapSize < 0 ) {
157+ THError ("Torch heap size <0 ?" );
158+ }
159+ #endif
160+
161+ if (size > 0 ) {
162+ maybeTriggerGC (newHeapSize );
163+ }
155164}
156165
157166static void * THAllocInternal (long size )
0 commit comments