Skip to content
Merged
37 changes: 37 additions & 0 deletions paddle/fluid/framework/generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,43 @@ const std::shared_ptr<Generator>& DefaultCPUGenerator() {
return default_cpu_generator;
}

// Maps a user-chosen name to the Generator holding that RNG's state.
using RNGMap = std::unordered_map<std::string, std::shared_ptr<Generator>>;

// Returns the process-wide name -> Generator registry. Constructed lazily on
// first use (Meyers-singleton style), so it is safe to call before main().
static RNGMap& GetRandomSeedGeneratorMap() {
  static RNGMap generator_registry;
  return generator_registry;
}

// Registers a new named Generator seeded with `seed` and returns it.
// Enforces (errors out) if a generator with the same name already exists.
const std::shared_ptr<Generator>& SetRandomSeedGenerator(
    const std::string& name, uint64_t seed) {
  auto& rng_map = GetRandomSeedGeneratorMap();
  auto iter = rng_map.find(name);
  PADDLE_ENFORCE_EQ(iter == rng_map.end(), true,
                    platform::errors::AlreadyExists(
                        "%s RandomSeedGenerator already exists", name));

  // Keep the iterator emplace returns so we do not pay a second lookup
  // (the original `return rng_map[name];` re-hashed the key).
  auto emplace_result =
      rng_map.emplace(name, std::make_shared<Generator>(seed));
  PADDLE_ENFORCE_EQ(
      emplace_result.second, true,
      platform::errors::PermissionDenied(
          "SetRandomSeedGenerator cannot emplace %s RandomSeedGenerator",
          name));
  return emplace_result.first->second;
}

// Looks up the Generator previously registered under `name` via
// SetRandomSeedGenerator; errors out when no such generator exists.
const std::shared_ptr<Generator>& GetRandomSeedGenerator(
    const std::string& name) {
  auto& generators = GetRandomSeedGeneratorMap();
  auto it = generators.find(name);
  PADDLE_ENFORCE_EQ(it != generators.end(), true,
                    platform::errors::NotFound(
                        "%s RandomSeedGenerator is not found, please "
                        "use `set_random_seed_generator` to set rng first",
                        name));
  return it->second;
}

std::shared_ptr<std::mt19937_64> OpDefaultCPUEngine() {
static auto op_default_cpu_engine = std::make_shared<std::mt19937_64>();
return op_default_cpu_engine;
Expand Down
6 changes: 6 additions & 0 deletions paddle/fluid/framework/generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,5 +126,11 @@ std::shared_ptr<std::mt19937_64> GetCPURandomEngine(uint64_t);
const std::shared_ptr<Generator>& GetDefaultCUDAGenerator(
int64_t device_id = -1);

const std::shared_ptr<Generator>& SetRandomSeedGenerator(
const std::string& name, uint64_t seed);

const std::shared_ptr<Generator>& GetRandomSeedGenerator(
const std::string& name);

} // namespace framework
} // namespace paddle
1 change: 1 addition & 0 deletions paddle/fluid/operators/dropout_impl.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ void DropoutFwGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx,
config.thread_per_block.x * vec_size) +
1) *
vec_size;

GetSeedDataAndIncrement(dev_ctx, seed, is_fix_seed, seed_val, offset,
&seed_data, &increment);

Expand Down
10 changes: 3 additions & 7 deletions paddle/fluid/operators/dropout_impl_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ inline void GetSeedDataAndIncrement(const platform::CUDADeviceContext& dev_ctx,
BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()).GetDeviceId();
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);

if ((seed) && platform::is_gpu_place(seed->place())) {
if (seed) {
framework::Tensor seed_cpu_tensor;
TensorCopySync(*seed, platform::CPUPlace(), &seed_cpu_tensor);
*seed_data = static_cast<uint64_t>(seed_cpu_tensor.data<int>()[0]);
Expand All @@ -39,12 +39,8 @@ inline void GetSeedDataAndIncrement(const platform::CUDADeviceContext& dev_ctx,
*seed_data = seed_offset.first;
*increment = seed_offset.second;
} else {
if (seed) {
*seed_data = *(seed->data<int>());
} else {
std::random_device rnd;
*seed_data = is_fix_seed ? seed_val : rnd();
}
std::random_device rnd;
*seed_data = is_fix_seed ? seed_val : rnd();
*increment = offset;
}
}
Expand Down
13 changes: 13 additions & 0 deletions paddle/fluid/operators/dropout_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,19 @@ class DropoutOp : public framework::OperatorWithKernel {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}

// Chooses the kernel type for an input variable. The "Seed" input is exempt
// from place/layout transformation (it is read wherever it lives); every
// other input takes the expected data type at the tensor's own place/layout.
framework::OpKernelType GetKernelTypeForVar(
    const std::string& var_name, const Tensor& tensor,
    const framework::OpKernelType& expected_kernel_type) const override {
  if (var_name != "Seed") {
    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(), tensor.layout());
  }
  VLOG(10) << "var_name:" << var_name
           << " does not need to transform in dropout op";
  return expected_kernel_type;
}
};

class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
Expand Down
29 changes: 29 additions & 0 deletions paddle/fluid/operators/seed_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,23 @@ class SeedOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() override {
AddOutput("Out", "The output of seed op.");
AddAttr<int>("seed", "Dropout random seed.").SetDefault(0);
AddAttr<bool>("deterministic",
"(bool, default false) Whether to use deterministic "
"RandomSeedGenerator which "
"generate by `set_random_seed_generator`")
.SetDefault(false)
.AsExtra();
AddAttr<std::string>(
"rng_name",
"use deterministic RandomSeedGenerator which name is `rng_name`")
.SetDefault("")
.AsExtra();
AddAttr<bool>("force_cpu",
"(bool, default false) Force fill output variable to cpu "
"memory. Otherwise, fill output variable to the running "
"device")
.SetDefault(false)
.AsExtra();
AddComment(R"DOC(
Seed Operator.
)DOC");
Expand All @@ -55,3 +72,15 @@ REGISTER_OPERATOR(
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(
seed, ops::CPUSeedKernel<paddle::platform::CPUDeviceContext, int>);

/* ========================== register checkpoint ===========================*/
// Op-version bump: old programs serialized without `force_cpu` load with the
// default value (false) recorded here.
REGISTER_OP_VERSION(seed)
    .AddCheckpoint(
        R"ROC(
Upgrade seed add a new attribute [force_cpu])ROC",
        paddle::framework::compatible::OpVersionDesc().NewAttr(
            "force_cpu",
            // Fixed stray period: the concatenated description used to read
            // "to cpu.memory. Otherwise ...".
            "If true, force fill output variable to cpu "
            "memory. Otherwise, fill output variable to the running "
            "device",
            false));
35 changes: 21 additions & 14 deletions paddle/fluid/operators/seed_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/seed_op.h"

namespace paddle {
Expand All @@ -20,22 +21,28 @@ namespace operators {
// NOTE(review): this hunk was a copied diff with the +/- markers lost, so the
// removed pre-change Compute opening and tail were interleaved with the added
// code (two Compute signatures, duplicated memory::Copy tail). Reconstructed
// below is the post-change implementation implied by the added lines.
template <typename Place, typename T>
class GPUSeedKernel : public framework::OpKernel<T> {
 public:
  // Picks a seed via get_seed(context) and writes it into Out:
  // on CPU when `force_cpu` is set or the kernel runs on CPUPlace,
  // otherwise via an async host-to-device copy on the current CUDA stream.
  void Compute(const framework::ExecutionContext &context) const override {
    auto *out = context.Output<Tensor>("Out");
    int seed = get_seed(context);

    auto force_cpu = context.Attr<bool>("force_cpu");
    bool cpu_place = force_cpu || context.GetPlace() == platform::CPUPlace();
    if (cpu_place) {
      platform::DeviceContextPool &pool =
          platform::DeviceContextPool::Instance();
      auto &dev_ctx = *pool.Get(context.GetPlace());
      out->mutable_data<T>(platform::CPUPlace());
      math::SetConstant<platform::CPUDeviceContext, T> functor;
      functor(reinterpret_cast<const platform::CPUDeviceContext &>(dev_ctx),
              out, static_cast<T>(seed));
    } else {
      auto *out_data = out->mutable_data<T>(context.GetPlace());
      auto target_gpu_place =
          BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace());
      auto stream = context.cuda_device_context().stream();
      // &seed is a host pointer; the copy is enqueued on `stream`.
      memory::Copy(target_gpu_place, out_data, platform::CPUPlace(), &seed,
                   sizeof(int), stream);
    }
  }
};

Expand Down
35 changes: 25 additions & 10 deletions paddle/fluid/operators/seed_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,30 +13,45 @@
// limitations under the License.
#pragma once

#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"

namespace paddle {
namespace operators {
using Tensor = framework::Tensor;

template <typename DeviceContext, typename T>
class CPUSeedKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* out = context.Output<Tensor>("Out");
auto* out_data = out->mutable_data<T>(context.GetPlace());
int user_seed = context.Attr<int>("seed");
static int get_seed(const framework::ExecutionContext& context) {
int user_seed = context.Attr<int>("seed");
bool deterministic = context.Attr<bool>("deterministic");

int seed = 0;
if (!deterministic) {
// NOTE: fixed seed should only be used in unittest or for debug.
// Guarantee to use random seed in training.
std::random_device rnd;
int seed;
if (user_seed != 0) {
seed = user_seed;
} else {
std::random_device rnd;
seed = rnd();
}
out_data[0] = seed;
} else {
std::string name = context.Attr<std::string>("rng_name");
auto rng = framework::GetRandomSeedGenerator(name);
do { // NOTE(wangxi): cpu dropout will use random seed if seed == 0
seed = static_cast<int>(rng->Random64());
} while (seed == 0);
}
return seed;
}

template <typename DeviceContext, typename T>
class CPUSeedKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* out = context.Output<Tensor>("Out");
auto* out_data = out->mutable_data<T>(context.GetPlace());
out_data[0] = get_seed(context);
}
};

Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/pybind/generator_py.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ void BindGenerator(py::module* m_ptr) {
&framework::Generator::SetIsInitPy);
m.def("default_cpu_generator", &framework::DefaultCPUGenerator);
m.def("default_cuda_generator", &framework::GetDefaultCUDAGenerator);
m.def("set_random_seed_generator", &framework::SetRandomSeedGenerator);
m.def("get_random_seed_generator", &framework::GetRandomSeedGenerator);
}
} // namespace pybind
} // namespace paddle
Loading