Skip to content

Commit 8341678

Browse files
authored
[Dy2St] Optimize memory allocation for inputs and params (#74095)
1 parent e70aca9 commit 8341678

File tree

3 files changed

+226
-5
lines changed

3 files changed

+226
-5
lines changed

paddle/fluid/pybind/eager_custom_python_api.h

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,15 +82,45 @@ static PyObject *eager_api_run_program(PyObject *self,
8282
PyObject *kwargs) {
8383
PyThreadState *tstate = nullptr;
8484
try {
85-
auto X = GetTensorListFromArgs("run_program", "X", args, 0, true);
86-
auto Params = GetTensorListFromArgs("run_program", "Params", args, 1, true);
85+
auto X_info = GetPyArgumentInfo("run_program", "X", args, 0, true);
86+
TensorListBufferAllocator X_allocator(X_info.second);
87+
auto &X = GetTensorListFromArgsWithBuffer("run_program",
88+
"X",
89+
0,
90+
nullptr,
91+
X_info.first,
92+
X_info.second,
93+
X_allocator);
94+
95+
auto Params_info =
96+
GetPyArgumentInfo("run_program", "Params", args, 1, true);
97+
TensorListBufferAllocator Params_allocator(Params_info.second);
98+
auto &Params = GetTensorListFromArgsWithBuffer("run_program",
99+
"Params",
100+
0,
101+
nullptr,
102+
Params_info.first,
103+
Params_info.second,
104+
Params_allocator);
105+
87106
auto OutScope =
88107
GetScopePtrListFromArgs("run_program", "OutScope", args, 2, false);
89108
const phi::distributed::ProcessMesh *mesh = nullptr;
90109
if (InputsContainDistTensor(&mesh, X, Params)) {
91-
X = GetTensorListFromArgs("run_program", "X", args, 0, true, mesh);
92-
Params =
93-
GetTensorListFromArgs("run_program", "Params", args, 1, true, mesh);
110+
X = GetTensorListFromArgsWithBuffer("run_program",
111+
"X",
112+
0,
113+
nullptr,
114+
X_info.first,
115+
X_info.second,
116+
X_allocator);
117+
Params = GetTensorListFromArgsWithBuffer("run_program",
118+
"Params",
119+
0,
120+
nullptr,
121+
Params_info.first,
122+
Params_info.second,
123+
Params_allocator);
94124
}
95125
framework::AttributeMap attrs;
96126
VLOG(6) << "Start PIR ConstructAttrMapFromPyArgs";

paddle/fluid/pybind/eager_utils.cc

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2411,6 +2411,154 @@ std::vector<paddle::framework::Scope*> GetScopePtrListFromArgs(
24112411
return result;
24122412
}
24132413

2414+
TensorListBufferAllocator::MapType
2415+
TensorListBufferAllocator::s_tensor_vector_map_;
2416+
TensorListBufferAllocator::TensorListBufferAllocator(ssize_t len) : key_(len) {
2417+
MapIterType iter;
2418+
if (key_ == -1) {
2419+
iter = s_tensor_vector_map_.find(-1);
2420+
if (iter == s_tensor_vector_map_.end()) {
2421+
iter = s_tensor_vector_map_.emplace(-1,
2422+
std::make_unique<TensorListBuffer>());
2423+
}
2424+
} else {
2425+
auto range = s_tensor_vector_map_.equal_range(key_);
2426+
for (iter = range.first; iter != range.second; ++iter) {
2427+
if (iter->second->is_available) {
2428+
break;
2429+
}
2430+
}
2431+
if (iter == range.second) {
2432+
iter = s_tensor_vector_map_.emplace(
2433+
key_, std::make_unique<TensorListBuffer>(key_));
2434+
}
2435+
iter->second->is_available = false;
2436+
}
2437+
buffer_ptr_ = iter->second.get();
2438+
}
2439+
2440+
TensorListBufferAllocator::~TensorListBufferAllocator() {
2441+
if (buffer_ptr_) {
2442+
buffer_ptr_->is_available = true;
2443+
2444+
for (auto& tensor : buffer_ptr_->buffer) {
2445+
tensor.reset();
2446+
}
2447+
}
2448+
}
2449+
std::pair<PyObject*, ssize_t> GetPyArgumentInfo(const std::string& op_type,
2450+
const std::string& arg_name,
2451+
PyObject* args,
2452+
ssize_t arg_idx,
2453+
bool dispensable) {
2454+
PyObject* list = PyTuple_GET_ITEM(args, arg_idx);
2455+
ssize_t list_len = 0;
2456+
if (list == nullptr && !dispensable) {
2457+
PADDLE_THROW(common::errors::InvalidArgument(
2458+
"%s(): argument '%s' (position %d) must be list of Tensor, but got "
2459+
"None",
2460+
op_type,
2461+
arg_name,
2462+
arg_idx));
2463+
}
2464+
if (list == nullptr || list == Py_None) {
2465+
list_len = -1;
2466+
} else if (PyList_Check(list)) {
2467+
list_len = PyList_Size(list);
2468+
} else if (PyTuple_Check(list)) {
2469+
list_len = PyTuple_Size(list);
2470+
} else {
2471+
PADDLE_THROW(common::errors::InvalidArgument(
2472+
"%s(): argument '%s' (position %d) must be list of Tensors, but got "
2473+
"%s",
2474+
op_type,
2475+
arg_name,
2476+
arg_idx,
2477+
(reinterpret_cast<PyTypeObject*>(list->ob_type))->tp_name));
2478+
}
2479+
return std::make_pair(list, list_len);
2480+
}
2481+
2482+
std::vector<paddle::Tensor>& GetTensorListFromArgsWithBuffer(
2483+
const std::string& op_type,
2484+
const std::string& arg_name,
2485+
ssize_t arg_idx,
2486+
const phi::distributed::ProcessMesh* mesh,
2487+
PyObject* list,
2488+
ssize_t list_len,
2489+
const TensorListBufferAllocator& allocator) {
2490+
auto& result = allocator.GetAllocatedBuffer();
2491+
2492+
const phi::distributed::ProcessMesh* local_mesh = nullptr;
2493+
ssize_t mesh_start_index = -1;
2494+
2495+
if (PyList_Check(list)) {
2496+
for (Py_ssize_t i = 0; i < list_len; i++) {
2497+
PyObject* tensor_obj = PyList_GetItem(list, i);
2498+
PADDLE_ENFORCE_EQ(
2499+
PyObject_TypeCheck(tensor_obj, p_tensor_type),
2500+
true,
2501+
common::errors::InvalidArgument(
2502+
"%s(): argument '%s' (position %d) must be list of Tensors",
2503+
op_type,
2504+
arg_name,
2505+
arg_idx));
2506+
paddle::Tensor& tensor =
2507+
reinterpret_cast<TensorObject*>(tensor_obj)->tensor;
2508+
if (local_mesh) {
2509+
ConvertToDistTensor(&tensor, local_mesh);
2510+
} else {
2511+
if (tensor.is_dist_tensor()) {
2512+
local_mesh = &(std::static_pointer_cast<phi::distributed::DistTensor>(
2513+
tensor.impl())
2514+
->process_mesh());
2515+
mesh_start_index = i;
2516+
}
2517+
}
2518+
result[i] = tensor;
2519+
}
2520+
for (Py_ssize_t i = 0; i < mesh_start_index; i++) {
2521+
paddle::Tensor& tensor =
2522+
reinterpret_cast<TensorObject*>(PyList_GetItem(list, i))->tensor;
2523+
ConvertToDistTensor(&tensor, local_mesh);
2524+
result[i] = tensor;
2525+
}
2526+
2527+
} else if (PyTuple_Check(list)) {
2528+
for (Py_ssize_t i = 0; i < list_len; i++) {
2529+
PyObject* tensor_obj = PyTuple_GetItem(list, i);
2530+
PADDLE_ENFORCE_EQ(
2531+
PyObject_TypeCheck(tensor_obj, p_tensor_type),
2532+
true,
2533+
common::errors::InvalidArgument(
2534+
"%s(): argument '%s' (position %d) must be list of Tensors",
2535+
op_type,
2536+
arg_name,
2537+
arg_idx));
2538+
paddle::Tensor& tensor =
2539+
reinterpret_cast<TensorObject*>(tensor_obj)->tensor;
2540+
if (local_mesh) {
2541+
ConvertToDistTensor(&tensor, local_mesh);
2542+
} else {
2543+
if (tensor.is_dist_tensor()) {
2544+
local_mesh = &(std::static_pointer_cast<phi::distributed::DistTensor>(
2545+
tensor.impl())
2546+
->process_mesh());
2547+
mesh_start_index = i;
2548+
}
2549+
}
2550+
result[i] = tensor;
2551+
}
2552+
for (Py_ssize_t i = 0; i < mesh_start_index; i++) {
2553+
paddle::Tensor& tensor =
2554+
reinterpret_cast<TensorObject*>(PyTuple_GetItem(list, i))->tensor;
2555+
ConvertToDistTensor(&tensor, local_mesh);
2556+
result[i] = tensor;
2557+
}
2558+
}
2559+
return result;
2560+
}
2561+
24142562
paddle::Place CastPyArg2Place(PyObject* obj,
24152563
const std::string& op_type,
24162564
ssize_t arg_pos) {

paddle/fluid/pybind/eager_utils.h

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,49 @@ class eager_gil_scoped_release {
446446
PyThreadState* tstate{nullptr};
447447
};
448448

449+
class TensorListBufferAllocator {
450+
private:
451+
struct TensorListBuffer {
452+
bool is_available;
453+
std::vector<paddle::Tensor> buffer;
454+
TensorListBuffer() = default;
455+
explicit TensorListBuffer(ssize_t len) : buffer(len), is_available(true) {}
456+
};
457+
458+
using MapType =
459+
std::unordered_multimap<ssize_t, std::unique_ptr<TensorListBuffer>>;
460+
using MapIterType = MapType::iterator;
461+
462+
ssize_t key_;
463+
TensorListBuffer* buffer_ptr_ = nullptr;
464+
static MapType s_tensor_vector_map_;
465+
466+
public:
467+
explicit TensorListBufferAllocator(ssize_t len);
468+
TensorListBufferAllocator(const TensorListBufferAllocator&) = delete;
469+
TensorListBufferAllocator& operator=(const TensorListBufferAllocator&) =
470+
delete;
471+
~TensorListBufferAllocator();
472+
std::vector<paddle::Tensor>& GetAllocatedBuffer() const {
473+
return buffer_ptr_->buffer;
474+
}
475+
};
476+
477+
// Returns the raw PyObject at positional argument `arg_idx` together with
// its element count: PyList_Size/PyTuple_Size for list/tuple inputs, or -1
// when the argument is absent or None. Throws InvalidArgument when the
// argument is missing and not dispensable, or is neither list nor tuple.
// Does not materialize any tensors.
std::pair<PyObject*, ssize_t> GetPyArgumentInfo(const std::string& op_type,
                                                const std::string& arg_name,
                                                PyObject* args,
                                                ssize_t arg_idx,
                                                bool dispensable);

// Copies the tensors of `list` (a list/tuple already measured by
// GetPyArgumentInfo as `list_len`) into the buffer leased by `allocator`
// and returns a reference to it; the reference is valid only while
// `allocator` is alive. If any element is a dist tensor, the whole list is
// converted to that tensor's process mesh.
// NOTE(review): `mesh` is accepted but the implementation rediscovers the
// mesh from the tensors — presumably kept for signature parity with
// GetTensorListFromArgs; confirm before relying on it.
std::vector<paddle::Tensor>& GetTensorListFromArgsWithBuffer(
    const std::string& op_type,
    const std::string& arg_name,
    ssize_t arg_idx,
    const phi::distributed::ProcessMesh* mesh,
    PyObject* list,
    ssize_t list_len,
    const TensorListBufferAllocator& allocator);
491+
449492
/* ------------------ for SetStaticOpArgPreCastHook ----------------------- */
450493

451494
inline static PyObject* static_op_arg_pre_cast_hook_get();

0 commit comments

Comments
 (0)