Skip to content

Commit fc13a47

Browse files
committed
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into pten_matmul_grad
2 parents c113ab5 + 5cf0bb7 commit fc13a47

File tree

80 files changed

+1990
-1001
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

80 files changed

+1990
-1001
lines changed

paddle/fluid/distributed/fleet_executor/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime
1919

2020
if(WITH_DISTRIBUTE)
2121
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
22+
if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
23+
set(DISTRIBUTE_COMPILE_FLAGS "${DISTRIBUTE_COMPILE_FLAGS} -faligned-new")
24+
endif()
2225
set_source_files_properties(interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
2326
set_source_files_properties(compute_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
2427
set_source_files_properties(amplifier_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})

paddle/fluid/distributed/fleet_executor/carrier.cc

Lines changed: 9 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
// limitations under the License.
1414

1515
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
16-
#include "paddle/fluid/distributed/fleet_executor/global_map.h"
16+
#include "paddle/fluid/distributed/fleet_executor/global.h"
1717
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
1818
#include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h"
1919
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
@@ -71,17 +71,13 @@ Carrier::~Carrier() { VLOG(3) << "Carrier's destructor."; }
7171

7272
bool Carrier::EnqueueInterceptorMessage(
7373
const InterceptorMessage& interceptor_message) {
74-
if (interceptor_message.ctrl_message()) {
75-
VLOG(3) << "Receiving control message from rank "
76-
<< interceptor_message.src_id() << " to rank "
77-
<< interceptor_message.dst_id();
78-
// for barrier
79-
msg_bus_->IncreaseBarrierCount();
80-
} else {
81-
int64_t dst_id = interceptor_message.dst_id();
82-
Interceptor* dst_interceptor = GetInterceptor(dst_id);
83-
dst_interceptor->EnqueueRemoteInterceptorMessage(interceptor_message);
84-
}
74+
PADDLE_ENFORCE_EQ(
75+
interceptor_message.ctrl_message(), false,
76+
platform::errors::Fatal(
77+
"Control message should be only send inter rank using message bus."));
78+
int64_t dst_id = interceptor_message.dst_id();
79+
Interceptor* dst_interceptor = GetInterceptor(dst_id);
80+
dst_interceptor->EnqueueRemoteInterceptorMessage(interceptor_message);
8581
return true;
8682
}
8783

@@ -106,11 +102,6 @@ void Carrier::WakeUp() {
106102
}
107103

108104
void Carrier::Start() {
109-
PADDLE_ENFORCE_EQ(msg_bus_->IsInit(), true,
110-
platform::errors::PreconditionNotMet(
111-
"Using message bus since it has not been initialized. "
112-
"Please invoke MessageBus::Init() before using it or "
113-
"neccessary components are not ready."));
114105
PADDLE_ENFORCE_EQ(is_init_, true, platform::errors::PreconditionNotMet(
115106
"Using carrier before initialized."));
116107
for (int64_t id : source_interceptor_ids_) {
@@ -154,19 +145,10 @@ bool Carrier::Send(const InterceptorMessage& msg) {
154145
<< " to interceptor " << dst_id << ", which are in the same ranks.";
155146
return EnqueueInterceptorMessage(msg);
156147
} else {
157-
PADDLE_ENFORCE_NOT_NULL(
158-
msg_bus_.get(),
159-
platform::errors::Unavailable("Message bus is released accidently"));
160-
PADDLE_ENFORCE_EQ(
161-
msg_bus_->IsInit(), true,
162-
platform::errors::PreconditionNotMet(
163-
"Using message bus since it has not been initialized. "
164-
"Please invoke MessageBus::Init() before using it or "
165-
"neccessary components are not ready."));
166148
VLOG(3) << "Send a message from interceptor " << src_id
167149
<< " to interceptor " << dst_id
168150
<< ", which are in different ranks.";
169-
return msg_bus_->Send(dst_rank, msg);
151+
return GlobalVal<MessageBus>::Get()->Send(dst_rank, msg);
170152
}
171153
}
172154

paddle/fluid/distributed/fleet_executor/carrier.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,6 @@ class Carrier final {
7373
Interceptor* SetInterceptor(int64_t interceptor_id,
7474
std::unique_ptr<Interceptor>);
7575

76-
void SetMsgBus(const std::shared_ptr<MessageBus>& msg_bus) {
77-
msg_bus_ = msg_bus;
78-
}
79-
8076
void Start();
8177

8278
bool IsInit() const;
@@ -107,7 +103,6 @@ class Carrier final {
107103
framework::Scope* minibatch_scope_;
108104
paddle::platform::Place place_;
109105
paddle::platform::DeviceContext* dev_ctx_{nullptr};
110-
std::shared_ptr<MessageBus> msg_bus_;
111106
int64_t rank_;
112107
std::string carrier_id_;
113108
std::unordered_map<int64_t, TaskNode*> interceptor_id_to_node_;

paddle/fluid/distributed/fleet_executor/fleet_executor.cc

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
// limitations under the License.
1414

1515
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
16-
#include "paddle/fluid/distributed/fleet_executor/global_map.h"
16+
#include "paddle/fluid/distributed/fleet_executor/global.h"
1717
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
1818
#include "paddle/fluid/distributed/fleet_executor/runtime_graph.h"
1919
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
@@ -32,6 +32,9 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
3232
bool parse_flag = exe_desc_.ParseFromString(exe_desc_str);
3333
PADDLE_ENFORCE(parse_flag, platform::errors::PreconditionNotMet(
3434
"Error occurs while parsing string to proto"));
35+
// The message bus is created and initialized only once.
36+
GlobalVal<MessageBus>::Create();
37+
InitMessageBus();
3538
}
3639

3740
FleetExecutor::~FleetExecutor() {
@@ -81,21 +84,16 @@ void FleetExecutor::Init(
8184
CopyParameters(i, program_desc);
8285
}
8386
VLOG(5) << runtime_graph_->DebugString();
84-
msg_bus_ = std::make_shared<MessageBus>();
8587
Carrier* carrier =
8688
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
8789
carrier_ids_.insert(carrier_id);
88-
GlobalVal<std::string>::Set(carrier_id);
89-
// TODO(liyurui): Maybe message bus should be created only once
90+
// Set current running carrier
91+
GlobalVal<std::string>::Set(new std::string(carrier_id));
9092
InitCarrier(carrier);
91-
InitMessageBus();
92-
93-
// Wait for all message bus connected.
94-
msg_bus_->Barrier();
93+
GlobalVal<MessageBus>::Get()->Barrier();
9594
}
9695

9796
void FleetExecutor::InitCarrier(Carrier* carrier) {
98-
carrier->SetMsgBus(msg_bus_);
9997
carrier->Init(exe_desc_.cur_rank(), runtime_graph_->interceptor_id_to_rank(),
10098
runtime_graph_->interceptor_id_to_node(), root_scope_,
10199
minibatch_scope_, microbatch_scopes_, place_);
@@ -131,14 +129,18 @@ void FleetExecutor::InitMessageBus() {
131129
VLOG(3) << "The number of ranks are "
132130
<< (rank_to_addr.size() == 0 ? 1 : rank_to_addr.size()) << ".";
133131
VLOG(5) << ss.str();
134-
if (!msg_bus_->IsInit()) {
135-
msg_bus_->Init(cur_rank, rank_to_addr, addr);
136-
}
132+
GlobalVal<MessageBus>::Get()->Init(cur_rank, rank_to_addr, addr);
137133
}
138134

139135
void FleetExecutor::Run(const std::string& carrier_id) {
140-
GlobalMap<std::string, Carrier>::Get(carrier_id)->Start();
141-
GlobalVal<std::string>::Set(carrier_id);
136+
Carrier* carrier = GlobalMap<std::string, Carrier>::Get(carrier_id);
137+
// Set current running carrier
138+
if (*GlobalVal<std::string>::Get() != carrier_id) {
139+
GlobalVal<std::string>::Set(new std::string(carrier_id));
140+
// TODO(liyurui): Move barrier to service
141+
GlobalVal<MessageBus>::Get()->Barrier();
142+
}
143+
carrier->Start();
142144
for (auto* micro_scop : microbatch_scopes_) {
143145
// By default, we should delete all kid scopes after run executor because
144146
// some operators may create local scope when running, such as while_op.

paddle/fluid/distributed/fleet_executor/fleet_executor.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,6 @@ class FleetExecutor final {
5555
framework::Scope* minibatch_scope_;
5656
platform::Place place_;
5757
std::vector<framework::Scope*> microbatch_scopes_;
58-
// The carriers under FleetExecutor will share message bus,
59-
// using shared_ptr to manage lifetime and condition race.
60-
std::shared_ptr<MessageBus> msg_bus_;
6158
std::unordered_set<std::string> carrier_ids_;
6259
};
6360

paddle/fluid/distributed/fleet_executor/global_map.h renamed to paddle/fluid/distributed/fleet_executor/global.h

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,24 +14,41 @@
1414

1515
#pragma once
1616

17+
#include "paddle/fluid/platform/enforce.h"
18+
1719
namespace paddle {
1820
namespace distributed {
1921

20-
// TODO(liyurui): Change this file to global.h
2122
template <typename T>
2223
class GlobalVal final {
2324
public:
24-
static T Get() { return *GetPtr(); }
25-
static T Set(T val) {
26-
auto* ptr = GetPtr();
27-
*ptr = val;
28-
return val;
25+
static T* Get() {
26+
T* ptr = GetPPtr()->get();
27+
PADDLE_ENFORCE_NOT_NULL(
28+
ptr, platform::errors::NotFound("This value is not global value."));
29+
return ptr;
30+
}
31+
template <typename... Args>
32+
static T* Create(Args&&... args) {
33+
auto* ptr = GetPPtr();
34+
PADDLE_ENFORCE_EQ(ptr->get(), nullptr,
35+
platform::errors::AlreadyExists(
36+
"This value is already a global value."));
37+
T* item = new T(std::forward<Args>(args)...);
38+
ptr->reset(item);
39+
return item;
40+
}
41+
42+
static T* Set(T* new_item) {
43+
auto* ptr = GetPPtr();
44+
ptr->reset(new_item);
45+
return ptr->get();
2946
}
3047

3148
private:
32-
static T* GetPtr() {
33-
static T value;
34-
return &value;
49+
static std::unique_ptr<T>* GetPPtr() {
50+
static std::unique_ptr<T> ptr;
51+
return &ptr;
3552
}
3653
};
3754

paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
!defined(PADDLE_WITH_ASCEND_CL)
1616
#include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h"
1717
#include "brpc/server.h"
18-
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
19-
#include "paddle/fluid/distributed/fleet_executor/global_map.h"
18+
#include "paddle/fluid/distributed/fleet_executor/global.h"
19+
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
2020

2121
namespace paddle {
2222
namespace distributed {
@@ -29,9 +29,7 @@ void InterceptorMessageServiceImpl::InterceptorMessageService(
2929
VLOG(3) << "Interceptor Message Service receives a message from interceptor "
3030
<< request->src_id() << " to interceptor " << request->dst_id()
3131
<< ", with the message: " << request->message_type();
32-
const auto& carrier_id = GlobalVal<std::string>::Get();
33-
bool flag = GlobalMap<std::string, Carrier>::Get(carrier_id)
34-
->EnqueueInterceptorMessage(*request);
32+
bool flag = GlobalVal<MessageBus>::Get()->DispatchMsgToCarrier(*request);
3533
response->set_rst(flag);
3634
}
3735

paddle/fluid/distributed/fleet_executor/message_bus.cc

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
#include <set>
1818
#include <thread>
1919

20+
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
21+
#include "paddle/fluid/distributed/fleet_executor/global.h"
2022
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
2123
#include "paddle/fluid/platform/gen_comm_id_helper.h"
2224

@@ -81,6 +83,10 @@ const std::string& MessageBus::GetAddr(int64_t rank) const {
8183

8284
bool MessageBus::Send(int64_t dst_rank,
8385
const InterceptorMessage& interceptor_message) {
86+
PADDLE_ENFORCE_EQ(
87+
IsInit(), true,
88+
platform::errors::PreconditionNotMet(
89+
"Using message bus since it has not been initialized."));
8490
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
8591
!defined(PADDLE_WITH_ASCEND_CL)
8692
int retry_time = 0; // message bus will retry sending for 10 times
@@ -155,6 +161,22 @@ void MessageBus::Barrier() {
155161
}
156162
}
157163

164+
bool MessageBus::DispatchMsgToCarrier(
165+
const InterceptorMessage& interceptor_message) {
166+
if (interceptor_message.ctrl_message()) {
167+
VLOG(3) << "Receiving control message from rank "
168+
<< interceptor_message.src_id() << " to rank "
169+
<< interceptor_message.dst_id();
170+
// for barrier
171+
IncreaseBarrierCount();
172+
return true;
173+
} else {
174+
const std::string& carrier_id = *GlobalVal<std::string>::Get();
175+
return GlobalMap<std::string, Carrier>::Get(carrier_id)
176+
->EnqueueInterceptorMessage(interceptor_message);
177+
}
178+
}
179+
158180
void MessageBus::ListenPort() {
159181
if (addr_ == "") {
160182
LOG(INFO) << "No need listen to port since training on single card.";

paddle/fluid/distributed/fleet_executor/message_bus.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ class MessageBus final {
5454

5555
void IncreaseBarrierCount();
5656
void Barrier();
57+
bool DispatchMsgToCarrier(const InterceptorMessage& interceptor_message);
5758

5859
private:
5960
DISABLE_COPY_AND_ASSIGN(MessageBus);

paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ limitations under the License. */
1818
#include "gtest/gtest.h"
1919

2020
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
21-
#include "paddle/fluid/distributed/fleet_executor/global_map.h"
21+
#include "paddle/fluid/distributed/fleet_executor/global.h"
2222
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
2323
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
2424
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
@@ -67,9 +67,8 @@ TEST(ComputeInterceptor, Compute) {
6767
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
6868
carrier->Init(0, {{0, 0}, {1, 0}});
6969

70-
auto msg_bus = std::make_shared<MessageBus>();
70+
MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
7171
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
72-
carrier->SetMsgBus(msg_bus);
7372

7473
// FIXME: don't delete, otherwise interceptor will use undefined node
7574
TaskNode* node_a =

0 commit comments

Comments
 (0)