Skip to content

Commit 6bf208c

Browse files
authored
[fleet_executor] Parse rank_to_ip map on cpp side and start message bus. (#37126)
1 parent 778a363 commit 6bf208c

File tree

5 files changed

+42
-0
lines changed

5 files changed

+42
-0
lines changed

paddle/fluid/distributed/fleet_executor/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ if(WITH_DISTRIBUTE)
1919
set_source_files_properties(interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
2020
set_source_files_properties(message_bus.h PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
2121
set_source_files_properties(message_bus.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
22+
set_source_files_properties(fleet_executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
2223
set_source_files_properties(carrier.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
2324
set_source_files_properties(interceptor_message_service.h PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
2425
set_source_files_properties(interceptor_message_service.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})

paddle/fluid/distributed/fleet_executor/fleet_executor.cc

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
// limitations under the License.
1414

1515
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
16+
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
1617
#include "paddle/fluid/distributed/fleet_executor/runtime_graph.h"
1718
#include "paddle/fluid/framework/program_desc.h"
1819

@@ -31,6 +32,40 @@ FleetExecutor::~FleetExecutor() {
3132

3233
// Sets up the executor. In the code shown here the only step performed is
// bringing up the message bus; `program_desc` is not read by this body.
void FleetExecutor::Init(const paddle::framework::ProgramDesc& program_desc) {
  // Compile and Initialize
  InitMessageBus();
}
37+
38+
void FleetExecutor::InitMessageBus() {
39+
std::stringstream ss;
40+
ss << "\nThe DNS table of the message bus is: \n";
41+
int64_t cur_rank = exe_desc_.cur_rank();
42+
std::unordered_map<int64_t, int64_t> interceptor_id_to_rank;
43+
std::unordered_map<int64_t, std::string> rank_to_addr;
44+
std::string addr;
45+
for (const auto& rank_info : exe_desc_.cluster_info()) {
46+
int64_t rank = rank_info.rank();
47+
std::string ip_port = rank_info.ip_port();
48+
ss << rank << "\t->\t" << ip_port << "\n";
49+
// TODO(Yuang): replace the first 'rank' with real interceptor id
50+
interceptor_id_to_rank.insert(std::make_pair(rank, rank));
51+
rank_to_addr.insert(std::make_pair(rank, ip_port));
52+
if (rank == cur_rank) {
53+
addr = ip_port;
54+
}
55+
}
56+
PADDLE_ENFORCE_NE(
57+
addr, "",
58+
platform::errors::NotFound(
59+
"Current rank is %s, which ip_port cannot be found in the config.",
60+
cur_rank));
61+
VLOG(3) << "Current rank is " << cur_rank << " and the ip_port is " << addr
62+
<< ".";
63+
VLOG(3) << "The number of ranks are " << interceptor_id_to_rank.size() << ".";
64+
VLOG(5) << ss.str();
65+
MessageBus& message_bus_instance = MessageBus::Instance();
66+
if (!message_bus_instance.IsInit()) {
67+
message_bus_instance.Init(interceptor_id_to_rank, rank_to_addr, addr);
68+
}
3469
}
3570

3671
void FleetExecutor::Run() {

paddle/fluid/distributed/fleet_executor/fleet_executor.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ class FleetExecutor final {
4242
DISABLE_COPY_AND_ASSIGN(FleetExecutor);
4343
FleetExecutorDesc exe_desc_;
4444
std::unique_ptr<RuntimeGraph> runtime_graph_;
45+
void InitMessageBus();
4546
static std::shared_ptr<Carrier> global_carrier_;
4647
};
4748

paddle/fluid/distributed/fleet_executor/message_bus.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,10 @@ void MessageBus::Init(
4242
});
4343
}
4444

45+
// Returns the current value of the is_init_ flag.
bool MessageBus::IsInit() const {
  return is_init_;
}
46+
4547
void MessageBus::Release() {
48+
VLOG(3) << "Message bus releases resource.";
4649
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
4750
!defined(PADDLE_WITH_ASCEND_CL)
4851
server_.Stop(1000);

paddle/fluid/distributed/fleet_executor/message_bus.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ class MessageBus final {
4848
const std::unordered_map<int64_t, std::string>& rank_to_addr,
4949
const std::string& addr);
5050

51+
bool IsInit() const;
52+
5153
void Release();
5254

5355
// called by Interceptor, send InterceptorMessage to dst

0 commit comments

Comments
 (0)