Skip to content

Commit 2e48442

Browse files
committed
cache optimization
1 parent 1325315 commit 2e48442

File tree

2 files changed

+105
-123
lines changed

2 files changed

+105
-123
lines changed

paddle/fluid/distributed/table/common_graph_table.h

Lines changed: 104 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,6 @@ class LRUNode {
105105
LRUNode(K _key, V _data, size_t _ttl) : key(_key), data(_data), ttl(_ttl) {
106106
next = pre = NULL;
107107
}
108-
std::chrono::milliseconds ms;
109-
// the last hit time
110108
K key;
111109
V data;
112110
size_t ttl;
@@ -119,12 +117,13 @@ class ScaledLRU;
119117
template <typename K, typename V>
120118
class RandomSampleLRU {
121119
public:
122-
RandomSampleLRU(ScaledLRU<K, V> *_father) : father(_father) {
120+
RandomSampleLRU(ScaledLRU<K, V> *_father) {
121+
father = _father;
122+
remove_count = 0;
123123
node_size = 0;
124124
node_head = node_end = NULL;
125125
global_ttl = father->ttl;
126-
extra_penalty = 0;
127-
size_limit = (father->size_limit / father->shard_num + 1);
126+
total_diff = 0;
128127
}
129128

130129
~RandomSampleLRU() {
@@ -138,63 +137,71 @@ class RandomSampleLRU {
138137
LRUResponse query(K *keys, size_t length, std::vector<std::pair<K, V>> &res) {
139138
if (pthread_rwlock_tryrdlock(&father->rwlock) != 0)
140139
return LRUResponse::blocked;
141-
int init_node_size = node_size;
142-
try {
143-
// pthread_rwlock_rdlock(&father->rwlock);
144-
for (size_t i = 0; i < length; i++) {
145-
auto iter = key_map.find(keys[i]);
146-
if (iter != key_map.end()) {
147-
res.emplace_back(keys[i], iter->second->data);
148-
iter->second->ttl--;
149-
if (iter->second->ttl == 0) {
150-
remove(iter->second);
151-
} else {
152-
move_to_tail(iter->second);
153-
}
140+
// pthread_rwlock_rdlock(&father->rwlock);
141+
int init_size = node_size - remove_count;
142+
process_redundant(length * 3);
143+
144+
for (size_t i = 0; i < length; i++) {
145+
auto iter = key_map.find(keys[i]);
146+
if (iter != key_map.end()) {
147+
res.emplace_back(keys[i], iter->second->data);
148+
iter->second->ttl--;
149+
if (iter->second->ttl == 0) {
150+
remove(iter->second);
151+
if (remove_count != 0) remove_count--;
152+
} else {
153+
move_to_tail(iter->second);
154154
}
155155
}
156-
} catch (...) {
157-
pthread_rwlock_unlock(&father->rwlock);
158-
father->handle_size_diff(node_size - init_node_size);
159-
return LRUResponse::err;
156+
}
157+
total_diff += node_size - remove_count - init_size;
158+
if (total_diff >= 500 || total_diff < -500) {
159+
father->handle_size_diff(total_diff);
160+
total_diff = 0;
160161
}
161162
pthread_rwlock_unlock(&father->rwlock);
162-
father->handle_size_diff(node_size - init_node_size);
163163
return LRUResponse::ok;
164164
}
165165
LRUResponse insert(K *keys, V *data, size_t length) {
166166
if (pthread_rwlock_tryrdlock(&father->rwlock) != 0)
167167
return LRUResponse::blocked;
168-
int init_node_size = node_size;
169-
try {
170-
for (size_t i = 0; i < length; i++) {
171-
auto iter = key_map.find(keys[i]);
172-
if (iter != key_map.end()) {
173-
move_to_tail(iter->second);
174-
iter->second->ttl = global_ttl;
175-
iter->second->data = data[i];
176-
} else {
177-
LRUNode<K, V> *temp = new LRUNode<K, V>(keys[i], data[i], global_ttl);
178-
add_new(temp);
179-
}
168+
// pthread_rwlock_rdlock(&father->rwlock);
169+
int init_size = node_size - remove_count;
170+
process_redundant(length * 3);
171+
for (size_t i = 0; i < length; i++) {
172+
auto iter = key_map.find(keys[i]);
173+
if (iter != key_map.end()) {
174+
move_to_tail(iter->second);
175+
iter->second->ttl = global_ttl;
176+
iter->second->data = data[i];
177+
} else {
178+
LRUNode<K, V> *temp = new LRUNode<K, V>(keys[i], data[i], global_ttl);
179+
add_new(temp);
180180
}
181-
} catch (...) {
182-
pthread_rwlock_unlock(&father->rwlock);
183-
father->handle_size_diff(node_size - init_node_size);
184-
return LRUResponse::err;
185181
}
182+
total_diff += node_size - remove_count - init_size;
183+
if (total_diff >= 500 || total_diff < -500) {
184+
father->handle_size_diff(total_diff);
185+
total_diff = 0;
186+
}
187+
186188
pthread_rwlock_unlock(&father->rwlock);
187-
father->handle_size_diff(node_size - init_node_size);
188189
return LRUResponse::ok;
189190
}
190191
void remove(LRUNode<K, V> *node) {
191192
fetch(node);
192193
node_size--;
193194
key_map.erase(node->key);
194195
delete node;
195-
if (node_size >= size_limit) {
196-
extra_penalty -= 1.0;
196+
}
197+
198+
void process_redundant(int process_size) {
199+
size_t length = std::min(remove_count, process_size);
200+
while (length--) {
201+
remove(node_head);
202+
remove_count--;
197203
}
204+
// std::cerr<<"after remove_count = "<<remove_count<<std::endl;
198205
}
199206

200207
void move_to_tail(LRUNode<K, V> *node) {
@@ -207,12 +214,6 @@ class RandomSampleLRU {
207214
place_at_tail(node);
208215
node_size++;
209216
key_map[node->key] = node;
210-
if (node_size > size_limit) {
211-
extra_penalty += penalty_inc;
212-
if (extra_penalty >= 1.0) {
213-
remove(node_head);
214-
}
215-
}
216217
}
217218
void place_at_tail(LRUNode<K, V> *node) {
218219
if (node_end == NULL) {
@@ -224,8 +225,6 @@ class RandomSampleLRU {
224225
node->next = NULL;
225226
node_end = node;
226227
}
227-
node->ms = std::chrono::duration_cast<std::chrono::milliseconds>(
228-
std::chrono::system_clock::now().time_since_epoch());
229228
}
230229

231230
void fetch(LRUNode<K, V> *node) {
@@ -245,11 +244,10 @@ class RandomSampleLRU {
245244
std::unordered_map<K, LRUNode<K, V> *> key_map;
246245
ScaledLRU<K, V> *father;
247246
size_t global_ttl, size_limit;
248-
int node_size;
247+
int node_size, total_diff;
249248
LRUNode<K, V> *node_head, *node_end;
250249
friend class ScaledLRU<K, V>;
251-
float extra_penalty;
252-
const float penalty_inc = 0.75;
250+
int remove_count;
253251
};
254252

255253
template <typename K, typename V>
@@ -268,7 +266,7 @@ class ScaledLRU {
268266
while (true) {
269267
{
270268
std::unique_lock<std::mutex> lock(mutex_);
271-
cv_.wait_for(lock, std::chrono::milliseconds(20000));
269+
cv_.wait_for(lock, std::chrono::milliseconds(3000));
272270
if (stop) {
273271
return;
274272
}
@@ -295,52 +293,33 @@ class ScaledLRU {
295293
int shrink() {
296294
int node_size = 0;
297295
for (size_t i = 0; i < lru_pool.size(); i++) {
298-
node_size += lru_pool[i].node_size;
296+
node_size += lru_pool[i].node_size - lru_pool[i].remove_count;
299297
}
300298

301-
if (node_size <= 1.2 * size_limit) return 0;
299+
if (node_size <= size_t(1.1 * size_limit) + 1) return 0;
302300
if (pthread_rwlock_wrlock(&rwlock) == 0) {
303-
try {
304-
global_count = 0;
305-
std::priority_queue<RemovedNode, std::vector<RemovedNode>,
306-
std::greater<RemovedNode>>
307-
q;
308-
for (size_t i = 0; i < lru_pool.size(); i++) {
309-
if (lru_pool[i].node_size > 0) {
310-
global_count += lru_pool[i].node_size;
311-
q.push({lru_pool[i].node_head, &lru_pool[i]});
312-
}
313-
}
314-
if (global_count > size_limit) {
315-
// VLOG(0)<<"before shrinking cache, cached nodes count =
316-
// "<<global_count<<std::endl;
317-
size_t remove = global_count - size_limit;
318-
while (remove--) {
319-
RemovedNode remove_node = q.top();
320-
q.pop();
321-
auto next = remove_node.node->next;
322-
if (next) {
323-
q.push({next, remove_node.lru_pointer});
324-
}
325-
global_count--;
326-
remove_node.lru_pointer->remove(remove_node.node);
327-
}
328-
for (size_t i = 0; i < lru_pool.size(); i++) {
329-
lru_pool[i].size_limit = lru_pool[i].node_size;
330-
lru_pool[i].extra_penalty = 0;
331-
}
332-
// VLOG(0)<<"after shrinking cache, cached nodes count =
333-
// // "<<global_count<<std::endl;
301+
// std::cerr<<"in shrink\n";
302+
global_count = 0;
303+
for (size_t i = 0; i < lru_pool.size(); i++) {
304+
global_count += lru_pool[i].node_size - lru_pool[i].remove_count;
305+
}
306+
// std::cerr<<"global_count "<<global_count<<"\n";
307+
if (global_count > size_limit) {
308+
size_t remove = global_count - size_limit;
309+
for (int i = 0; i < lru_pool.size(); i++) {
310+
lru_pool[i].total_diff = 0;
311+
lru_pool[i].remove_count +=
312+
1.0 * (lru_pool[i].node_size - lru_pool[i].remove_count) /
313+
global_count * remove;
314+
// std::cerr<<i<<" "<<lru_pool[i].remove_count<<std::endl;
334315
}
335-
} catch (...) {
336-
pthread_rwlock_unlock(&rwlock);
337-
return -1;
338316
}
339317
pthread_rwlock_unlock(&rwlock);
340318
return 0;
341319
}
342320
return 0;
343321
}
322+
344323
void handle_size_diff(int diff) {
345324
if (diff != 0) {
346325
__sync_fetch_and_add(&global_count, diff);
@@ -358,18 +337,13 @@ class ScaledLRU {
358337
pthread_rwlock_t rwlock;
359338
size_t shard_num;
360339
int global_count;
361-
size_t size_limit;
340+
size_t size_limit, total, hit;
362341
size_t ttl;
363342
bool stop;
364343
std::thread shrink_job;
365344
std::vector<RandomSampleLRU<K, V>> lru_pool;
366345
mutable std::mutex mutex_;
367346
std::condition_variable cv_;
368-
struct RemovedNode {
369-
LRUNode<K, V> *node;
370-
RandomSampleLRU<K, V> *lru_pointer;
371-
bool operator>(const RemovedNode &a) const { return node->ms > a.node->ms; }
372-
};
373347
std::shared_ptr<::ThreadPool> thread_pool;
374348
friend class RandomSampleLRU<K, V>;
375349
};
@@ -448,13 +422,46 @@ class GraphTable : public SparseTable {
448422
std::unique_lock<std::mutex> lock(mutex_);
449423
if (use_cache == false) {
450424
scaled_lru.reset(new ScaledLRU<SampleKey, SampleResult>(
451-
shard_end - shard_start, size_limit, ttl));
425+
task_pool_size_, size_limit, ttl));
452426
use_cache = true;
453427
}
454428
}
455429
return 0;
456430
}
457431

432+
virtual int32_t test_sample_with_cache(int size, int batch_size,
433+
int sample_size) {
434+
std::vector<int> actual_sizes1, actual_sizes2;
435+
std::vector<std::shared_ptr<char>> buffers1, buffers2;
436+
std::vector<uint64_t> node_ids1(batch_size), node_ids2;
437+
for (int i = 0; i <= size - batch_size; i += batch_size) {
438+
for (int j = 0; j < batch_size; j++) {
439+
node_ids1[j] = i + j;
440+
}
441+
actual_sizes1.resize(batch_size);
442+
buffers1.resize(batch_size);
443+
random_sample_neighbors(node_ids1.data(), sample_size, buffers1,
444+
actual_sizes1);
445+
node_ids2.clear();
446+
for (int j = 0; j < batch_size; j++) {
447+
if (actual_sizes1[j] != 0) {
448+
int offset = 0;
449+
char *p = buffers1[j].get();
450+
while (offset < actual_sizes1[j]) {
451+
node_ids2.push_back(*(uint64_t *)(p + offset));
452+
offset += Node::id_size + Node::weight_size;
453+
}
454+
}
455+
}
456+
buffers2.resize(node_ids2.size());
457+
actual_sizes2.resize(node_ids2.size());
458+
random_sample_neighbors(node_ids2.data(), sample_size, buffers2,
459+
actual_sizes2);
460+
}
461+
462+
return 0;
463+
}
464+
458465
protected:
459466
std::vector<GraphShard> shards;
460467
size_t shard_start, shard_end, server_num, shard_num_per_server, shard_num;

paddle/fluid/distributed/test/graph_node_test.cc

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
11
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2-
32
Licensed under the Apache License, Version 2.0 (the "License");
43
you may not use this file except in compliance with the License.
54
You may obtain a copy of the License at
6-
75
http://www.apache.org/licenses/LICENSE-2.0
8-
96
Unless required by applicable law or agreed to in writing, software
107
distributed under the License is distributed on an "AS IS" BASIS,
118
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -681,28 +678,6 @@ void testCache() {
681678
}
682679
st.query(0, &skey, 1, r);
683680
ASSERT_EQ((int)r.size(), 0);
684-
::paddle::distributed::ScaledLRU<::paddle::distributed::SampleKey,
685-
::paddle::distributed::SampleResult>
686-
cache1(2, 1, 4);
687-
str = new char[18];
688-
strcpy(str, "3433776521");
689-
result = new ::paddle::distributed::SampleResult(strlen(str), str);
690-
cache1.insert(1, &skey, result, 1);
691-
::paddle::distributed::SampleKey skey1 = {8, 1};
692-
char* str1 = new char[18];
693-
strcpy(str1, "3xcf2eersfd");
694-
usleep(3000); // sleep 3ms to guarantee that skey1's time stamp is larger
695-
// than skey;
696-
auto result1 = new ::paddle::distributed::SampleResult(strlen(str1), str1);
697-
cache1.insert(0, &skey1, result1, 1);
698-
sleep(1); // sleep 1 s to guarantee that shrinking work is done
699-
cache1.query(1, &skey, 1, r);
700-
ASSERT_EQ((int)r.size(), 0);
701-
cache1.query(0, &skey1, 1, r);
702-
ASSERT_EQ((int)r.size(), 1);
703-
char* p1 = (char*)r[0].second.buffer.get();
704-
for (int j = 0; j < r[0].second.actual_size; j++) ASSERT_EQ(p1[j], str1[j]);
705-
r.clear();
706681
}
707682
void testGraphToBuffer() {
708683
::paddle::distributed::GraphNode s, s1;
@@ -718,4 +693,4 @@ void testGraphToBuffer() {
718693
VLOG(0) << s1.get_feature(0);
719694
}
720695

721-
TEST(RunBrpcPushSparse, Run) { RunBrpcPushSparse(); }
696+
TEST(RunBrpcPushSparse, Run) { RunBrpcPushSparse(); }

0 commit comments

Comments
 (0)