Skip to content

Commit ca39599

Browse files
committed
follow comments
1 parent 6b48fdc commit ca39599

File tree

2 files changed

+13
-13
lines changed

2 files changed

+13
-13
lines changed

paddle/pten/kernels/bernoulli_kernel.h

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,6 @@
1919

2020
namespace pten {
2121

22-
template <typename T>
23-
inline HOSTDEVICE T BernoulliFunctor(T p, T rand) {
24-
PADDLE_ENFORCE_LE(p,
25-
1.0,
26-
pten::errors::OutOfRange(
27-
"The probability should be <= 1, but got %f", p));
28-
PADDLE_ENFORCE_GE(p,
29-
0.0,
30-
pten::errors::OutOfRange(
31-
"The probability should be >= 0, but got %f", p));
32-
return static_cast<T>(rand < p);
33-
}
34-
3522
template <typename T, typename Context>
3623
void BernoulliKernel(const Context& ctx,
3724
const DenseTensor& x,

paddle/pten/kernels/cpu/bernoulli_kernel.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,19 @@
1919

2020
namespace pten {
2121

22+
template <typename T>
23+
inline T BernoulliFunctor(T p, T rand) {
24+
PADDLE_ENFORCE_LE(p,
25+
1.0,
26+
pten::errors::OutOfRange(
27+
"The probability should be <= 1, but got %f", p));
28+
PADDLE_ENFORCE_GE(p,
29+
0.0,
30+
pten::errors::OutOfRange(
31+
"The probability should be >= 0, but got %f", p));
32+
return static_cast<T>(rand < p);
33+
}
34+
2235
template <typename T, typename Context>
2336
void BernoulliKernel(const Context& ctx,
2437
const DenseTensor& x,

0 commit comments

Comments
 (0)