namespace paddle {
namespace framework {

-namespace {
-
-/*
- * Parse the var_ids that need to be associated with an event.
- * The caller should guarantee front_op and back_op satisfy the
- * following conditions:
- * 1. kQueueAsync -> kQueueAsync
- * 2. kQueueAsync -> kQueueSync
- *
- * For example: matmul(gpu) -> out_var -> memcpy_d2h
- * out_var should be associated with an event.
- */
-std::vector<size_t> ParseEventVarIds(const Instruction& cur_instr,
-                                     const Instruction& next_instr) {
-  std::unordered_set<size_t> unique_var_ids;
-  for (auto& item : cur_instr.output_index_) {
-    unique_var_ids.insert(item.second.begin(), item.second.end());
-  }
-
-  std::vector<size_t> new_event_var_ids;
-  for (auto& item : next_instr.input_index_) {
-    for (auto var_id : item.second) {
-      if (unique_var_ids.count(var_id) > 0) {
-        new_event_var_ids.push_back(var_id);
-      }
-    }
-  }
-  return new_event_var_ids;
-}
-
-void AssociateInputWithEvents(
-    const platform::Place& place, const std::vector<size_t>& new_event_var_id,
-    Instruction* next_instr,
-    std::map<size_t, std::shared_ptr<platform::DeviceEvent>>* var_id2event,
-    bool is_sync) {
-  for (auto var_id : new_event_var_id) {
-    if (var_id2event->count(var_id) == 0) {
-      auto device_event = std::make_shared<platform::DeviceEvent>(
-          place, platform::GenerateDeviceEventFlag());
-      var_id2event->emplace(var_id, std::move(device_event));
-    }
-    // Add events for next_instr.inputs
-    next_instr->intput_events_.emplace_back(var_id, var_id2event->at(var_id),
-                                            is_sync);
-  }
-}
-
-void ParseDirectAndEventRunOps(
-    const platform::Place& place, const std::vector<OpFuncNode>& op_func_nodes,
-    const std::vector<size_t>& downstream_ops, size_t op_index,
-    std::map<size_t, std::shared_ptr<platform::DeviceEvent>>* var_id2event,
-    std::vector<Instruction>* instructions) {
-  auto& op_func_type = op_func_nodes[op_index].type_;
-  auto& cur_instr = instructions->at(op_index);
-  auto& next_instruction = cur_instr.next_instruction_;
-
-  if (op_func_type == OpFuncType::kQueueSync) {
-    // All downstream ops of kQueueSync can run directly, such as CPU -> Any
-    next_instruction.direct_run_ = downstream_ops;
-  } else {  // kQueueAsync
-    std::vector<size_t> event_var_ids;
-    for (auto next_op_id : downstream_ops) {
-      auto& next_instr = instructions->at(next_op_id);
-      // case 1: GPU -> GPU (same stream)
-      if (cur_instr.dev_ctx_ == next_instr.dev_ctx_) {
-        next_instruction.direct_run_.emplace_back(next_op_id);
-        continue;
-      }
-      // Always insert events between different streams
-      auto new_event_var_ids = ParseEventVarIds(cur_instr, next_instr);
-      event_var_ids.insert(event_var_ids.end(), new_event_var_ids.begin(),
-                           new_event_var_ids.end());
-
-      bool is_sync =
-          (op_func_nodes[next_op_id].type_ == OpFuncType::kQueueSync);
-      AssociateInputWithEvents(place, new_event_var_ids, &next_instr,
-                               var_id2event, is_sync);
-
-      if (is_sync) {  // GPU -> CPU
-        next_instruction.synchronize_run_.emplace_back(next_op_id);
-      } else {  // GPU -> GPU (different stream)
-        next_instruction.event_wait_run_.emplace_back(next_op_id);
-      }
-    }
-    // Create events for these cross-stream vars
-    VLOG(3) << cur_instr.kernel_func_.operator_base_->Type()
-            << " event_var_ids.size: " << event_var_ids.size();
-    for (auto var_id : event_var_ids) {
-      cur_instr.output_events_.emplace_back(var_id, var_id2event->at(var_id),
-                                            false /*not used*/);
-    }
-  }
-}
-}  // namespace
-
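The anonymous-namespace helpers removed above do not simply disappear: judging from the call sites later in this diff, `stream_analyzer_.ParseDeviceContext(...)` and `stream_analyzer_.Schedule(...)`, their logic appears to move into a dedicated `StreamAnalyzer` member. The new header is not part of this excerpt, so the sketch below is only a guess at its shape, with signatures inferred from those call sites and the `var_id2event_` map carried over from the removed code as an assumption rather than taken from the real `stream_analyzer.h`.

```cpp
// Hypothetical sketch only -- NOT the actual stream_analyzer.h introduced by
// this change. Signatures are inferred from the call sites visible below;
// the referenced types come from the surrounding new_executor headers.
class StreamAnalyzer {
 public:
  explicit StreamAnalyzer(const platform::Place& place) : place_(place) {}

  // Chooses the device context an instruction runs on, e.g. dedicated
  // streams for memcpy_h2d / memcpy_d2h ops (the role previously played by
  // InterpreterCore::ParseDeviceContextForInstruction).
  platform::DeviceContext* ParseDeviceContext(const OpFuncNode& op_func_node,
                                              const OperatorBase& op_base);

  // For op `op_index`, splits its downstream ops into direct-run,
  // event-wait-run and synchronize-run sets and attaches DeviceEvents to the
  // variables that cross streams -- roughly the combined job of the removed
  // ParseEventVarIds / AssociateInputWithEvents / ParseDirectAndEventRunOps.
  void Schedule(const std::vector<OpFuncNode>& op_func_nodes,
                const std::vector<size_t>& downstream_ops, size_t op_index,
                std::vector<Instruction>* instructions);

 private:
  platform::Place place_;
  // Assumed to take over the var_id -> DeviceEvent map that InterpreterCore
  // previously kept in var_id2event_.
  std::map<size_t, std::shared_ptr<platform::DeviceEvent>> var_id2event_;
};
```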
InterpreterCore::InterpreterCore(const platform::Place& place,
                                 const ProgramDesc& main_prog,
                                 VariableScope* global_scope,
@@ -123,8 +28,7 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
    : place_(place),
      main_program_(main_prog),
      global_scope_(global_scope),
-      d2h_ctx_pool_({place}),
-      h2d_ctx_pool_({place}) {
+      stream_analyzer_(place) {
  is_build_ = false;

  feed_names_ = feed_names;
@@ -199,7 +103,7 @@ void InterpreterCore::Convert() {
    Instruction temp_inst;
    auto* op_base = op_list_[i];
    temp_inst.dev_ctx_ =
-        ParseDeviceContextForInstruction(vec_func_list_[i], *op_base);
+        stream_analyzer_.ParseDeviceContext(vec_func_list_[i], *op_base);
    temp_inst.kernel_func_.compute_func_ = vec_func_list_[i].kernel_func_;
    temp_inst.kernel_func_.operator_base_ = op_base;
    temp_inst.input_index_ = vec_func_list_[i].input_index;
@@ -270,8 +174,8 @@ void InterpreterCore::Convert() {
      }
    }

-    ParseDirectAndEventRunOps(place_, vec_func_list_, filter_next, i,
-                              &var_id2event_, &vec_instruction_);
+    stream_analyzer_.Schedule(vec_func_list_, filter_next, i,
+                              &vec_instruction_);

    for (auto inst_id : filter_next) {
      dependecy_count_[inst_id]++;
@@ -361,7 +265,7 @@ void InterpreterCore::ExecuteInstructionList(
    working_queue.pop();
    auto& instr_node = vec_instr[instr_id];
    // step1: stream_wait (non-block host) or sync (block host)
-    StreamWaitEventOrSync(instr_node);
+    event_manager_.WaitEvent(instr_node, place_);
    // step2: run instruction
    RunInstruction(instr_node);
    ++run_op_number;
@@ -371,7 +275,7 @@ void InterpreterCore::ExecuteInstructionList(
    }

    // step3: insert event for out_vars if needed
-    RecordEventInstruction(instr_node, vec_func_list_[instr_id]);
+    event_manager_.RecordEvent(instr_node, vec_func_list_[instr_id], place_);

    // step4: update working_queue
    auto& next_instr = instr_node.next_instruction_.all_next_ops_;
@@ -450,54 +354,5 @@ const CostInfo& InterpreterCore::DryRun(
  return dry_run_profiler_.GetCostInfo();
}

-platform::DeviceContext* InterpreterCore::ParseDeviceContextForInstruction(
-    const OpFuncNode& op_func_node, const OperatorBase& op_base) {
-  auto& op_type = op_base.Type();
-  auto* dev_ctx = op_func_node.dev_ctx_;
-  if (op_type == interpretercore::kMemcpyH2D) {
-    VLOG(3) << "Get dev_ctx from d2h_context_pool_";
-    dev_ctx = d2h_ctx_pool_.Get(place_);
-  } else if (op_type == interpretercore::kMemcpyD2H) {
-    VLOG(3) << "Get dev_ctx from h2d_context_pool_";
-    dev_ctx = h2d_ctx_pool_.Get(place_);
-  }
-
-  return dev_ctx;
-}
-
-void InterpreterCore::RecordEventInstruction(const Instruction& instruction,
-                                             const OpFuncNode& op_func_node) {
-  // If InterpreterCore is on CPUPlace, do nothing.
-  if (platform::is_cpu_place(place_)) return;
-
-  for (auto& event : instruction.output_events_) {
-    VLOG(3) << "Record event in out_var_id: " << event.var_id_;
-    event.event_->Record(instruction.dev_ctx_);
-  }
-}
-
-void InterpreterCore::WaitOrSync(const std::vector<EventInter>& events,
-                                 const platform::DeviceContext* dev_ctx) {
-  for (auto& event_iter : events) {
-    if (event_iter.is_sync_) {
-      VLOG(3) << "host sync wait in_var_id " << event_iter.var_id_;
-      event_iter.event_->Wait(platform::kCPU, dev_ctx);
-    } else {
-      VLOG(3) << "stream async wait in_var_id " << event_iter.var_id_;
-      event_iter.event_->Wait(platform::kCUDA, dev_ctx);
-    }
-  }
-}
-
-void InterpreterCore::StreamWaitEventOrSync(const Instruction& instruction) {
-  // If InterpreterCore is on CPUPlace, do nothing.
-  if (platform::is_cpu_place(place_)) return;
-
-  VLOG(3) << "Deal StreamWaitEventOrSync for "
-          << instruction.kernel_func_.operator_base_->Type();
-  auto* dev_ctx = instruction.dev_ctx_;
-
-  WaitOrSync(instruction.intput_events_, dev_ctx);
-}
}  // namespace framework
}  // namespace paddle
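The event handling removed above (RecordEventInstruction, WaitOrSync, StreamWaitEventOrSync) is now reached through `event_manager_.WaitEvent(...)` and `event_manager_.RecordEvent(...)` in ExecuteInstructionList. The actual `event_manager.h` is not shown in this diff either; what follows is a hedged sketch of an interface consistent with those two call sites, with bodies that simply mirror the removed member functions and types taken from the surrounding file.

```cpp
// Hypothetical sketch only -- NOT the actual event_manager.h from this
// change. Signatures match the two call sites in ExecuteInstructionList;
// the bodies mirror the removed StreamWaitEventOrSync / WaitOrSync /
// RecordEventInstruction logic.
class EventManager {
 public:
  // Before an instruction runs, wait on the events attached to its inputs:
  // a host-side sync for GPU -> CPU edges, a stream wait otherwise.
  // No-op when the interpreter runs on CPUPlace.
  void WaitEvent(const Instruction& instruction,
                 const platform::Place& place) {
    if (platform::is_cpu_place(place)) return;
    for (auto& event_iter : instruction.intput_events_) {
      if (event_iter.is_sync_) {
        event_iter.event_->Wait(platform::kCPU, instruction.dev_ctx_);
      } else {
        event_iter.event_->Wait(platform::kCUDA, instruction.dev_ctx_);
      }
    }
  }

  // After an instruction runs, record events on its cross-stream output
  // variables so downstream ops on other streams can wait on them.
  void RecordEvent(const Instruction& instruction,
                   const OpFuncNode& op_func_node,
                   const platform::Place& place) {
    (void)op_func_node;  // unused in this simplified sketch
    if (platform::is_cpu_place(place)) return;
    for (auto& event : instruction.output_events_) {
      event.event_->Record(instruction.dev_ctx_);
    }
  }
};
```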