11#include " THCCachingAllocator.h"
22
33#include < cuda_runtime_api.h>
4+ #include < deque>
45#include < map>
56#include < memory>
67#include < mutex>
// split. If no block is found, the allocator will delegate to cudaMalloc.
// - If the cudaMalloc fails, the allocator will free all cached blocks that
//   are not split and retry the allocation.
// - Large (>1MB) and small allocation requests are handled separately. Large
//   allocation requests can be filled by a cudaMalloc call of the exact size.
//   Small requests will allocate and split a 1MB buffer, if necessary.
//
// launches. The programmer must insert the proper synchronization if memory
// segments are used from multiple streams.
//
// The library provides a recordStream() function to help insert the correct
// synchronization when allocations are used on multiple streams. This will
// ensure that the block is not reused before each recorded stream completes
// work.
//
2935
3036
3137namespace {
3238
39+ typedef std::shared_ptr<THCStream> THCStreamPtr;
40+ typedef std::set<THCStreamPtr> stream_set;
41+
3342const size_t kRoundSmall = 512 ; // round up small allocs to 512 bytes
3443const size_t kRoundLarge = 131072 ; // round up large allocs to 128 KiB
3544const size_t kSmallAlloc = 1048576 ; // largest "small" allocation is 1 MiB
3645
3746struct Block {
38- int device; // gpu
39- cudaStream_t stream; // allocation stream
40- size_t size; // block size in bytes
41- char * ptr; // memory address
42- bool allocated; // in-use flag
43- Block* prev; // prev block if split from a larger allocation
44- Block* next; // next block if split from a larger allocation
47+ int device; // gpu
48+ cudaStream_t stream; // allocation stream
49+ stream_set stream_uses; // streams on which the block was used
50+ size_t size; // block size in bytes
51+ char * ptr; // memory address
52+ bool allocated; // in-use flag
53+ Block* prev; // prev block if split from a larger allocation
54+ Block* next; // next block if split from a larger allocation
55+ int event_count; // number of outstanding CUDA events
4556
4657 Block (int device, cudaStream_t stream, size_t size, char * ptr=NULL ) :
47- device (device), stream(stream), size(size ), ptr(ptr ), allocated( 0 ),
48- prev (NULL ), next(NULL ) { }
58+ device (device), stream(stream), stream_uses( ), size(size ), ptr(ptr ),
59+ allocated ( 0 ), prev(NULL ), next(NULL ), event_count( 0 ) { }
4960};
5061
5162static bool BlockComparator (const Block* a, const Block* b)
@@ -84,6 +95,9 @@ struct THCCachingAllocator
8495 // allocated blocks by device pointer
8596 std::unordered_map<void *, Block*> allocated_blocks;
8697
98+ // outstanding cuda events
99+ std::deque<std::pair<cudaEvent_t, Block*>> cuda_events;
100+
87101 THCCachingAllocator () :
88102 large_blocks (BlockComparator),
89103 small_blocks (BlockComparator) {}
@@ -99,6 +113,11 @@ struct THCCachingAllocator
99113 return err;
100114 }
101115
116+ err = process_events ();
117+ if (err != cudaSuccess) {
118+ return err;
119+ }
120+
102121 size = round_size (size);
103122 bool small = size <= kSmallAlloc ;
104123
@@ -159,15 +178,13 @@ struct THCCachingAllocator
159178
160179 Block* block = it->second ;
161180 allocated_blocks.erase (it);
162-
163- bool small = block->size <= kSmallAlloc ;
164- auto & free_blocks = small ? large_blocks : small_blocks;
165- try_merge_blocks (block, block->prev , free_blocks);
166- try_merge_blocks (block, block->next , free_blocks);
167-
168181 block->allocated = false ;
169- free_blocks.insert (block);
170182
183+ if (!block->stream_uses .empty ()) {
184+ return insert_events (block);
185+ }
186+
187+ free_block (block);
171188 return cudaSuccess;
172189 }
173190
@@ -229,6 +246,33 @@ struct THCCachingAllocator
229246 cacheInfoAux (small_blocks, dev_id, total, largest);
230247 }
231248
249+ void recordStream (void * ptr, THCStream* stream)
250+ {
251+ std::lock_guard<std::mutex> lock (mutex);
252+ Block* block = find_allocated_block (ptr);
253+ if (!block) {
254+ THError (" invalid device pointer: %p" , ptr);
255+ }
256+ if (stream->stream == block->stream ) {
257+ // ignore uses on the allocation stream, since those don't require any
258+ // special synchronization
259+ return ;
260+ }
261+ THCStream_retain (stream);
262+ block->stream_uses .insert (THCStreamPtr (stream, &THCStream_free));
263+ }
264+
265+ /* * moves a block into the free block list */
266+ void free_block (Block* block)
267+ {
268+ THAssert (!block->allocated && block->event_count == 0 );
269+ bool small = block->size <= kSmallAlloc ;
270+ auto & free_blocks = small ? large_blocks : small_blocks;
271+ try_merge_blocks (block, block->prev , free_blocks);
272+ try_merge_blocks (block, block->next , free_blocks);
273+ free_blocks.insert (block);
274+ }
275+
232276 /* * combine previously split blocks */
233277 void try_merge_blocks (Block* dst, Block* src, FreeBlocks& free_blocks)
234278 {
@@ -332,6 +376,68 @@ struct THCCachingAllocator
332376 }
333377 return it->second ;
334378 }
379+
380+ cudaError_t insert_events (Block* block)
381+ {
382+ cudaError_t err;
383+
384+ int prev_device;
385+ err = cudaGetDevice (&prev_device);
386+ if (err != cudaSuccess) return err;
387+
388+ std::set<THCStreamPtr> streams (std::move (block->stream_uses ));
389+ for (auto it = streams.begin (); it != streams.end (); ++it) {
390+ auto & stream = *it;
391+
392+ err = cudaSetDevice (stream->device );
393+ if (err != cudaSuccess) break ;
394+
395+ cudaEvent_t event;
396+ err = cudaEventCreateWithFlags (&event, cudaEventDisableTiming);
397+ if (err != cudaSuccess) break ;
398+
399+ err = cudaEventRecord (event, stream->stream );
400+ if (err != cudaSuccess) break ;
401+
402+ block->event_count ++;
403+ cuda_events.emplace_back (event, block);
404+ }
405+
406+ cudaSetDevice (prev_device);
407+ return err;
408+ }
409+
410+ cudaError_t process_events ()
411+ {
412+ // Process outstanding cudaEvents. Events that are completed are removed
413+ // from the queue, and the 'event_count' for the corresponding allocation
414+ // is decremented. Stops at the first event which has not been completed.
415+ // Since events on different devices or streams may occur out of order,
416+ // the processing of some events may be delayed.
417+ while (!cuda_events.empty ()) {
418+ auto & e = cuda_events.front ();
419+ cudaEvent_t event = e.first ;
420+ Block* block = e.second ;
421+
422+ cudaError_t err = cudaEventQuery (event);
423+ if (err == cudaErrorNotReady) {
424+ break ;
425+ } else if (err != cudaSuccess) {
426+ return err;
427+ }
428+ err = cudaEventDestroy (event);
429+ if (err != cudaSuccess) {
430+ return err;
431+ }
432+
433+ block->event_count --;
434+ if (block->event_count == 0 ) {
435+ free_block (block);
436+ }
437+ cuda_events.pop_front ();
438+ }
439+ return cudaSuccess;
440+ }
335441};
336442
337443static cudaError_t THCCachingAllocator_malloc (void * ctx, void ** ptr, size_t size, cudaStream_t stream)
@@ -379,6 +485,11 @@ THC_API void* THCCachingAllocator_getBaseAllocation(void *ptr, size_t *size)
379485 return caching_allocator.getBaseAllocation (ptr, size);
380486}
381487
488+ THC_API void THCCachingAllocator_recordStream (void *ptr, THCStream* stream)
489+ {
490+ caching_allocator.recordStream (ptr, stream);
491+ }
492+
382493THC_API std::mutex* THCCachingAllocator_getCudaFreeMutex ()
383494{
384495 return &caching_allocator.cuda_free_mutex ;
0 commit comments