
Commit 6b48fdc

move bernoulli kernel to pten
1 parent 3e7825f commit 6b48fdc


6 files changed: +206 -112 lines changed


paddle/fluid/operators/bernoulli_op.cc

Lines changed: 0 additions & 28 deletions
@@ -49,30 +49,6 @@ class BernoulliOp : public framework::OperatorWithKernel {
   }
 };
 
-// It seems that Eigen::Tensor::random in GPU will SEGFAULT.
-// Use std::random and thrust::random(thrust is a std library in CUDA) to
-// implement uniform random.
-template <typename T>
-class BernoulliOpKernel<platform::CPUDeviceContext, T>
-    : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    const auto x = ctx.Input<framework::Tensor>("X");
-    auto out = ctx.Output<framework::Tensor>("Out");
-    auto *in_data = x->data<T>();
-    auto *out_data = out->mutable_data<T>(ctx.GetPlace());
-
-    int64_t size = x->numel();
-    std::uniform_real_distribution<T> dist(0.0, 1.0);
-    auto gen_ptr = framework::DefaultCPUGenerator();
-    auto engine = gen_ptr->GetCPUEngine();
-
-    for (int64_t i = 0; i < size; ++i) {
-      out_data[i] = BernoulliFunctor(in_data[i], dist(*engine));
-    }
-  }
-};  // namespace operators
-
 }  // namespace operators
 }  // namespace paddle
 
@@ -82,7 +58,3 @@ REGISTER_OPERATOR(
     bernoulli, ops::BernoulliOp, ops::BernoulliOpMaker,
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
-
-REGISTER_OP_CPU_KERNEL(bernoulli,
-                       ops::BernoulliOpKernel<plat::CPUDeviceContext, float>,
-                       ops::BernoulliOpKernel<plat::CPUDeviceContext, double>);

paddle/fluid/operators/bernoulli_op.cu

Lines changed: 0 additions & 84 deletions
This file was deleted.

paddle/fluid/platform/transform.h

Lines changed: 47 additions & 0 deletions
@@ -141,6 +141,53 @@ struct Transform<platform::CUDADeviceContext> {
 #endif
   }
 };
+
+template <>
+struct Transform<pten::GPUContext> {
+  template <typename InputIter, typename OutputIter, typename UnaryOperation>
+  void operator()(const pten::GPUContext& context, InputIter first,
+                  InputIter last, OutputIter result, UnaryOperation op) {
+    auto place = context.GetPlace();
+    PADDLE_ENFORCE_EQ(is_gpu_place(place), true,
+                      platform::errors::PreconditionNotMet(
+                          "The CUDA Transform must be used in GPU place."));
+#ifdef __HIPCC__
+    thrust::transform(thrust::hip::par.on(context.stream()),
+                      details::CastToCUDATransformIterator(first),
+                      details::CastToCUDATransformIterator(last),
+                      details::CastToCUDATransformIterator(result), op);
+#else
+    thrust::transform(thrust::cuda::par.on(context.stream()),
+                      details::CastToCUDATransformIterator(first),
+                      details::CastToCUDATransformIterator(last),
+                      details::CastToCUDATransformIterator(result), op);
+#endif
+  }
+
+  template <typename InputIter1, typename InputIter2, typename OutputIter,
+            typename BinaryOperation>
+  void operator()(const pten::GPUContext& context, InputIter1 first1,
+                  InputIter1 last1, InputIter2 first2, OutputIter result,
+                  BinaryOperation op) {
+    auto place = context.GetPlace();
+    PADDLE_ENFORCE_EQ(is_gpu_place(place), true,
+                      platform::errors::PreconditionNotMet(
+                          "The CUDA Transform must be used in GPU place."));
+#ifdef __HIPCC__
+    thrust::transform(thrust::hip::par.on(context.stream()),
+                      details::CastToCUDATransformIterator(first1),
+                      details::CastToCUDATransformIterator(last1),
+                      details::CastToCUDATransformIterator(first2),
+                      details::CastToCUDATransformIterator(result), op);
+#else
+    thrust::transform(thrust::cuda::par.on(context.stream()),
+                      details::CastToCUDATransformIterator(first1),
+                      details::CastToCUDATransformIterator(last1),
+                      details::CastToCUDATransformIterator(first2),
+                      details::CastToCUDATransformIterator(result), op);
+#endif
+  }
+};
 #endif
 
 }  // namespace platform

paddle/pten/kernels/bernoulli_kernel.h

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/pten/core/dense_tensor.h"
+#include "paddle/pten/core/device_context.h"
+
+namespace pten {
+
+template <typename T>
+inline HOSTDEVICE T BernoulliFunctor(T p, T rand) {
+  PADDLE_ENFORCE_LE(p,
+                    1.0,
+                    pten::errors::OutOfRange(
+                        "The probability should be <= 1, but got %f", p));
+  PADDLE_ENFORCE_GE(p,
+                    0.0,
+                    pten::errors::OutOfRange(
+                        "The probability should be >= 0, but got %f", p));
+  return static_cast<T>(rand < p);
+}
+
+template <typename T, typename Context>
+void BernoulliKernel(const Context& ctx,
+                     const DenseTensor& x,
+                     DenseTensor* out);
+
+}  // namespace pten
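
Aside (reviewer note, not part of the commit): BernoulliFunctor above maps a probability p and a uniform draw rand in [0, 1) to 1 when rand < p and to 0 otherwise, which is exactly a Bernoulli(p) sample. A minimal standalone C++ sketch of that contract (plain std::random; the seed, probability, and trial count are arbitrary and chosen only for illustration):

    // Standalone illustration of the rand < p test used by BernoulliFunctor.
    // Not Paddle code; any C++11 compiler will do.
    #include <cstdio>
    #include <random>

    int main() {
      std::mt19937_64 engine(42);                             // fixed seed, illustrative only
      std::uniform_real_distribution<double> dist(0.0, 1.0);  // rand ~ Uniform[0, 1)
      const double p = 0.3;                                   // example probability
      const int trials = 100000;
      int ones = 0;
      for (int i = 0; i < trials; ++i) {
        ones += (dist(engine) < p) ? 1 : 0;  // same comparison as BernoulliFunctor
      }
      std::printf("empirical mean = %f (expect roughly %f)\n",
                  static_cast<double>(ones) / trials, p);
      return 0;
    }

The empirical mean approaches p, which is the per-element behavior the CPU and GPU kernels below implement.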

paddle/pten/kernels/cpu/bernoulli_kernel.cc

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/pten/kernels/bernoulli_kernel.h"
+#include <random>
+#include "paddle/pten/backends/cpu/cpu_context.h"
+#include "paddle/pten/core/kernel_registry.h"
+
+namespace pten {
+
+template <typename T, typename Context>
+void BernoulliKernel(const Context& ctx,
+                     const DenseTensor& x,
+                     DenseTensor* out) {
+  auto numel = x.numel();
+  auto* x_data = x.data<T>();
+  T* out_data = ctx.template Alloc<T>(out);
+
+  std::uniform_real_distribution<T> dist(0.0, 1.0);
+  auto gen_ptr = ctx.GetGenerator();
+  auto engine = gen_ptr->GetCPUEngine();
+
+  for (int64_t i = 0; i < numel; ++i) {
+    out_data[i] = BernoulliFunctor(x_data[i], dist(*engine));
+  }
+}
+
+}  // namespace pten
+
+PT_REGISTER_KERNEL(
+    bernoulli, CPU, ALL_LAYOUT, pten::BernoulliKernel, float, double) {}

paddle/pten/kernels/gpu/bernoulli_kernel.cu

Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <thrust/execution_policy.h>
+#include <thrust/random.h>
+#include <thrust/transform.h>
+#include <algorithm>
+#include <vector>
+#include "paddle/pten/backends/gpu/gpu_context.h"
+#include "paddle/pten/core/dense_tensor.h"
+#include "paddle/pten/core/kernel_registry.h"
+#include "paddle/pten/kernels/bernoulli_kernel.h"
+
+// See Note [ Why still include the fluid headers? ]
+#include "paddle/fluid/platform/transform.h"
+
+namespace pten {
+
+template <typename T>
+struct BernoulliCudaFunctor {
+  unsigned int seed_;
+  unsigned int offset_;
+  __host__ __device__ BernoulliCudaFunctor(unsigned int seed,
+                                           unsigned int offset)
+      : seed_(seed), offset_(offset) {}
+
+  __host__ __device__ T operator()(const unsigned int n, const T p) const {
+    // NOTE(zhiqiu): currently, PADDLE_ENFORCE in cuda kernel may print several
+    // lines of error messages if, and it should be refined.
+    PADDLE_ENFORCE(p >= 0.0 && p <= 1.0,
+                   "The probability should be >=0 and <= 1, but got %f",
+                   p);
+    thrust::minstd_rand rng;
+    rng.seed(seed_);
+    thrust::uniform_real_distribution<T> dist(0.0, 1.0);
+    rng.discard(n + offset_);
+    return static_cast<T>(dist(rng) < p);
+  }
+};
+
+template <typename T, typename Context>
+void BernoulliKernel(const Context& ctx,
+                     const DenseTensor& x,
+                     DenseTensor* out) {
+  auto numel = x.numel();
+  auto* x_data = x.data<T>();
+  T* out_data = ctx.template Alloc<T>(out);
+
+  auto gen_cuda = ctx.GetGenerator();
+  auto seed_offset = gen_cuda->IncrementOffset(1);
+  int64_t gen_offset = numel * seed_offset.second;
+  paddle::platform::Transform<pten::GPUContext> trans;
+  thrust::counting_iterator<int64_t> index_sequence_begin(0);
+  trans(ctx,
+        index_sequence_begin,
+        index_sequence_begin + numel,
+        x_data,
+        out_data,
+        BernoulliCudaFunctor<T>(static_cast<int64_t>(seed_offset.first),
+                                static_cast<int64_t>(gen_offset)));
+}
+
+}  // namespace pten
+
+PT_REGISTER_KERNEL(
+    bernoulli, GPU, ALL_LAYOUT, pten::BernoulliKernel, float, double) {}
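
Aside (reviewer note, not part of the commit): the GPU path pairs a thrust::counting_iterator with the per-element functor, so element n seeds a fresh minstd_rand, discards n + offset draws, and compares the resulting uniform number against p. Seeding once and discarding to position n keeps the result for a given (seed, offset) independent of how the work is scheduled across threads. A minimal plain-Thrust sketch of that counting-iterator pattern (not Paddle code; UniformAt, the seed 123, and the offset 0 are illustrative only, and compilation with nvcc is assumed):

    // Plain-Thrust sketch of the counting-iterator RNG pattern.
    #include <cstdio>
    #include <thrust/device_vector.h>
    #include <thrust/host_vector.h>
    #include <thrust/iterator/counting_iterator.h>
    #include <thrust/random.h>
    #include <thrust/transform.h>

    struct UniformAt {
      unsigned int seed;
      unsigned long long offset;
      __host__ __device__ float operator()(long long n) const {
        thrust::minstd_rand rng(seed);   // one lightweight engine per element
        rng.discard(n + offset);         // jump to this element's slot in the stream
        thrust::uniform_real_distribution<float> dist(0.0f, 1.0f);
        return dist(rng);
      }
    };

    int main() {
      const long long numel = 8;
      thrust::device_vector<float> u(numel);
      thrust::counting_iterator<long long> first(0);
      thrust::transform(first, first + numel, u.begin(), UniformAt{123u, 0ull});
      thrust::host_vector<float> h = u;  // copy results back to the host
      for (long long i = 0; i < numel; ++i) std::printf("%f\n", h[i]);
      return 0;
    }

BernoulliCudaFunctor does the same per element, except that it also compares the draw against x[n] and takes seed and offset from the device generator, so repeated kernel launches advance through the random stream instead of repeating it.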
