
Commit e3e6a9f

[Prim][PIR] Add pd.rint op (#74012)
1 parent 8d47b98 commit e3e6a9f

13 files changed, +154 −2 lines changed


paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_result.cc

Lines changed: 2 additions & 0 deletions
@@ -139,6 +139,8 @@ OP_SAME_OPERANDS_AND_RESULT(Relu)
 OP_SAME_OPERANDS_AND_RESULT(Relu6)
 OP_SAME_OPERANDS_AND_RESULT(Relu_)
 OP_SAME_OPERANDS_AND_RESULT(Reverse)
+OP_SAME_OPERANDS_AND_RESULT(Rint)
+OP_SAME_OPERANDS_AND_RESULT(Rint_)
 OP_SAME_OPERANDS_AND_RESULT(Roll)
 OP_SAME_OPERANDS_AND_RESULT(Round)
 OP_SAME_OPERANDS_AND_RESULT(Round_)

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_result.h

Lines changed: 2 additions & 0 deletions
@@ -131,6 +131,8 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Relu)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(Relu6)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(Relu_)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(Reverse)
+OP_DECLARE_INFER_SYMBOLIC_SHAPE(Rint)
+OP_DECLARE_INFER_SYMBOLIC_SHAPE(Rint_)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(Roll)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(Round)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(Round_)
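
Note: Rint and Rint_ are elementwise, so their symbolic shape inference reuses the same-operands-and-result rule, i.e. the output simply keeps the operand's shape. A minimal sketch of what that rule expresses (illustrative only; the function name and types below are not Paddle's actual infrastructure, which is generated by the macros above):

#include <cstdint>
#include <vector>

// Illustrative: for an elementwise op such as rint, the result shape
// is a pass-through of the input shape.
std::vector<int64_t> InferRintShape(const std::vector<int64_t>& x_dims) {
  return x_dims;  // same operands and result
}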

paddle/fluid/primitive/decomp_rule/decomp_rule/composite.h

Lines changed: 7 additions & 1 deletion
@@ -1525,8 +1525,14 @@ std::tuple<Tensor, Tensor, Tensor> rms_norm_decomp(
     auto quant_max_bound_scalar =
         full_scalar<T>(quant_max_bound, out.dtype(), out.place());
     auto scale_out = out * quant_scale_scalar * quant_max_bound_scalar;
+    Tensor round_out;
+    if (quant_round_type == 0) {
+      round_out = backend::rint<T>(scale_out);
+    } else {
+      round_out = round<T>(scale_out);
+    }
     auto clip_out = clip_decomp<T>(
-        round<T>(scale_out), quant_min_bound_scalar, quant_max_bound_scalar);
+        round_out, quant_min_bound_scalar, quant_max_bound_scalar);
     if (fabs(quant_max_bound - 127.0f) < 0.000001) {
       out = cast<T>(clip_out, phi::DataType::INT8);
     } else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
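
Note: in the rms_norm decomposition, quant_round_type now selects the rounding applied before clipping: reading the diff, 0 maps to rint (ties to even under the default floating-point environment) and any other value keeps round (ties away from zero). A small host-side sketch of the difference between the two C++ primitives the functors ultimately rely on (assumes only <cmath>/<cstdio>):

#include <cmath>
#include <cstdio>

int main() {
  // std::rint follows the current rounding mode (default: ties to even);
  // std::round always rounds halfway cases away from zero.
  std::printf("rint(2.5)=%.1f  round(2.5)=%.1f\n",
              std::rint(2.5), std::round(2.5));    // 2.0 vs 3.0
  std::printf("rint(-0.5)=%.1f round(-0.5)=%.1f\n",
              std::rint(-0.5), std::round(-0.5));  // -0.0 vs -1.0
  return 0;
}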

paddle/phi/kernels/activation_grad_kernel.h

Lines changed: 1 addition & 0 deletions
@@ -302,6 +302,7 @@ DECLARE_ACTIVATION_GRAD_KERNEL_DEPOUT(Sigmoid);
 DECLARE_ACTIVATION_GRAD_KERNEL_DEPOUT(Sqrt);
 DECLARE_ACTIVATION_GRAD_KERNEL_DEPOUT(Relu6);

+DECLARE_ACTIVATION_GRAD_KERNEL_NODEP(Rint);
 DECLARE_ACTIVATION_GRAD_KERNEL_NODEP(Round);
 DECLARE_ACTIVATION_GRAD_KERNEL_NODEP(Floor);
 DECLARE_ACTIVATION_GRAD_KERNEL_NODEP(Ceil);

paddle/phi/kernels/cpu/activation_grad_kernel.cc

Lines changed: 10 additions & 0 deletions
@@ -150,6 +150,7 @@ DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Relu, ReluGradFunctor);
 DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Tanh, TanhGradFunctor);
 DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Sigmoid, SigmoidGradFunctor);

+DEFINE_CPU_ACTIVATION_GRAD_KERNEL_NODEP(Rint, ZeroGradFunctor);
 DEFINE_CPU_ACTIVATION_GRAD_KERNEL_NODEP(Round, ZeroGradFunctor);
 DEFINE_CPU_ACTIVATION_GRAD_KERNEL_NODEP(Floor, ZeroGradFunctor);
 DEFINE_CPU_ACTIVATION_GRAD_KERNEL_NODEP(Ceil, ZeroGradFunctor);

@@ -488,6 +489,15 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(celu_grad, CeluGradKernel)
 PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(celu_double_grad,
                                           CeluDoubleGradKernel)

+PD_REGISTER_KERNEL(rint_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::RintGradKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t) {}
+
 PD_REGISTER_KERNEL(round_grad,
                    CPU,
                    ALL_LAYOUT,

paddle/phi/kernels/cpu/activation_kernel.cc

Lines changed: 4 additions & 0 deletions
@@ -97,6 +97,7 @@ DEFINE_CPU_ACTIVATION_KERNEL(LogSigmoid, LogSigmoidFunctor)
 DEFINE_CPU_ACTIVATION_KERNEL(Floor, FloorFunctor)
 DEFINE_CPU_ACTIVATION_KERNEL(Ceil, CeilFunctor)
 DEFINE_CPU_ACTIVATION_KERNEL(Negative, NegativeFunctor)
+DEFINE_CPU_ACTIVATION_KERNEL(Rint, RintFunctor)

 DEFINE_CPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Log, LogFunctor)
 DEFINE_CPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Log2, Log2Functor)

@@ -257,6 +258,9 @@ PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel)
 PD_REGISTER_ACTIVATION_KERNEL(ceil, CeilKernel)
 PD_REGISTER_ACTIVATION_KERNEL(celu, CeluKernel)

+PD_REGISTER_KERNEL(
+    rint, CPU, ALL_LAYOUT, phi::RintKernel, int, int64_t, float, double) {}
+
 PD_REGISTER_KERNEL(round,
                    CPU,
                    ALL_LAYOUT,

paddle/phi/kernels/funcs/activation_functor.h

Lines changed: 39 additions & 0 deletions
@@ -3005,6 +3005,26 @@ struct FloorFunctor : public BaseActivationFunctor<T> {
   }
 };

+// rint(x) = [x]
+template <typename T, typename Enable = void>
+struct RintFunctor : public BaseActivationFunctor<T> {
+  template <typename Device, typename X, typename Out>
+  void operator()(Device d, X x, Out out) const {
+    out.device(d) = x.unaryExpr([](const T& val) {
+      return (std::isnan(val) || std::isinf(val)) ? val : std::rint(val);
+    });
+  }
+};
+
+template <typename T>
+struct RintFunctor<T, std::enable_if_t<std::is_integral_v<T>>>
+    : public BaseActivationFunctor<T> {
+  template <typename Device, typename X, typename Out>
+  void operator()(Device d, X x, Out out) const {
+    out.device(d) = x;
+  }
+};
+
 // round(x) = [x]
 template <typename T, typename Enable = void>
 struct RoundFunctor : public BaseActivationFunctor<T> {

@@ -5410,6 +5430,25 @@ struct CudaFloorFunctor : public BaseActivationFunctor<T> {
   }
 };

+template <typename T, typename Enable = void>
+struct CudaRintFunctor : public BaseActivationFunctor<T> {
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
+
+  // rint(x) = rint(x)
+  __device__ __forceinline__ T operator()(const T arg_x) const {
+    MPType x = static_cast<MPType>(arg_x);
+    if (isnan(x) || isinf(x)) return arg_x;
+    return static_cast<T>(std::rint(x));
+  }
+};
+
+template <typename T>
+struct CudaRintFunctor<T, std::enable_if_t<std::is_integral_v<T>>>
+    : public BaseActivationFunctor<T> {
+  // rint(x) = x
+  __device__ __forceinline__ T operator()(const T arg_x) const { return arg_x; }
+};
+
 template <typename T, typename Enable = void>
 struct CudaRoundFunctor : public BaseActivationFunctor<T> {
   using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
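
Note: both the CPU and CUDA functors pass NaN and +/-Inf through unchanged, and the integral specializations make rint an identity on integer tensors. The per-element logic reduces to something like the scalar sketch below (assumes only <cmath>; the real functors apply it through Eigen expressions or per-thread CUDA code):

#include <cmath>

// Scalar view of RintFunctor / CudaRintFunctor for floating-point T:
// NaN and infinities are returned as-is; everything else is rounded to the
// nearest integral value, ties to even under the default rounding mode.
template <typename T>
T RintScalar(T val) {
  return (std::isnan(val) || std::isinf(val))
             ? val
             : static_cast<T>(std::rint(val));
}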

paddle/phi/kernels/gpu/activation_grad_kernel.cu

Lines changed: 11 additions & 0 deletions
@@ -178,6 +178,7 @@ DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Relu, CudaReluGradFunctor);
 DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Tanh, CudaTanhGradFunctor);
 DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Sigmoid, CudaSigmoidGradFunctor);

+DEFINE_GPU_ACTIVATION_GRAD_KERNEL_NODEP(Rint, CudaZeroGradFunctor);
 DEFINE_GPU_ACTIVATION_GRAD_KERNEL_NODEP(Round, CudaZeroGradFunctor);
 DEFINE_GPU_ACTIVATION_GRAD_KERNEL_NODEP(Floor, CudaZeroGradFunctor);
 DEFINE_GPU_ACTIVATION_GRAD_KERNEL_NODEP(Ceil, CudaZeroGradFunctor);

@@ -558,6 +559,16 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(ceil_grad, CeilGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(celu_grad, CeluGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(celu_double_grad, CeluDoubleGradKernel)

+PD_REGISTER_KERNEL(rint_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::RintGradKernel,
+                   int,
+                   int64_t,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 PD_REGISTER_KERNEL(round_grad,
                    GPU,
                    ALL_LAYOUT,

paddle/phi/kernels/gpu/activation_kernel.cu

Lines changed: 11 additions & 0 deletions
@@ -116,6 +116,7 @@ DEFINE_GPU_ACTIVATION_KERNEL(Sigmoid, CudaSigmoidFunctor)
 DEFINE_GPU_ACTIVATION_KERNEL(LogSigmoid, CudaLogSigmoidFunctor)
 DEFINE_GPU_ACTIVATION_KERNEL(Floor, CudaFloorFunctor)
 DEFINE_GPU_ACTIVATION_KERNEL(Ceil, CudaCeilFunctor)
+DEFINE_GPU_ACTIVATION_KERNEL(Rint, CudaRintFunctor)

 DEFINE_GPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Log, CudaLogFunctor)
 DEFINE_GPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Log2, CudaLog2Functor)

@@ -352,6 +353,16 @@ PD_REGISTER_ACTIVATION_KERNEL(celu, CeluKernel)
 PD_REGISTER_ACTIVATION_KERNEL(selu, SeluKernel)
 PD_REGISTER_ACTIVATION_KERNEL(logit, LogitCUDAKernel)

+PD_REGISTER_KERNEL(rint,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::RintKernel,
+                   int,
+                   int64_t,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 PD_REGISTER_KERNEL(round,
                    GPU,
                    ALL_LAYOUT,

paddle/phi/ops/yaml/backward.yaml

Lines changed: 11 additions & 0 deletions
@@ -2896,6 +2896,17 @@
   output : Tensor(x_grad)
   invoke : reverse(out_grad, axis)

+- backward_op : rint_grad
+  forward : rint(Tensor x) -> Tensor(out)
+  args : (Tensor out_grad)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : UnchangedInferMeta
+    param: [out_grad]
+  kernel :
+    func : rint_grad
+  inplace : (out_grad -> x_grad)
+
 - backward_op : rms_norm_grad
   forward : rms_norm (Tensor x, Tensor bias, Tensor residual, Tensor norm_weight, Tensor norm_bias, float epsilon, int begin_norm_axis, float quant_scale, int quant_round_type, float quant_max_bound, float quant_min_bound) -> Tensor(out), Tensor(residual_out), Tensor(inv_var)
   args : (Tensor x, Tensor bias, Tensor residual, Tensor norm_weight, Tensor norm_bias, Tensor inv_var, Tensor out_grad, float epsilon, int begin_norm_axis, float quant_scale)
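
Note: the backward op only needs out_grad: UnchangedInferMeta gives x_grad the same meta as out_grad, and since rint is piecewise constant its derivative is zero almost everywhere, so RintGradKernel (backed by ZeroGradFunctor / CudaZeroGradFunctor above) zero-fills x_grad; the inplace (out_grad -> x_grad) annotation lets it reuse the out_grad buffer. A minimal sketch of that semantics, not the registered kernel:

#include <algorithm>
#include <vector>

// d/dx rint(x) = 0 wherever it is defined, so the gradient is all zeros
// with the same shape as out_grad.
template <typename T>
std::vector<T> RintGradSketch(const std::vector<T>& out_grad) {
  std::vector<T> x_grad(out_grad.size());
  std::fill(x_grad.begin(), x_grad.end(), static_cast<T>(0));
  return x_grad;
}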
