@@ -304,10 +304,15 @@ std::future<int32_t> GraphBrpcClient::remove_graph_node(
304304// char* &buffer,int &actual_size
305305std::future<int32_t > GraphBrpcClient::batch_sample_neighbors (
306306 uint32_t table_id, std::vector<uint64_t > node_ids, int sample_size,
307- std::vector<std::vector<std::pair<uint64_t , float >>> &res,
307+ // std::vector<std::vector<std::pair<uint64_t, float>>> &res,
308+ std::vector<std::vector<uint64_t >> &res,
309+ std::vector<std::vector<float >> &res_weight, bool need_weight,
308310 int server_index) {
309311 if (server_index != -1 ) {
310312 res.resize (node_ids.size ());
313+ if (need_weight) {
314+ res_weight.resize (node_ids.size ());
315+ }
311316 DownpourBrpcClosure *closure = new DownpourBrpcClosure (1 , [&](void *done) {
312317 int ret = 0 ;
313318 auto *closure = (DownpourBrpcClosure *)done;
@@ -331,11 +336,14 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
331336 int actual_size = actual_sizes[node_idx];
332337 int start = 0 ;
333338 while (start < actual_size) {
334- res[node_idx].push_back (
335- {*(uint64_t *)(node_buffer + offset + start),
336- *(float *)(node_buffer + offset + start +
337- GraphNode::id_size)});
338- start += GraphNode::id_size + GraphNode::weight_size;
339+ res[node_idx].emplace_back (
340+ *(uint64_t *)(node_buffer + offset + start));
341+ start += GraphNode::id_size;
342+ if (need_weight) {
343+ res_weight[node_idx].emplace_back (
344+ *(float *)(node_buffer + offset + start));
345+ start += GraphNode::weight_size;
346+ }
339347 }
340348 offset += actual_size;
341349 }
@@ -352,6 +360,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
352360 closure->request (0 )->add_params ((char *)node_ids.data (),
353361 sizeof (uint64_t ) * node_ids.size ());
354362 closure->request (0 )->add_params ((char *)&sample_size, sizeof (int ));
363+ closure->request (0 )->add_params ((char *)&need_weight, sizeof (bool ));
355364 ;
356365 // PsService_Stub rpc_stub(get_cmd_channel(server_index));
357366 GraphPsService_Stub rpc_stub =
@@ -364,13 +373,18 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
364373 std::vector<int > request2server;
365374 std::vector<int > server2request (server_size, -1 );
366375 res.clear ();
376+ res_weight.clear ();
367377 for (int query_idx = 0 ; query_idx < node_ids.size (); ++query_idx) {
368378 int server_index = get_server_index_by_id (node_ids[query_idx]);
369379 if (server2request[server_index] == -1 ) {
370380 server2request[server_index] = request2server.size ();
371381 request2server.push_back (server_index);
372382 }
373- res.push_back (std::vector<std::pair<uint64_t , float >>());
383+ // res.push_back(std::vector<std::pair<uint64_t, float>>());
384+ res.push_back ({});
385+ if (need_weight) {
386+ res_weight.push_back ({});
387+ }
374388 }
375389 size_t request_call_num = request2server.size ();
376390 std::vector<std::vector<uint64_t >> node_id_buckets (request_call_num);
@@ -413,11 +427,14 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
413427 int actual_size = actual_sizes[node_idx];
414428 int start = 0 ;
415429 while (start < actual_size) {
416- res[query_idx].push_back (
417- {*(uint64_t *)(node_buffer + offset + start),
418- *(float *)(node_buffer + offset + start +
419- GraphNode::id_size)});
420- start += GraphNode::id_size + GraphNode::weight_size;
430+ res[query_idx].emplace_back (
431+ *(uint64_t *)(node_buffer + offset + start));
432+ start += GraphNode::id_size;
433+ if (need_weight) {
434+ res_weight[query_idx].emplace_back (
435+ *(float *)(node_buffer + offset + start));
436+ start += GraphNode::weight_size;
437+ }
421438 }
422439 offset += actual_size;
423440 }
@@ -445,6 +462,8 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
445462 sizeof (uint64_t ) * node_num);
446463 closure->request (request_idx)
447464 ->add_params ((char *)&sample_size, sizeof (int ));
465+ closure->request (request_idx)
466+ ->add_params ((char *)&need_weight, sizeof (bool ));
448467 // PsService_Stub rpc_stub(get_cmd_channel(server_index));
449468 GraphPsService_Stub rpc_stub =
450469 getServiceStub (get_cmd_channel (server_index));
0 commit comments