@@ -2411,6 +2411,154 @@ std::vector<paddle::framework::Scope*> GetScopePtrListFromArgs(
24112411 return result;
24122412}
24132413
2414+ TensorListBufferAllocator::MapType
2415+ TensorListBufferAllocator::s_tensor_vector_map_;
2416+ TensorListBufferAllocator::TensorListBufferAllocator (ssize_t len) : key_(len) {
2417+ MapIterType iter;
2418+ if (key_ == -1 ) {
2419+ iter = s_tensor_vector_map_.find (-1 );
2420+ if (iter == s_tensor_vector_map_.end ()) {
2421+ iter = s_tensor_vector_map_.emplace (-1 ,
2422+ std::make_unique<TensorListBuffer>());
2423+ }
2424+ } else {
2425+ auto range = s_tensor_vector_map_.equal_range (key_);
2426+ for (iter = range.first ; iter != range.second ; ++iter) {
2427+ if (iter->second ->is_available ) {
2428+ break ;
2429+ }
2430+ }
2431+ if (iter == range.second ) {
2432+ iter = s_tensor_vector_map_.emplace (
2433+ key_, std::make_unique<TensorListBuffer>(key_));
2434+ }
2435+ iter->second ->is_available = false ;
2436+ }
2437+ buffer_ptr_ = iter->second .get ();
2438+ }
2439+
2440+ TensorListBufferAllocator::~TensorListBufferAllocator () {
2441+ if (buffer_ptr_) {
2442+ buffer_ptr_->is_available = true ;
2443+
2444+ for (auto & tensor : buffer_ptr_->buffer ) {
2445+ tensor.reset ();
2446+ }
2447+ }
2448+ }
2449+ std::pair<PyObject*, ssize_t > GetPyArgumentInfo (const std::string& op_type,
2450+ const std::string& arg_name,
2451+ PyObject* args,
2452+ ssize_t arg_idx,
2453+ bool dispensable) {
2454+ PyObject* list = PyTuple_GET_ITEM (args, arg_idx);
2455+ ssize_t list_len = 0 ;
2456+ if (list == nullptr && !dispensable) {
2457+ PADDLE_THROW (common::errors::InvalidArgument (
2458+ " %s(): argument '%s' (position %d) must be list of Tensor, but got "
2459+ " None" ,
2460+ op_type,
2461+ arg_name,
2462+ arg_idx));
2463+ }
2464+ if (list == nullptr || list == Py_None) {
2465+ list_len = -1 ;
2466+ } else if (PyList_Check (list)) {
2467+ list_len = PyList_Size (list);
2468+ } else if (PyTuple_Check (list)) {
2469+ list_len = PyTuple_Size (list);
2470+ } else {
2471+ PADDLE_THROW (common::errors::InvalidArgument (
2472+ " %s(): argument '%s' (position %d) must be list of Tensors, but got "
2473+ " %s" ,
2474+ op_type,
2475+ arg_name,
2476+ arg_idx,
2477+ (reinterpret_cast <PyTypeObject*>(list->ob_type ))->tp_name ));
2478+ }
2479+ return std::make_pair (list, list_len);
2480+ }
2481+
2482+ std::vector<paddle::Tensor>& GetTensorListFromArgsWithBuffer (
2483+ const std::string& op_type,
2484+ const std::string& arg_name,
2485+ ssize_t arg_idx,
2486+ const phi::distributed::ProcessMesh* mesh,
2487+ PyObject* list,
2488+ ssize_t list_len,
2489+ const TensorListBufferAllocator& allocator) {
2490+ auto & result = allocator.GetAllocatedBuffer ();
2491+
2492+ const phi::distributed::ProcessMesh* local_mesh = nullptr ;
2493+ ssize_t mesh_start_index = -1 ;
2494+
2495+ if (PyList_Check (list)) {
2496+ for (Py_ssize_t i = 0 ; i < list_len; i++) {
2497+ PyObject* tensor_obj = PyList_GetItem (list, i);
2498+ PADDLE_ENFORCE_EQ (
2499+ PyObject_TypeCheck (tensor_obj, p_tensor_type),
2500+ true ,
2501+ common::errors::InvalidArgument (
2502+ " %s(): argument '%s' (position %d) must be list of Tensors" ,
2503+ op_type,
2504+ arg_name,
2505+ arg_idx));
2506+ paddle::Tensor& tensor =
2507+ reinterpret_cast <TensorObject*>(tensor_obj)->tensor ;
2508+ if (local_mesh) {
2509+ ConvertToDistTensor (&tensor, local_mesh);
2510+ } else {
2511+ if (tensor.is_dist_tensor ()) {
2512+ local_mesh = &(std::static_pointer_cast<phi::distributed::DistTensor>(
2513+ tensor.impl ())
2514+ ->process_mesh ());
2515+ mesh_start_index = i;
2516+ }
2517+ }
2518+ result[i] = tensor;
2519+ }
2520+ for (Py_ssize_t i = 0 ; i < mesh_start_index; i++) {
2521+ paddle::Tensor& tensor =
2522+ reinterpret_cast <TensorObject*>(PyList_GetItem (list, i))->tensor ;
2523+ ConvertToDistTensor (&tensor, local_mesh);
2524+ result[i] = tensor;
2525+ }
2526+
2527+ } else if (PyTuple_Check (list)) {
2528+ for (Py_ssize_t i = 0 ; i < list_len; i++) {
2529+ PyObject* tensor_obj = PyTuple_GetItem (list, i);
2530+ PADDLE_ENFORCE_EQ (
2531+ PyObject_TypeCheck (tensor_obj, p_tensor_type),
2532+ true ,
2533+ common::errors::InvalidArgument (
2534+ " %s(): argument '%s' (position %d) must be list of Tensors" ,
2535+ op_type,
2536+ arg_name,
2537+ arg_idx));
2538+ paddle::Tensor& tensor =
2539+ reinterpret_cast <TensorObject*>(tensor_obj)->tensor ;
2540+ if (local_mesh) {
2541+ ConvertToDistTensor (&tensor, local_mesh);
2542+ } else {
2543+ if (tensor.is_dist_tensor ()) {
2544+ local_mesh = &(std::static_pointer_cast<phi::distributed::DistTensor>(
2545+ tensor.impl ())
2546+ ->process_mesh ());
2547+ mesh_start_index = i;
2548+ }
2549+ }
2550+ result[i] = tensor;
2551+ }
2552+ for (Py_ssize_t i = 0 ; i < mesh_start_index; i++) {
2553+ paddle::Tensor& tensor =
2554+ reinterpret_cast <TensorObject*>(PyTuple_GetItem (list, i))->tensor ;
2555+ ConvertToDistTensor (&tensor, local_mesh);
2556+ result[i] = tensor;
2557+ }
2558+ }
2559+ return result;
2560+ }
2561+
24142562paddle::Place CastPyArg2Place (PyObject* obj,
24152563 const std::string& op_type,
24162564 ssize_t arg_pos) {
0 commit comments