Skip to content

Commit 6da111c

Browse files
committed
Merge commit '00843c57c936720b3d17f4c0afaab08dcb52a7cc'
2 parents 568c5c9 + 00843c5 commit 6da111c

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

torch/lib/THC/THCGeneral.c

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -871,10 +871,16 @@ void THCudaHostRecord(THCState *state, void *ptr)
871871

872872
cudaError_t THCudaMemGetInfo(THCState *state, size_t* freeBytes, size_t* totalBytes)
873873
{
874-
size_t cachedBytes = 0;
875874
size_t largestBlock = 0;
875+
return THCudaMemGetInfoCached(state, freeBytes, totalBytes, &largestBlock);
876+
}
877+
878+
cudaError_t THCudaMemGetInfoCached(THCState *state, size_t* freeBytes, size_t* totalBytes, size_t* largestBlock)
879+
{
880+
size_t cachedBytes = 0;
876881
THCDeviceAllocator* allocator = state->cudaDeviceAllocator;
877882

883+
*largestBlock = 0;
878884
/* get info from CUDA first */
879885
cudaError_t ret = cudaMemGetInfo(freeBytes, totalBytes);
880886
if (ret!= cudaSuccess)
@@ -886,11 +892,11 @@ cudaError_t THCudaMemGetInfo(THCState *state, size_t* freeBytes, size_t* totalB
886892
return ret;
887893

888894
/* not always true - our optimistic guess here */
889-
largestBlock = *freeBytes;
895+
*largestBlock = *freeBytes;
890896

891897
if (allocator->cacheInfo != NULL)
892-
allocator->cacheInfo(allocator->state, device, &cachedBytes, &largestBlock);
893-
898+
allocator->cacheInfo(allocator->state, device, &cachedBytes, largestBlock);
899+
894900
/* Adjust resulting free bytes number. largesBlock unused for now */
895901
*freeBytes += cachedBytes;
896902
return cudaSuccess;

torch/lib/THC/THCGeneral.h.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ THC_API void THCudaHostFree(THCState *state, void *ptr);
209209
THC_API void THCudaHostRecord(THCState *state, void *ptr);
210210

211211
THC_API cudaError_t THCudaMemGetInfo(THCState *state, size_t* freeBytes, size_t* totalBytes);
212+
THC_API cudaError_t THCudaMemGetInfoCached(THCState *state, size_t* freeBytes, size_t* totalBytes, size_t* largestBlock);
212213
THC_API void THCSetGCHandler(THCState *state,
213214
void (*torchGCHandlerFunction)(void *data),
214215
void *data );

0 commit comments

Comments
 (0)