19 changes: 19 additions & 0 deletions paddle/fluid/framework/block_desc.cc
@@ -42,6 +42,25 @@ bool BlockDesc::HasVar(const std::string &name) const {
return vars_.find(name) != vars_.end();
}

VarDesc *BlockDesc::RenameVar(const std::string &old_name,
const std::string &new_name) {
if (!this->HasVar(old_name)) {
return nullptr;
}
need_update_ = true;
auto *var = this->Var(old_name);
VarDesc *new_var = new VarDesc(*(var->Proto()));
new_var->SetName(new_name);
vars_[new_name].reset(new_var);
// rename inputs and outputs
for (const auto &op : ops_) {
auto *it = op.get();
it->Rename(old_name, new_name);
}
vars_.erase(old_name);
return new_var;
}

VarDesc *BlockDesc::FindVarRecursive(const std::string &name) const {
if (name == kEmptyVarName) return nullptr;

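For illustration only, here is a small Python sketch of the flow the new `RenameVar` implements: copy the variable description under the new name, rewrite every op's inputs/outputs in the same block, then drop the old entry (returning `None` when the old name is absent, mirroring the `nullptr` early return). The dict-based variables and ops are stand-ins, not Paddle classes.

```python
# Stand-in sketch of BlockDesc::RenameVar; `vars` maps name -> var-desc dict,
# and each op keeps its inputs/outputs as lists of variable names.
def rename_var(vars, ops, old_name, new_name):
    if old_name not in vars:           # mirrors the nullptr early return
        return None
    new_var = dict(vars[old_name])     # copy the descriptor under the new name
    new_var["name"] = new_name
    vars[new_name] = new_var
    for op in ops:                     # rename inputs and outputs of every op
        op["inputs"] = [new_name if n == old_name else n for n in op["inputs"]]
        op["outputs"] = [new_name if n == old_name else n for n in op["outputs"]]
    del vars[old_name]                 # the old entry disappears from the block
    return new_var
```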
2 changes: 2 additions & 0 deletions paddle/fluid/framework/block_desc.h
@@ -55,6 +55,8 @@ class BlockDesc {

bool HasVar(const std::string &var_name) const;

VarDesc *RenameVar(const std::string &old_name, const std::string &new_name);

VarDesc *FindVarRecursive(const std::string &name_bytes) const;

VarDesc &FindRecursiveOrCreateVar(const std::string &name_bytes);
60 changes: 12 additions & 48 deletions paddle/fluid/operators/listen_and_serv_op.cc
@@ -75,13 +75,6 @@ class ListenAndServOp : public framework::OperatorBase {
server_thread_->join();
}

std::string GetGradVarNameForTrainer(const std::string &varname) const {
if (grads_counter_.find(varname) == grads_counter_.end()) {
grads_counter_[varname] = 0;
}
return string::Sprintf("%s.trainer_%d", varname, grads_counter_[varname]++);
}

void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
@@ -91,8 +84,7 @@
// FIXME(Yancey1989): initialize rpc server with lazy mode.
rpc_service_->SetScope(&recv_scope);
rpc_service_->SetDevCtx(&dev_ctx);
auto param_list = Attr<std::vector<std::string>>("ParamList");
auto grad_list = Attr<std::vector<std::string>>("GradList");
auto ins = Inputs("X");
auto fan_in = Attr<int>("Fanin");

auto *block = Attr<framework::BlockDesc *>(kOptimizeBlock);
@@ -109,40 +101,24 @@
// the gradients arrive, just add suffix 0~n and merge the gradient.
rpc_service_->SetCond(0);
size_t recv_var_cnt = 0;
size_t update_param_cnt = 0;
int batch_barrier = 0;
while (batch_barrier != fan_in) {
const detail::MessageWithName &v = rpc_service_->Get();
auto grad_var_name = v.first;
if (grad_var_name == LISTEN_TERMINATE_MESSAGE) {
auto recv_var_name = v.first;
if (recv_var_name == LISTEN_TERMINATE_MESSAGE) {
LOG(INFO) << "received terminate message and exit";
exit_flag = true;
break;
} else if (grad_var_name == BATCH_BARRIER_MESSAGE) {
} else if (recv_var_name == BATCH_BARRIER_MESSAGE) {
VLOG(3) << "recv batch barrier message";
batch_barrier++;
continue;
} else {
// receive a variable
VLOG(3) << "received grad: " << recv_var_name;
recv_var_cnt++;
auto it =
std::find(grad_list.begin(), grad_list.end(), grad_var_name);
std::string param_var_name;
if (it != grad_list.end()) {
param_var_name = param_list[it - grad_list.begin()];
update_param_cnt++;
VLOG(3) << "received grad: " << grad_var_name
<< " updating param: " << param_var_name;
} else {
VLOG(3) << "received variable: " << grad_var_name
<< " no need to update param";
}
if (fan_in > 1 && !param_var_name.empty()) {
grad_var_name = this->GetGradVarNameForTrainer(grad_var_name);
}
auto *var = recv_scope.FindVar(grad_var_name);
auto *var = recv_scope.FindVar(recv_var_name);
if (var == nullptr) {
LOG(ERROR) << "Can not find server side var: " << grad_var_name;
LOG(ERROR) << "Can not find server side var: " << recv_var_name;
PADDLE_THROW("Can not find server side var");
}
detail::DeserializeFromMessage(v.second, dev_ctx, var);
@@ -151,43 +127,40 @@
}
}
}
VLOG(3) << "recv " << recv_var_cnt << " parmeters for one barrier.";
if (exit_flag) {
rpc_service_->ShutDown();
}
VLOG(3) << "run optimize graph...";
try {
executor.Run(*program, &recv_scope, block->ID(), /*global_block*/
false /*create_local_scope*/, false /*create_vars*/);
} catch (std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what();
}

// Reset the received sparse variables; the sum operator will not
// sum input sparse variables whose rows are empty at the next
// mini-batch.
// TOOD(Yancey1989): move the reset action into an operator, we couldn't
// TODO(Yancey1989): move the reset action into an operator, we couldn't
// have any hide logic in the operator.
for (auto &var : sparse_vars) {
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
}
rpc_service_->SetCond(1);
rpc_service_->WaitClientGet(update_param_cnt);
grads_counter_.clear();
// FIXME(typhoonzero): use another condition to sync wait clients get.
rpc_service_->WaitClientGet(ins.size());
sparse_vars.clear();
} // while(true)
}

protected:
std::shared_ptr<detail::AsyncGRPCServer> rpc_service_;
std::shared_ptr<std::thread> server_thread_;
mutable std::unordered_map<std::string, int> grads_counter_;
};

class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ListenAndServOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor) Variables that server recv.").AsDuplicable();
AddComment(R"DOC(
ListenAndServ operator

@@ -201,16 +174,7 @@ from send_op and send back variables to recv_op.
.AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
AddAttr<framework::BlockDesc *>(kOptimizeBlock,
"BlockID to run on server side.");
AddAttr<std::vector<std::string>>(
"ParamList", "type list of string",
"grad->param name mapping to find which parameters to optimize.")
.SetDefault({});
AddAttr<std::vector<std::string>>(
"GradList", "type list of string",
"grad->param name mapping to find which parameters to optimize.")
.SetDefault({});
AddAttr<int>("Fanin", "type int",
"Number of trainers in the current cluster job")
AddAttr<int>("Fanin", "How many clients send to this server.")
.SetDefault(1);
}
};
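A rough Python sketch of the receive loop above, showing only the control flow: the server pulls messages, counts one BATCH_BARRIER_MESSAGE per trainer until it has seen `fan_in` of them, deserializes everything else into the receive scope, then runs the optimize block and lets clients fetch the updated parameters. The message queue, scope, and sentinel values are plain stand-ins, not the gRPC service, `framework::Scope`, or the real constants from the detail headers.

```python
# Placeholder sentinels; the real values live in the C++ detail headers.
BATCH_BARRIER_MESSAGE = "BATCH_BARRIER"
LISTEN_TERMINATE_MESSAGE = "TERMINATE"

def serve_one_round(get_message, scope, run_optimize_block, notify_clients, fan_in):
    """Mirror of the loop in RunImpl: block until every trainer hits the barrier."""
    recv_var_cnt, batch_barrier, exit_flag = 0, 0, False
    while batch_barrier != fan_in:
        name, payload = get_message()        # rpc_service_->Get()
        if name == LISTEN_TERMINATE_MESSAGE:
            exit_flag = True
            break
        elif name == BATCH_BARRIER_MESSAGE:
            batch_barrier += 1               # one barrier message per trainer
        else:
            recv_var_cnt += 1
            scope[name] = payload            # DeserializeFromMessage into the var
    if not exit_flag:
        run_optimize_block()                 # executor.Run on the optimize block
        notify_clients()                     # clients may now fetch parameters
    return exit_flag, recv_var_cnt
```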
10 changes: 9 additions & 1 deletion paddle/fluid/pybind/protobuf.cc
@@ -170,6 +170,14 @@ void BindBlockDesc(py::module &m) {
[](BlockDesc &self, py::bytes byte_name) {
std::string name = byte_name;
return self.HasVar(name);
},
py::return_value_policy::reference)
.def("rename_var",
[](BlockDesc &self, const py::bytes &byte_name,
const py::bytes &byte_name_new) {
std::string name = byte_name;
std::string new_name = byte_name_new;
self.RenameVar(name, new_name);
})
.def("has_var_recursive",
[](BlockDesc &self, py::bytes byte_name) {
@@ -207,7 +215,7 @@ void BindVarDsec(py::module &m) {
py::class_<VarDesc> var_desc(m, "VarDesc", "");
var_desc
.def("name",
[](const VarDesc &self) {
[](VarDesc &self) {
py::bytes name = self.Name();
return name;
},
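A hedged usage sketch of the new `rename_var` binding from the Python side. Only `has_var` and `rename_var` come from the bindings above; how the `BlockDesc` handle is obtained (for example from a Program's desc) is an assumption, and depending on the build the names may need to be passed as bytes.

```python
def rename_block_var(block_desc, old_name, new_name):
    """Rename a variable on a pybind-exposed BlockDesc.

    `block_desc` is assumed to be the bound C++ BlockDesc (e.g. taken from a
    Program desc). rename_var also rewrites the inputs/outputs of every op in
    the block, as implemented in block_desc.cc above.
    """
    if not block_desc.has_var(old_name):
        raise ValueError("variable '%s' not found in this block" % old_name)
    block_desc.rename_var(old_name, new_name)
```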
2 changes: 2 additions & 0 deletions python/paddle/v2/dataset/common.py
@@ -74,6 +74,8 @@ def download(url, module_name, md5sum, save_name=None):
retry = 0
retry_limit = 3
while not (os.path.exists(filename) and md5file(filename) == md5sum):
if os.path.exists(filename):
print "file md5", md5file(filename), md5sum
if retry < retry_limit:
retry += 1
else:
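The two added lines print the actual and expected checksums whenever a partial or corrupted file triggers another download attempt. Below is a self-contained sketch of that retry-until-md5-matches pattern, with urllib/hashlib standing in for Paddle's helpers; the function names here are illustrative, not the module's API.

```python
import hashlib
import os
import urllib

def _md5(path, chunk_size=4096):
    digest = hashlib.md5()
    with open(path, 'rb') as f:
        for block in iter(lambda: f.read(chunk_size), b''):
            digest.update(block)
    return digest.hexdigest()

def fetch(url, path, md5sum, retry_limit=3):
    retry = 0
    while not (os.path.exists(path) and _md5(path) == md5sum):
        if os.path.exists(path):
            # same diagnostic as the added lines: show actual vs expected md5
            print("file md5 %s, expected %s" % (_md5(path), md5sum))
        if retry < retry_limit:
            retry += 1
        else:
            raise RuntimeError("Cannot download %s within retry limit %d" %
                               (url, retry_limit))
        urllib.urlretrieve(url, path)  # Python 2; urllib.request.urlretrieve on 3
    return path
```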