2525#include " paddle/fluid/framework/program_desc.h"
2626#include " paddle/fluid/framework/scope.h"
2727#include " paddle/fluid/framework/tensor.h"
28+ #include " paddle/utils/string/split.h"
2829
2930namespace paddle {
3031namespace distributed {
@@ -53,7 +54,8 @@ bool LoadDataFromDistModelTensor(const DistModelTensor &input_data,
5354 } else if (input_data.dtype == DistModelDataType::INT32) {
5455 input_tensor_ptr = input_tensor->mutable_data <int32_t >(dims, place);
5556 } else if (input_data.dtype == DistModelDataType::FLOAT16) {
56- input_tensor_ptr = input_tensor->mutable_data <float16>(dims, place);
57+ input_tensor_ptr =
58+ input_tensor->mutable_data <paddle::platform::float16>(dims, place);
5759 } else {
5860 LOG (ERROR) << " unsupported feed type " << input_data.dtype ;
5961 return false ;
@@ -136,11 +138,14 @@ class DistModelTimer {
136138
137139bool DistModel::Init () {
138140 carrier_id_ = " inference" ;
139- bool init_method = (!config_.model_dir .empty () || config_.program_desc );
140- PADDLE_ENFORCE_EQ (init_method, true ,
141- platform::errors::InvalidArgument (
142- " One of model dir or program desc must be provided to "
143- " dist model inference." ));
141+ bool init_method =
142+ !config_.model_dir .empty () || config_.program_desc ||
143+ (!config_.program_path .empty () && !config_.param_path .empty ());
144+ PADDLE_ENFORCE_EQ (
145+ init_method, true ,
146+ platform::errors::InvalidArgument (
147+ " One of model dir, program desc or (program_path, param_path) pair "
148+ " must be provided to dist model inference." ));
144149 if (config_.program_desc ) {
145150 PADDLE_ENFORCE_NOT_NULL (
146151 config_.scope , platform::errors::InvalidArgument (
@@ -178,6 +183,7 @@ bool DistModel::Init() {
178183}
179184
180185bool DistModel::PreparePlace () {
186+ VLOG (3 ) << " DistModel is going to set place for: " << config_.place ;
181187 if (config_.place == " GPU" ) {
182188 place_ = paddle::platform::CUDAPlace (config_.device_id );
183189 } else if (config_.place == " CPU" ) {
@@ -186,10 +192,151 @@ bool DistModel::PreparePlace() {
186192 PADDLE_THROW (platform::errors::InvalidArgument (
187193 " Place must be choosen from GPU or CPU, but got %s." , config_.place ));
188194 }
195+ VLOG (3 ) << " DistModel prepare place success" ;
196+ return true ;
197+ }
198+
199+ bool DistModel::LoadConverterConfig () {
200+ VLOG (3 ) << " Going to load converter config from: " << config_.comm_init_config
201+ << " \n " ;
202+ std::ifstream fin (config_.comm_init_config , std::ios::in);
203+ PADDLE_ENFORCE_EQ (
204+ static_cast <bool >(fin.is_open ()), true ,
205+ platform::errors::NotFound (
206+ " Cannot open file %s, please confirm whether the file is normal." ,
207+ config_.comm_init_config ));
208+ std::string line;
209+ bool ring_to_rank;
210+ // Reading config from file, the config file should like these format
211+ // [ring_id -> ranks]
212+ // 0,0,1,2,3
213+ // 1,0,1
214+ // 2,2,3
215+ // 21,0,1
216+ // 22,1,2
217+ // 23,2,3
218+ // [rank -> ring_ids]
219+ // 0,0,1,21
220+ // 1,0,1,21,22
221+ // 2,0,2,22,23
222+ // 3,0,2,23
223+ while (std::getline (fin, line)) {
224+ std::vector<std::string> one_line = paddle::string::Split (line, ' ,' );
225+ if (one_line.size () == 1 ) {
226+ // start a new section of the config
227+ if (line == " [ring_id -> ranks]" ) {
228+ ring_to_rank = true ;
229+ } else if (line == " [rank -> ring_ids]" ) {
230+ ring_to_rank = false ;
231+ }
232+ } else {
233+ // parse key - values pairs in one section
234+ int64_t key = std::stoll (one_line[0 ]);
235+ for (size_t i = 1 ; i < one_line.size (); ++i) {
236+ int64_t val = std::stoll (one_line[i]);
237+ if (ring_to_rank) {
238+ if (config_.ring_id_to_ranks_ .find (key) ==
239+ config_.ring_id_to_ranks_ .end ()) {
240+ config_.ring_id_to_ranks_ [key] = std::vector<int64_t >();
241+ }
242+ config_.ring_id_to_ranks_ [key].emplace_back (val);
243+ } else {
244+ if (config_.rank_to_ring_ids_ .find (key) ==
245+ config_.rank_to_ring_ids_ .end ()) {
246+ config_.rank_to_ring_ids_ [key] = std::vector<int64_t >();
247+ }
248+ config_.rank_to_ring_ids_ [key].emplace_back (val);
249+ }
250+ // NOTE: add more configuration sections here
251+ }
252+ }
253+ }
254+ std::stringstream ss;
255+ ss << " Loaded the following converter config:\n " ;
256+ ss << " ring_id_to_ranks:\n " ;
257+ for (auto pair : config_.ring_id_to_ranks_ ) {
258+ int64_t key = pair.first ;
259+ ss << " \t " << key << " \t ->\t " ;
260+ for (auto value : pair.second ) {
261+ ss << value << " \t " ;
262+ }
263+ ss << " \n " ;
264+ }
265+ ss << " rank_to_ring_ids:\n " ;
266+ for (auto pair : config_.rank_to_ring_ids_ ) {
267+ int64_t key = pair.first ;
268+ ss << " \t " << key << " \t ->\t " ;
269+ for (auto value : pair.second ) {
270+ ss << value << " \t " ;
271+ }
272+ ss << " \n " ;
273+ }
274+ VLOG (3 ) << ss.str ();
275+ return true ;
276+ }
277+
278+ std::vector<std::string> DistModel::GetOutputNames () {
279+ std::vector<std::string> rst;
280+ std::stringstream ss;
281+ ss << " DistModel GetOutputNames: " ;
282+ for (const auto &pair : idx_to_fetches_) {
283+ ss << pair.second << " , " ;
284+ rst.emplace_back (pair.second );
285+ }
286+ ss << " \n " ;
287+ VLOG (3 ) << ss.str ();
288+ return rst;
289+ }
290+
291+ std::vector<std::string> DistModel::GetInputNames () {
292+ std::vector<std::string> rst;
293+ std::stringstream ss;
294+ ss << " DistModel GetInputNames: " ;
295+ for (const auto &pair : idx_to_feeds_) {
296+ ss << pair.second << " , " ;
297+ rst.emplace_back (pair.second );
298+ }
299+ ss << " \n " ;
300+ VLOG (3 ) << ss.str ();
301+ return rst;
302+ }
303+
304+ framework::Scope *DistModel::GetScope () {
305+ VLOG (3 ) << " DistModel GetScope()" ;
306+ return scope_.get ();
307+ }
308+
309+ paddle::platform::Place DistModel::GetPlace () {
310+ VLOG (3 ) << " DistModel GetPlace()" ;
311+ return place_;
312+ }
313+
314+ bool DistModel::ZeroCopyRun () {
315+ VLOG (3 ) << " DistModel run with ZeroCopy." ;
316+
317+ DistModelTimer timer;
318+ timer.tic ();
319+ double start_time = timer.toc ();
320+
321+ fleet_exe->Run (carrier_id_);
322+
323+ double end_time = timer.toc ();
324+ if (config_.enable_timer ) {
325+ LOG (INFO) << " DistModel finish inf, cost " << end_time - start_time << " ms" ;
326+ } else {
327+ VLOG (3 ) << " DistModel finish inf." ;
328+ }
189329 return true ;
190330}
191331
192332bool DistModel::CommInit () {
333+ VLOG (3 ) << " DistModel CommInit()" ;
334+ if (!config_.comm_init_config .empty ()) {
335+ if (!LoadConverterConfig ()) {
336+ VLOG (3 ) << " Load converter config failed, DistModel init failed." ;
337+ return false ;
338+ }
339+ }
193340 std::unique_ptr<framework::ProgramDesc> comm_init_program (
194341 new framework::ProgramDesc ());
195342 framework::BlockDesc *comm_init_block = comm_init_program->MutableBlock (0 );
@@ -278,25 +425,32 @@ void DistModel::InsertCommOp(std::string tmp_var_name, int nranks, int rank,
278425}
279426
280427bool DistModel::PrepareScope () {
428+ VLOG (3 ) << " DistModel PrepareScope()" ;
281429 scope_.reset (new framework::Scope ());
430+ VLOG (3 ) << " DistModel prepare scope success" ;
282431 return true ;
283432}
284433
285434bool DistModel::PrepareProgram () {
435+ VLOG (3 ) << " DistModel PrepareProgram()" ;
286436 if (!LoadProgram ()) {
287437 return false ;
288438 }
289439 if (!LoadParameters ()) {
290440 return false ;
291441 }
442+ VLOG (3 ) << " DistModel prepare program success" ;
292443 return true ;
293444}
294445
295446bool DistModel::LoadProgram () {
296- VLOG (3 ) << " Loading program from " << config_.model_dir ;
297- PADDLE_ENFORCE_NE (config_.model_dir , " " , platform::errors::InvalidArgument (
298- " Model dir must be provided." ));
299- std::string model_path = config_.model_dir + " .pdmodel" ;
447+ std::string model_path = config_.model_dir .empty ()
448+ ? config_.program_path
449+ : (config_.model_dir + " .pdmodel" );
450+ PADDLE_ENFORCE_NE (model_path, " " ,
451+ platform::errors::InvalidArgument (
452+ " One of model dir or program_path must be provided." ));
453+ VLOG (3 ) << " Loading program from " << model_path;
300454 framework::proto::ProgramDesc program_proto;
301455 std::string pb_content;
302456 // Read binary
@@ -318,7 +472,6 @@ bool DistModel::LoadProgram() {
318472}
319473
320474bool DistModel::LoadParameters () {
321- VLOG (3 ) << " Loading parameters from " << config_.model_dir ;
322475 PADDLE_ENFORCE_NOT_NULL (program_.get (),
323476 platform::errors::PreconditionNotMet (
324477 " The program should be loaded first." ));
@@ -346,7 +499,13 @@ bool DistModel::LoadParameters() {
346499 }
347500 }
348501
349- std::string param_path = config_.model_dir + " .pdiparams" ;
502+ std::string param_path = config_.model_dir .empty ()
503+ ? config_.param_path
504+ : (config_.model_dir + " .pdiparams" );
505+ PADDLE_ENFORCE_NE (param_path, " " ,
506+ platform::errors::InvalidArgument (
507+ " One of model dir or param_path must be provided." ));
508+ VLOG (3 ) << " Loading parameters from " << param_path;
350509 // sort paramlist to have consistent ordering
351510 std::sort (params.begin (), params.end ());
352511 // append just the load_combine op
@@ -370,10 +529,11 @@ bool DistModel::LoadParameters() {
370529}
371530
372531bool DistModel::PrepareFleetExe () {
532+ VLOG (3 ) << " DistModel PrepareFleetExe()" ;
373533 task_node_.reset (new TaskNode (program_.get (), config_.local_rank ));
374534 // With auto cut, there is no concept of pp, no need to add dependency.
375535 task_node_->SetType (" Compute" );
376- task_node_->Init ();
536+ task_node_->Init (config_. use_feed_fetch_ops );
377537 executor_desc_ = FleetExecutorDesc ();
378538 executor_desc_.set_cur_rank (config_.local_rank );
379539 std::unordered_map<int64_t , int64_t > id_to_rank;
@@ -385,11 +545,13 @@ bool DistModel::PrepareFleetExe() {
385545 }
386546 fleet_exe.reset (new FleetExecutor (executor_desc_));
387547 fleet_exe->Init (carrier_id_, *(program_.get ()), scope_.get (), place_, 1 ,
388- {task_node_.get ()}, id_to_rank);
548+ {task_node_.get ()}, id_to_rank, force_root_scope_var_names_);
549+ VLOG (3 ) << " DistModel prepare fleet exe success." ;
389550 return true ;
390551}
391552
392553bool DistModel::PrepareFeedAndFetch () {
554+ VLOG (3 ) << " DistModel PrepareFeedAndFetch()" ;
393555 for (auto *op : program_->Block (0 ).AllOps ()) {
394556 if (op->Type () == " feed" ) {
395557 VLOG (3 ) << " feed op with feed var: " << op->Output (" Out" )[0 ];
@@ -399,6 +561,15 @@ bool DistModel::PrepareFeedAndFetch() {
399561 }
400562 feeds_[idx] = op;
401563 std::string var_name = op->Output (" Out" )[0 ];
564+ // NOTE: Vars of feed fetch ops are not persistable,
565+ // which will result in that those vars will be created in
566+ // the subscope (microscope) in fleet executor. This will
567+ // cause that the GetInputTensor/GetOutputTensor funct
568+ // in analysis predictor cannot find those vars in the scope
569+ // returned by the DistModel, since DistModel only return the
570+ // root scope. So, those vars must be forced to be created
571+ // in the root scope instead of in the microscope.
572+ force_root_scope_var_names_.emplace_back (var_name);
402573 feed_names_[var_name] = idx;
403574 idx_to_feeds_[idx] = var_name;
404575 framework::VarDesc *real_var = program_->Block (0 ).FindVar (var_name);
@@ -428,7 +599,9 @@ bool DistModel::PrepareFeedAndFetch() {
428599 fetches_.resize (idx + 1 );
429600 }
430601 fetches_[idx] = op;
431- idx_to_fetches_[idx] = op->Input (" X" )[0 ];
602+ std::string var_name = op->Input (" X" )[0 ];
603+ force_root_scope_var_names_.emplace_back (var_name);
604+ idx_to_fetches_[idx] = var_name;
432605 }
433606 }
434607
@@ -440,6 +613,7 @@ bool DistModel::PrepareFeedAndFetch() {
440613 LOG (ERROR) << " No fetch op in the inf program, please check the program." ;
441614 return false ;
442615 }
616+ VLOG (3 ) << " DistModel prepare feed and fetch success." ;
443617 return true ;
444618}
445619
@@ -508,7 +682,7 @@ bool DistModel::FetchResults(std::vector<DistModelTensor> *output_data,
508682 rst = FetchResult<int32_t >(fetch, output);
509683 output->dtype = DistModelDataType::INT32;
510684 } else if (type == framework::proto::VarType::FP16) {
511- rst = FetchResult<float16>(fetch, output);
685+ rst = FetchResult<paddle::platform:: float16>(fetch, output);
512686 output->dtype = DistModelDataType::FLOAT16;
513687 } else {
514688 LOG (ERROR) << " DistModel meets unknown fetch data type. DistModel only "
0 commit comments