Skip to content

Commit 0f872ed

Browse files
committed
Add THCCachingAllocator_recordStream()
This is similar to THCCachingHostAllocator_recordEvent() but on CUDA allocations. It's useful for overlapping copies with computation. The workflow is approximately: 0. allocate dst tensor on copy stream 1. copy from CPU to GPU on copy stream 2. synchronize the main stream with the copy stream via cudaStreamWaitEvent 3. THCCachingAllocator_recordStream(dst, main_stream) The recordStream() call is necessary to prevent the dst tensor from being reused on the copy stream before the main stream finishes work. Previously, you would need to insert a second cudaStreamWaitEvent before dst is freed to force the copy stream to wait on the main stream.
1 parent aec182a commit 0f872ed

File tree

2 files changed

+130
-17
lines changed

2 files changed

+130
-17
lines changed

THCCachingAllocator.cpp

Lines changed: 128 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "THCCachingAllocator.h"
22

33
#include <cuda_runtime_api.h>
4+
#include <deque>
45
#include <map>
56
#include <memory>
67
#include <mutex>
@@ -17,7 +18,7 @@
1718
// split. If no block is found, the allocator will delegate to cudaMalloc.
1819
// - If the cudaMalloc fails, the allocator will free all cached blocks that
1920
// are not split and retry the allocation.
20-
// - Large (>1MB) and small allocation requestss are handled separately. Large
21+
// - Large (>1MB) and small allocation requests are handled separately. Large
2122
// allocation requests can be filled by a cudaMalloc call of the exact size.
2223
// Small requests will allocate and split a 1MB buffer, if necessary.
2324
//
@@ -26,26 +27,36 @@
2627
// launches. The programmer must insert the proper synchronization if memory
2728
// segments are used from multiple streams.
2829
//
30+
// The library provides a recordStream() function to help insert the correct
31+
// synchronization when allocations are used on multiple streams. This will
32+
// ensure that the block is not reused before each recorded stream completes
33+
// work.
34+
//
2935

3036

3137
namespace {
3238

39+
typedef std::shared_ptr<THCStream> THCStreamPtr;
40+
typedef std::set<THCStreamPtr> stream_set;
41+
3342
const size_t kRoundSmall = 512; // round up small allocs to 512 bytes
3443
const size_t kRoundLarge = 131072; // round up large allocs to 128 KiB
3544
const size_t kSmallAlloc = 1048576; // largest "small" allocation is 1 MiB
3645

3746
struct Block {
38-
int device; // gpu
39-
cudaStream_t stream; // allocation stream
40-
size_t size; // block size in bytes
41-
char* ptr; // memory address
42-
bool allocated; // in-use flag
43-
Block* prev; // prev block if split from a larger allocation
44-
Block* next; // next block if split from a larger allocation
47+
int device; // gpu
48+
cudaStream_t stream; // allocation stream
49+
stream_set stream_uses; // streams on which the block was used
50+
size_t size; // block size in bytes
51+
char* ptr; // memory address
52+
bool allocated; // in-use flag
53+
Block* prev; // prev block if split from a larger allocation
54+
Block* next; // next block if split from a larger allocation
55+
int event_count; // number of outstanding CUDA events
4556

4657
Block(int device, cudaStream_t stream, size_t size, char* ptr=NULL) :
47-
device(device), stream(stream), size(size), ptr(ptr), allocated(0),
48-
prev(NULL), next(NULL) { }
58+
device(device), stream(stream), stream_uses(), size(size), ptr(ptr),
59+
allocated(0), prev(NULL), next(NULL), event_count(0) { }
4960
};
5061

5162
static bool BlockComparator(const Block* a, const Block* b)
@@ -84,6 +95,9 @@ struct THCCachingAllocator
8495
// allocated blocks by device pointer
8596
std::unordered_map<void*, Block*> allocated_blocks;
8697

98+
// outstanding cuda events
99+
std::deque<std::pair<cudaEvent_t, Block*>> cuda_events;
100+
87101
THCCachingAllocator() :
88102
large_blocks(BlockComparator),
89103
small_blocks(BlockComparator) {}
@@ -99,6 +113,11 @@ struct THCCachingAllocator
99113
return err;
100114
}
101115

116+
err = process_events();
117+
if (err != cudaSuccess) {
118+
return err;
119+
}
120+
102121
size = round_size(size);
103122
bool small = size <= kSmallAlloc;
104123

@@ -159,15 +178,13 @@ struct THCCachingAllocator
159178

160179
Block* block = it->second;
161180
allocated_blocks.erase(it);
162-
163-
bool small = block->size <= kSmallAlloc;
164-
auto& free_blocks = small ? large_blocks : small_blocks;
165-
try_merge_blocks(block, block->prev, free_blocks);
166-
try_merge_blocks(block, block->next, free_blocks);
167-
168181
block->allocated = false;
169-
free_blocks.insert(block);
170182

183+
if (!block->stream_uses.empty()) {
184+
return insert_events(block);
185+
}
186+
187+
free_block(block);
171188
return cudaSuccess;
172189
}
173190

@@ -229,6 +246,33 @@ struct THCCachingAllocator
229246
cacheInfoAux(small_blocks, dev_id, total, largest);
230247
}
231248

249+
void recordStream(void* ptr, THCStream* stream)
250+
{
251+
std::lock_guard<std::mutex> lock(mutex);
252+
Block* block = find_allocated_block(ptr);
253+
if (!block) {
254+
THError("invalid device pointer: %p", ptr);
255+
}
256+
if (stream->stream == block->stream) {
257+
// ignore uses on the allocation stream, since those don't require any
258+
// special synchronization
259+
return;
260+
}
261+
THCStream_retain(stream);
262+
block->stream_uses.insert(THCStreamPtr(stream, &THCStream_free));
263+
}
264+
265+
/** moves a block into the free block list */
266+
void free_block(Block* block)
267+
{
268+
THAssert(!block->allocated && block->event_count == 0);
269+
bool small = block->size <= kSmallAlloc;
270+
auto& free_blocks = small ? large_blocks : small_blocks;
271+
try_merge_blocks(block, block->prev, free_blocks);
272+
try_merge_blocks(block, block->next, free_blocks);
273+
free_blocks.insert(block);
274+
}
275+
232276
/** combine previously split blocks */
233277
void try_merge_blocks(Block* dst, Block* src, FreeBlocks& free_blocks)
234278
{
@@ -332,6 +376,68 @@ struct THCCachingAllocator
332376
}
333377
return it->second;
334378
}
379+
380+
cudaError_t insert_events(Block* block)
381+
{
382+
cudaError_t err;
383+
384+
int prev_device;
385+
err = cudaGetDevice(&prev_device);
386+
if (err != cudaSuccess) return err;
387+
388+
std::set<THCStreamPtr> streams(std::move(block->stream_uses));
389+
for (auto it = streams.begin(); it != streams.end(); ++it) {
390+
auto& stream = *it;
391+
392+
err = cudaSetDevice(stream->device);
393+
if (err != cudaSuccess) break;
394+
395+
cudaEvent_t event;
396+
err = cudaEventCreateWithFlags(&event, cudaEventDisableTiming);
397+
if (err != cudaSuccess) break;
398+
399+
err = cudaEventRecord(event, stream->stream);
400+
if (err != cudaSuccess) break;
401+
402+
block->event_count++;
403+
cuda_events.emplace_back(event, block);
404+
}
405+
406+
cudaSetDevice(prev_device);
407+
return err;
408+
}
409+
410+
cudaError_t process_events()
411+
{
412+
// Process outstanding cudaEvents. Events that are completed are removed
413+
// from the queue, and the 'event_count' for the corresponding allocation
414+
// is decremented. Stops at the first event which has not been completed.
415+
// Since events on different devices or streams may occur out of order,
416+
// the processing of some events may be delayed.
417+
while (!cuda_events.empty()) {
418+
auto& e = cuda_events.front();
419+
cudaEvent_t event = e.first;
420+
Block* block = e.second;
421+
422+
cudaError_t err = cudaEventQuery(event);
423+
if (err == cudaErrorNotReady) {
424+
break;
425+
} else if (err != cudaSuccess) {
426+
return err;
427+
}
428+
err = cudaEventDestroy(event);
429+
if (err != cudaSuccess) {
430+
return err;
431+
}
432+
433+
block->event_count--;
434+
if (block->event_count == 0) {
435+
free_block(block);
436+
}
437+
cuda_events.pop_front();
438+
}
439+
return cudaSuccess;
440+
}
335441
};
336442

337443
static cudaError_t THCCachingAllocator_malloc(void* ctx, void** ptr, size_t size, cudaStream_t stream)
@@ -379,6 +485,11 @@ THC_API void* THCCachingAllocator_getBaseAllocation(void *ptr, size_t *size)
379485
return caching_allocator.getBaseAllocation(ptr, size);
380486
}
381487

488+
THC_API void THCCachingAllocator_recordStream(void *ptr, THCStream* stream)
489+
{
490+
caching_allocator.recordStream(ptr, stream);
491+
}
492+
382493
THC_API std::mutex* THCCachingAllocator_getCudaFreeMutex()
383494
{
384495
return &caching_allocator.cuda_free_mutex;

THCCachingAllocator.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66
#endif
77

88
#include "THCGeneral.h"
9+
#include "THCStream.h"
910

1011
THC_API THCDeviceAllocator* THCCachingAllocator_get(void);
1112
THC_API void* THCCachingAllocator_getBaseAllocation(void *ptr, size_t *size);
13+
THC_API void THCCachingAllocator_recordStream(void *ptr, THCStream* stream);
1214

1315
#if __cplusplus >= 201103L
1416
THC_API std::mutex* THCCachingAllocator_getCudaFreeMutex();

0 commit comments

Comments
 (0)