Skip to content

Commit d60fe8d

Browse files
committed
Batch updates to global heapSize
1 parent d006752 commit d60fe8d

File tree

2 files changed

+20
-7
lines changed

2 files changed

+20
-7
lines changed

THCGeneral.c

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,8 @@ void THCudaInit(THCState* state)
7474

7575
state->cutorchGCFunction = NULL;
7676
state->cutorchGCData = NULL;
77-
state->heapSoftmax = 300000000; // 300MB, adjusted upward dynamically
77+
state->heapSoftmax = 3e8; // 300MB, adjusted upward dynamically
78+
state->heapDelta = 0;
7879
}
7980

8081
void THCudaShutdown(THCState* state)
@@ -494,6 +495,7 @@ void __THCublasCheck(cublasStatus_t status, const char *file, const int line)
494495
}
495496

496497
static long heapSize = 0; // not thread-local
498+
static const long heapMaxDelta = 1e6;
497499
static const double heapSoftmaxGrowthThresh = 0.8; // grow softmax if >80% max after GC
498500
static const double heapSoftmaxGrowthFactor = 1.4; // grow softmax by 40%
499501

@@ -521,27 +523,37 @@ cudaError_t THCudaFree(THCState *state, void *ptr)
521523
return err;
522524
}
523525

526+
static long applyHeapDelta(THCState *state) {
527+
long newHeapSize = THAtomicAddLong(&heapSize, state->heapDelta) + state->heapDelta;
528+
state->heapDelta = 0;
529+
return newHeapSize;
530+
}
531+
524532
// Here we maintain a dynamic softmax threshold for THC-allocated storages.
525533
// When THC heap size goes above this softmax, the GC hook is triggered.
526534
// If heap size is above 80% of the softmax after GC, then the softmax is
527535
// increased.
528536
static void maybeTriggerGC(THCState *state, long curHeapSize) {
529537
if (state->cutorchGCFunction != NULL && curHeapSize > state->heapSoftmax) {
530538
(state->cutorchGCFunction)(state->cutorchGCData);
531-
long newHeapSize = THAtomicGetLong(&heapSize);
539+
540+
// ensure heapSize is accurate before updating heapSoftmax
541+
long newHeapSize = applyHeapDelta(state);
542+
532543
if (newHeapSize > state->heapSoftmax * heapSoftmaxGrowthThresh) {
533544
state->heapSoftmax = state->heapSoftmax * heapSoftmaxGrowthFactor;
534545
}
535546
}
536547
}
537548

538549
void THCHeapUpdate(THCState *state, long size) {
539-
long newHeapSize = THAtomicAddLong(&heapSize, size) + size;
540-
#ifdef THC_CHECK_HEAP_UPDATE
541-
if (newHeapSize < 0) {
542-
THError("Internal error: THC heapSize < 0");
550+
state->heapDelta += size;
551+
// batch updates to global heapSize to minimize thread contention
552+
if (abs(state->heapDelta) < heapMaxDelta) {
553+
return;
543554
}
544-
#endif
555+
556+
long newHeapSize = applyHeapDelta(state);
545557
if (size > 0) {
546558
maybeTriggerGC(state, newHeapSize);
547559
}

THCGeneral.h.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ typedef struct THCState
7878
void (*cutorchGCFunction)(void *data);
7979
void *cutorchGCData;
8080
long heapSoftmax;
81+
long heapDelta;
8182
} THCState;
8283

8384
THC_API void THCudaInit(THCState* state);

0 commit comments

Comments
 (0)