Skip to content

Commit a2ef2c4

Browse files
committed
move elementwise_max/min/mod into phi
1 parent b1a4668 commit a2ef2c4

26 files changed

+610
-627
lines changed

paddle/fluid/framework/new_executor/standalone_executor_test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ USE_OP(sum);
5454
USE_OP_ITSELF(slice_grad);
5555
USE_OP_ITSELF(lookup_table_grad);
5656
USE_OP(sqrt);
57-
USE_OP(elementwise_max);
57+
USE_OP_ITSELF(elementwise_max);
5858
USE_OP_ITSELF(elementwise_div);
5959
USE_OP_ITSELF(sgd);
6060
USE_OP(squared_l2_norm);

paddle/fluid/operators/elementwise/elementwise_functor.h

Lines changed: 9 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -70,75 +70,29 @@ struct InverseFloorDivFunctor {
7070

7171
// Maximum
7272
template <typename T>
73-
struct MaxFunctor {
74-
inline HOSTDEVICE T operator()(const T a, const T b) const {
75-
return a > b ? a : b;
76-
}
77-
};
73+
using MaxFunctor = phi::funcs::MaximumFunctor<T>;
7874

7975
// Minmum
8076
template <typename T>
81-
struct MinFunctor {
82-
inline HOSTDEVICE T operator()(const T a, const T b) const {
83-
return a < b ? a : b;
84-
}
85-
};
77+
using MinFunctor = phi::funcs::MinimumFunctor<T>;
8678

8779
template <typename T>
8880
using Complex = paddle::platform::complex<T>;
8981

82+
// Ternary compare
9083
template <typename T>
91-
struct MinGradXFunctor {
92-
inline HOSTDEVICE T operator()(const T x, const T y, const T dout) const {
93-
return dout * static_cast<T>(x < y);
94-
}
95-
};
84+
using MaxGradXFunctor = phi::funcs::MaxGradXFunctor<T>;
9685
template <typename T>
97-
struct MinGradYFunctor {
98-
inline HOSTDEVICE T operator()(const T x, const T y, const T dout) const {
99-
return dout * static_cast<T>(x >= y);
100-
}
101-
};
102-
86+
using MaxGradYFunctor = phi::funcs::MaxGradYFunctor<T>;
10387
template <typename InT, typename OutT>
104-
struct MinGradXYFunctor {
105-
inline HOSTDEVICE phi::Array<OutT, 2> operator()(const InT x, const InT y,
106-
const InT dout) {
107-
phi::Array<OutT, 2> outs;
108-
// dx = dout * (x < y)
109-
outs[0] = static_cast<OutT>(dout * static_cast<InT>(x < y));
110-
// dy = dout * (x >= y)
111-
outs[1] = static_cast<OutT>(dout * static_cast<InT>(x >= y));
112-
return outs;
113-
}
114-
};
88+
using MaxGradXYFunctor = phi::funcs::MaxGradXYFunctor<InT, OutT>;
11589

116-
// Ternary compare
11790
template <typename T>
118-
struct MaxGradXFunctor {
119-
inline HOSTDEVICE T operator()(const T x, const T y, const T dout) const {
120-
return dout * static_cast<T>(x > y);
121-
}
122-
};
91+
using MinGradXFunctor = phi::funcs::MinGradXFunctor<T>;
12392
template <typename T>
124-
struct MaxGradYFunctor {
125-
inline HOSTDEVICE T operator()(const T x, const T y, const T dout) const {
126-
return dout * static_cast<T>(x <= y);
127-
}
128-
};
129-
93+
using MinGradYFunctor = phi::funcs::MinGradYFunctor<T>;
13094
template <typename InT, typename OutT>
131-
struct MaxGradXYFunctor {
132-
inline HOSTDEVICE phi::Array<OutT, 2> operator()(const InT x, const InT y,
133-
const InT dout) {
134-
phi::Array<OutT, 2> outs;
135-
// dx = dout * (x > y)
136-
outs[0] = static_cast<OutT>(dout * static_cast<InT>(x > y));
137-
// dy = dout * (x <= y)
138-
outs[1] = static_cast<OutT>(dout * static_cast<InT>(x <= y));
139-
return outs;
140-
}
141-
};
95+
using MinGradXYFunctor = phi::funcs::MinGradXYFunctor<InT, OutT>;
14296

14397
} // namespace operators
14498
} // namespace paddle

paddle/fluid/operators/elementwise/elementwise_max_op.cc

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15-
#include "paddle/fluid/operators/elementwise/elementwise_max_op.h"
16-
1715
#include <string>
1816

1917
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
@@ -119,23 +117,6 @@ REGISTER_OPERATOR(elementwise_max, ops::ElementwiseOp,
119117

120118
REGISTER_OPERATOR(elementwise_max_grad, ops::ElementwiseOpGrad);
121119

122-
REGISTER_OP_CPU_KERNEL(
123-
elementwise_max,
124-
ops::ElementwiseMaxKernel<paddle::platform::CPUDeviceContext, float>,
125-
ops::ElementwiseMaxKernel<paddle::platform::CPUDeviceContext, double>,
126-
ops::ElementwiseMaxKernel<paddle::platform::CPUDeviceContext, int>,
127-
ops::ElementwiseMaxKernel<paddle::platform::CPUDeviceContext, int64_t>,
128-
ops::ElementwiseMaxKernel<paddle::platform::CPUDeviceContext,
129-
paddle::platform::bfloat16>);
130-
REGISTER_OP_CPU_KERNEL(
131-
elementwise_max_grad,
132-
ops::ElementwiseMaxGradKernel<paddle::platform::CPUDeviceContext, float>,
133-
ops::ElementwiseMaxGradKernel<paddle::platform::CPUDeviceContext, double>,
134-
ops::ElementwiseMaxGradKernel<paddle::platform::CPUDeviceContext, int>,
135-
ops::ElementwiseMaxGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
136-
ops::ElementwiseMaxGradKernel<paddle::platform::CPUDeviceContext,
137-
paddle::platform::bfloat16>);
138-
139120
REGISTER_OP_VERSION(elementwise_max)
140121
.AddCheckpoint(
141122
R"ROC(Register elementwise_max for adding the attribute of Scale_y)ROC",

paddle/fluid/operators/elementwise/elementwise_max_op.cu

Lines changed: 0 additions & 88 deletions
This file was deleted.

paddle/fluid/operators/elementwise/elementwise_max_op.h

Lines changed: 0 additions & 93 deletions
This file was deleted.

paddle/fluid/operators/elementwise/elementwise_max_op_npu.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15-
#include "paddle/fluid/operators/elementwise/elementwise_max_op.h"
1615
#include "paddle/fluid/operators/elementwise/elementwise_npu.h"
1716
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
1817

paddle/fluid/operators/elementwise/elementwise_max_op_xpu.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ limitations under the License. */
1414

1515
#ifdef PADDLE_WITH_XPU
1616

17-
#include "paddle/fluid/operators/elementwise/elementwise_max_op.h"
1817
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
1918
#include "paddle/fluid/operators/elementwise/elementwise_xpu.h"
2019
namespace paddle {

paddle/fluid/operators/elementwise/elementwise_min_op.cc

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15-
#include "paddle/fluid/operators/elementwise/elementwise_min_op.h"
16-
1715
#include <string>
1816

1917
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
@@ -119,19 +117,6 @@ REGISTER_OPERATOR(elementwise_min, ops::ElementwiseOp,
119117

120118
REGISTER_OPERATOR(elementwise_min_grad, ops::ElementwiseOpGrad);
121119

122-
REGISTER_OP_CPU_KERNEL(
123-
elementwise_min,
124-
ops::ElementwiseMinKernel<paddle::platform::CPUDeviceContext, float>,
125-
ops::ElementwiseMinKernel<paddle::platform::CPUDeviceContext, double>,
126-
ops::ElementwiseMinKernel<paddle::platform::CPUDeviceContext, int>,
127-
ops::ElementwiseMinKernel<paddle::platform::CPUDeviceContext, int64_t>);
128-
REGISTER_OP_CPU_KERNEL(
129-
elementwise_min_grad,
130-
ops::ElementwiseMinGradKernel<paddle::platform::CPUDeviceContext, float>,
131-
ops::ElementwiseMinGradKernel<paddle::platform::CPUDeviceContext, double>,
132-
ops::ElementwiseMinGradKernel<paddle::platform::CPUDeviceContext, int>,
133-
ops::ElementwiseMinGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
134-
135120
REGISTER_OP_VERSION(elementwise_min)
136121
.AddCheckpoint(
137122
R"ROC(Register elementwise_min for adding the attribute of Scale_y)ROC",

0 commit comments

Comments
 (0)