29 changes: 29 additions & 0 deletions paddle/fluid/distributed/collective/ProcessGroup.h
@@ -117,6 +117,35 @@ class ProcessGroup {
"ProcessGroup%s does not support receive", GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<Tensor>& in_tensors /* tensors */, // NOLINT
std::vector<Tensor>& out_tensors /* tensors */) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support AllGather", GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> AllToAll(
std::vector<Tensor>& in /* tensors */, // NOLINT
std::vector<Tensor>& out /* tensors */) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support AllToAll", GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> Reduce(
std::vector<Tensor>& tensors /* tensors */, // NOLINT
const ReduceOptions& opts) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support Reduce", GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> Scatter(
std::vector<Tensor>& in_tensors /* tensors */, // NOLINT
std::vector<Tensor>& out_tensors /* tensors */, // NOLINT
const ScatterOptions&) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support Scatter", GetBackendName()));
}

protected:
const int rank_;
const int size_;
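Note: the base-class methods added above are deliberately implemented as virtuals that throw, so a backend only has to override the collectives it actually supports. A minimal standalone sketch of that pattern (illustrative only, not Paddle code; FakeTask, BaseGroup, and MyBackend are made-up names):

#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

struct FakeTask {};  // stand-in for ProcessGroup::Task

class BaseGroup {
 public:
  virtual ~BaseGroup() = default;
  virtual std::string GetBackendName() const = 0;
  // Default implementation mirrors the pattern above: unsupported collectives throw.
  virtual std::shared_ptr<FakeTask> AllGather(std::vector<float>& in,
                                              std::vector<float>& out) {
    throw std::runtime_error("ProcessGroup" + GetBackendName() +
                             " does not support AllGather");
  }
};

class MyBackend : public BaseGroup {
 public:
  std::string GetBackendName() const override { return "MyBackend"; }
  std::shared_ptr<FakeTask> AllGather(std::vector<float>& in,
                                      std::vector<float>& out) override {
    out = in;  // trivial single-process behavior, for illustration only
    return std::make_shared<FakeTask>();
  }
};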
143 changes: 143 additions & 0 deletions paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
@@ -473,5 +473,148 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
return task;
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
std::vector<Tensor>& in_tensors, std::vector<Tensor>& out_tensors) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(in_tensors), true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(out_tensors), true,
platform::errors::InvalidArgument("All outputs should be in CudaPlace."));
return Collective(
in_tensors, out_tensors,
[&](const Tensor& input, Tensor& output, ncclComm_t comm,
const gpuStream_t& stream) {
auto input_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(input.impl());
auto output_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(output.impl());
return platform::dynload::ncclAllGather(
input_tensor->data(), output_tensor->data(), input_tensor->numel(),
platform::ToNCCLDataType(input.type()), comm, stream);
},
CommType::ALLGATHER);
}

void* GetPointerByOffset(void* raw_pointer, size_t offset,
experimental::DataType type) {
if (type == experimental::DataType::FLOAT32) {
return reinterpret_cast<void*>(reinterpret_cast<float*>(raw_pointer) +
offset);
} else if (type == experimental::DataType::FLOAT64) {
return reinterpret_cast<void*>(reinterpret_cast<double*>(raw_pointer) +
offset);
} else if (type == experimental::DataType::INT32) {
return reinterpret_cast<void*>(reinterpret_cast<int32_t*>(raw_pointer) +
offset);
} else if (type == experimental::DataType::INT64) {
return reinterpret_cast<void*>(reinterpret_cast<int64_t*>(raw_pointer) +
offset);
} else if (type == experimental::DataType::FLOAT16) {
return reinterpret_cast<void*>(reinterpret_cast<int16_t*>(raw_pointer) +
offset);
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"This datatype in nccl is not supported."));
}
}
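Note: GetPointerByOffset advances a type-erased buffer pointer by `offset` elements, so each branch reinterprets the pointer with the matching element width (FLOAT16 reuses int16_t purely for its 2-byte stride). A hypothetical byte-based equivalent, shown only to make the arithmetic explicit (AdvanceByElements is a made-up helper, not part of Paddle):

#include <cstddef>

// Advance a void* by `offset` elements of `elem_size` bytes each;
// char* arithmetic is defined in bytes, so the element offset is scaled.
inline void* AdvanceByElements(void* raw, std::size_t offset,
                               std::size_t elem_size) {
  return static_cast<void*>(static_cast<char*>(raw) + offset * elem_size);
}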

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
std::vector<Tensor>& in_tensors, std::vector<Tensor>& out_tensors) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(in_tensors), true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(out_tensors), true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
return Collective(
in_tensors, out_tensors,
[&](const Tensor& input, Tensor& output, ncclComm_t comm,
const gpuStream_t& stream) {
auto input_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(input.impl());
auto output_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(output.impl());
size_t offset = 0;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
for (auto i = 0; i < size_; i++) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
GetPointerByOffset(input_tensor->data(), offset, input.type()),
input_tensor->numel() / size_,
platform::ToNCCLDataType(input.type()), i, comm, stream));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
GetPointerByOffset(output_tensor->data(), offset, input.type()),
input_tensor->numel() / size_,
platform::ToNCCLDataType(input.type()), i, comm, stream));
offset += input_tensor->numel() / size_;
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
},
      CommType::ALLTOALL);
}
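Note: the loop above implements all-to-all as size_ paired ncclSend/ncclRecv calls inside one ncclGroupStart/ncclGroupEnd, exchanging equal chunks of numel/size_ elements with every rank (which assumes numel divides evenly by the group size). A standalone sketch of the chunk layout the offsets walk through (illustrative only, not Paddle code):

#include <cstdio>

int main() {
  const int world_size = 4;              // stands in for size_
  const int numel = 8;                   // elements per tensor on this rank
  const int chunk = numel / world_size;  // elements exchanged with each peer
  int offset = 0;
  for (int peer = 0; peer < world_size; ++peer) {
    // send elements [offset, offset + chunk) to `peer` and receive the
    // same-sized slice from `peer` into the output at the same offset
    std::printf("peer %d: elements [%d, %d)\n", peer, offset, offset + chunk);
    offset += chunk;
  }
  return 0;
}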

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Reduce(
std::vector<Tensor>& tensors, const ReduceOptions& opts) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(tensors), true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
return Collective(
tensors, tensors,
[&](const Tensor& input, Tensor& output, ncclComm_t comm,
const gpuStream_t& stream) {
auto input_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(input.impl());
auto output_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(output.impl());
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclReduce(
input_tensor->data(), output_tensor->data(), input.numel(),
platform::ToNCCLDataType(input.type()),
ToNCCLRedType(opts.reduce_op), opts.root_rank, comm, stream));
},
CommType::REDUCE);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
std::vector<Tensor>& in_tensors, std::vector<Tensor>& out_tensors,
const ScatterOptions& opts) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(in_tensors), true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(out_tensors), true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
return Collective(
in_tensors, out_tensors,
[&](const Tensor& input, Tensor& output, ncclComm_t comm,
const gpuStream_t& stream) {
auto input_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(input.impl());
auto output_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(output.impl());
size_t offset = 0;
if (rank_ == opts.root_rank) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
for (auto i = 0; i < size_; i++) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
GetPointerByOffset(input_tensor->data(), offset, input.type()),
input_tensor->numel() / size_,
platform::ToNCCLDataType(input.type()), i, comm, stream));
offset += input_tensor->numel() / size_;
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
output_tensor->data(), input_tensor->numel() / size_,
platform::ToNCCLDataType(input.type()), opts.root_rank, comm,
stream));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
} else {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
output_tensor->data(), input_tensor->numel() / size_,
platform::ToNCCLDataType(input.type()), opts.root_rank, comm,
stream));
}
},
CommType::SCATTER);
}
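Note: in the scatter above, only the root posts size_ sends (one numel/size_ chunk per destination, in rank order), while every rank, the root included, posts a single recv from the root for its own chunk. A standalone sketch of that schedule (illustrative only, not Paddle code):

#include <cstdio>

void ScatterPlan(int rank, int root, int world_size, int numel) {
  const int chunk = numel / world_size;  // assumes numel divides evenly
  if (rank == root) {
    int offset = 0;
    for (int peer = 0; peer < world_size; ++peer) {
      std::printf("root sends elements [%d, %d) to rank %d\n", offset,
                  offset + chunk, peer);
      offset += chunk;
    }
  }
  std::printf("rank %d receives %d elements from root %d\n", rank, chunk, root);
}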

} // namespace distributed
} // namespace paddle
14 changes: 14 additions & 0 deletions paddle/fluid/distributed/collective/ProcessGroupNCCL.h
@@ -98,6 +98,20 @@ class ProcessGroupNCCL : public ProcessGroup {
std::shared_ptr<ProcessGroup::Task> Recv(std::vector<Tensor>& tensors,
int src_rank) override;

std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<Tensor>& in_tensors,
std::vector<Tensor>& out_tensors) override;

std::shared_ptr<ProcessGroup::Task> AllToAll(
std::vector<Tensor>& in, std::vector<Tensor>& out) override;

std::shared_ptr<ProcessGroup::Task> Reduce(
std::vector<Tensor>& tensors, const ReduceOptions& opts) override;

std::shared_ptr<ProcessGroup::Task> Scatter(std::vector<Tensor>& in_tensors,
std::vector<Tensor>& out_tensors,
const ScatterOptions&) override;

protected:
virtual std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(
std::vector<Place> places, int rank, CommType opType,
9 changes: 9 additions & 0 deletions paddle/fluid/distributed/collective/Types.h
@@ -36,5 +36,14 @@ struct BarrierOptions {
std::vector<int> place_ids;
};

struct ReduceOptions {
ReduceOp reduce_op = ReduceOp::SUM;
int root_rank = 0;
};

struct ScatterOptions {
int root_rank = 0;
};

} // namespace distributed
} // namespace paddle
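Note: the new option structs carry the root rank (and, for reduce, the reduction kind) from the Python bindings down to the backend. A minimal usage sketch, assuming `pg` is a shared_ptr to a concrete backend such as ProcessGroupNCCL and `tensors`, `in_tensors`, `out_tensors` are std::vector<Tensor> already placed on the right devices (those names are placeholders, not part of this diff); it mirrors what the bindings below do:

paddle::distributed::ReduceOptions reduce_opts;
reduce_opts.reduce_op = paddle::distributed::ReduceOp::SUM;  // SUM is also the default
reduce_opts.root_rank = 0;
auto reduce_task = pg->Reduce(tensors, reduce_opts);

paddle::distributed::ScatterOptions scatter_opts;
scatter_opts.root_rank = 0;
auto scatter_task = pg->Scatter(in_tensors, out_tensors, scatter_opts);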
57 changes: 57 additions & 0 deletions paddle/fluid/pybind/distributed_py.cc
@@ -64,6 +64,11 @@ void BindDistributed(py::module *m) {
.def(py::init<>())
.def_readwrite("place_ids", &distributed::BarrierOptions::place_ids);

py::class_<distributed::ReduceOptions>(*m, "ReduceOptions")
.def(py::init<>())
.def_readwrite("reduce_op", &distributed::ReduceOptions::reduce_op)
.def_readwrite("source_root", &distributed::ReduceOptions::root_rank);

auto ProcessGroup =
py::class_<distributed::ProcessGroup,
std::shared_ptr<distributed::ProcessGroup>>(*m, "ProcessGroup")
@@ -121,6 +126,58 @@ void BindDistributed(py::module *m) {
return self.Recv(tensors, src);
},
py::arg("tensor"), py::arg("src"),
py::call_guard<py::gil_scoped_release>())

.def("all_gather",
[](distributed::ProcessGroup &self, py::handle py_in_tensor,
py::handle py_out_tensor) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
std::vector<Tensor> in_tensors = {in_tensor};
std::vector<Tensor> out_tensors = {out_tensor};
return self.AllGather(in_tensors, out_tensors);
},
py::arg("in"), py::arg("out"),
py::call_guard<py::gil_scoped_release>())

.def("alltoall",
[](distributed::ProcessGroup &self, py::handle py_in_tensor,
py::handle py_out_tensor) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
std::vector<Tensor> in_tensors = {in_tensor};
std::vector<Tensor> out_tensors = {out_tensor};
return self.AllToAll(in_tensors, out_tensors);
},
py::arg("in"), py::arg("out"),
py::call_guard<py::gil_scoped_release>())

.def("reduce",
[](distributed::ProcessGroup &self, py::handle py_in_tensor,
int dst, distributed::ReduceOp op) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
distributed::ReduceOptions opts;
opts.reduce_op = op;
opts.root_rank = dst;
std::vector<Tensor> tensors = {in_tensor};
return self.Reduce(tensors, opts);
},
py::arg("tensor"), py::arg("dst"),
py::arg("op") = distributed::ReduceOp::SUM,
py::call_guard<py::gil_scoped_release>())

.def("scatter",
[](distributed::ProcessGroup &self, py::handle py_in_tensor,
py::handle py_out_tensor, int src) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
distributed::ScatterOptions opts;
opts.root_rank = src;
std::vector<Tensor> in_tensors = {in_tensor};
std::vector<Tensor> out_tensors = {out_tensor};
return self.Scatter(in_tensors, out_tensors, opts);
},
py::arg("in"), py::arg("out"), py::arg("src"),
py::call_guard<py::gil_scoped_release>());

#if defined(PADDLE_WITH_NCCL)
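Note: each binding above converts a single Python tensor handle into the std::vector<Tensor> the C++ API expects and releases the GIL around the collective call via py::call_guard. A standalone pybind11 sketch of that wrapping pattern (illustrative only; binding_sketch, FakeReduce, and the module layout are made up, and here the arguments are converted by pybind11 before the GIL is released):

#include <pybind11/pybind11.h>

#include <vector>

namespace py = pybind11;

// Stand-in for a blocking collective such as ProcessGroup::Reduce.
static int FakeReduce(const std::vector<int>& values, int dst) {
  return dst + static_cast<int>(values.size());
}

PYBIND11_MODULE(binding_sketch, m) {
  m.def(
      "reduce",
      [](int value, int dst) {
        std::vector<int> values = {value};  // the C++ API takes a vector
        return FakeReduce(values, dst);     // runs with the GIL released
      },
      py::arg("value"), py::arg("dst"),
      py::call_guard<py::gil_scoped_release>());
}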