Skip to content

Commit 830debc

Browse files
Add functor_primitives.h for kernel primtive api (#36203)
* Add functor_primitives.h for kernel primtive api * update * move namespace kps * subFunctor init_data * delete InvalidArgumentError
1 parent eaeeb88 commit 830debc

File tree

2 files changed

+231
-0
lines changed

2 files changed

+231
-0
lines changed
Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
namespace paddle {
18+
namespace operators {
19+
namespace kernel_primitives {
20+
namespace details {
21+
22+
static __device__ __forceinline__ platform::float16 Exp(platform::float16 x) {
23+
return ::Eigen::numext::exp(x);
24+
}
25+
26+
static __device__ __forceinline__ float Exp(float x) { return expf(x); }
27+
28+
static __device__ __forceinline__ double Exp(double x) { return exp(x); }
29+
30+
static __device__ __forceinline__ platform::float16 Log(platform::float16 x) {
31+
return ::Eigen::numext::log(x);
32+
}
33+
34+
static __device__ __forceinline__ float Log(float x) { return logf(x); }
35+
36+
static __device__ __forceinline__ double Log(double x) { return log(x); }
37+
38+
} // namespace details
39+
40+
/******************************** Unary Functor *******************************/
41+
42+
/**
43+
* @brief Default unary exp functor
44+
*/
45+
template <typename Tx, typename Ty = Tx>
46+
struct ExpFunctor {
47+
HOSTDEVICE inline ExpFunctor() {}
48+
49+
HOSTDEVICE explicit inline ExpFunctor(int n) {}
50+
51+
HOSTDEVICE inline Ty operator()(const Tx& x) const {
52+
return static_cast<Ty>(details::Exp(x));
53+
}
54+
};
55+
56+
/**
57+
* @brief Default unary identity functor
58+
*/
59+
template <typename Tx, typename Ty = Tx>
60+
struct IdentityFunctor {
61+
HOSTDEVICE inline IdentityFunctor() {}
62+
63+
HOSTDEVICE explicit inline IdentityFunctor(int n) {}
64+
65+
HOSTDEVICE inline Ty operator()(const Tx& x) const {
66+
return static_cast<Ty>(x);
67+
}
68+
};
69+
70+
/**
71+
* @brief Default unary div functor. Divide by a constant
72+
*/
73+
template <typename Tx, typename Ty = Tx>
74+
struct DivideFunctor {
75+
HOSTDEVICE inline DivideFunctor() { n_inv = static_cast<Tx>(1.0f); }
76+
77+
HOSTDEVICE explicit inline DivideFunctor(int n) : n_inv((Tx)(1.0 / n)) {}
78+
79+
HOSTDEVICE inline Ty operator()(const Tx& x) const {
80+
return static_cast<Ty>(x * n_inv);
81+
}
82+
83+
private:
84+
Tx n_inv;
85+
};
86+
87+
/**
88+
* @brief Default unary square functor
89+
*/
90+
template <typename Tx, typename Ty = Tx>
91+
struct SquareFunctor {
92+
HOSTDEVICE inline SquareFunctor() {}
93+
94+
HOSTDEVICE explicit inline SquareFunctor(int n) {}
95+
96+
HOSTDEVICE inline Ty operator()(const Tx& x) const {
97+
return static_cast<Ty>(x) * static_cast<Ty>(x);
98+
}
99+
};
100+
101+
/****************************** Binary Functor ********************************/
102+
103+
/**
104+
* @brief Default binary min functor
105+
*/
106+
template <typename T>
107+
struct MinFunctor {
108+
inline T initial() { return static_cast<T>(std::numeric_limits<T>::max()); }
109+
110+
__device__ __forceinline__ T operator()(const T& a, const T& b) const {
111+
return (b < a) ? b : a;
112+
}
113+
};
114+
115+
/**
116+
* @brief Default binary max functor
117+
*/
118+
template <typename T>
119+
struct MaxFunctor {
120+
inline T initial() {
121+
return static_cast<T>(std::numeric_limits<T>::lowest());
122+
}
123+
124+
__device__ __forceinline__ T operator()(const T& a, const T& b) const {
125+
return (b > a) ? b : a;
126+
}
127+
};
128+
129+
/**
130+
* @brief Default binary add functor
131+
*/
132+
template <typename T>
133+
struct AddFunctor {
134+
inline T initial() { return static_cast<T>(0.0f); }
135+
136+
__device__ __forceinline__ T operator()(const T& a, const T& b) const {
137+
return b + a;
138+
}
139+
};
140+
141+
/**
142+
* @brief Default binary add functor
143+
*/
144+
template <typename T>
145+
struct MulFunctor {
146+
inline T initial() { return static_cast<T>(1.0f); }
147+
148+
__device__ __forceinline__ T operator()(const T& a, const T& b) const {
149+
return b * a;
150+
}
151+
};
152+
153+
/**
154+
* @brief Default binary logic or functor
155+
*/
156+
template <typename T>
157+
struct LogicalOrFunctor {
158+
inline T initial() { return static_cast<T>(false); }
159+
160+
__device__ __forceinline__ T operator()(const T& a, const T& b) const {
161+
return b || a;
162+
}
163+
};
164+
165+
/**
166+
* @brief Default binary logic and functor
167+
*/
168+
template <typename T>
169+
struct LogicalAndFunctor {
170+
inline T initial() { return static_cast<T>(true); }
171+
172+
__device__ __forceinline__ T operator()(const T& a, const T& b) const {
173+
return b && a;
174+
}
175+
};
176+
177+
/**
178+
* @brief Default binary sub functor
179+
*/
180+
template <typename T>
181+
struct SubFunctor {
182+
inline T initial() { return static_cast<T>(0.0f); }
183+
184+
inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a - b; }
185+
};
186+
187+
/**
188+
* @brief Default binary div functor
189+
*/
190+
template <typename T, typename Enable = void>
191+
struct DivFunctor {
192+
inline T initial() { return static_cast<T>(1.0f); }
193+
194+
inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a / b; }
195+
};
196+
197+
template <typename T>
198+
struct DivFunctor<T,
199+
typename std::enable_if<std::is_integral<T>::value>::type> {
200+
inline T initial() { return static_cast<T>(1.0f); }
201+
202+
inline HOSTDEVICE T operator()(const T& a, const T& b) const {
203+
// For int32/int64, need to check whether the divison is zero.
204+
PADDLE_ENFORCE_NE(b, 0,
205+
platform::errors::InvalidArgument(
206+
"Integer division by zero encountered "
207+
"in (floor) divide. Please check the input value."));
208+
return a / b;
209+
}
210+
};
211+
212+
/**
213+
* @brief Default binary floor divide functor
214+
*/
215+
template <typename T>
216+
struct FloorDivFunctor {
217+
inline T initial() { return static_cast<T>(1.0f); }
218+
219+
inline HOSTDEVICE T operator()(const T& a, const T& b) const {
220+
PADDLE_ENFORCE_NE(b, 0,
221+
platform::errors::InvalidArgument(
222+
"Integer division by zero encountered "
223+
"in (floor) divide. Please check the input value."));
224+
return static_cast<T>(std::trunc(a / b));
225+
}
226+
};
227+
228+
} // namespace kernel_primitives
229+
} // namespace operators
230+
} // namespace paddle

paddle/fluid/operators/kernel_primitives/kernel_primitives.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include "paddle/fluid/operators/kernel_primitives/compute_primitives.h"
1818
#include "paddle/fluid/operators/kernel_primitives/datamover_primitives.h"
19+
#include "paddle/fluid/operators/kernel_primitives/functor_primitives.h"
1920
#include "paddle/fluid/operators/kernel_primitives/helper_primitives.h"
2021

2122
namespace paddle {

0 commit comments

Comments
 (0)