@@ -61,6 +61,10 @@ int32_t GraphBrpcServer::initialize() {
6161 return 0 ;
6262}
6363
64+ brpc::Channel *GraphBrpcServer::get_cmd_channel (size_t server_index) {
65+ return _pserver_channels[server_index].get ();
66+ }
67+
6468uint64_t GraphBrpcServer::start (const std::string &ip, uint32_t port) {
6569 std::unique_lock<std::mutex> lock (mutex_);
6670
@@ -80,6 +84,42 @@ uint64_t GraphBrpcServer::start(const std::string &ip, uint32_t port) {
8084 return 0 ;
8185}
8286
87+ int32_t GraphBrpcServer::build_peer2peer_connection (int rank) {
88+ this ->rank = rank;
89+ auto _env = environment ();
90+ brpc::ChannelOptions options;
91+ options.protocol = " baidu_std" ;
92+ options.timeout_ms = 500000 ;
93+ options.connection_type = " pooled" ;
94+ options.connect_timeout_ms = 10000 ;
95+ options.max_retry = 3 ;
96+
97+ std::vector<PSHost> server_list = _env->get_ps_servers ();
98+ _pserver_channels.resize (server_list.size ());
99+ std::ostringstream os;
100+ std::string server_ip_port;
101+ for (size_t i = 0 ; i < server_list.size (); ++i) {
102+ server_ip_port.assign (server_list[i].ip .c_str ());
103+ server_ip_port.append (" :" );
104+ server_ip_port.append (std::to_string (server_list[i].port ));
105+ _pserver_channels[i].reset (new brpc::Channel ());
106+ if (_pserver_channels[i]->Init (server_ip_port.c_str (), " " , &options) != 0 ) {
107+ VLOG (0 ) << " GraphServer connect to Server:" << server_ip_port
108+ << " Failed! Try again." ;
109+ std::string int_ip_port =
110+ GetIntTypeEndpoint (server_list[i].ip , server_list[i].port );
111+ if (_pserver_channels[i]->Init (int_ip_port.c_str (), " " , &options) != 0 ) {
112+ LOG (ERROR) << " GraphServer connect to Server:" << int_ip_port
113+ << " Failed!" ;
114+ return -1 ;
115+ }
116+ }
117+ os << server_ip_port << " ," ;
118+ }
119+ LOG (INFO) << " servers peer2peer connection success:" << os.str ();
120+ return 0 ;
121+ }
122+
83123int32_t GraphBrpcService::clear_nodes (Table *table,
84124 const PsRequestMessage &request,
85125 PsResponseMessage &response,
@@ -160,6 +200,9 @@ int32_t GraphBrpcService::initialize() {
160200 &GraphBrpcService::remove_graph_node;
161201 _service_handler_map[PS_GRAPH_SET_NODE_FEAT] =
162202 &GraphBrpcService::graph_set_node_feat;
203+ _service_handler_map[PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER] =
204+ &GraphBrpcService::sample_neighboors_across_multi_servers;
205+
163206 // shard初始化,server启动后才可从env获取到server_list的shard信息
164207 initialize_shard_info ();
165208
@@ -172,10 +215,10 @@ int32_t GraphBrpcService::initialize_shard_info() {
172215 if (_is_initialize_shard_info) {
173216 return 0 ;
174217 }
175- size_t shard_num = _server->environment ()->get_ps_servers ().size ();
218+ server_size = _server->environment ()->get_ps_servers ().size ();
176219 auto &table_map = *(_server->table ());
177220 for (auto itr : table_map) {
178- itr.second ->set_shard (_rank, shard_num );
221+ itr.second ->set_shard (_rank, server_size );
179222 }
180223 _is_initialize_shard_info = true ;
181224 }
@@ -209,7 +252,9 @@ void GraphBrpcService::service(google::protobuf::RpcController *cntl_base,
209252 int service_ret = (this ->*handler_func)(table, *request, *response, cntl);
210253 if (service_ret != 0 ) {
211254 response->set_err_code (service_ret);
212- response->set_err_msg (" server internal error" );
255+ if (!response->has_err_msg ()) {
256+ response->set_err_msg (" server internal error" );
257+ }
213258 }
214259}
215260
@@ -403,7 +448,156 @@ int32_t GraphBrpcService::graph_get_node_feat(Table *table,
403448
404449 return 0 ;
405450}
406-
451+ int32_t GraphBrpcService::sample_neighboors_across_multi_servers (
452+ Table *table, const PsRequestMessage &request, PsResponseMessage &response,
453+ brpc::Controller *cntl) {
454+ // sleep(5);
455+ CHECK_TABLE_EXIST (table, request, response)
456+ if (request.params_size () < 2 ) {
457+ set_response_code (
458+ response, -1 ,
459+ " graph_random_sample request requires at least 2 arguments" );
460+ return 0 ;
461+ }
462+ size_t node_num = request.params (0 ).size () / sizeof (uint64_t ),
463+ size_of_size_t = sizeof (size_t );
464+ uint64_t *node_data = (uint64_t *)(request.params (0 ).c_str ());
465+ int sample_size = *(uint64_t *)(request.params (1 ).c_str ());
466+ // std::vector<uint64_t> res = ((GraphTable
467+ // *)table).filter_out_non_exist_nodes(node_data, sample_size);
468+ std::vector<int > request2server;
469+ std::vector<int > server2request (server_size, -1 );
470+ std::vector<uint64_t > local_id;
471+ std::vector<int > local_query_idx;
472+ size_t rank = get_rank ();
473+ for (int query_idx = 0 ; query_idx < node_num; ++query_idx) {
474+ int server_index =
475+ ((GraphTable *)table)->get_server_index_by_id (node_data[query_idx]);
476+ if (server2request[server_index] == -1 ) {
477+ server2request[server_index] = request2server.size ();
478+ request2server.push_back (server_index);
479+ }
480+ }
481+ if (server2request[rank] != -1 ) {
482+ auto pos = server2request[rank];
483+ std::swap (request2server[pos],
484+ request2server[(int )request2server.size () - 1 ]);
485+ server2request[request2server[pos]] = pos;
486+ server2request[request2server[(int )request2server.size () - 1 ]] =
487+ request2server.size () - 1 ;
488+ }
489+ size_t request_call_num = request2server.size ();
490+ std::vector<std::unique_ptr<char []>> local_buffers;
491+ std::vector<int > local_actual_sizes;
492+ std::vector<size_t > seq;
493+ std::vector<std::vector<uint64_t >> node_id_buckets (request_call_num);
494+ std::vector<std::vector<int >> query_idx_buckets (request_call_num);
495+ for (int query_idx = 0 ; query_idx < node_num; ++query_idx) {
496+ int server_index =
497+ ((GraphTable *)table)->get_server_index_by_id (node_data[query_idx]);
498+ int request_idx = server2request[server_index];
499+ node_id_buckets[request_idx].push_back (node_data[query_idx]);
500+ query_idx_buckets[request_idx].push_back (query_idx);
501+ seq.push_back (request_idx);
502+ }
503+ size_t remote_call_num = request_call_num;
504+ if (request2server.size () != 0 && request2server.back () == rank) {
505+ remote_call_num--;
506+ local_buffers.resize (node_id_buckets.back ().size ());
507+ local_actual_sizes.resize (node_id_buckets.back ().size ());
508+ }
509+ cntl->response_attachment ().append (&node_num, sizeof (size_t ));
510+ auto local_promise = std::make_shared<std::promise<int32_t >>();
511+ std::future<int > local_fut = local_promise->get_future ();
512+ std::vector<bool > failed (server_size, false );
513+ std::function<void (void *)> func = [&, node_id_buckets, query_idx_buckets,
514+ request_call_num](void *done) {
515+ local_fut.get ();
516+ std::vector<int > actual_size;
517+ auto *closure = (DownpourBrpcClosure *)done;
518+ std::vector<std::unique_ptr<butil::IOBufBytesIterator>> res (
519+ remote_call_num);
520+ size_t fail_num = 0 ;
521+ for (size_t request_idx = 0 ; request_idx < remote_call_num; ++request_idx) {
522+ if (closure->check_response (request_idx, PS_GRAPH_SAMPLE_NEIGHBOORS) !=
523+ 0 ) {
524+ ++fail_num;
525+ failed[request2server[request_idx]] = true ;
526+ } else {
527+ auto &res_io_buffer = closure->cntl (request_idx)->response_attachment ();
528+ size_t node_size;
529+ res[request_idx].reset (new butil::IOBufBytesIterator (res_io_buffer));
530+ size_t num;
531+ res[request_idx]->copy_and_forward (&num, sizeof (size_t ));
532+ }
533+ }
534+ int size;
535+ int local_index = 0 ;
536+ for (size_t i = 0 ; i < node_num; i++) {
537+ if (fail_num > 0 && failed[seq[i]]) {
538+ size = 0 ;
539+ } else if (request2server[seq[i]] != rank) {
540+ res[seq[i]]->copy_and_forward (&size, sizeof (int ));
541+ } else {
542+ size = local_actual_sizes[local_index++];
543+ }
544+ actual_size.push_back (size);
545+ }
546+ cntl->response_attachment ().append (actual_size.data (),
547+ actual_size.size () * sizeof (int ));
548+
549+ local_index = 0 ;
550+ for (size_t i = 0 ; i < node_num; i++) {
551+ if (fail_num > 0 && failed[seq[i]]) {
552+ continue ;
553+ } else if (request2server[seq[i]] != rank) {
554+ char temp[actual_size[i] + 1 ];
555+ res[seq[i]]->copy_and_forward (temp, actual_size[i]);
556+ cntl->response_attachment ().append (temp, actual_size[i]);
557+ } else {
558+ char *temp = local_buffers[local_index++].get ();
559+ cntl->response_attachment ().append (temp, actual_size[i]);
560+ }
561+ }
562+ closure->set_promise_value (0 );
563+ };
564+
565+ DownpourBrpcClosure *closure = new DownpourBrpcClosure (remote_call_num, func);
566+
567+ auto promise = std::make_shared<std::promise<int32_t >>();
568+ closure->add_promise (promise);
569+ std::future<int > fut = promise->get_future ();
570+
571+ for (int request_idx = 0 ; request_idx < remote_call_num; ++request_idx) {
572+ int server_index = request2server[request_idx];
573+ closure->request (request_idx)->set_cmd_id (PS_GRAPH_SAMPLE_NEIGHBOORS);
574+ closure->request (request_idx)->set_table_id (request.table_id ());
575+ closure->request (request_idx)->set_client_id (rank);
576+ size_t node_num = node_id_buckets[request_idx].size ();
577+
578+ closure->request (request_idx)
579+ ->add_params ((char *)node_id_buckets[request_idx].data (),
580+ sizeof (uint64_t ) * node_num);
581+ closure->request (request_idx)
582+ ->add_params ((char *)&sample_size, sizeof (int ));
583+ PsService_Stub rpc_stub (
584+ ((GraphBrpcServer *)get_server ())->get_cmd_channel (server_index));
585+ // GraphPsService_Stub rpc_stub =
586+ // getServiceStub(get_cmd_channel(server_index));
587+ closure->cntl (request_idx)->set_log_id (butil::gettimeofday_ms ());
588+ rpc_stub.service (closure->cntl (request_idx), closure->request (request_idx),
589+ closure->response (request_idx), closure);
590+ }
591+ if (server2request[rank] != -1 ) {
592+ ((GraphTable *)table)
593+ ->random_sample_neighboors (node_id_buckets.back ().data (), sample_size,
594+ local_buffers, local_actual_sizes);
595+ }
596+ local_promise.get ()->set_value (0 );
597+ if (remote_call_num == 0 ) func (closure);
598+ fut.get ();
599+ return 0 ;
600+ }
407601int32_t GraphBrpcService::graph_set_node_feat (Table *table,
408602 const PsRequestMessage &request,
409603 PsResponseMessage &response,
@@ -412,7 +606,7 @@ int32_t GraphBrpcService::graph_set_node_feat(Table *table,
412606 if (request.params_size () < 3 ) {
413607 set_response_code (
414608 response, -1 ,
415- " graph_set_node_feat request requires at least 2 arguments" );
609+ " graph_set_node_feat request requires at least 3 arguments" );
416610 return 0 ;
417611 }
418612 size_t node_num = request.params (0 ).size () / sizeof (uint64_t );
0 commit comments