|
14 | 14 |
|
15 | 15 | #include "paddle/phi/kernels/dropout_kernel.h" |
16 | 16 |
|
17 | | -#include "paddle/fluid/framework/generator.h" |
18 | 17 | #include "paddle/phi/backends/cpu/cpu_context.h" |
| 18 | +#include "paddle/phi/core/generator.h" |
19 | 19 | #include "paddle/phi/core/kernel_registry.h" |
20 | 20 | #include "paddle/phi/kernels/expand_kernel.h" |
21 | 21 | #include "paddle/phi/kernels/funcs/eigen/common.h" |
@@ -82,7 +82,13 @@ void DropoutRawKernel(const Context& dev_ctx, |
82 | 82 | } else { |
83 | 83 | seed_data = fix_seed ? seed : 0; |
84 | 84 | } |
85 | | - auto engine = paddle::framework::GetCPURandomEngine(seed_data); |
| 85 | + std::shared_ptr<std::mt19937_64> engine; |
| 86 | + if (seed_data) { |
| 87 | + engine = std::make_shared<std::mt19937_64>(); |
| 88 | + engine->seed(seed_data); |
| 89 | + } else { |
| 90 | + engine = dev_ctx.GetGenerator()->GetCPUEngine(); |
| 91 | + } |
86 | 92 |
|
87 | 93 | std::uniform_real_distribution<float> dist(0, 1); |
88 | 94 |
|
@@ -147,7 +153,13 @@ void DropoutNdKernel(const Context& dev_ctx, |
147 | 153 | } else { |
148 | 154 | seed_data = fix_seed ? seed : 0; |
149 | 155 | } |
150 | | - auto engine = paddle::framework::GetCPURandomEngine(seed_data); |
| 156 | + std::shared_ptr<std::mt19937_64> engine; |
| 157 | + if (seed_data) { |
| 158 | + engine = std::make_shared<std::mt19937_64>(); |
| 159 | + engine->seed(seed_data); |
| 160 | + } else { |
| 161 | + engine = dev_ctx.GetGenerator()->GetCPUEngine(); |
| 162 | + } |
151 | 163 |
|
152 | 164 | std::uniform_real_distribution<float> dist(0, 1); |
153 | 165 |
|
|
0 commit comments