Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 32 additions & 4 deletions paddle/fluid/distributed/ps/service/brpc_ps_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,14 @@ int32_t BrpcPsService::PushDenseParam(Table *table,
uint32_t num = *(const uint32_t *)data;

const float *values = (const float *)(data + sizeof(uint32_t));
if (table->PushDenseParam(values, num) != 0) {
TableContext table_context;
table_context.value_type = Dense;
table_context.push_context.values = values;
table_context.push_context.is_param = true;
table_context.num = num;

// if (table->PushDenseParam(values, num) != 0) {
if (table->Push(table_context) != 0) {
set_response_code(response, -1, "PushDenseParam failed");
}
return 0;
Expand Down Expand Up @@ -330,7 +337,15 @@ int32_t BrpcPsService::PushSparseParam(Table *table,
const uint64_t *keys = (const uint64_t *)push_data.data();
const float *values =
(const float *)(push_data.data() + sizeof(uint64_t) * num);
if (table->PushSparseParam(keys, values, num) != 0) {

TableContext table_context;
table_context.value_type = Sparse;
table_context.push_context.keys = keys;
table_context.push_context.values = values;
table_context.push_context.is_param = true;
table_context.num = num;
// if (table->PushSparseParam(keys, values, num) != 0) {
if (table->Push(table_context) != 0) {
set_response_code(response, -1, "PushSparseParam error");
}
return 0;
Expand All @@ -349,7 +364,14 @@ int32_t BrpcPsService::PullGeoParam(Table *table,

std::vector<float> values;
std::vector<uint64_t> ids;
table->PullGeoParam(trainer_id, &values, &ids);

TableContext table_context;
table_context.value_type = Sparse;
table_context.pull_context.geo_pull_keys = &ids;
table_context.pull_context.geo_pull_values = &values;
table_context.trainer_id = trainer_id;
table->Pull(table_context);
// table->PullGeoParam(trainer_id, &values, &ids);

uint32_t num = ids.size();
cntl->response_attachment().append((char *)(&num), sizeof(uint32_t));
Expand Down Expand Up @@ -625,7 +647,13 @@ int32_t BrpcPsService::PushGlobalStep(Table *table,
const int64_t *values =
(const int64_t *)(request.data().data() + sizeof(uint32_t));
auto trainer_id = request.client_id();
if (table->PushDense(values, trainer_id) != 0) {

TableContext context;
context.trainer_id = trainer_id;
context.push_context.push_steps = values;

// if (table->PushDense(values, trainer_id) != 0) {
if (table->Push(context) != 0) {
set_response_code(response, -1, "run_program failed");
}

Expand Down
60 changes: 54 additions & 6 deletions paddle/fluid/distributed/ps/service/ps_local_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,13 @@ ::std::future<int32_t> PsLocalClient::PullDense(Region* regions,

std::vector<float> region_buffer;
region_buffer.resize(num_per_shard);
table_ptr->PullDense(region_buffer.data(), region_buffer.size());

TableContext table_context;
table_context.value_type = Dense;
table_context.pull_context.values = region_buffer.data();
table_context.num = region_buffer.size();
table_ptr->Pull(table_context);
// table_ptr->PullDense(region_buffer.data(), region_buffer.size());

size_t region_idx = 0;
size_t region_data_idx = 0;
Expand Down Expand Up @@ -154,6 +160,13 @@ ::std::future<int32_t> PsLocalClient::PushDenseParam(const Region* regions,
offset += data_num;
}

TableContext table_context;
table_context.value_type = Dense;
table_context.push_context.values = region_buffer.data();
table_context.push_context.is_param = true;
table_context.num = region_buffer.size();

table_ptr->Push(table_context);
// table_ptr->PushDenseParam(region_buffer.data(), region_buffer.size());

return done();
Expand All @@ -168,7 +181,13 @@ ::std::future<int32_t> PsLocalClient::PushDenseRawGradient(

auto* table_ptr = GetTable(table_id);

table_ptr->PushDense(total_send_data, total_send_data_size);
TableContext table_context;
table_context.value_type = Dense;
table_context.push_context.values = total_send_data;
table_context.num = total_send_data_size;
// table_ptr->PushDense(total_send_data, total_send_data_size);
table_ptr->Push(table_context);

delete closure;
return done();
}
Expand All @@ -194,7 +213,12 @@ ::std::future<int32_t> PsLocalClient::PushDense(const Region* regions,
offset += data_num;
}

table_ptr->PushDense(region_buffer.data(), region_buffer.size());
TableContext table_context;
table_context.value_type = Dense;
table_context.push_context.values = region_buffer.data();
table_context.num = region_buffer.size();
// table_ptr->PushDense(total_send_data, total_send_data_size);
table_ptr->Push(table_context);

return done();
}
Expand Down Expand Up @@ -241,7 +265,15 @@ ::std::future<int32_t> PsLocalClient::PullSparsePtr(char** select_values,
//将key拆分到各shard请求,并记录原始对应value指针
auto* table_ptr = GetTable(table_id);

table_ptr->PullSparsePtr(select_values, keys, num);
TableContext table_context;
table_context.value_type = Sparse;
table_context.pull_context.keys = keys;
table_context.pull_context.ptr_values = select_values;
table_context.use_ptr = true;
table_context.num = num;

// table_ptr->PullSparsePtr(select_values, keys, num);
table_ptr->Pull(table_context);

return done();
}
Expand All @@ -253,7 +285,15 @@ ::std::future<int32_t> PsLocalClient::PushSparseRawGradient(
auto* accessor = GetTableAccessor(table_id);
auto* table_ptr = GetTable(table_id);

table_ptr->PushSparse(keys, update_values, num);
TableContext table_context;
table_context.value_type = Sparse;
table_context.push_context.keys = keys;
table_context.push_context.ptr_values = update_values;
table_context.num = num;
table_context.use_ptr = true;

// table_ptr->PushSparse(keys, update_values, num);
table_ptr->Push(table_context);
delete closure;
return done();
}
Expand All @@ -265,7 +305,15 @@ ::std::future<int32_t> PsLocalClient::PushSparse(size_t table_id,
auto* accessor = GetTableAccessor(table_id);
auto* table_ptr = GetTable(table_id);

table_ptr->PushSparse(keys, update_values, num);
TableContext table_context;
table_context.value_type = Sparse;
table_context.push_context.keys = keys;
table_context.push_context.ptr_values = update_values;
table_context.num = num;
table_context.use_ptr = true;

// table_ptr->PushSparse(keys, update_values, num);
table_ptr->Push(table_context);
return done();
}
}
Expand Down
7 changes: 5 additions & 2 deletions paddle/fluid/distributed/ps/table/common_dense_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,11 @@ int32_t CommonDenseTable::Pull(TableContext& context) {
int32_t CommonDenseTable::Push(TableContext& context) {
CHECK(context.value_type == Dense);
if (context.push_context.values != nullptr) {
const float* values = context.push_context.values;
return PushDense(values, context.num);
if (!context.push_context.is_param) {
return PushDense(context.push_context.values, context.num);
} else {
return PushDenseParam(context.push_context.values, context.num);
}
}
return 0;
}
Expand Down
22 changes: 12 additions & 10 deletions paddle/fluid/distributed/ps/table/common_dense_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,22 @@ namespace distributed {

class DenseOptimizer;

class CommonDenseTable : public DenseTable {
class CommonDenseTable : public Table {
public:
CommonDenseTable() {}
virtual ~CommonDenseTable() {}
int32_t Initialize() override;
int32_t InitializeShard() override { return 0; }
virtual void CreateInitializer(const std::string& attr,
const std::string& name);
virtual int32_t InitializeValue();
virtual int32_t InitializeOptimizer();
virtual int32_t Pull(TableContext& context);
virtual int32_t Push(TableContext& context);
int32_t PullDense(float* pull_values, size_t num) override;
int32_t PushDenseParam(const float* values, size_t num) override;
int32_t PushDense(const float* values, size_t num) override;
void CreateInitializer(const std::string& attr, const std::string& name);
int32_t InitializeValue();
int32_t InitializeOptimizer();

int32_t Pull(TableContext& context) override;
int32_t Push(TableContext& context) override;

int32_t PullDense(float* pull_values, size_t num);
int32_t PushDenseParam(const float* values, size_t num);
int32_t PushDense(const float* values, size_t num);
int32_t Pour() override;
int32_t SetGlobalLR(float* lr) override;

Expand All @@ -54,6 +55,7 @@ class CommonDenseTable : public DenseTable {
int32_t Flush() override { return 0; }
int32_t Shrink(const std::string& param) override { return 0; }
void Clear() override { return; }
void* GetShard(size_t shard_idx) override { return 0; }

protected:
int32_t _PushDense(const float* values, size_t num);
Expand Down
28 changes: 18 additions & 10 deletions paddle/fluid/distributed/ps/table/common_graph_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ class GraphSampler {
};
#endif

class GraphTable : public SparseTable {
class GraphTable : public Table {
public:
GraphTable() {
use_cache = false;
Expand All @@ -415,6 +415,23 @@ class GraphTable : public SparseTable {
rw_lock.reset(new pthread_rwlock_t());
}
virtual ~GraphTable();

virtual void *GetShard(size_t shard_idx) { return 0; }

static int32_t sparse_local_shard_num(uint32_t shard_num,
uint32_t server_num) {
if (shard_num % server_num == 0) {
return shard_num / server_num;
}
size_t local_shard_num = shard_num / server_num + 1;
return local_shard_num;
}

static size_t get_sparse_shard(uint32_t shard_num, uint32_t server_num,
uint64_t key) {
return (key % shard_num) / sparse_local_shard_num(shard_num, server_num);
}

virtual int32_t pull_graph_list(int start, int size,
std::unique_ptr<char[]> &buffer,
int &actual_size, bool need_feature,
Expand Down Expand Up @@ -452,15 +469,6 @@ class GraphTable : public SparseTable {
virtual int32_t Pull(TableContext &context) { return 0; }
virtual int32_t Push(TableContext &context) { return 0; }

virtual int32_t PullSparse(float *values, const PullSparseValue &pull_value) {
return 0;
}

virtual int32_t PushSparse(const uint64_t *keys, const float *values,
size_t num) {
return 0;
}

virtual int32_t clear_nodes();
virtual void Clear() {}
virtual int32_t Flush() { return 0; }
Expand Down
14 changes: 8 additions & 6 deletions paddle/fluid/distributed/ps/table/common_sparse_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,15 +108,16 @@ struct Meta {
}
};

class CommonSparseTable : public SparseTable {
class CommonSparseTable : public Table {
public:
CommonSparseTable() { rwlock_.reset(new phi::RWLock); }
virtual ~CommonSparseTable() {}

// unused method begin
virtual int32_t PullDense(float* pull_values, size_t num) { return 0; }
virtual int32_t PushDenseParam(const float* values, size_t num) { return 0; }
virtual int32_t PushDense(const float* values, size_t num) { return 0; }
// virtual int32_t PullDense(float* pull_values, size_t num) { return 0; }
// virtual int32_t PushDenseParam(const float* values, size_t num) { return
// 0; }
// virtual int32_t PushDense(const float* values, size_t num) { return 0; }
// unused method end

virtual int32_t Pull(TableContext& context);
Expand Down Expand Up @@ -163,14 +164,15 @@ class CommonSparseTable : public SparseTable {
// only for sparse geo table
virtual int32_t PushSparseParam(const uint64_t* keys, const float* values,
size_t num);

virtual int32_t SetGlobalLR(float* lr) override;
virtual int32_t SetGlobalLR(float* lr);

virtual int32_t Pour();
virtual int32_t Flush();
virtual int32_t Shrink(const std::string& param);
virtual void Clear();

virtual void* GetShard(size_t shard_idx) { return 0; }

protected:
virtual int32_t _PushSparse(const uint64_t* keys, const float* values,
size_t num);
Expand Down
57 changes: 0 additions & 57 deletions paddle/fluid/distributed/ps/table/common_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,50 +66,6 @@ struct ReservoirValue {
}
};

class SparseTable : public Table {
public:
SparseTable() {}
virtual ~SparseTable() {}

virtual void *GetShard(size_t shard_idx) { return 0; }

int32_t PullDense(float *values, size_t num) override { return 0; }

int32_t PushDense(const float *values, size_t num) override { return 0; }

static int32_t sparse_local_shard_num(uint32_t shard_num,
uint32_t server_num) {
if (shard_num % server_num == 0) {
return shard_num / server_num;
}
size_t local_shard_num = shard_num / server_num + 1;
return local_shard_num;
}

static size_t get_sparse_shard(uint32_t shard_num, uint32_t server_num,
uint64_t key) {
return (key % shard_num) / sparse_local_shard_num(shard_num, server_num);
}
};

class DenseTable : public Table {
public:
DenseTable() {}
virtual ~DenseTable() {}

virtual void *GetShard(size_t shard_idx) { return 0; }
int32_t PullSparse(float *values,
const PullSparseValue &pull_value) override {
return 0;
}
int32_t PushSparse(const uint64_t *keys, const float *values,
size_t num) override {
return 0;
}
int32_t PushDenseParam(const float *values, size_t num) override { return 0; }
int32_t Shrink(const std::string &param) override { return 0; }
};

class BarrierTable : public Table {
public:
BarrierTable() {}
Expand All @@ -120,19 +76,6 @@ class BarrierTable : public Table {
virtual int32_t Pull(TableContext &context) { return 0; }
virtual int32_t Push(TableContext &context) { return 0; }

int32_t PullDense(float *values, size_t num) override { return 0; }

int32_t PushDense(const float *values, size_t num) override { return 0; }

int32_t PullSparse(float *values,
const PullSparseValue &pull_value) override {
return 0;
}
int32_t PushSparse(const uint64_t *keys, const float *values,
size_t num) override {
return 0;
}
int32_t PushDenseParam(const float *values, size_t num) override { return 0; }
int32_t Shrink(const std::string &param) override { return 0; }
virtual void Clear() {}
virtual int32_t Flush() { return 0; }
Expand Down
Loading