@@ -27,31 +27,130 @@ ComputeInterceptor::ComputeInterceptor(int64_t interceptor_id, TaskNode* node)
2727
2828void ComputeInterceptor::PrepareDeps () {
2929 auto & upstream = GetTaskNode ()->upstream ();
30- upstream_deps_.insert (upstream.begin (), upstream.end ());
30+ auto & downstream = GetTaskNode ()->downstream ();
31+
32+ // TODO(wangxi): get from task node
33+ int64_t in_buff_size = std::numeric_limits<int64_t >::max ();
34+ int64_t out_buff_size = 2 ;
35+
36+ for (auto up_id : upstream) {
37+ in_readys_.emplace (up_id, std::make_pair (in_buff_size, 0 ));
38+ }
39+ for (auto down_id : downstream) {
40+ out_buffs_.emplace (down_id, std::make_pair (out_buff_size, 0 ));
41+ }
42+ }
43+
44+ void ComputeInterceptor::IncreaseReady (int64_t up_id) {
45+ auto it = in_readys_.find (up_id);
46+ PADDLE_ENFORCE_NE (it, in_readys_.end (),
47+ platform::errors::NotFound (
48+ " Cannot find upstream=%lld in in_readys." , up_id));
49+
50+ auto max_ready_size = it->second .first ;
51+ auto ready_size = it->second .second ;
52+ ready_size += 1 ;
53+ PADDLE_ENFORCE_LE (ready_size, max_ready_size,
54+ platform::errors::OutOfRange (
55+ " upstream=%lld ready_size must <= max_ready_size, but "
56+ " now ready_size=%lld, max_ready_size=%lld" ,
57+ up_id, ready_size, max_ready_size));
58+ it->second .second = ready_size;
59+ }
60+
61+ void ComputeInterceptor::DecreaseBuff (int64_t down_id) {
62+ auto it = out_buffs_.find (down_id);
63+ PADDLE_ENFORCE_NE (it, out_buffs_.end (),
64+ platform::errors::NotFound (
65+ " Cannot find downstream=%lld in out_buffs." , down_id));
66+ auto used_size = it->second .second ;
67+ used_size -= 1 ;
68+ PADDLE_ENFORCE_GE (
69+ used_size, 0 ,
70+ platform::errors::OutOfRange (
71+ " downstream=%lld used buff size must >= 0, but now equal %lld" ,
72+ down_id, used_size));
73+ it->second .second = used_size;
74+ }
75+
76+ bool ComputeInterceptor::IsInputReady () {
77+ for (auto & ins : in_readys_) {
78+ auto ready_size = ins.second .second ;
79+ // not ready, return false
80+ if (ready_size == 0 ) return false ;
81+ }
82+ return true ;
83+ }
84+
85+ bool ComputeInterceptor::CanWriteOutput () {
86+ for (auto & outs : out_buffs_) {
87+ auto max_buffer_size = outs.second .first ;
88+ auto used_size = outs.second .second ;
89+ // full, return false
90+ if (used_size == max_buffer_size) return false ;
91+ }
92+ return true ;
3193}
3294
3395void ComputeInterceptor::SendDataReadyToDownStream () {
34- auto & downstream = GetTaskNode ()->downstream ();
35- for (auto dst_id : downstream) {
36- InterceptorMessage dst_msg;
37- dst_msg.set_message_type (DATA_IS_READY);
38- VLOG (3 ) << " ComputeInterceptor Send msg to " << dst_id;
39- Send (dst_id, dst_msg);
96+ for (auto & outs : out_buffs_) {
97+ auto down_id = outs.first ;
98+ auto max_buff_size = outs.second .first ;
99+ auto used_size = outs.second .second ;
100+ used_size += 1 ;
101+ PADDLE_ENFORCE_LE (
102+ used_size, max_buff_size,
103+ platform::errors::OutOfRange (" downstream=%lld used buff size must <= "
104+ " max_buff_size, but now used_size=%lld, "
105+ " max_buff_size=%lld" ,
106+ down_id, used_size, max_buff_size));
107+ outs.second .second = used_size;
108+
109+ InterceptorMessage ready_msg;
110+ ready_msg.set_message_type (DATA_IS_READY);
111+ VLOG (3 ) << " ComputeInterceptor Send data_is_ready msg to " << down_id;
112+ Send (down_id, ready_msg);
113+ }
114+ }
115+
116+ void ComputeInterceptor::ReplyCompletedToUpStream () {
117+ for (auto & ins : in_readys_) {
118+ auto up_id = ins.first ;
119+ auto ready_size = ins.second .second ;
120+ ready_size -= 1 ;
121+ PADDLE_ENFORCE_GE (
122+ ready_size, 0 ,
123+ platform::errors::OutOfRange (
124+ " upstream=%lld ready_size must >= 0, but now got %lld" , up_id,
125+ ready_size));
126+ ins.second .second = ready_size;
127+
128+ InterceptorMessage reply_msg;
129+ reply_msg.set_message_type (DATE_IS_USELESS);
130+ VLOG (3 ) << " ComputeInterceptor Reply data_is_useless msg to " << up_id;
131+ Send (up_id, reply_msg);
132+ }
133+ }
134+
135+ void ComputeInterceptor::Run () {
136+ while (IsInputReady () && CanWriteOutput ()) {
137+ VLOG (3 ) << " id=" << GetInterceptorId () << " ComputeInterceptor running" ;
138+ // TODO(wangxi): add op run
139+
140+ // send to downstream and increase buff used
141+ SendDataReadyToDownStream ();
142+ // reply to upstream and decrease ready data
143+ ReplyCompletedToUpStream ();
40144 }
41145}
42146
43147void ComputeInterceptor::Compute (const InterceptorMessage& msg) {
44148 if (msg.message_type () == DATA_IS_READY) {
45- auto src_id = msg.src_id ();
46- upstream_deps_.erase (src_id);
47-
48- // all input is ready
49- if (upstream_deps_.empty ()) {
50- // TODO(wangxi): op run
51- VLOG (3 ) << " id=" << GetInterceptorId () << " ComputeInterceptor running" ;
52- SendDataReadyToDownStream ();
53- PrepareDeps ();
54- }
149+ IncreaseReady (msg.src_id ());
150+ Run ();
151+ } else if (msg.message_type () == DATE_IS_USELESS) {
152+ DecreaseBuff (msg.src_id ());
153+ Run ();
55154 }
56155}
57156
0 commit comments