Skip to content

Commit 6877d97

Browse files
committed
graph-engine data transfer optimization
1 parent 2cc00bd commit 6877d97

File tree

8 files changed

+103
-62
lines changed

8 files changed

+103
-62
lines changed

paddle/fluid/distributed/service/graph_brpc_client.cc

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -304,10 +304,15 @@ std::future<int32_t> GraphBrpcClient::remove_graph_node(
304304
// char* &buffer,int &actual_size
305305
std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
306306
uint32_t table_id, std::vector<uint64_t> node_ids, int sample_size,
307-
std::vector<std::vector<std::pair<uint64_t, float>>> &res,
307+
// std::vector<std::vector<std::pair<uint64_t, float>>> &res,
308+
std::vector<std::vector<uint64_t>> &res,
309+
std::vector<std::vector<float>> &res_weight, bool need_weight,
308310
int server_index) {
309311
if (server_index != -1) {
310312
res.resize(node_ids.size());
313+
if (need_weight) {
314+
res_weight.resize(node_ids.size());
315+
}
311316
// need_weight is a by-value parameter of batch_sample_neighbors, and this
// callback is handed to the RPC layer while the function itself returns a
// future. Capture need_weight by copy so the callback never reads a
// reference to a parameter whose lifetime has ended; everything else keeps
// the original by-reference capture.
DownpourBrpcClosure *closure =
    new DownpourBrpcClosure(1, [&, need_weight](void *done) {
312317
int ret = 0;
313318
auto *closure = (DownpourBrpcClosure *)done;
@@ -331,11 +336,14 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
331336
int actual_size = actual_sizes[node_idx];
332337
int start = 0;
333338
while (start < actual_size) {
334-
res[node_idx].push_back(
335-
{*(uint64_t *)(node_buffer + offset + start),
336-
*(float *)(node_buffer + offset + start +
337-
GraphNode::id_size)});
338-
start += GraphNode::id_size + GraphNode::weight_size;
339+
res[node_idx].emplace_back(
340+
*(uint64_t *)(node_buffer + offset + start));
341+
start += GraphNode::id_size;
342+
if (need_weight) {
343+
res_weight[node_idx].emplace_back(
344+
*(float *)(node_buffer + offset + start));
345+
start += GraphNode::weight_size;
346+
}
339347
}
340348
offset += actual_size;
341349
}
@@ -352,6 +360,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
352360
closure->request(0)->add_params((char *)node_ids.data(),
353361
sizeof(uint64_t) * node_ids.size());
354362
closure->request(0)->add_params((char *)&sample_size, sizeof(int));
363+
closure->request(0)->add_params((char *)&need_weight, sizeof(bool));
355364
;
356365
// PsService_Stub rpc_stub(get_cmd_channel(server_index));
357366
GraphPsService_Stub rpc_stub =
@@ -364,13 +373,18 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
364373
std::vector<int> request2server;
365374
std::vector<int> server2request(server_size, -1);
366375
res.clear();
376+
res_weight.clear();
367377
for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
368378
int server_index = get_server_index_by_id(node_ids[query_idx]);
369379
if (server2request[server_index] == -1) {
370380
server2request[server_index] = request2server.size();
371381
request2server.push_back(server_index);
372382
}
373-
res.push_back(std::vector<std::pair<uint64_t, float>>());
383+
// res.push_back(std::vector<std::pair<uint64_t, float>>());
384+
res.push_back({});
385+
if (need_weight) {
386+
res_weight.push_back({});
387+
}
374388
}
375389
size_t request_call_num = request2server.size();
376390
std::vector<std::vector<uint64_t>> node_id_buckets(request_call_num);
@@ -413,11 +427,14 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
413427
int actual_size = actual_sizes[node_idx];
414428
int start = 0;
415429
while (start < actual_size) {
416-
res[query_idx].push_back(
417-
{*(uint64_t *)(node_buffer + offset + start),
418-
*(float *)(node_buffer + offset + start +
419-
GraphNode::id_size)});
420-
start += GraphNode::id_size + GraphNode::weight_size;
430+
res[query_idx].emplace_back(
431+
*(uint64_t *)(node_buffer + offset + start));
432+
start += GraphNode::id_size;
433+
if (need_weight) {
434+
res_weight[query_idx].emplace_back(
435+
*(float *)(node_buffer + offset + start));
436+
start += GraphNode::weight_size;
437+
}
421438
}
422439
offset += actual_size;
423440
}
@@ -445,6 +462,8 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
445462
sizeof(uint64_t) * node_num);
446463
closure->request(request_idx)
447464
->add_params((char *)&sample_size, sizeof(int));
465+
closure->request(request_idx)
466+
->add_params((char *)&need_weight, sizeof(bool));
448467
// PsService_Stub rpc_stub(get_cmd_channel(server_index));
449468
GraphPsService_Stub rpc_stub =
450469
getServiceStub(get_cmd_channel(server_index));

paddle/fluid/distributed/service/graph_brpc_client.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,8 @@ class GraphBrpcClient : public BrpcPsClient {
6464
// given a batch of nodes, sample graph_neighbors for each of them
6565
virtual std::future<int32_t> batch_sample_neighbors(
6666
uint32_t table_id, std::vector<uint64_t> node_ids, int sample_size,
67-
std::vector<std::vector<std::pair<uint64_t, float>>>& res,
67+
std::vector<std::vector<uint64_t>>& res,
68+
std::vector<std::vector<float>>& res_weight, bool need_weight,
6869
int server_index = -1);
6970

7071
virtual std::future<int32_t> pull_graph_list(uint32_t table_id,

paddle/fluid/distributed/service/graph_brpc_server.cc

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -378,19 +378,21 @@ int32_t GraphBrpcService::graph_random_sample_neighbors(
378378
Table *table, const PsRequestMessage &request, PsResponseMessage &response,
379379
brpc::Controller *cntl) {
380380
CHECK_TABLE_EXIST(table, request, response)
381-
if (request.params_size() < 2) {
381+
if (request.params_size() < 3) {
382382
set_response_code(
383383
response, -1,
384-
"graph_random_sample request requires at least 2 arguments");
384+
"graph_random_sample_neighbors request requires at least 3 arguments");
385385
return 0;
386386
}
387387
size_t node_num = request.params(0).size() / sizeof(uint64_t);
388388
uint64_t *node_data = (uint64_t *)(request.params(0).c_str());
389389
// The sender appends sample_size with sizeof(int) bytes, so decode it as an
// int. Reading it through a uint64_t pointer reads past the end of the
// parameter and yields the wrong value on big-endian hosts.
int sample_size = *reinterpret_cast<const int *>(request.params(1).c_str());
390+
bool need_weight = *(bool *)(request.params(2).c_str());
390391
std::vector<std::shared_ptr<char>> buffers(node_num);
391392
std::vector<int> actual_sizes(node_num, 0);
392393
((GraphTable *)table)
393-
->random_sample_neighbors(node_data, sample_size, buffers, actual_sizes);
394+
->random_sample_neighbors(node_data, sample_size, buffers, actual_sizes,
395+
need_weight);
394396

395397
cntl->response_attachment().append(&node_num, sizeof(size_t));
396398
cntl->response_attachment().append(actual_sizes.data(),
@@ -454,16 +456,17 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
454456
brpc::Controller *cntl) {
455457
// sleep(5);
456458
CHECK_TABLE_EXIST(table, request, response)
457-
if (request.params_size() < 2) {
458-
set_response_code(
459-
response, -1,
460-
"graph_random_neighbors_sample request requires at least 2 arguments");
459+
if (request.params_size() < 3) {
460+
set_response_code(response, -1,
461+
"sample_neighbors_across_multi_servers request requires "
462+
"at least 3 arguments");
461463
return 0;
462464
}
463465
size_t node_num = request.params(0).size() / sizeof(uint64_t),
464466
size_of_size_t = sizeof(size_t);
465467
uint64_t *node_data = (uint64_t *)(request.params(0).c_str());
466468
// The sender appends sample_size with sizeof(int) bytes, so decode it as an
// int; a uint64_t read would run past the end of the parameter.
int sample_size = *reinterpret_cast<const int *>(request.params(1).c_str());
469+
// need_weight travels as a single bool (sizeof(bool) bytes). Decode it as a
// bool, as graph_random_sample_neighbors does; reading eight bytes could pick
// up unrelated non-zero bytes and turn a false flag into true.
bool need_weight = *reinterpret_cast<const bool *>(request.params(2).c_str());
467470
// std::vector<uint64_t> res = ((GraphTable
468471
// *)table).filter_out_non_exist_nodes(node_data, sample_size);
469472
std::vector<int> request2server;
@@ -581,6 +584,8 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
581584
sizeof(uint64_t) * node_num);
582585
closure->request(request_idx)
583586
->add_params((char *)&sample_size, sizeof(int));
587+
closure->request(request_idx)
588+
->add_params((char *)&need_weight, sizeof(bool));
584589
PsService_Stub rpc_stub(
585590
((GraphBrpcServer *)get_server())->get_cmd_channel(server_index));
586591
// GraphPsService_Stub rpc_stub =
@@ -592,7 +597,8 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
592597
if (server2request[rank] != -1) {
593598
((GraphTable *)table)
594599
->random_sample_neighbors(node_id_buckets.back().data(), sample_size,
595-
local_buffers, local_actual_sizes);
600+
local_buffers, local_actual_sizes,
601+
need_weight);
596602
}
597603
local_promise.get()->set_value(0);
598604
if (remote_call_num == 0) func(closure);

paddle/fluid/distributed/service/graph_py_service.cc

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -295,11 +295,13 @@ GraphPyClient::batch_sample_neighbors(std::string name,
295295
std::vector<uint64_t> node_ids,
296296
int sample_size, bool return_weight,
297297
bool return_edges) {
298-
std::vector<std::vector<std::pair<uint64_t, float>>> v;
298+
// std::vector<std::vector<std::pair<uint64_t, float>>> v;
299+
std::vector<std::vector<uint64_t>> v;
300+
std::vector<std::vector<float>> v1;
299301
if (this->table_id_map.count(name)) {
300302
uint32_t table_id = this->table_id_map[name];
301-
auto status =
302-
worker_ptr->batch_sample_neighbors(table_id, node_ids, sample_size, v);
303+
auto status = worker_ptr->batch_sample_neighbors(
304+
table_id, node_ids, sample_size, v, v1, return_weight);
303305
status.wait();
304306
}
305307

@@ -313,9 +315,10 @@ GraphPyClient::batch_sample_neighbors(std::string name,
313315
if (return_edges) res.first.push_back({});
314316
for (size_t i = 0; i < v.size(); i++) {
315317
for (size_t j = 0; j < v[i].size(); j++) {
316-
res.first[0].push_back(v[i][j].first);
318+
// res.first[0].push_back(v[i][j].first);
319+
res.first[0].push_back(v[i][j]);
317320
if (return_edges) res.first[2].push_back(node_ids[i]);
318-
if (return_weight) res.second.push_back(v[i][j].second);
321+
if (return_weight) res.second.push_back(v1[i][j]);
319322
}
320323
if (i == v.size() - 1) break;
321324

paddle/fluid/distributed/table/common_graph_table.cc

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -396,8 +396,8 @@ int32_t GraphTable::random_sample_nodes(int sample_size,
396396
}
397397
int32_t GraphTable::random_sample_neighbors(
398398
uint64_t *node_ids, int sample_size,
399-
std::vector<std::shared_ptr<char>> &buffers,
400-
std::vector<int> &actual_sizes) {
399+
std::vector<std::shared_ptr<char>> &buffers, std::vector<int> &actual_sizes,
400+
bool need_weight) {
401401
size_t node_num = buffers.size();
402402
std::function<void(char *)> char_del = [](char *c) { delete[] c; };
403403
std::vector<std::future<int>> tasks;
@@ -407,7 +407,7 @@ int32_t GraphTable::random_sample_neighbors(
407407
for (size_t idx = 0; idx < node_num; ++idx) {
408408
index = get_thread_pool_index(node_ids[idx]);
409409
seq_id[index].emplace_back(idx);
410-
id_list[index].emplace_back(node_ids[idx], sample_size);
410+
id_list[index].emplace_back(node_ids[idx], sample_size, need_weight);
411411
}
412412
for (int i = 0; i < seq_id.size(); i++) {
413413
if (seq_id[i].size() == 0) continue;
@@ -442,25 +442,29 @@ int32_t GraphTable::random_sample_neighbors(
442442
}
443443
std::shared_ptr<char> &buffer = buffers[idx];
444444
std::vector<int> res = node->sample_k(sample_size, rng);
445-
actual_size = res.size() * (Node::id_size + Node::weight_size);
445+
actual_size =
446+
res.size() * (need_weight ? (Node::id_size + Node::weight_size)
447+
: Node::id_size);
446448
int offset = 0;
447449
uint64_t id;
448450
float weight;
449451
char *buffer_addr = new char[actual_size];
450452
if (response == LRUResponse::ok) {
451-
sample_keys.emplace_back(node_id, sample_size);
453+
sample_keys.emplace_back(node_id, sample_size, need_weight);
452454
sample_res.emplace_back(actual_size, buffer_addr);
453455
buffer = sample_res.back().buffer;
454456
} else {
455457
buffer.reset(buffer_addr, char_del);
456458
}
457459
for (int &x : res) {
458460
id = node->get_neighbor_id(x);
459-
weight = node->get_neighbor_weight(x);
460461
memcpy(buffer_addr + offset, &id, Node::id_size);
461462
offset += Node::id_size;
462-
memcpy(buffer_addr + offset, &weight, Node::weight_size);
463-
offset += Node::weight_size;
463+
if (need_weight) {
464+
weight = node->get_neighbor_weight(x);
465+
memcpy(buffer_addr + offset, &weight, Node::weight_size);
466+
offset += Node::weight_size;
467+
}
464468
}
465469
}
466470
}

paddle/fluid/distributed/table/common_graph_table.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,14 @@ enum LRUResponse { ok = 0, blocked = 1, err = 2 };
8080
// Key describing one neighbor-sampling request: which node, how many
// neighbors were requested, and whether weights are part of the result.
// GraphTable::random_sample_neighbors emplaces (node id, sample_size,
// need_weight) triples that match this constructor.
// NOTE(review): presumably used as the key of the sampling cache (see
// LRUResponse) — confirm, and confirm any hash functor for SampleKey
// defined elsewhere stays consistent with operator== below.
struct SampleKey {
8181
// Id of the node whose neighbors are sampled.
uint64_t node_key;
8282
// Number of neighbors requested for node_key.
size_t sample_size;
83-
SampleKey(uint64_t _node_key, size_t _sample_size)
84-
: node_key(_node_key), sample_size(_sample_size) {}
83+
// True when the sampled buffer also carries a weight after each neighbor id.
bool is_weighted;
84+
SampleKey(uint64_t _node_key, size_t _sample_size, bool _is_weighted)
85+
: node_key(_node_key),
86+
sample_size(_sample_size),
87+
is_weighted(_is_weighted) {}
8588
// Two keys are equal only if all fields match; is_weighted takes part so
// that a weighted and an unweighted request never compare equal.
bool operator==(const SampleKey &s) const {
86-
return node_key == s.node_key && sample_size == s.sample_size;
89+
return node_key == s.node_key && sample_size == s.sample_size &&
90+
is_weighted == s.is_weighted;
8791
}
8892
};
8993

@@ -360,7 +364,7 @@ class GraphTable : public SparseTable {
360364
virtual int32_t random_sample_neighbors(
361365
uint64_t *node_ids, int sample_size,
362366
std::vector<std::shared_ptr<char>> &buffers,
363-
std::vector<int> &actual_sizes);
367+
std::vector<int> &actual_sizes, bool need_weight);
364368

365369
int32_t random_sample_nodes(int sample_size, std::unique_ptr<char[]> &buffers,
366370
int &actual_sizes);

paddle/fluid/distributed/table/graph/graph_node.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ class Node {
5151

5252
protected:
5353
uint64_t id;
54+
bool is_weighted;
5455
};
5556

5657
class GraphNode : public Node {

0 commit comments

Comments
 (0)