
Commit 1b03198

[Dygraph] Support sparse tensor in refactored reducer (#40836)
* [Dygraph] Support sparse tensor in refactored reducer
* add uts
* refactor
* update
* fix bugs
1 parent 625dd72 commit 1b03198

14 files changed (+440, -37 lines)
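
For context, this change teaches the eager-mode DataParallel reducer to handle SelectedRows (sparse) gradients, which show up when an Embedding layer is built with sparse=True. Below is a minimal, illustrative usage sketch of that scenario, assuming a paddle 2.x dygraph setup launched with paddle.distributed.launch; it mirrors the unit tests touched by this commit, and the model, sizes, and data are made up for illustration rather than taken from the commit itself.

    # Illustrative sketch only: assumes a multi-GPU run via
    #   python -m paddle.distributed.launch train.py
    import paddle
    import paddle.distributed as dist


    class SparseEmbeddingNet(paddle.nn.Layer):
        def __init__(self, vocab_size=1000, hidden_size=64):
            super().__init__()
            # sparse=True makes the embedding emit a SelectedRows gradient,
            # which is the case the refactored reducer now supports.
            self.embedding = paddle.nn.Embedding(
                vocab_size, hidden_size, sparse=True)
            self.linear = paddle.nn.Linear(hidden_size, vocab_size)

        def forward(self, x):
            return self.linear(self.embedding(x))


    def main():
        dist.init_parallel_env()
        model = paddle.DataParallel(SparseEmbeddingNet())
        opt = paddle.optimizer.SGD(learning_rate=0.001,
                                   parameters=model.parameters())
        x = paddle.randint(0, 1000, shape=[8, 16])          # token ids
        label = paddle.randint(0, 1000, shape=[8 * 16, 1])  # fake targets
        logits = model(x).reshape([-1, 1000])
        loss = paddle.nn.functional.cross_entropy(logits, label)
        loss.backward()  # sparse grads are reduced across ranks here
        opt.step()
        opt.clear_grad()


    if __name__ == "__main__":
        main()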

paddle/fluid/distributed/collective/reducer.cc
Lines changed: 204 additions & 29 deletions

@@ -360,6 +360,7 @@ void EagerReducer::InitializeGroups(
         is_sparse_gradient_[tensor_indices_.front()]) {
       // process the sparse gradient. one sparse, one group
       group.dtype_ = first_var.dtype();
+      group.is_sparse_ = true;
     } else {
       // process the dense gradient.
       InitializeDenseGroups(tensor_indices_, &group);
@@ -391,6 +392,12 @@ void EagerReducer::InitializeDenseGroups(
     auto &tensor = tensors_[tensor_index];
     auto &tensor_name = tensor.name();
 
+    PADDLE_ENFORCE_EQ(is_sparse_gradient_[tensor_index], false,
+                      platform::errors::PreconditionNotMet(
+                          "Tensor %s's GRAD must be Tensor, but received "
+                          "GRAD is SelectedRows",
+                          tensor_name));
+
     PADDLE_ENFORCE_EQ(tensor.is_initialized(), true,
                       platform::errors::PreconditionNotMet(
                           "Tensor %s is not initialized.", tensor_name));
@@ -480,6 +487,7 @@ void EagerReducer::PrepareForBackward(const std::vector<Tensor> &outputs) {
   next_group_ = 0;
   std::for_each(groups_.begin(), groups_.end(), [](EagerGroup &group) {
     group.pending_ = group.tensor_indices_.size();
+    group.sparse_contents_ = Tensor();
   });
 
   // reinitialize vars_marked_ready_ for next iteration
@@ -544,9 +552,6 @@ void EagerReducer::AddDistHook(size_t var_index) {
     return;
   }
 
-  auto &tensor = tensors_[var_index];
-  const auto &grad_node = GetGradNodeFromTensor(&tensor);
-
   VLOG(3) << "Tensor[" << var_index << "] [" << tensors_[var_index].name()
           << "@Grad] arrived and triggered disthook";
 
@@ -608,33 +613,69 @@ void EagerReducer::MarkVarReady(const size_t var_index,
   auto &group_tensor = group.dense_tensors_[inside_group_index];
   const auto length = group.length_[inside_group_index];
 
-  if (is_used_var) {
-    auto *autograd_meta = tensors_[var_index].get_autograd_meta();
-    auto &grad_tensor = static_cast<egr::AutogradMeta *>(autograd_meta)->Grad();
-    group_tensor
-        .ShareDataWith(
-            *(std::dynamic_pointer_cast<phi::DenseTensor>(grad_tensor.impl())))
-        .Resize({grad_tensor.numel()});
-  } else {
-    // TODO(shenliang03): maybe save the memory by avoiding tensor construction
-    if (!group_tensor.initialized()) {
-      group_tensor.Resize({static_cast<int64_t>(length)});
-      group_tensor.mutable_data(inner_place_, group.dtype_);
-    }
-    if (HasGrad(var_index)) {
-      VLOG(3) << "Tensor[" << tensors_[var_index].name() << "] has grad";
-      auto grad_tensor = egr::EagerUtils::mutable_grad(tensors_[var_index]);
+  if (!group.is_sparse_) {
+    if (is_used_var) {
+      auto *autograd_meta = tensors_[var_index].get_autograd_meta();
+      auto &grad_tensor =
+          static_cast<egr::AutogradMeta *>(autograd_meta)->Grad();
       group_tensor
           .ShareDataWith(*(
-              std::dynamic_pointer_cast<phi::DenseTensor>(grad_tensor->impl())))
-          .Resize({length});
+              std::dynamic_pointer_cast<phi::DenseTensor>(grad_tensor.impl())))
+          .Resize({grad_tensor.numel()});
     } else {
-      VLOG(3) << "Tensor[" << tensors_[var_index].name()
-              << "] doesn't have grad";
-      auto *dev_ctx = platform::DeviceContextPool::Instance().Get(inner_place_);
-      group_tensor.Resize({static_cast<int64_t>(length)});
-      phi::funcs::set_constant(*dev_ctx, &group_tensor, 0.0);
+      // TODO(shenliang03): maybe save the memory by avoiding tensor
+      // construction
+      if (!group_tensor.initialized()) {
+        group_tensor.Resize({static_cast<int64_t>(length)});
+        group_tensor.mutable_data(inner_place_, group.dtype_);
+      }
+      if (HasGrad(var_index)) {
+        VLOG(3) << "Tensor[" << tensors_[var_index].name() << "] has grad";
+        auto grad_tensor = egr::EagerUtils::mutable_grad(tensors_[var_index]);
+        group_tensor
+            .ShareDataWith(*(std::dynamic_pointer_cast<phi::DenseTensor>(
+                grad_tensor->impl())))
+            .Resize({length});
+      } else {
+        VLOG(3) << "Tensor[" << tensors_[var_index].name()
+                << "] doesn't have grad";
+        auto *dev_ctx =
+            platform::DeviceContextPool::Instance().Get(inner_place_);
+        group_tensor.Resize({static_cast<int64_t>(length)});
+        phi::funcs::set_constant(*dev_ctx, &group_tensor, 0.0);
+      }
     }
+  } else {
+    auto *autograd_meta = tensors_[var_index].get_autograd_meta();
+    auto &grad_tensor = static_cast<egr::AutogradMeta *>(autograd_meta)->Grad();
+
+    // process sparse group
+    PADDLE_ENFORCE_EQ(
+        HasGrad(var_index), true,
+        platform::errors::PreconditionNotMet(
+            "The sparse parameter[%d][%s] should have gradient. "
+            "Currently, DataParallel does not support sparse "
+            "parameters without generating gradients during training. "
+            "For example, if is_sparese=True is used in Embedding, "
+            "the current step of this parameter cannot generate gradient "
+            "because of stop_gradient/detatch, where error will occur.",
+            var_index, tensors_[var_index].name()));
+
+    // need to check tensor type
+    PADDLE_ENFORCE_EQ(
+        grad_tensor.is_selected_rows(), true,
+        platform::errors::PreconditionNotMet(
+            "The sparse parameter[%d][%s] must have a selectedrows gradient. "
+            "Before forward pass, the parameter type is inferred to be "
+            "SelectedRows, but after backward pass, its actual type becomes "
+            "LodTensor. It is currently not supported by DataParallel. "
+            "For example, if sparse embedding is used, and the weight of "
+            "embedding is shared with subsequent dense parameters, then "
+            "the parameter gradient of the embedding will be converted "
+            "to dense parameters.",
+            var_index, tensors_[var_index].name()));
+
+    group.sparse_contents_.set_impl(grad_tensor.impl());
   }
 
   if (--group.pending_ == 0) {
@@ -666,7 +707,11 @@ void EagerReducer::MarkGroupReady(size_t group_index) {
   for (; next_group_ < groups_.size() && groups_[next_group_].pending_ == 0;
        ++next_group_) {
     UNUSED auto &group = groups_[next_group_];
-    FusedAllReduceSchedule(&group, next_group_);
+    if (group.is_sparse_) {
+      AllReduceSparse(&group, next_group_);
+    } else {
+      FusedAllReduceSchedule(&group, next_group_);
+    }
   }
 }
 
@@ -725,6 +770,11 @@ void EagerReducer::ProcessUnusedDenseVars() {
     const auto inside_group_index = var_locator.inside_group_index;
     auto &src_tensor = group.dense_tensors_[inside_group_index];
 
+    // sparse no need to check and no support find_unused_parameters
+    if (group.is_sparse_) {
+      continue;
+    }
+
     Tensor grad_value(std::make_shared<phi::DenseTensor>(src_tensor));
 
     auto dest_var_base = tensors_[var_index];
@@ -739,11 +789,15 @@ void EagerReducer::FinalizeBackward() {
   groups_need_finalize_ = false;
   grad_need_hooks_ = false;
   for (auto &group : groups_) {
-    group.task->Synchronize();
+    if (!group.is_sparse_) {
+      group.task->Synchronize();
+    }
   }
 
   for (auto &group : groups_) {
-    group.SplitTensors(inner_place_);
+    if (!group.is_sparse_) {
+      group.SplitTensors(inner_place_);
+    }
   }
 
   if (find_unused_vars_each_step_) {
@@ -778,6 +832,127 @@ void EagerReducer::FusedAllReduceSchedule(EagerGroup *group,
   // split in FinalizeBackward()
 }
 
+void EagerReducer::AllReduceSparse(EagerGroup *group,
+                                   const int curr_group_index) {
+  // div nranks
+  Tensor sparse_tensor(group->sparse_contents_);
+  paddle::experimental::scale_(sparse_tensor, 1.0 / nranks_, 0.0, false);
+
+  VLOG(3) << "sparse_group [" << curr_group_index << "] start allreduce.";
+
+  auto *dev_ctx = platform::DeviceContextPool::Instance().Get(inner_place_);
+  if (platform::is_gpu_place(inner_place_)) {
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+    dev_ctx = static_cast<platform::CUDADeviceContext *>(
+        platform::DeviceContextPool::Instance().Get(inner_place_));
+#else
+    PADDLE_THROW(platform::errors::PermissionDenied(
+        "Paddle can't concat grad tensors since it's not compiled with NCCL,"
+        "Please recompile or reinstall Paddle with NCCL support."));
+#endif
+  } else if (platform::is_cpu_place(inner_place_)) {
+    dev_ctx = static_cast<platform::CPUDeviceContext *>(
+        platform::DeviceContextPool::Instance().Get(inner_place_));
+  } else {
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "Split grad tensor not supported on place (%s)", inner_place_));
+  }
+
+  auto src = std::dynamic_pointer_cast<phi::SelectedRows>(
+      group->sparse_contents_.impl());
+  const auto &src_rows = src->rows();
+
+  const auto &rank_ = process_group_->GetRank();
+  const auto &size_ = process_group_->GetSize();
+
+  framework::Vector<int64_t> rows_num_vector(size_);
+  rows_num_vector[rank_] = static_cast<int64_t>(src_rows.size());
+
+  Tensor rows_num_tensor = paddle::experimental::empty(
+      IntArray({static_cast<int64_t>(size_)}), DataType::INT64, inner_place_);
+  auto *rows_num_dense_tensor =
+      std::dynamic_pointer_cast<phi::DenseTensor>(rows_num_tensor.impl()).get();
+  framework::TensorFromVector<int64_t>(rows_num_vector, *dev_ctx,
+                                       rows_num_dense_tensor);
+
+  distributed::AllreduceOptions opts;
+  opts.reduce_op = ReduceOp::SUM;
+  std::vector<Tensor> reduce_tensors = {rows_num_tensor};
+  process_group_->AllReduce(reduce_tensors, opts)->Synchronize();
+
+  framework::TensorToVector<int64_t>(*rows_num_dense_tensor, *dev_ctx,
+                                     &rows_num_vector);
+  dev_ctx->Wait();
+
+  const auto *cpu_rows_num_ptr = rows_num_vector.data();
+  auto rows_num = std::accumulate(cpu_rows_num_ptr, cpu_rows_num_ptr + size_,
+                                  static_cast<int64_t>(0));
+
+  VLOG(3) << "Gather rows: " << string::join_strings(rows_num_vector, ',')
+          << ", total rows number: " << rows_num
+          << ", height: " << src->height();
+
+  dev_ctx->Wait();
+
+  if (std::all_of(cpu_rows_num_ptr, cpu_rows_num_ptr + size_,
+                  [&](int64_t row) { return row == cpu_rows_num_ptr[0]; })) {
+    // During sparse communication, the number of each card is same.
+    // allgather is used to speed up the allreduce by replacing broadcast.
+
+    VLOG(3) << "allgather replaces broadcast to speed up in sparse allreduce";
+
+    Tensor dst_rows_tensor =
+        paddle::experimental::empty(IntArray({static_cast<int64_t>(rows_num)}),
+                                    DataType::INT64, inner_place_);
+    Tensor src_rows_tensor = paddle::experimental::empty(
+        IntArray({static_cast<int64_t>((*src).rows().size())}), DataType::INT64,
+        inner_place_);
+    auto *src_rows_dense_tensor =
+        std::dynamic_pointer_cast<phi::DenseTensor>(src_rows_tensor.impl())
+            .get();
+    framework::TensorFromVector<int64_t>((*src).rows(), *dev_ctx,
+                                         src_rows_dense_tensor);
+
+    std::vector<Tensor> src_rows_tensors = {src_rows_tensor};
+    std::vector<Tensor> dst_rows_tensors = {dst_rows_tensor};
+    process_group_->AllGather(src_rows_tensors, dst_rows_tensors)
+        ->Synchronize();
+
+    framework::Vector<int64_t> dst_rows_vector(rows_num, 0);
+    auto *dst_rows_dense_tensor =
+        std::dynamic_pointer_cast<phi::DenseTensor>(dst_rows_tensor.impl())
+            .get();
+    framework::TensorToVector<int64_t>(*dst_rows_dense_tensor, *dev_ctx,
+                                       &dst_rows_vector);
+    dev_ctx->Wait();
+
+    Tensor src_value_tensor(std::make_shared<phi::DenseTensor>(src->value()));
+    std::vector<int64_t> dst_shape = src_value_tensor.shape();
+    dst_shape[dst_shape.size() - 2] = rows_num;
+    auto dst_dense_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
+        paddle::experimental::full(IntArray(dst_shape), 0,
+                                   src_value_tensor.dtype(), inner_place_)
+            .impl());
+
+    auto dst =
+        std::make_shared<phi::SelectedRows>(dst_rows_vector, (*src).height());
+    *(dst->mutable_value()) = *dst_dense_tensor;
+    Tensor dst_value_tensor(std::make_shared<phi::DenseTensor>(dst->value()));
+
+    std::vector<Tensor> src_value_tensors = {src_value_tensor};
+    std::vector<Tensor> dst_value_tensors = {dst_value_tensor};
+    process_group_->AllGather(src_value_tensors, dst_value_tensors)
+        ->Synchronize();
+
+    src->set_rows(dst_rows_vector);
+    *(src->mutable_value()) =
+        *(std::dynamic_pointer_cast<phi::DenseTensor>(dst_value_tensor.impl()));
+  } else {
+    PADDLE_THROW(
+        platform::errors::Unimplemented("This case is not supported."));
+  }
+}
+
 std::ostream &operator<<(std::ostream &out, const EagerGroup &group) {
   const auto &tensors_ = group.tensor_indices_;
   out << "numel: " << group.all_length_ << " ;var number: " << tensors_.size()
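
In short, AllReduceSparse scales the local SelectedRows gradient by 1/nranks, allreduces the per-rank row counts, and, when every rank reports the same count, allgathers the row indices and the values (the "allgather replaces broadcast" fast path); any other case currently throws Unimplemented. A framework-free Python sketch of that control flow follows; RankGrad and sparse_allreduce are made-up names for illustration, and plain lists stand in for each rank's SelectedRows, so this is a model of the logic rather than Paddle code.

    # Toy model of the allgather-based sparse all-reduce above (illustrative).
    from dataclasses import dataclass
    from typing import List


    @dataclass
    class RankGrad:                 # stand-in for one rank's SelectedRows grad
        rows: List[int]             # selected row indices
        values: List[List[float]]   # one value vector per selected row


    def sparse_allreduce(per_rank: List[RankGrad], nranks: int) -> RankGrad:
        # step 1: each rank divides its local gradient by nranks ("div nranks")
        for g in per_rank:
            g.values = [[v / nranks for v in row] for row in g.values]

        # step 2: "allreduce" the per-rank row counts
        rows_num = [len(g.rows) for g in per_rank]

        # step 3: only the equal-row-count case is handled; the C++ code
        # otherwise raises errors::Unimplemented
        if len(set(rows_num)) != 1:
            raise NotImplementedError("This case is not supported.")

        # step 4: allgather rows and values, i.e. concatenate in rank order;
        # duplicate row ids stay in place, as in the SelectedRows result
        all_rows = [r for g in per_rank for r in g.rows]
        all_values = [v for g in per_rank for v in g.values]
        return RankGrad(rows=all_rows, values=all_values)


    # two ranks, each contributing two rows of a shared embedding table
    out = sparse_allreduce(
        [RankGrad([0, 2], [[2.0, 2.0], [4.0, 4.0]]),
         RankGrad([1, 2], [[6.0, 6.0], [8.0, 8.0]])],
        nranks=2)
    print(out.rows)    # [0, 2, 1, 2]
    print(out.values)  # halved values, concatenated in rank order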

paddle/fluid/distributed/collective/reducer.h
Lines changed: 3 additions & 0 deletions

@@ -47,6 +47,8 @@ std::vector<std::vector<size_t>> Eager_AssignGroupBySize(
 class EagerGroup {
  public:
   Tensor dense_contents_;
+  Tensor sparse_contents_;
+  bool is_sparse_ = false;
 
   // for concat kernel
   std::vector<phi::DenseTensor> dense_tensors_;
@@ -104,6 +106,7 @@ class EagerReducer {
   void MarkVarReady(const size_t var_index, const bool is_used_var);
   void MarkGroupReady(const size_t group_index);
   void FusedAllReduceSchedule(EagerGroup *group, const int curr_group_index);
+  void AllReduceSparse(EagerGroup *group, const int curr_group_index);
   void FinalizeBackward();
   void TraverseBackwardGraph(const std::vector<Tensor> &outputs);
   void ProcessUnusedDenseVars();

python/paddle/fluid/tests/unittests/CMakeLists.txt
Lines changed: 4 additions & 3 deletions

@@ -1128,7 +1128,7 @@ set_tests_properties(test_split_program PROPERTIES TIMEOUT 120)
 if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL)
     set_tests_properties(test_parallel_dygraph_dataparallel PROPERTIES TIMEOUT 120)
     set_tests_properties(test_parallel_dygraph_mnist PROPERTIES TIMEOUT 120)
-    set_tests_properties(test_parallel_dygraph_unused_variables PROPERTIES TIMEOUT 150)
+    set_tests_properties(test_parallel_dygraph_unused_variables PROPERTIES TIMEOUT 300)
     set_tests_properties(test_parallel_dygraph_control_flow PROPERTIES TIMEOUT 200)
     set_tests_properties(test_parallel_dygraph_control_flow_in_eager_mode PROPERTIES TIMEOUT 150)
     set_tests_properties(test_parallel_dygraph_no_sync PROPERTIES TIMEOUT 150)
@@ -1153,8 +1153,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL)
     set_tests_properties(test_eager_dist_api PROPERTIES TIMEOUT 300)
 
     if(${NCCL_VERSION} VERSION_GREATER_EQUAL 2212)
-        set_tests_properties(test_parallel_dygraph_sparse_embedding PROPERTIES TIMEOUT 120)
-        set_tests_properties(test_parallel_dygraph_transformer PROPERTIES TIMEOUT 120)
+        set_tests_properties(test_parallel_dygraph_sparse_embedding PROPERTIES TIMEOUT 200)
+        set_tests_properties(test_parallel_dygraph_sparse_embedding_over_height PROPERTIES TIMEOUT 150)
+        set_tests_properties(test_parallel_dygraph_transformer PROPERTIES TIMEOUT 150)
     endif()
 endif()
 

python/paddle/fluid/tests/unittests/parallel_dygraph_sparse_embedding.py
Lines changed: 2 additions & 3 deletions

@@ -42,7 +42,6 @@ def __init__(self,
             dtype=dtype,
             is_sparse=is_sparse,
             param_attr=fluid.ParamAttr(
-                name='embedding_param',
                 initializer=fluid.initializer.UniformInitializer(
                     low=-init_scale, high=init_scale)))
         self.softmax_weight = self.create_parameter(
@@ -103,8 +102,8 @@ def get_model(self):
         train_reader = paddle.batch(
             fake_sample_reader(), batch_size=batch_size, drop_last=True)
 
-        optimizer = fluid.optimizer.SGD(learning_rate=0.001,
-                                        parameter_list=model.parameters())
+        optimizer = paddle.optimizer.SGD(learning_rate=0.001,
+                                         parameters=model.parameters())
 
         return model, train_reader, optimizer
 
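
Besides dropping the hard-coded parameter name, this test moves off the legacy fluid optimizer: fluid.optimizer.SGD(parameter_list=...) becomes paddle.optimizer.SGD(parameters=...). A small standalone sketch of that migration (the Linear layer here is only a placeholder model for illustration):

    import paddle

    model = paddle.nn.Linear(4, 4)  # any paddle.nn.Layer works here

    # legacy 1.x dygraph form, as used before this commit:
    #   optimizer = paddle.fluid.optimizer.SGD(
    #       learning_rate=0.001, parameter_list=model.parameters())

    # 2.x form used after this commit:
    optimizer = paddle.optimizer.SGD(learning_rate=0.001,
                                     parameters=model.parameters())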

python/paddle/fluid/tests/unittests/parallel_dygraph_sparse_embedding_fp64.py
Lines changed: 0 additions & 1 deletion

@@ -40,7 +40,6 @@ def __init__(self,
             self.hidden_size,
             sparse=True,
             weight_attr=paddle.ParamAttr(
-                name='embedding_param',
                 initializer=paddle.nn.initializer.Uniform(
                     low=-init_scale, high=init_scale)))
         self.softmax_weight = self.create_parameter(

python/paddle/fluid/tests/unittests/parallel_dygraph_unused_variables.py
Lines changed: 0 additions & 1 deletion

@@ -39,7 +39,6 @@ def __init__(self,
             self.hidden_size,
             sparse=is_sparse,
             weight_attr=paddle.ParamAttr(
-                name='embedding_param',
                 initializer=paddle.nn.initializer.Uniform(
                     low=-init_scale, high=init_scale)))
         self.softmax_weight = self.create_parameter(
