233 changes: 204 additions & 29 deletions paddle/fluid/distributed/collective/reducer.cc
@@ -360,6 +360,7 @@ void EagerReducer::InitializeGroups(
is_sparse_gradient_[tensor_indices_.front()]) {
// process the sparse gradient. one sparse, one group
group.dtype_ = first_var.dtype();
group.is_sparse_ = true;
} else {
// process the dense gradient.
InitializeDenseGroups(tensor_indices_, &group);
@@ -391,6 +392,12 @@ void EagerReducer::InitializeDenseGroups(
auto &tensor = tensors_[tensor_index];
auto &tensor_name = tensor.name();

PADDLE_ENFORCE_EQ(is_sparse_gradient_[tensor_index], false,
platform::errors::PreconditionNotMet(
"Tensor %s's GRAD must be Tensor, but received "
"GRAD is SelectedRows",
tensor_name));

PADDLE_ENFORCE_EQ(tensor.is_initialized(), true,
platform::errors::PreconditionNotMet(
"Tensor %s is not initialized.", tensor_name));
@@ -480,6 +487,7 @@ void EagerReducer::PrepareForBackward(const std::vector<Tensor> &outputs) {
next_group_ = 0;
std::for_each(groups_.begin(), groups_.end(), [](EagerGroup &group) {
group.pending_ = group.tensor_indices_.size();
group.sparse_contents_ = Tensor();
});

// reinitialize vars_marked_ready_ for next iteration
@@ -544,9 +552,6 @@ void EagerReducer::AddDistHook(size_t var_index) {
return;
}

auto &tensor = tensors_[var_index];
const auto &grad_node = GetGradNodeFromTensor(&tensor);

VLOG(3) << "Tensor[" << var_index << "] [" << tensors_[var_index].name()
<< "@Grad] arrived and triggered disthook";

@@ -608,33 +613,69 @@ void EagerReducer::MarkVarReady(const size_t var_index,
auto &group_tensor = group.dense_tensors_[inside_group_index];
const auto length = group.length_[inside_group_index];

  if (!group.is_sparse_) {
    if (is_used_var) {
      auto *autograd_meta = tensors_[var_index].get_autograd_meta();
      auto &grad_tensor =
          static_cast<egr::AutogradMeta *>(autograd_meta)->Grad();
      group_tensor
          .ShareDataWith(*(
              std::dynamic_pointer_cast<phi::DenseTensor>(grad_tensor.impl())))
          .Resize({grad_tensor.numel()});
    } else {
      // TODO(shenliang03): maybe save the memory by avoiding tensor
      // construction
      if (!group_tensor.initialized()) {
        group_tensor.Resize({static_cast<int64_t>(length)});
        group_tensor.mutable_data(inner_place_, group.dtype_);
      }
      if (HasGrad(var_index)) {
        VLOG(3) << "Tensor[" << tensors_[var_index].name() << "] has grad";
        auto grad_tensor = egr::EagerUtils::mutable_grad(tensors_[var_index]);
        group_tensor
            .ShareDataWith(*(std::dynamic_pointer_cast<phi::DenseTensor>(
                grad_tensor->impl())))
            .Resize({length});
      } else {
        VLOG(3) << "Tensor[" << tensors_[var_index].name()
                << "] doesn't have grad";
        auto *dev_ctx =
            platform::DeviceContextPool::Instance().Get(inner_place_);
        group_tensor.Resize({static_cast<int64_t>(length)});
        phi::funcs::set_constant(*dev_ctx, &group_tensor, 0.0);
      }
    }
  } else {
    auto *autograd_meta = tensors_[var_index].get_autograd_meta();
    auto &grad_tensor = static_cast<egr::AutogradMeta *>(autograd_meta)->Grad();

    // process sparse group
    PADDLE_ENFORCE_EQ(
        HasGrad(var_index), true,
        platform::errors::PreconditionNotMet(
            "The sparse parameter[%d][%s] should have gradient. "
            "Currently, DataParallel does not support sparse "
            "parameters that do not generate gradients during training. "
            "For example, if is_sparse=True is used in Embedding and this "
            "parameter does not generate a gradient in the current step "
            "because of stop_gradient/detach, this error will occur.",
            var_index, tensors_[var_index].name()));

    // need to check the tensor type
    PADDLE_ENFORCE_EQ(
        grad_tensor.is_selected_rows(), true,
        platform::errors::PreconditionNotMet(
            "The sparse parameter[%d][%s] must have a SelectedRows gradient. "
            "Before the forward pass its gradient type is inferred to be "
            "SelectedRows, but after the backward pass the actual type "
            "becomes LoDTensor, which DataParallel does not currently "
            "support. For example, if a sparse embedding is used and its "
            "weight is shared with subsequent dense parameters, the "
            "embedding's gradient will be converted to a dense gradient.",
            var_index, tensors_[var_index].name()));

    group.sparse_contents_.set_impl(grad_tensor.impl());
  }

if (--group.pending_ == 0) {
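The two preconditions in the sparse branch above are easiest to see from the user side. Below is a minimal, hedged usage sketch that is not part of this PR: the toy model and its names are illustrative, and the script assumes a multi-process launch (e.g. `python -m paddle.distributed.launch`). The point is that an embedding created with `sparse=True` produces a SelectedRows gradient and must receive a gradient on every step for the sparse path in `MarkVarReady` to apply.

```python
# Hedged usage sketch for the sparse-parameter checks above (not from the PR).
# Assumes a multi-process launch, e.g. `python -m paddle.distributed.launch`.
import paddle
import paddle.nn as nn


class TinyModel(nn.Layer):
    """Illustrative toy model: a sparse embedding followed by a dense layer."""

    def __init__(self, vocab=100, hidden=16):
        super().__init__()
        # sparse=True makes the embedding weight produce a SelectedRows
        # gradient, which is what the sparse group in MarkVarReady expects.
        self.emb = nn.Embedding(vocab, hidden, sparse=True)
        self.fc = nn.Linear(hidden, 1)

    def forward(self, ids):
        return self.fc(self.emb(ids)).mean()


if __name__ == "__main__":
    paddle.distributed.init_parallel_env()
    model = paddle.DataParallel(TinyModel())
    opt = paddle.optimizer.SGD(learning_rate=0.001,
                               parameters=model.parameters())
    for _ in range(3):
        ids = paddle.randint(0, 100, shape=[8])
        loss = model(ids)
        # The embedding must receive a gradient on every step; skipping it
        # (stop_gradient/detach) trips the first PADDLE_ENFORCE above.
        loss.backward()
        opt.step()
        opt.clear_grad()
```

If the embedding weight were instead shared with a later dense layer, its gradient would come back as a dense tensor and the second PADDLE_ENFORCE above would fire.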
@@ -666,7 +707,11 @@ void EagerReducer::MarkGroupReady(size_t group_index) {
for (; next_group_ < groups_.size() && groups_[next_group_].pending_ == 0;
++next_group_) {
UNUSED auto &group = groups_[next_group_];
    if (group.is_sparse_) {
      AllReduceSparse(&group, next_group_);
    } else {
      FusedAllReduceSchedule(&group, next_group_);
    }
}
}

@@ -725,6 +770,11 @@ void EagerReducer::ProcessUnusedDenseVars() {
const auto inside_group_index = var_locator.inside_group_index;
auto &src_tensor = group.dense_tensors_[inside_group_index];

// Sparse gradients are skipped here: find_unused_parameters is not
// supported for sparse parameters.
if (group.is_sparse_) {
continue;
}

Tensor grad_value(std::make_shared<phi::DenseTensor>(src_tensor));

auto dest_var_base = tensors_[var_index];
@@ -739,11 +789,15 @@ void EagerReducer::FinalizeBackward() {
groups_need_finalize_ = false;
grad_need_hooks_ = false;
for (auto &group : groups_) {
    if (!group.is_sparse_) {
      group.task->Synchronize();
    }
}

for (auto &group : groups_) {
    if (!group.is_sparse_) {
      group.SplitTensors(inner_place_);
    }
}

if (find_unused_vars_each_step_) {
Expand Down Expand Up @@ -778,6 +832,127 @@ void EagerReducer::FusedAllReduceSchedule(EagerGroup *group,
// split in FinalizeBackward()
}

void EagerReducer::AllReduceSparse(EagerGroup *group,
const int curr_group_index) {
// div nranks
Tensor sparse_tensor(group->sparse_contents_);
paddle::experimental::scale_(sparse_tensor, 1.0 / nranks_, 0.0, false);

VLOG(3) << "sparse_group [" << curr_group_index << "] start allreduce.";

auto *dev_ctx = platform::DeviceContextPool::Instance().Get(inner_place_);
if (platform::is_gpu_place(inner_place_)) {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
dev_ctx = static_cast<platform::CUDADeviceContext *>(
platform::DeviceContextPool::Instance().Get(inner_place_));
#else
PADDLE_THROW(platform::errors::PermissionDenied(
"Paddle can't concat grad tensors since it's not compiled with NCCL,"
"Please recompile or reinstall Paddle with NCCL support."));
#endif
} else if (platform::is_cpu_place(inner_place_)) {
dev_ctx = static_cast<platform::CPUDeviceContext *>(
platform::DeviceContextPool::Instance().Get(inner_place_));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Split grad tensor not supported on place (%s)", inner_place_));
}

auto src = std::dynamic_pointer_cast<phi::SelectedRows>(
group->sparse_contents_.impl());
const auto &src_rows = src->rows();

const auto &rank_ = process_group_->GetRank();
const auto &size_ = process_group_->GetSize();

framework::Vector<int64_t> rows_num_vector(size_);
rows_num_vector[rank_] = static_cast<int64_t>(src_rows.size());

Tensor rows_num_tensor = paddle::experimental::empty(
IntArray({static_cast<int64_t>(size_)}), DataType::INT64, inner_place_);
auto *rows_num_dense_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(rows_num_tensor.impl()).get();
framework::TensorFromVector<int64_t>(rows_num_vector, *dev_ctx,
rows_num_dense_tensor);

distributed::AllreduceOptions opts;
opts.reduce_op = ReduceOp::SUM;
std::vector<Tensor> reduce_tensors = {rows_num_tensor};
process_group_->AllReduce(reduce_tensors, opts)->Synchronize();

framework::TensorToVector<int64_t>(*rows_num_dense_tensor, *dev_ctx,
&rows_num_vector);
dev_ctx->Wait();

const auto *cpu_rows_num_ptr = rows_num_vector.data();
auto rows_num = std::accumulate(cpu_rows_num_ptr, cpu_rows_num_ptr + size_,
static_cast<int64_t>(0));

VLOG(3) << "Gather rows: " << string::join_strings(rows_num_vector, ',')
<< ", total rows number: " << rows_num
<< ", height: " << src->height();

dev_ctx->Wait();

if (std::all_of(cpu_rows_num_ptr, cpu_rows_num_ptr + size_,
[&](int64_t row) { return row == cpu_rows_num_ptr[0]; })) {
// During sparse communication the number of rows on each card is the
// same, so allgather can replace broadcast and speed up the allreduce.

VLOG(3) << "allgather replaces broadcast to speed up in sparse allreduce";

Tensor dst_rows_tensor =
paddle::experimental::empty(IntArray({static_cast<int64_t>(rows_num)}),
DataType::INT64, inner_place_);
Tensor src_rows_tensor = paddle::experimental::empty(
IntArray({static_cast<int64_t>((*src).rows().size())}), DataType::INT64,
inner_place_);
auto *src_rows_dense_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(src_rows_tensor.impl())
.get();
framework::TensorFromVector<int64_t>((*src).rows(), *dev_ctx,
src_rows_dense_tensor);

std::vector<Tensor> src_rows_tensors = {src_rows_tensor};
std::vector<Tensor> dst_rows_tensors = {dst_rows_tensor};
process_group_->AllGather(src_rows_tensors, dst_rows_tensors)
->Synchronize();

framework::Vector<int64_t> dst_rows_vector(rows_num, 0);
auto *dst_rows_dense_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(dst_rows_tensor.impl())
.get();
framework::TensorToVector<int64_t>(*dst_rows_dense_tensor, *dev_ctx,
&dst_rows_vector);
dev_ctx->Wait();

Tensor src_value_tensor(std::make_shared<phi::DenseTensor>(src->value()));
std::vector<int64_t> dst_shape = src_value_tensor.shape();
dst_shape[dst_shape.size() - 2] = rows_num;
auto dst_dense_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
paddle::experimental::full(IntArray(dst_shape), 0,
src_value_tensor.dtype(), inner_place_)
.impl());

auto dst =
std::make_shared<phi::SelectedRows>(dst_rows_vector, (*src).height());
*(dst->mutable_value()) = *dst_dense_tensor;
Tensor dst_value_tensor(std::make_shared<phi::DenseTensor>(dst->value()));

std::vector<Tensor> src_value_tensors = {src_value_tensor};
std::vector<Tensor> dst_value_tensors = {dst_value_tensor};
process_group_->AllGather(src_value_tensors, dst_value_tensors)
->Synchronize();

src->set_rows(dst_rows_vector);
*(src->mutable_value()) =
*(std::dynamic_pointer_cast<phi::DenseTensor>(dst_value_tensor.impl()));
} else {
PADDLE_THROW(
platform::errors::Unimplemented("This case is not supported."));
}
}
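For readers who want the algorithm above without the framework plumbing, here is a small self-contained sketch. Assumptions: `SparseGrad` and `sparse_all_reduce` are illustrative stand-ins (a `SparseGrad` models `phi::SelectedRows` as row indices plus a value matrix), and the ranks are simulated in a single process; none of these names are Paddle APIs. It mirrors the same steps: scale values by 1/nranks, exchange per-rank row counts, and when every rank contributes the same number of rows, allgather the rows and values to form the reduced gradient.

```python
# Minimal single-process model of the allgather-based sparse allreduce above.
# Assumptions: "SparseGrad" stands in for phi::SelectedRows (a list of row
# indices plus a dense value matrix); ranks are simulated in one process.
# None of these names are Paddle APIs.
import numpy as np


class SparseGrad:
    def __init__(self, rows, value, height):
        self.rows = list(rows)           # row indices touched on this rank
        self.value = np.asarray(value, dtype=np.float64)  # [len(rows), width]
        self.height = height             # total rows of the full parameter


def sparse_all_reduce(per_rank_grads):
    nranks = len(per_rank_grads)

    # 1) Scale by 1/nranks, mirroring paddle::experimental::scale_.
    values = [g.value / nranks for g in per_rank_grads]

    # 2) Exchange per-rank row counts (the rows_num allreduce above).
    rows_num = [len(g.rows) for g in per_rank_grads]

    # 3) Only the equal-row-count case is handled, like the code above;
    #    anything else falls through to "not supported".
    if len(set(rows_num)) != 1:
        raise NotImplementedError("unequal row counts are not supported")

    # 4) Allgather rows and values: plain concatenation across ranks.
    all_rows = [r for g in per_rank_grads for r in g.rows]
    all_values = np.concatenate(values, axis=0)
    return SparseGrad(all_rows, all_values, per_rank_grads[0].height)


if __name__ == "__main__":
    g0 = SparseGrad([0, 3], [[1.0, 1.0], [2.0, 2.0]], height=8)
    g1 = SparseGrad([3, 5], [[4.0, 4.0], [6.0, 6.0]], height=8)
    merged = sparse_all_reduce([g0, g1])
    print(merged.rows)    # [0, 3, 3, 5]
    print(merged.value)   # each rank's values divided by nranks == 2
```

Duplicate row indices across ranks (row 3 in the sketch) are acceptable because SelectedRows semantics sum repeated rows when the gradient is applied, so concatenation is enough here; the unequal-row-count case corresponds to the Unimplemented error in the final else branch above.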

std::ostream &operator<<(std::ostream &out, const EagerGroup &group) {
const auto &tensors_ = group.tensor_indices_;
out << "numel: " << group.all_length_ << " ;var number: " << tensors_.size()
3 changes: 3 additions & 0 deletions paddle/fluid/distributed/collective/reducer.h
@@ -47,6 +47,8 @@ std::vector<std::vector<size_t>> Eager_AssignGroupBySize(
class EagerGroup {
public:
Tensor dense_contents_;
Tensor sparse_contents_;
bool is_sparse_ = false;

// for concat kernel
std::vector<phi::DenseTensor> dense_tensors_;
@@ -104,6 +106,7 @@ class EagerReducer {
void MarkVarReady(const size_t var_index, const bool is_used_var);
void MarkGroupReady(const size_t group_index);
void FusedAllReduceSchedule(EagerGroup *group, const int curr_group_index);
void AllReduceSparse(EagerGroup *group, const int curr_group_index);
void FinalizeBackward();
void TraverseBackwardGraph(const std::vector<Tensor> &outputs);
void ProcessUnusedDenseVars();
7 changes: 4 additions & 3 deletions python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -1127,7 +1127,7 @@ set_tests_properties(test_split_program PROPERTIES TIMEOUT 120)
if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL)
set_tests_properties(test_parallel_dygraph_dataparallel PROPERTIES TIMEOUT 120)
set_tests_properties(test_parallel_dygraph_mnist PROPERTIES TIMEOUT 120)
set_tests_properties(test_parallel_dygraph_unused_variables PROPERTIES TIMEOUT 300)
set_tests_properties(test_parallel_dygraph_control_flow PROPERTIES TIMEOUT 200)
set_tests_properties(test_parallel_dygraph_control_flow_in_eager_mode PROPERTIES TIMEOUT 150)
set_tests_properties(test_parallel_dygraph_no_sync PROPERTIES TIMEOUT 150)
@@ -1152,8 +1152,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL)
set_tests_properties(test_eager_dist_api PROPERTIES TIMEOUT 300)

if(${NCCL_VERSION} VERSION_GREATER_EQUAL 2212)
set_tests_properties(test_parallel_dygraph_sparse_embedding PROPERTIES TIMEOUT 200)
set_tests_properties(test_parallel_dygraph_sparse_embedding_over_height PROPERTIES TIMEOUT 150)
set_tests_properties(test_parallel_dygraph_transformer PROPERTIES TIMEOUT 150)
endif()
endif()

@@ -42,7 +42,6 @@ def __init__(self,
dtype=dtype,
is_sparse=is_sparse,
param_attr=fluid.ParamAttr(
name='embedding_param',
initializer=fluid.initializer.UniformInitializer(
low=-init_scale, high=init_scale)))
self.softmax_weight = self.create_parameter(
Expand Down Expand Up @@ -103,8 +102,8 @@ def get_model(self):
train_reader = paddle.batch(
fake_sample_reader(), batch_size=batch_size, drop_last=True)

optimizer = paddle.optimizer.SGD(learning_rate=0.001,
                                 parameters=model.parameters())

return model, train_reader, optimizer

@@ -40,7 +40,6 @@ def __init__(self,
self.hidden_size,
sparse=True,
weight_attr=paddle.ParamAttr(
name='embedding_param',
initializer=paddle.nn.initializer.Uniform(
low=-init_scale, high=init_scale)))
self.softmax_weight = self.create_parameter(
@@ -39,7 +39,6 @@ def __init__(self,
self.hidden_size,
sparse=is_sparse,
weight_attr=paddle.ParamAttr(
name='embedding_param',
initializer=paddle.nn.initializer.Uniform(
low=-init_scale, high=init_scale)))
self.softmax_weight = self.create_parameter(