Skip to content

Commit 5eb640c

Browse files
authored
Graph engine4 (#36587)
1 parent d64f7b3 commit 5eb640c

File tree

10 files changed

+292
-16
lines changed

10 files changed

+292
-16
lines changed

paddle/fluid/distributed/service/graph_brpc_client.cc

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,63 @@ std::future<int32_t> GraphBrpcClient::remove_graph_node(
304304
// char* &buffer,int &actual_size
305305
std::future<int32_t> GraphBrpcClient::batch_sample_neighboors(
306306
uint32_t table_id, std::vector<uint64_t> node_ids, int sample_size,
307-
std::vector<std::vector<std::pair<uint64_t, float>>> &res) {
307+
std::vector<std::vector<std::pair<uint64_t, float>>> &res,
308+
int server_index) {
309+
if (server_index != -1) {
310+
res.resize(node_ids.size());
311+
DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [&](void *done) {
312+
int ret = 0;
313+
auto *closure = (DownpourBrpcClosure *)done;
314+
if (closure->check_response(0, PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER) !=
315+
0) {
316+
ret = -1;
317+
} else {
318+
auto &res_io_buffer = closure->cntl(0)->response_attachment();
319+
butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
320+
size_t bytes_size = io_buffer_itr.bytes_left();
321+
std::unique_ptr<char[]> buffer_wrapper(new char[bytes_size]);
322+
char *buffer = buffer_wrapper.get();
323+
io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
324+
325+
size_t node_num = *(size_t *)buffer;
326+
int *actual_sizes = (int *)(buffer + sizeof(size_t));
327+
char *node_buffer = buffer + sizeof(size_t) + sizeof(int) * node_num;
328+
329+
int offset = 0;
330+
for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
331+
int actual_size = actual_sizes[node_idx];
332+
int start = 0;
333+
while (start < actual_size) {
334+
res[node_idx].push_back(
335+
{*(uint64_t *)(node_buffer + offset + start),
336+
*(float *)(node_buffer + offset + start +
337+
GraphNode::id_size)});
338+
start += GraphNode::id_size + GraphNode::weight_size;
339+
}
340+
offset += actual_size;
341+
}
342+
}
343+
closure->set_promise_value(ret);
344+
});
345+
auto promise = std::make_shared<std::promise<int32_t>>();
346+
closure->add_promise(promise);
347+
std::future<int> fut = promise->get_future();
348+
;
349+
closure->request(0)->set_cmd_id(PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER);
350+
closure->request(0)->set_table_id(table_id);
351+
closure->request(0)->set_client_id(_client_id);
352+
closure->request(0)->add_params((char *)node_ids.data(),
353+
sizeof(uint64_t) * node_ids.size());
354+
closure->request(0)->add_params((char *)&sample_size, sizeof(int));
355+
;
356+
// PsService_Stub rpc_stub(get_cmd_channel(server_index));
357+
GraphPsService_Stub rpc_stub =
358+
getServiceStub(get_cmd_channel(server_index));
359+
closure->cntl(0)->set_log_id(butil::gettimeofday_ms());
360+
rpc_stub.service(closure->cntl(0), closure->request(0),
361+
closure->response(0), closure);
362+
return fut;
363+
}
308364
std::vector<int> request2server;
309365
std::vector<int> server2request(server_size, -1);
310366
res.clear();

paddle/fluid/distributed/service/graph_brpc_client.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,8 @@ class GraphBrpcClient : public BrpcPsClient {
6464
// given a batch of nodes, sample graph_neighboors for each of them
6565
virtual std::future<int32_t> batch_sample_neighboors(
6666
uint32_t table_id, std::vector<uint64_t> node_ids, int sample_size,
67-
std::vector<std::vector<std::pair<uint64_t, float>>>& res);
67+
std::vector<std::vector<std::pair<uint64_t, float>>>& res,
68+
int server_index = -1);
6869

6970
virtual std::future<int32_t> pull_graph_list(uint32_t table_id,
7071
int server_index, int start,

paddle/fluid/distributed/service/graph_brpc_server.cc

Lines changed: 199 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,10 @@ int32_t GraphBrpcServer::initialize() {
6161
return 0;
6262
}
6363

64+
brpc::Channel *GraphBrpcServer::get_cmd_channel(size_t server_index) {
65+
return _pserver_channels[server_index].get();
66+
}
67+
6468
uint64_t GraphBrpcServer::start(const std::string &ip, uint32_t port) {
6569
std::unique_lock<std::mutex> lock(mutex_);
6670

@@ -80,6 +84,42 @@ uint64_t GraphBrpcServer::start(const std::string &ip, uint32_t port) {
8084
return 0;
8185
}
8286

87+
int32_t GraphBrpcServer::build_peer2peer_connection(int rank) {
88+
this->rank = rank;
89+
auto _env = environment();
90+
brpc::ChannelOptions options;
91+
options.protocol = "baidu_std";
92+
options.timeout_ms = 500000;
93+
options.connection_type = "pooled";
94+
options.connect_timeout_ms = 10000;
95+
options.max_retry = 3;
96+
97+
std::vector<PSHost> server_list = _env->get_ps_servers();
98+
_pserver_channels.resize(server_list.size());
99+
std::ostringstream os;
100+
std::string server_ip_port;
101+
for (size_t i = 0; i < server_list.size(); ++i) {
102+
server_ip_port.assign(server_list[i].ip.c_str());
103+
server_ip_port.append(":");
104+
server_ip_port.append(std::to_string(server_list[i].port));
105+
_pserver_channels[i].reset(new brpc::Channel());
106+
if (_pserver_channels[i]->Init(server_ip_port.c_str(), "", &options) != 0) {
107+
VLOG(0) << "GraphServer connect to Server:" << server_ip_port
108+
<< " Failed! Try again.";
109+
std::string int_ip_port =
110+
GetIntTypeEndpoint(server_list[i].ip, server_list[i].port);
111+
if (_pserver_channels[i]->Init(int_ip_port.c_str(), "", &options) != 0) {
112+
LOG(ERROR) << "GraphServer connect to Server:" << int_ip_port
113+
<< " Failed!";
114+
return -1;
115+
}
116+
}
117+
os << server_ip_port << ",";
118+
}
119+
LOG(INFO) << "servers peer2peer connection success:" << os.str();
120+
return 0;
121+
}
122+
83123
int32_t GraphBrpcService::clear_nodes(Table *table,
84124
const PsRequestMessage &request,
85125
PsResponseMessage &response,
@@ -160,6 +200,9 @@ int32_t GraphBrpcService::initialize() {
160200
&GraphBrpcService::remove_graph_node;
161201
_service_handler_map[PS_GRAPH_SET_NODE_FEAT] =
162202
&GraphBrpcService::graph_set_node_feat;
203+
_service_handler_map[PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER] =
204+
&GraphBrpcService::sample_neighboors_across_multi_servers;
205+
163206
// shard初始化,server启动后才可从env获取到server_list的shard信息
164207
initialize_shard_info();
165208

@@ -172,10 +215,10 @@ int32_t GraphBrpcService::initialize_shard_info() {
172215
if (_is_initialize_shard_info) {
173216
return 0;
174217
}
175-
size_t shard_num = _server->environment()->get_ps_servers().size();
218+
server_size = _server->environment()->get_ps_servers().size();
176219
auto &table_map = *(_server->table());
177220
for (auto itr : table_map) {
178-
itr.second->set_shard(_rank, shard_num);
221+
itr.second->set_shard(_rank, server_size);
179222
}
180223
_is_initialize_shard_info = true;
181224
}
@@ -209,7 +252,9 @@ void GraphBrpcService::service(google::protobuf::RpcController *cntl_base,
209252
int service_ret = (this->*handler_func)(table, *request, *response, cntl);
210253
if (service_ret != 0) {
211254
response->set_err_code(service_ret);
212-
response->set_err_msg("server internal error");
255+
if (!response->has_err_msg()) {
256+
response->set_err_msg("server internal error");
257+
}
213258
}
214259
}
215260

@@ -403,7 +448,156 @@ int32_t GraphBrpcService::graph_get_node_feat(Table *table,
403448

404449
return 0;
405450
}
406-
451+
int32_t GraphBrpcService::sample_neighboors_across_multi_servers(
452+
Table *table, const PsRequestMessage &request, PsResponseMessage &response,
453+
brpc::Controller *cntl) {
454+
// sleep(5);
455+
CHECK_TABLE_EXIST(table, request, response)
456+
if (request.params_size() < 2) {
457+
set_response_code(
458+
response, -1,
459+
"graph_random_sample request requires at least 2 arguments");
460+
return 0;
461+
}
462+
size_t node_num = request.params(0).size() / sizeof(uint64_t),
463+
size_of_size_t = sizeof(size_t);
464+
uint64_t *node_data = (uint64_t *)(request.params(0).c_str());
465+
int sample_size = *(uint64_t *)(request.params(1).c_str());
466+
// std::vector<uint64_t> res = ((GraphTable
467+
// *)table).filter_out_non_exist_nodes(node_data, sample_size);
468+
std::vector<int> request2server;
469+
std::vector<int> server2request(server_size, -1);
470+
std::vector<uint64_t> local_id;
471+
std::vector<int> local_query_idx;
472+
size_t rank = get_rank();
473+
for (int query_idx = 0; query_idx < node_num; ++query_idx) {
474+
int server_index =
475+
((GraphTable *)table)->get_server_index_by_id(node_data[query_idx]);
476+
if (server2request[server_index] == -1) {
477+
server2request[server_index] = request2server.size();
478+
request2server.push_back(server_index);
479+
}
480+
}
481+
if (server2request[rank] != -1) {
482+
auto pos = server2request[rank];
483+
std::swap(request2server[pos],
484+
request2server[(int)request2server.size() - 1]);
485+
server2request[request2server[pos]] = pos;
486+
server2request[request2server[(int)request2server.size() - 1]] =
487+
request2server.size() - 1;
488+
}
489+
size_t request_call_num = request2server.size();
490+
std::vector<std::unique_ptr<char[]>> local_buffers;
491+
std::vector<int> local_actual_sizes;
492+
std::vector<size_t> seq;
493+
std::vector<std::vector<uint64_t>> node_id_buckets(request_call_num);
494+
std::vector<std::vector<int>> query_idx_buckets(request_call_num);
495+
for (int query_idx = 0; query_idx < node_num; ++query_idx) {
496+
int server_index =
497+
((GraphTable *)table)->get_server_index_by_id(node_data[query_idx]);
498+
int request_idx = server2request[server_index];
499+
node_id_buckets[request_idx].push_back(node_data[query_idx]);
500+
query_idx_buckets[request_idx].push_back(query_idx);
501+
seq.push_back(request_idx);
502+
}
503+
size_t remote_call_num = request_call_num;
504+
if (request2server.size() != 0 && request2server.back() == rank) {
505+
remote_call_num--;
506+
local_buffers.resize(node_id_buckets.back().size());
507+
local_actual_sizes.resize(node_id_buckets.back().size());
508+
}
509+
cntl->response_attachment().append(&node_num, sizeof(size_t));
510+
auto local_promise = std::make_shared<std::promise<int32_t>>();
511+
std::future<int> local_fut = local_promise->get_future();
512+
std::vector<bool> failed(server_size, false);
513+
std::function<void(void *)> func = [&, node_id_buckets, query_idx_buckets,
514+
request_call_num](void *done) {
515+
local_fut.get();
516+
std::vector<int> actual_size;
517+
auto *closure = (DownpourBrpcClosure *)done;
518+
std::vector<std::unique_ptr<butil::IOBufBytesIterator>> res(
519+
remote_call_num);
520+
size_t fail_num = 0;
521+
for (size_t request_idx = 0; request_idx < remote_call_num; ++request_idx) {
522+
if (closure->check_response(request_idx, PS_GRAPH_SAMPLE_NEIGHBOORS) !=
523+
0) {
524+
++fail_num;
525+
failed[request2server[request_idx]] = true;
526+
} else {
527+
auto &res_io_buffer = closure->cntl(request_idx)->response_attachment();
528+
size_t node_size;
529+
res[request_idx].reset(new butil::IOBufBytesIterator(res_io_buffer));
530+
size_t num;
531+
res[request_idx]->copy_and_forward(&num, sizeof(size_t));
532+
}
533+
}
534+
int size;
535+
int local_index = 0;
536+
for (size_t i = 0; i < node_num; i++) {
537+
if (fail_num > 0 && failed[seq[i]]) {
538+
size = 0;
539+
} else if (request2server[seq[i]] != rank) {
540+
res[seq[i]]->copy_and_forward(&size, sizeof(int));
541+
} else {
542+
size = local_actual_sizes[local_index++];
543+
}
544+
actual_size.push_back(size);
545+
}
546+
cntl->response_attachment().append(actual_size.data(),
547+
actual_size.size() * sizeof(int));
548+
549+
local_index = 0;
550+
for (size_t i = 0; i < node_num; i++) {
551+
if (fail_num > 0 && failed[seq[i]]) {
552+
continue;
553+
} else if (request2server[seq[i]] != rank) {
554+
char temp[actual_size[i] + 1];
555+
res[seq[i]]->copy_and_forward(temp, actual_size[i]);
556+
cntl->response_attachment().append(temp, actual_size[i]);
557+
} else {
558+
char *temp = local_buffers[local_index++].get();
559+
cntl->response_attachment().append(temp, actual_size[i]);
560+
}
561+
}
562+
closure->set_promise_value(0);
563+
};
564+
565+
DownpourBrpcClosure *closure = new DownpourBrpcClosure(remote_call_num, func);
566+
567+
auto promise = std::make_shared<std::promise<int32_t>>();
568+
closure->add_promise(promise);
569+
std::future<int> fut = promise->get_future();
570+
571+
for (int request_idx = 0; request_idx < remote_call_num; ++request_idx) {
572+
int server_index = request2server[request_idx];
573+
closure->request(request_idx)->set_cmd_id(PS_GRAPH_SAMPLE_NEIGHBOORS);
574+
closure->request(request_idx)->set_table_id(request.table_id());
575+
closure->request(request_idx)->set_client_id(rank);
576+
size_t node_num = node_id_buckets[request_idx].size();
577+
578+
closure->request(request_idx)
579+
->add_params((char *)node_id_buckets[request_idx].data(),
580+
sizeof(uint64_t) * node_num);
581+
closure->request(request_idx)
582+
->add_params((char *)&sample_size, sizeof(int));
583+
PsService_Stub rpc_stub(
584+
((GraphBrpcServer *)get_server())->get_cmd_channel(server_index));
585+
// GraphPsService_Stub rpc_stub =
586+
// getServiceStub(get_cmd_channel(server_index));
587+
closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
588+
rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx),
589+
closure->response(request_idx), closure);
590+
}
591+
if (server2request[rank] != -1) {
592+
((GraphTable *)table)
593+
->random_sample_neighboors(node_id_buckets.back().data(), sample_size,
594+
local_buffers, local_actual_sizes);
595+
}
596+
local_promise.get()->set_value(0);
597+
if (remote_call_num == 0) func(closure);
598+
fut.get();
599+
return 0;
600+
}
407601
int32_t GraphBrpcService::graph_set_node_feat(Table *table,
408602
const PsRequestMessage &request,
409603
PsResponseMessage &response,
@@ -412,7 +606,7 @@ int32_t GraphBrpcService::graph_set_node_feat(Table *table,
412606
if (request.params_size() < 3) {
413607
set_response_code(
414608
response, -1,
415-
"graph_set_node_feat request requires at least 2 arguments");
609+
"graph_set_node_feat request requires at least 3 arguments");
416610
return 0;
417611
}
418612
size_t node_num = request.params(0).size() / sizeof(uint64_t);

paddle/fluid/distributed/service/graph_brpc_server.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ class GraphBrpcServer : public PSServer {
3232
virtual ~GraphBrpcServer() {}
3333
PsBaseService *get_service() { return _service.get(); }
3434
virtual uint64_t start(const std::string &ip, uint32_t port);
35+
virtual int32_t build_peer2peer_connection(int rank);
36+
virtual brpc::Channel *get_cmd_channel(size_t server_index);
3537
virtual int32_t stop() {
3638
std::unique_lock<std::mutex> lock(mutex_);
3739
if (stoped_) return 0;
@@ -50,6 +52,7 @@ class GraphBrpcServer : public PSServer {
5052
mutable std::mutex mutex_;
5153
std::condition_variable cv_;
5254
bool stoped_ = false;
55+
int rank;
5356
brpc::Server _server;
5457
std::shared_ptr<PsBaseService> _service;
5558
std::vector<std::shared_ptr<brpc::Channel>> _pserver_channels;
@@ -113,12 +116,18 @@ class GraphBrpcService : public PsBaseService {
113116
int32_t print_table_stat(Table *table, const PsRequestMessage &request,
114117
PsResponseMessage &response, brpc::Controller *cntl);
115118

119+
int32_t sample_neighboors_across_multi_servers(
120+
Table *table, const PsRequestMessage &request,
121+
PsResponseMessage &response, brpc::Controller *cntl);
122+
116123
private:
117124
bool _is_initialize_shard_info;
118125
std::mutex _initialize_shard_mutex;
119126
std::unordered_map<int32_t, serviceHandlerFunc> _msg_handler_map;
120127
std::vector<float> _ori_values;
121128
const int sample_nodes_ranges = 23;
129+
size_t server_size;
130+
std::shared_ptr<::ThreadPool> task_pool;
122131
};
123132

124133
} // namespace distributed

paddle/fluid/distributed/service/graph_py_service.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ void GraphPyServer::start_server(bool block) {
107107
empty_vec.push_back(empty_prog);
108108
pserver_ptr->configure(server_proto, _ps_env, rank, empty_vec);
109109
pserver_ptr->start(ip, port);
110+
pserver_ptr->build_peer2peer_connection(rank);
110111
std::condition_variable* cv_ = pserver_ptr->export_cv();
111112
if (block) {
112113
std::mutex mutex_;

paddle/fluid/distributed/service/sendrecv.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ enum PsCmdID {
5656
PS_GRAPH_ADD_GRAPH_NODE = 35;
5757
PS_GRAPH_REMOVE_GRAPH_NODE = 36;
5858
PS_GRAPH_SET_NODE_FEAT = 37;
59+
PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER = 38;
5960
}
6061

6162
message PsRequestMessage {

paddle/fluid/distributed/service/server.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ class PsBaseService : public PsService {
147147
public:
148148
PsBaseService() : _rank(0), _server(NULL), _config(NULL) {}
149149
virtual ~PsBaseService() {}
150-
150+
virtual size_t get_rank() { return _rank; }
151151
virtual int32_t configure(PSServer *server) {
152152
_server = server;
153153
_rank = _server->rank();
@@ -167,6 +167,7 @@ class PsBaseService : public PsService {
167167
}
168168

169169
virtual int32_t initialize() = 0;
170+
PSServer *get_server() { return _server; }
170171

171172
protected:
172173
size_t _rank;

0 commit comments

Comments
 (0)