@@ -46,6 +46,48 @@ void PSGPUTrainer::Initialize(const TrainerDesc& trainer_desc,
4646 dense_grad_names_[table_id][j] = table.dense_grad_name (j);
4747 }
4848 }
49+ InitializeGPUServer (trainer_desc);
50+ scale_datanorm_ = trainer_desc.scale_datanorm ();
51+ int place_num = trainer_desc.worker_places_size ();
52+ const std::vector<paddle::framework::DataFeed*> readers =
53+ dataset->GetReaders ();
54+ dump_file_num_ = trainer_desc.dump_file_num ();
55+ user_define_dump_filename_ = trainer_desc.user_define_dump_filename ();
56+ std::vector<int > dev_ids;
57+ for (int i = 0 ; i < place_num; ++i) {
58+ int num = trainer_desc.worker_places (i);
59+ platform::CUDAPlace place = platform::CUDAPlace (num);
60+ places_.push_back (place);
61+ dev_ids.push_back (num);
62+ }
63+ for (int i = 0 ; i < trainer_desc.downpour_param ().stat_var_names_size ();
64+ i++) {
65+ need_merge_var_names_.push_back (
66+ trainer_desc.downpour_param ().stat_var_names (i));
67+ }
68+ VLOG (3 ) << " going to initialize pull dense worker" ;
69+ SetDebug (trainer_desc.debug ());
70+ trainer_desc_ = trainer_desc;
71+ workers_.resize (place_num);
72+ for (int i = 0 ; i < place_num; ++i) {
73+ workers_[i] = DeviceWorkerFactory::CreateDeviceWorker (
74+ trainer_desc.device_worker_name ());
75+ workers_[i]->SetDeviceIndex (i);
76+ workers_[i]->SetNeedDumpField (need_dump_field_);
77+ workers_[i]->SetNeedDumpParam (need_dump_param_);
78+ workers_[i]->SetDumpFieldVector (dump_fields_);
79+ workers_[i]->SetDumpParamVector (dump_param_);
80+ workers_[i]->InitRandomDumpConfig (trainer_desc);
81+ workers_[i]->SetDataFeed (readers[i]);
82+ workers_[i]->SetPlace (places_[i]);
83+ workers_[i]->SetReaderPlace (places_[i]);
84+ workers_[i]->Initialize (trainer_desc);
85+ workers_[i]->SetWorkerNum (place_num);
86+ }
87+ return ;
88+ }
89+
90+ void PSGPUTrainer::InitializeGPUServer (const TrainerDesc& trainer_desc) {
4991 // add for hbmps optimizer config
5092 auto fleet_desc_str = trainer_desc.fleet_desc ();
5193 google::protobuf::TextFormat::ParseFromString (fleet_desc_str, &_ps_param);
@@ -203,45 +245,6 @@ void PSGPUTrainer::Initialize(const TrainerDesc& trainer_desc,
203245
204246 auto ps_gpu_wrapper = paddle::framework::PSGPUWrapper::GetInstance ();
205247 ps_gpu_wrapper->InitializeGPUServer (config);
206-
207- scale_datanorm_ = trainer_desc.scale_datanorm ();
208- int place_num = trainer_desc.worker_places_size ();
209- const std::vector<paddle::framework::DataFeed*> readers =
210- dataset->GetReaders ();
211- dump_file_num_ = trainer_desc.dump_file_num ();
212- user_define_dump_filename_ = trainer_desc.user_define_dump_filename ();
213- std::vector<int > dev_ids;
214- for (int i = 0 ; i < place_num; ++i) {
215- int num = trainer_desc.worker_places (i);
216- platform::CUDAPlace place = platform::CUDAPlace (num);
217- places_.push_back (place);
218- dev_ids.push_back (num);
219- }
220- for (int i = 0 ; i < trainer_desc.downpour_param ().stat_var_names_size ();
221- i++) {
222- need_merge_var_names_.push_back (
223- trainer_desc.downpour_param ().stat_var_names (i));
224- }
225- VLOG (3 ) << " going to initialize pull dense worker" ;
226- SetDebug (trainer_desc.debug ());
227- trainer_desc_ = trainer_desc;
228- workers_.resize (place_num);
229- for (int i = 0 ; i < place_num; ++i) {
230- workers_[i] = DeviceWorkerFactory::CreateDeviceWorker (
231- trainer_desc.device_worker_name ());
232- workers_[i]->SetDeviceIndex (i);
233- workers_[i]->SetNeedDumpField (need_dump_field_);
234- workers_[i]->SetNeedDumpParam (need_dump_param_);
235- workers_[i]->SetDumpFieldVector (dump_fields_);
236- workers_[i]->SetDumpParamVector (dump_param_);
237- workers_[i]->InitRandomDumpConfig (trainer_desc);
238- workers_[i]->SetDataFeed (readers[i]);
239- workers_[i]->SetPlace (places_[i]);
240- workers_[i]->SetReaderPlace (places_[i]);
241- workers_[i]->Initialize (trainer_desc);
242- workers_[i]->SetWorkerNum (place_num);
243- }
244- return ;
245248}
246249
247250std::string PSGPUTrainer::GetDumpPath (int tid) {
0 commit comments