Skip to content

Commit be33b7f

Browse files
licy666 authored and pytorchmergebot committed
[DeviceMemory] Add Basic Statistics to Device Memory in OpenReg (pytorch#166395)
Implement a complete OpenRegDeviceAllocator with the following enhancements:
- Implement memory statistics tracking (allocated/reserved bytes, allocation count)
- Track allocation sizes for accurate memory statistics
- Refactor OpenRegDeviceAllocator so that it inherits from c10::DeviceAllocator
- Lay the groundwork for later adding a memory caching function to DeviceMemory

Add comprehensive test coverage:
- Memory allocation/deallocation tests with statistics validation
- Storage operations and tensor-from-blob tests
- Multithreading safety tests for concurrent allocations
- Gradient tracking and requires_grad compatibility tests

Fixes pytorch#166157
Pull Request resolved: pytorch#166395
Approved by: https://github.com/fffrog
1 parent 7a963ff commit be33b7f

File tree

3 files changed

+826
-36
lines changed

3 files changed

+826
-36
lines changed
Lines changed: 269 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,275 @@
11
#include "OpenRegDeviceAllocator.h"
2+
#include "OpenRegFunctions.h"
3+
4+
#include <c10/util/Exception.h>
5+
#include <c10/util/irange.h>
6+
7+
using namespace c10::CachingAllocator;
28

39
namespace c10::openreg {
410

5-
static OpenRegDeviceAllocator global_openreg_alloc;
6-
REGISTER_ALLOCATOR(c10::DeviceType::PrivateUse1, &global_openreg_alloc);
11+
// Index of the StatType::AGGREGATE entry, used below to address the aggregate
// slot of the per-StatType statistics arrays (e.g. stats_.allocated_bytes).
constexpr size_t kAggregate = static_cast<size_t>(StatType::AGGREGATE);
12+
13+
14+
DeviceMemoryAllocator::DeviceMemoryAllocator(c10::DeviceIndex device_index)
15+
: device_index_(device_index) {}
16+
17+
void* DeviceMemoryAllocator::malloc(size_t nbytes) {
18+
if (nbytes == 0) {
19+
return nullptr;
20+
}
21+
22+
std::lock_guard<std::recursive_mutex> lock(mutex_);
23+
24+
void* data = nullptr;
25+
auto ret = orMalloc(&data, nbytes);
26+
27+
TORCH_CHECK(
28+
ret == orSuccess && data != nullptr,
29+
"Failed to allocate ",
30+
nbytes,
31+
" bytes on openreg device ",
32+
device_index_,
33+
". ",
34+
"Allocated: ",
35+
stats_.allocated_bytes[0].current,
36+
" bytes, ",
37+
"Reserved: ",
38+
stats_.reserved_bytes[0].current,
39+
" bytes");
40+
41+
// Track allocation size for proper deallocation statistics
42+
allocation_sizes_[data] = nbytes;
43+
44+
// Update statistics
45+
stats_.allocated_bytes[kAggregate].increase(nbytes);
46+
stats_.reserved_bytes[kAggregate].increase(nbytes);
47+
stats_.num_device_alloc++;
48+
49+
return data;
50+
}
51+
52+
void DeviceMemoryAllocator::free(void* ptr) {
53+
if (!ptr) {
54+
return;
55+
}
56+
57+
std::lock_guard<std::recursive_mutex> lock(mutex_);
58+
59+
auto ret = orFree(ptr);
60+
61+
if (ret == orSuccess) {
62+
auto it = allocation_sizes_.find(ptr);
63+
if (it != allocation_sizes_.end()) {
64+
size_t nbytes = it->second;
65+
66+
stats_.allocated_bytes[kAggregate].decrease(nbytes);
67+
stats_.reserved_bytes[kAggregate].decrease(nbytes);
68+
stats_.num_device_free++;
69+
70+
allocation_sizes_.erase(it);
71+
} else {
72+
TORCH_WARN(
73+
"Successfully freed OpenReg memory pointer ",
74+
ptr,
75+
" on device ",
76+
device_index_,
77+
" that was not tracked by the allocator. "
78+
"Statistics may be inaccurate.");
79+
}
80+
} else {
81+
// orFree failed
82+
auto it = allocation_sizes_.find(ptr);
83+
if (it != allocation_sizes_.end()) {
84+
TORCH_WARN(
85+
"orFree failed for tracked pointer ",
86+
ptr,
87+
" with size ",
88+
it->second,
89+
" bytes on device ",
90+
device_index_,
91+
". Return code: ",
92+
ret,
93+
". Keeping tracking record - this may indicate a double-free or invalid pointer.");
94+
} else {
95+
TORCH_WARN(
96+
"orFree failed for untracked pointer ",
97+
ptr,
98+
" on device ",
99+
device_index_,
100+
". Return code: ",
101+
ret,
102+
". This likely indicates a double-free or invalid pointer.");
103+
}
104+
}
105+
}
106+
107+
// Returns a by-value copy of this device's statistics, taken while holding the
// allocator mutex.
c10::CachingDeviceAllocator::DeviceStats DeviceMemoryAllocator::getStats() {
  std::lock_guard<std::recursive_mutex> guard(mutex_);
  auto snapshot = stats_;
  return snapshot;
}
111+
112+
void DeviceMemoryAllocator::resetAccumulatedStats() {
113+
std::lock_guard<std::recursive_mutex> lock(mutex_);
114+
115+
// Reset accumulated statistics for all StatTypes
116+
for (const auto stat_type :
117+
c10::irange(static_cast<size_t>(StatType::NUM_TYPES))) {
118+
stats_.allocated_bytes[stat_type].reset_accumulated();
119+
stats_.reserved_bytes[stat_type].reset_accumulated();
120+
stats_.active_bytes[stat_type].reset_accumulated();
121+
stats_.inactive_split_bytes[stat_type].reset_accumulated();
122+
stats_.requested_bytes[stat_type].reset_accumulated();
123+
}
124+
125+
stats_.num_alloc_retries = 0;
126+
stats_.num_ooms = 0;
127+
stats_.num_sync_all_streams = 0;
128+
stats_.num_device_alloc = 0;
129+
stats_.num_device_free = 0;
130+
}
131+
132+
void DeviceMemoryAllocator::resetPeakStats() {
133+
std::lock_guard<std::recursive_mutex> lock(mutex_);
134+
135+
// Reset peak statistics for all StatTypes
136+
for (const auto stat_type :
137+
c10::irange(static_cast<size_t>(StatType::NUM_TYPES))) {
138+
stats_.allocated_bytes[stat_type].reset_peak();
139+
stats_.reserved_bytes[stat_type].reset_peak();
140+
stats_.active_bytes[stat_type].reset_peak();
141+
stats_.inactive_split_bytes[stat_type].reset_peak();
142+
stats_.requested_bytes[stat_type].reset_peak();
143+
}
144+
145+
stats_.oversize_allocations.reset_peak();
146+
stats_.oversize_segments.reset_peak();
147+
}
148+
149+
namespace {
150+
151+
OpenRegDeviceAllocator g_allocator;
152+
153+
void deleteOpenRegMemory(void* ptr) {
154+
g_allocator.freeMemory(ptr);
155+
}
156+
157+
}
158+
159+
// Builds one DeviceMemoryAllocator per device reported by
// c10::openreg::device_count(), stored so that the vector position equals the
// device index.
OpenRegDeviceAllocator::OpenRegDeviceAllocator() {
  std::lock_guard<std::recursive_mutex> guard(mutex_);

  const auto device_count = c10::openreg::device_count();
  device_allocators_.reserve(device_count);
  for (const auto i : c10::irange(device_count)) {
    device_allocators_.emplace_back(std::make_unique<DeviceMemoryAllocator>(i));
  }
}
167+
168+
169+
// Allocates `nbytes` bytes on the device reported by orGetDevice and wraps the
// result in a DataPtr whose deleter is deleteOpenRegMemory.
//
// For `nbytes == 0` no memory is requested and the DataPtr carries a null
// pointer (still tagged with the current device). Raises via TORCH_CHECK when
// the current device cannot be queried, when a non-empty request targets a
// device index this allocator has no per-device allocator for, or when the
// per-device allocation itself fails.
at::DataPtr OpenRegDeviceAllocator::allocate(size_t nbytes) {
  int current_device_index = -1;
  const auto ret = orGetDevice(&current_device_index);
  TORCH_CHECK(ret == orSuccess, "Failed to get current OpenReg device");

  // Make the int -> DeviceIndex narrowing explicit; it was implicit before.
  const auto device_index =
      static_cast<c10::DeviceIndex>(current_device_index);
  const auto curr_device =
      c10::Device(c10::DeviceType::PrivateUse1, device_index);

  void* data = nullptr;
  if (nbytes > 0) {
    // Validate before indexing: an out-of-range index into
    // device_allocators_ would be undefined behavior.
    TORCH_CHECK(
        current_device_index >= 0 &&
            static_cast<size_t>(current_device_index) <
                device_allocators_.size(),
        "OpenReg device index ",
        current_device_index,
        " is out of range; the allocator manages ",
        device_allocators_.size(),
        " device(s)");

    // Allocate memory via device-specific allocator
    data = device_allocators_[current_device_index]->malloc(nbytes);

    // Track which device owns this pointer so freeMemory() can route the
    // release back to the same per-device allocator.
    std::lock_guard<std::recursive_mutex> lock(mutex_);
    allocated_blocks_[data] = device_index;
  }

  return {data, data, &deleteOpenRegMemory, curr_device};
}
189+
190+
// Returns the same deleter that allocate() installs into its DataPtrs.
at::DeleterFnPtr OpenRegDeviceAllocator::raw_deleter() const {
  at::DeleterFnPtr deleter = &deleteOpenRegMemory;
  return deleter;
}
193+
194+
void OpenRegDeviceAllocator::copy_data(
195+
void* dest,
196+
const void* src,
197+
std::size_t count) const {
198+
auto ret = orMemcpy(dest, src, count, orMemcpyDeviceToDevice);
199+
TORCH_CHECK(
200+
ret == orSuccess, "Failed to copy ", count, " bytes on openreg device");
201+
}
202+
203+
bool OpenRegDeviceAllocator::initialized() {
204+
std::lock_guard<std::recursive_mutex> lock(mutex_);
205+
return !device_allocators_.empty();
206+
}
207+
208+
void OpenRegDeviceAllocator::freeMemory(void* ptr) {
209+
if (!ptr) {
210+
return;
211+
}
212+
213+
// Try to find which device owns this pointer
214+
c10::DeviceIndex device_index = -1;
215+
bool found_in_map = false;
216+
217+
{
218+
std::lock_guard<std::recursive_mutex> lock(mutex_);
219+
auto it = allocated_blocks_.find(ptr);
220+
if (it != allocated_blocks_.end()) {
221+
device_index = it->second;
222+
allocated_blocks_.erase(it);
223+
found_in_map = true;
224+
}
225+
}
226+
227+
if (found_in_map) {
228+
// Pointer was tracked - free via device-specific allocator with stats
229+
device_allocators_[device_index]->free(ptr);
230+
} else {
231+
// Pointer not tracked - might be already freed by storage or other path
232+
// Try to free it directly via orFree without updating statistics
233+
auto ret = orFree(ptr);
234+
235+
// Only warn if orFree actually failed (not just "not found")
236+
// In OpenReg's case, orFree returns orErrorUnknown if pointer not in registry
237+
// which is expected for already-freed memory
238+
if (ret != orSuccess && ret != orErrorUnknown) {
239+
TORCH_WARN(
240+
"orFree failed for untracked OpenReg memory pointer ",
241+
ptr,
242+
". Error code: ", ret);
243+
}
244+
}
245+
}
246+
247+
// Returns a copy of the statistics kept by the per-device allocator for
// `device`. Raises via TORCH_CHECK when `device` is negative or not smaller
// than the number of per-device allocators.
c10::CachingDeviceAllocator::DeviceStats OpenRegDeviceAllocator::
    getDeviceStats(c10::DeviceIndex device) {
  // Validate before indexing: an out-of-range index would be undefined
  // behavior. The index is widened to int so it prints as a number.
  TORCH_CHECK(
      device >= 0 && static_cast<size_t>(device) < device_allocators_.size(),
      "Invalid OpenReg device index ",
      static_cast<int>(device),
      "; the allocator manages ",
      device_allocators_.size(),
      " device(s)");
  return device_allocators_[device]->getStats();
}
251+
252+
// Resets the accumulated statistics of the per-device allocator for `device`.
// Raises via TORCH_CHECK when `device` is negative or not smaller than the
// number of per-device allocators.
void OpenRegDeviceAllocator::resetAccumulatedStats(c10::DeviceIndex device) {
  // Validate before indexing: an out-of-range index would be undefined
  // behavior. The index is widened to int so it prints as a number.
  TORCH_CHECK(
      device >= 0 && static_cast<size_t>(device) < device_allocators_.size(),
      "Invalid OpenReg device index ",
      static_cast<int>(device),
      "; the allocator manages ",
      device_allocators_.size(),
      " device(s)");
  device_allocators_[device]->resetAccumulatedStats();
}
255+
256+
// Resets the peak statistics of the per-device allocator for `device`.
// Raises via TORCH_CHECK when `device` is negative or not smaller than the
// number of per-device allocators.
void OpenRegDeviceAllocator::resetPeakStats(c10::DeviceIndex device) {
  // Validate before indexing: an out-of-range index would be undefined
  // behavior. The index is widened to int so it prints as a number.
  TORCH_CHECK(
      device >= 0 && static_cast<size_t>(device) < device_allocators_.size(),
      "Invalid OpenReg device index ",
      static_cast<int>(device),
      "; the allocator manages ",
      device_allocators_.size(),
      " device(s)");
  device_allocators_[device]->resetPeakStats();
}
259+
260+
// Intentionally a no-op: this allocator keeps no cache, so there is nothing
// to release and the mempool id is ignored.
// TODO: When caching is implemented, release all free blocks here
void OpenRegDeviceAllocator::emptyCache(MempoolId_t /*mempool_id*/) {}
264+
265+
void OpenRegDeviceAllocator::recordStream(
266+
const DataPtr& ptr,
267+
c10::Stream stream) {
268+
// OpenReg doesn't track stream usage yet
269+
// TODO: When stream support is added, track which streams are using this pointer
270+
}
271+
// ============ Global Registration ============
272+
273+
// Hands the file-local g_allocator to REGISTER_ALLOCATOR for the PrivateUse1
// device type (presumably making it the allocator c10 uses for that device
// type - confirm against the macro's definition).
REGISTER_ALLOCATOR(c10::DeviceType::PrivateUse1, &g_allocator);
7274

8275
} // namespace c10::openreg
Lines changed: 69 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,78 @@
1-
#include <ATen/core/CachingHostAllocator.h>
1+
#pragma once
22

33
#include <c10/core/Allocator.h>
4+
#include <c10/core/CachingDeviceAllocator.h>
45
#include <c10/core/Device.h>
6+
#include <c10/util/flat_hash_map.h>
57

68
#include <include/openreg.h>
79

10+
#include <memory>
11+
#include <mutex>
12+
#include <unordered_map>
13+
#include <vector>
14+
815
namespace c10::openreg {
9-
struct OpenRegDeviceAllocator final : at::Allocator {
10-
OpenRegDeviceAllocator() = default;
11-
12-
static void ReportAndDelete(void* ptr) {
13-
if (!ptr) {
14-
return;
15-
}
16-
orFreeHost(ptr);
17-
}
18-
19-
at::DataPtr allocate(size_t nbytes) override {
20-
int current_device_index = -1;
21-
orGetDevice(&current_device_index);
22-
23-
auto curr_device =
24-
c10::Device(c10::DeviceType::PrivateUse1, current_device_index);
25-
void* data = nullptr;
26-
if (nbytes > 0) {
27-
orMalloc(&data, nbytes);
28-
TORCH_CHECK(
29-
data, "Failed to allocator ", nbytes, " bytes on openreg device.");
30-
}
31-
return {data, data, &ReportAndDelete, curr_device};
32-
}
33-
34-
at::DeleterFnPtr raw_deleter() const override {
35-
return &ReportAndDelete;
36-
}
37-
38-
void copy_data(void* dest, const void* src, std::size_t count) const final {
39-
orMemcpy(dest, src, count, orMemcpyDeviceToDevice);
40-
}
16+
17+
class DeviceMemoryAllocator {
18+
public:
19+
explicit DeviceMemoryAllocator(c10::DeviceIndex device_index);
20+
21+
DeviceMemoryAllocator(const DeviceMemoryAllocator&) = delete;
22+
DeviceMemoryAllocator& operator=(const DeviceMemoryAllocator&) = delete;
23+
24+
void* malloc(size_t nbytes);
25+
26+
void free(void* ptr);
27+
28+
c10::CachingDeviceAllocator::DeviceStats getStats();
29+
30+
void resetAccumulatedStats();
31+
32+
void resetPeakStats();
33+
34+
private:
35+
c10::DeviceIndex device_index_;
36+
37+
c10::CachingDeviceAllocator::DeviceStats stats_;
38+
39+
std::unordered_map<void*, size_t> allocation_sizes_;
40+
41+
std::recursive_mutex mutex_;
42+
};
43+
44+
45+
class OpenRegDeviceAllocator final : public c10::DeviceAllocator {
46+
public:
47+
OpenRegDeviceAllocator();
48+
49+
at::DataPtr allocate(size_t nbytes) override;
50+
at::DeleterFnPtr raw_deleter() const override;
51+
void copy_data(void* dest, const void* src, std::size_t count) const final;
52+
53+
54+
bool initialized() override;
55+
void emptyCache(MempoolId_t mempool_id = {0, 0}) override;
56+
void recordStream(const DataPtr& ptr, c10::Stream stream) override;
57+
c10::CachingDeviceAllocator::DeviceStats getDeviceStats(
58+
c10::DeviceIndex device) override;
59+
void resetAccumulatedStats(c10::DeviceIndex device) override;
60+
void resetPeakStats(c10::DeviceIndex device) override;
61+
62+
63+
void freeMemory(void* ptr);
64+
65+
private:
66+
67+
// Per-device allocators
68+
std::vector<std::unique_ptr<DeviceMemoryAllocator>> device_allocators_;
69+
70+
// Global mapping from pointer to device index
71+
std::recursive_mutex mutex_;
72+
ska::flat_hash_map<void*, c10::DeviceIndex> allocated_blocks_;
4173
};
4274

43-
} // namespace c10::openreg
75+
76+
77+
78+
}

0 commit comments

Comments
 (0)