Skip to content
16 changes: 16 additions & 0 deletions paddle/fluid/operators/listen_and_serv_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ class ListenAndServOp : public framework::OperatorBase {

// TODO(typhoonzero): change this to a while_op for every cluster-batch.
bool exit_flag = false;
// Record received sparse variables, so that
// we could reset those after execute optimize program
std::vector<framework::Variable *> sparse_vars;
while (!exit_flag) {
// Get from multiple trainers, we don't care about the order in which
// the gradients arrives, just add suffix 0~n and merge the gradient.
Expand Down Expand Up @@ -143,6 +146,9 @@ class ListenAndServOp : public framework::OperatorBase {
PADDLE_THROW("Can not find server side var");
}
detail::DeserializeFromMessage(v.second, dev_ctx, var);
if (var->IsType<framework::SelectedRows>()) {
sparse_vars.push_back(var);
}
}
}
VLOG(3) << "recv " << recv_var_cnt << " parmeters for one barrier.";
Expand All @@ -156,9 +162,19 @@ class ListenAndServOp : public framework::OperatorBase {
} catch (std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what();
}

// Reset the received sparse variables, the sum operator would not
// sum the input sparse variables which rows is empty at the next
// mini-batch.
// TOOD(Yancey1989): move the reset action into an operator, we couldn't
// have any hide logic in the operator.
for (auto &var : sparse_vars) {
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
}
rpc_service_->SetCond(1);
rpc_service_->WaitClientGet(update_param_cnt);
grads_counter_.clear();
sparse_vars.clear();
} // while(true)
}

Expand Down
24 changes: 22 additions & 2 deletions paddle/fluid/operators/send_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,22 @@ limitations under the License. */

namespace paddle {
namespace operators {
static bool IsVariableInitialized(const framework::Scope& scope,
const std::string& varname) {
auto* var = scope.FindVar(varname);
PADDLE_ENFORCE_NOT_NULL(var, "Can not find variable '%s' in the send side.",
varname);
if (var->IsType<framework::LoDTensor>()) {
return var->Get<framework::LoDTensor>().IsInitialized();
} else if (var->IsType<framework::SelectedRows>()) {
return var->Get<framework::SelectedRows>().value().IsInitialized();
} else {
PADDLE_THROW(
"Variable type in send side should be in "
"[LodTensor, SelectedRows]");
}
return false;
}

class SendOp : public framework::OperatorBase {
public:
Expand Down Expand Up @@ -51,8 +67,12 @@ class SendOp : public framework::OperatorBase {
detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();

for (size_t i = 0; i < ins.size(); i++) {
VLOG(3) << "sending " << ins[i] << " to " << epmap[i];
rpc_client->AsyncSendVariable(epmap[i], ctx, scope, ins[i]);
if (IsVariableInitialized(scope, ins[i])) {
VLOG(3) << "sending " << ins[i] << " to " << epmap[i];
rpc_client->AsyncSendVariable(epmap[i], ctx, scope, ins[i]);
} else {
VLOG(3) << "don't send no-initialied variable: " << ins[i];
}
}
PADDLE_ENFORCE(rpc_client->Wait());

Expand Down
23 changes: 1 addition & 22 deletions paddle/fluid/operators/split_selected_rows_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class SplitSelectedRowsOpMaker : public framework::OpProtoAndCheckerMaker {
SplitSelectedRowsOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input SelectedRows.");
AddOutput("Out", "The outputs of input SelectedRows.").AsDuplicable();
AddOutput("Out", "The outputs of the input SelectedRows.").AsDuplicable();
AddAttr<std::vector<int>>("height_sections",
"Height for each output SelectedRows.")
.SetDefault(std::vector<int>({}));
Expand Down Expand Up @@ -56,27 +56,6 @@ class SplitSelectedRowsOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasInput("X"), "SplitSelectedRowsOp must has input X.");
PADDLE_ENFORCE(ctx->HasOutputs("Out"),
"SplitSelectedRowsOp must has output Out.");

std::vector<int> height_sections =
ctx->Attrs().Get<std::vector<int>>("height_sections");
int64_t n = ctx->Outputs("Out").size();

std::vector<framework::DDim> outs_dims;
outs_dims.reserve(n);

// make output dims
for (int64_t i = 0; i < n; ++i) {
auto dims = ctx->GetInputDim("X");
if (height_sections.size()) {
PADDLE_ENFORCE_EQ(
height_sections.size(), static_cast<size_t>(n),
"The size of height section should be the same with height"
" section size.");
dims[0] = height_sections[i];
}
outs_dims.push_back(dims);
}
ctx->SetOutputsDim("Out", outs_dims);
}
};

Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/operators/split_selected_rows_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class SplitSelectedRowsOpKernel : public framework::OpKernel<T> {

for (size_t i = 0; i < outs_rows_idx.size(); ++i) {
auto rows_idx = outs_rows_idx[i];
outs[i]->set_height(height_sections[i]);
if (rows_idx.size() > 0) {
auto dims = x->GetCompleteDims();
dims[0] = rows_idx.size();
Expand Down
4 changes: 3 additions & 1 deletion paddle/fluid/operators/sum_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,9 @@ class SumKernel : public framework::OpKernel<T> {
int64_t offset = 0;
for (int i = 0; i < N; i++) {
auto &sel_row = get_selected_row(i);

if (!sel_row.value().IsInitialized() || sel_row.rows().size() == 0) {
continue;
}
PADDLE_ENFORCE_EQ(out->height(), sel_row.height());
functor(context.template device_context<DeviceContext>(), sel_row,
offset, out);
Expand Down
4 changes: 4 additions & 0 deletions python/paddle/v2/fluid/distribute_transpiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ def transpile(self,
for b in param_blocks:
varname, block_id, _ = b.split(":")
send_outputs.append(param_var_mapping[varname][int(block_id)])

# let send_op know which endpoint to send which var to, eplist has the same
# order as send_inputs.
eplist = split_method(send_inputs, pserver_endpoints)
Expand Down Expand Up @@ -274,6 +275,7 @@ def _create_vars_from_blocklist(self, program, block_list):
name="%s.block%d" % (varname, i),
psersistable=False,
dtype=orig_var.dtype,
type=orig_var.type,
shape=splited_shape) # flattend splited var
var_mapping[varname].append(var)
return var_mapping
Expand Down Expand Up @@ -335,6 +337,7 @@ def _create_var_for_trainers(self, block, var, trainers):
name="%s.trainer_%d" % (var.name, i),
psersistable=var.persistable,
dtype=var.dtype,
type=var.type,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems we are merging splited SelectedRows using concat_op and it's not supported?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sum_op implement the same feature, so I reuse it.

shape=var.shape)
var_list.append(var_each)
return var_list
Expand Down Expand Up @@ -561,6 +564,7 @@ def get_pserver_program(self, endpoint):
persistable=True,
dtype=v.dtype,
shape=v.shape)

# step6
optimize_block = pserver_program.create_block(0)
# step 6.1
Expand Down