Skip to content

Commit be3b774

Browse files
authored
[fleet_executor] Complete compute interceptor (#37485)
1 parent 1799c03 commit be3b774

File tree

3 files changed

+161
-23
lines changed

3 files changed

+161
-23
lines changed

paddle/fluid/distributed/fleet_executor/compute_interceptor.cc

Lines changed: 116 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -27,31 +27,130 @@ ComputeInterceptor::ComputeInterceptor(int64_t interceptor_id, TaskNode* node)
2727

2828
void ComputeInterceptor::PrepareDeps() {
2929
auto& upstream = GetTaskNode()->upstream();
30-
upstream_deps_.insert(upstream.begin(), upstream.end());
30+
auto& downstream = GetTaskNode()->downstream();
31+
32+
// TODO(wangxi): get from task node
33+
int64_t in_buff_size = std::numeric_limits<int64_t>::max();
34+
int64_t out_buff_size = 2;
35+
36+
for (auto up_id : upstream) {
37+
in_readys_.emplace(up_id, std::make_pair(in_buff_size, 0));
38+
}
39+
for (auto down_id : downstream) {
40+
out_buffs_.emplace(down_id, std::make_pair(out_buff_size, 0));
41+
}
42+
}
43+
44+
// Records that upstream `up_id` has produced one more piece of ready data.
//
// Raises NotFound when `up_id` has no entry in in_readys_, and OutOfRange
// when the incremented counter would exceed that entry's max_ready_size.
// The stored counter is only updated after both checks pass.
void ComputeInterceptor::IncreaseReady(int64_t up_id) {
  auto it = in_readys_.find(up_id);
  PADDLE_ENFORCE_NE(it, in_readys_.end(),
                    platform::errors::NotFound(
                        "Cannot find upstream=%lld in in_readys.", up_id));

  // Entry layout: (max_ready_size, ready_size).
  auto& counters = it->second;
  auto max_ready_size = counters.first;
  auto ready_size = counters.second + 1;
  PADDLE_ENFORCE_LE(ready_size, max_ready_size,
                    platform::errors::OutOfRange(
                        "upstream=%lld ready_size must <= max_ready_size, but "
                        "now ready_size=%lld, max_ready_size=%lld",
                        up_id, ready_size, max_ready_size));
  counters.second = ready_size;
}
60+
61+
// Releases one used buffer slot for downstream `down_id`.
//
// Raises NotFound when `down_id` has no entry in out_buffs_, and OutOfRange
// when the decremented counter would drop below 0. The stored counter is only
// updated after both checks pass.
void ComputeInterceptor::DecreaseBuff(int64_t down_id) {
  auto it = out_buffs_.find(down_id);
  PADDLE_ENFORCE_NE(it, out_buffs_.end(),
                    platform::errors::NotFound(
                        "Cannot find downstream=%lld in out_buffs.", down_id));

  // Entry layout: (max_buffer_size, used_size).
  auto& counters = it->second;
  auto used_size = counters.second - 1;
  PADDLE_ENFORCE_GE(
      used_size, 0,
      platform::errors::OutOfRange(
          "downstream=%lld used buff size must >= 0, but now equal %lld",
          down_id, used_size));
  counters.second = used_size;
}
75+
76+
bool ComputeInterceptor::IsInputReady() {
77+
for (auto& ins : in_readys_) {
78+
auto ready_size = ins.second.second;
79+
// not ready, return false
80+
if (ready_size == 0) return false;
81+
}
82+
return true;
83+
}
84+
85+
bool ComputeInterceptor::CanWriteOutput() {
86+
for (auto& outs : out_buffs_) {
87+
auto max_buffer_size = outs.second.first;
88+
auto used_size = outs.second.second;
89+
// full, return false
90+
if (used_size == max_buffer_size) return false;
91+
}
92+
return true;
3193
}
3294

3395
void ComputeInterceptor::SendDataReadyToDownStream() {
34-
auto& downstream = GetTaskNode()->downstream();
35-
for (auto dst_id : downstream) {
36-
InterceptorMessage dst_msg;
37-
dst_msg.set_message_type(DATA_IS_READY);
38-
VLOG(3) << "ComputeInterceptor Send msg to " << dst_id;
39-
Send(dst_id, dst_msg);
96+
for (auto& outs : out_buffs_) {
97+
auto down_id = outs.first;
98+
auto max_buff_size = outs.second.first;
99+
auto used_size = outs.second.second;
100+
used_size += 1;
101+
PADDLE_ENFORCE_LE(
102+
used_size, max_buff_size,
103+
platform::errors::OutOfRange("downstream=%lld used buff size must <= "
104+
"max_buff_size, but now used_size=%lld, "
105+
"max_buff_size=%lld",
106+
down_id, used_size, max_buff_size));
107+
outs.second.second = used_size;
108+
109+
InterceptorMessage ready_msg;
110+
ready_msg.set_message_type(DATA_IS_READY);
111+
VLOG(3) << "ComputeInterceptor Send data_is_ready msg to " << down_id;
112+
Send(down_id, ready_msg);
113+
}
114+
}
115+
116+
void ComputeInterceptor::ReplyCompletedToUpStream() {
117+
for (auto& ins : in_readys_) {
118+
auto up_id = ins.first;
119+
auto ready_size = ins.second.second;
120+
ready_size -= 1;
121+
PADDLE_ENFORCE_GE(
122+
ready_size, 0,
123+
platform::errors::OutOfRange(
124+
"upstream=%lld ready_size must >= 0, but now got %lld", up_id,
125+
ready_size));
126+
ins.second.second = ready_size;
127+
128+
InterceptorMessage reply_msg;
129+
reply_msg.set_message_type(DATE_IS_USELESS);
130+
VLOG(3) << "ComputeInterceptor Reply data_is_useless msg to " << up_id;
131+
Send(up_id, reply_msg);
132+
}
133+
}
134+
135+
void ComputeInterceptor::Run() {
136+
while (IsInputReady() && CanWriteOutput()) {
137+
VLOG(3) << "id=" << GetInterceptorId() << " ComputeInterceptor running";
138+
// TODO(wangxi): add op run
139+
140+
// send to downstream and increase buff used
141+
SendDataReadyToDownStream();
142+
// reply to upstream and decrease ready data
143+
ReplyCompletedToUpStream();
40144
}
41145
}
42146

43147
void ComputeInterceptor::Compute(const InterceptorMessage& msg) {
44148
if (msg.message_type() == DATA_IS_READY) {
45-
auto src_id = msg.src_id();
46-
upstream_deps_.erase(src_id);
47-
48-
// all input is ready
49-
if (upstream_deps_.empty()) {
50-
// TODO(wangxi): op run
51-
VLOG(3) << "id=" << GetInterceptorId() << " ComputeInterceptor running";
52-
SendDataReadyToDownStream();
53-
PrepareDeps();
54-
}
149+
IncreaseReady(msg.src_id());
150+
Run();
151+
} else if (msg.message_type() == DATE_IS_USELESS) {
152+
DecreaseBuff(msg.src_id());
153+
Run();
55154
}
56155
}
57156

paddle/fluid/distributed/fleet_executor/compute_interceptor.h

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
#pragma once
1616

17+
#include <utility>
18+
1719
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
1820

1921
namespace paddle {
@@ -25,12 +27,24 @@ class ComputeInterceptor : public Interceptor {
2527

2628
void PrepareDeps();
2729

30+
void IncreaseReady(int64_t up_id);
31+
void DecreaseBuff(int64_t down_id);
32+
bool IsInputReady();
33+
bool CanWriteOutput();
34+
2835
void SendDataReadyToDownStream();
36+
void ReplyCompletedToUpStream();
2937

38+
void Run();
3039
void Compute(const InterceptorMessage& msg);
3140

3241
private:
33-
std::unordered_set<int64_t> upstream_deps_;
42+
// FIXME(wangxi): if use step_ and max_steps_, how to restart step_ from 0
43+
int64_t step_{0};
44+
// upstream_id-->(max_ready_size, ready_size)
45+
std::map<int64_t, std::pair<int64_t, int64_t>> in_readys_{};
46+
// downstream_id-->(max_buffer_size, used_size)
47+
std::map<int64_t, std::pair<int64_t, int64_t>> out_buffs_{};
3448
};
3549

3650
} // namespace distributed

paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,39 +35,64 @@ class StopInterceptor : public Interceptor {
3535
void Stop(const InterceptorMessage& msg) {
3636
std::cout << GetInterceptorId() << " recv msg from " << msg.src_id()
3737
<< std::endl;
38+
count_ += 1;
39+
if (count_ == 1) return;
3840
InterceptorMessage stop;
3941
stop.set_message_type(STOP);
4042
Send(0, stop);
4143
Send(1, stop);
4244
Send(2, stop);
45+
Send(3, stop);
46+
}
47+
int count_{0};
48+
};
49+
50+
class StartInterceptor : public Interceptor {
51+
public:
52+
StartInterceptor(int64_t interceptor_id, TaskNode* node)
53+
: Interceptor(interceptor_id, node) {
54+
RegisterMsgHandle([this](const InterceptorMessage& msg) { NOP(msg); });
55+
}
56+
57+
void NOP(const InterceptorMessage& msg) {
58+
std::cout << GetInterceptorId() << " recv msg from " << msg.src_id()
59+
<< std::endl;
4360
}
4461
};
4562

4663
// Wires a four-stage chain a->b->c->d on a single rank:
//   a: StartInterceptor (only logs what it receives)
//   b, c: "Compute" interceptors under test
//   d: StopInterceptor (ignores its first message, then sends STOP to 0..3)
// Two DATA_IS_READY messages are sent from a to b so that d receives a second
// message and issues the STOP broadcast.
TEST(ComputeInterceptor, Compute) {
  MessageBus& msg_bus = MessageBus::Instance();
  msg_bus.Init({{0, 0}, {1, 0}, {2, 0}, {3, 0}}, {{0, "127.0.0.0:0"}},
               "127.0.0.0:0");

  Carrier& carrier = Carrier::Instance();

  // NOTE: don't delete, otherwise interceptor will use undefined node
  TaskNode* node_a = new TaskNode(0, 0, 0, 0, 0);  // role, rank, task_id
  TaskNode* node_b = new TaskNode(0, 0, 1, 0, 0);
  TaskNode* node_c = new TaskNode(0, 0, 2, 0, 0);
  TaskNode* node_d = new TaskNode(0, 0, 3, 0, 0);

  // a->b->c->d
  node_a->AddDownstreamTask(1);
  node_b->AddUpstreamTask(0);
  node_b->AddDownstreamTask(2);
  node_c->AddUpstreamTask(1);
  node_c->AddDownstreamTask(3);
  node_d->AddUpstreamTask(2);

  Interceptor* a =
      carrier.SetInterceptor(0, std::make_unique<StartInterceptor>(0, node_a));
  carrier.SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b));
  carrier.SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c));
  // Interceptor 3 must own node_d (task 3); node_c already belongs to
  // interceptor 2.
  carrier.SetInterceptor(3, std::make_unique<StopInterceptor>(3, node_d));

  carrier.SetCreatingFlag(false);

  InterceptorMessage msg;
  msg.set_message_type(DATA_IS_READY);
  // double buff, send twice
  a->Send(1, msg);
  a->Send(1, msg);
}
7398

0 commit comments

Comments
 (0)