@@ -139,6 +139,8 @@ OP_SAME_OPERANDS_AND_RESULT(Relu)
OP_SAME_OPERANDS_AND_RESULT(Relu6)
OP_SAME_OPERANDS_AND_RESULT(Relu_)
OP_SAME_OPERANDS_AND_RESULT(Reverse)
OP_SAME_OPERANDS_AND_RESULT(Rint)
OP_SAME_OPERANDS_AND_RESULT(Rint_)
OP_SAME_OPERANDS_AND_RESULT(Roll)
OP_SAME_OPERANDS_AND_RESULT(Round)
OP_SAME_OPERANDS_AND_RESULT(Round_)
@@ -131,6 +131,8 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Relu)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Relu6)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Relu_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Reverse)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Rint)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Rint_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Roll)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Round)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Round_)
8 changes: 7 additions & 1 deletion paddle/fluid/primitive/decomp_rule/decomp_rule/composite.h
@@ -1525,8 +1525,14 @@ std::tuple<Tensor, Tensor, Tensor> rms_norm_decomp(
auto quant_max_bound_scalar =
full_scalar<T>(quant_max_bound, out.dtype(), out.place());
auto scale_out = out * quant_scale_scalar * quant_max_bound_scalar;
Tensor round_out;
if (quant_round_type == 0) {
round_out = backend::rint<T>(scale_out);
} else {
round_out = round<T>(scale_out);
}
auto clip_out = clip_decomp<T>(
round<T>(scale_out), quant_min_bound_scalar, quant_max_bound_scalar);
round_out, quant_min_bound_scalar, quant_max_bound_scalar);
if (fabs(quant_max_bound - 127.0f) < 0.000001) {
out = cast<T>(clip_out, phi::DataType::INT8);
} else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
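For reference, the branch added here chooses between two tie-breaking rules: backend::rint<T> rounds halfway cases to the nearest even integer (the default floating-point rounding mode), while the existing round<T> path rounds halfway cases away from zero. A minimal standalone C++ sketch of the difference, illustrative only and not Paddle code:

#include <cmath>
#include <cstdio>

int main() {
  const double ties[] = {0.5, 1.5, 2.5, -0.5, -1.5, -2.5};
  for (double v : ties) {
    // std::rint : 0.5 -> 0, 1.5 -> 2, 2.5 -> 2  (ties to even)
    // std::round: 0.5 -> 1, 1.5 -> 2, 2.5 -> 3  (ties away from zero)
    std::printf("x=%5.1f  rint=%5.1f  round=%5.1f\n", v, std::rint(v), std::round(v));
  }
  return 0;
}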
1 change: 1 addition & 0 deletions paddle/phi/kernels/activation_grad_kernel.h
@@ -302,6 +302,7 @@ DECLARE_ACTIVATION_GRAD_KERNEL_DEPOUT(Sigmoid);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPOUT(Sqrt);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPOUT(Relu6);

DECLARE_ACTIVATION_GRAD_KERNEL_NODEP(Rint);
DECLARE_ACTIVATION_GRAD_KERNEL_NODEP(Round);
DECLARE_ACTIVATION_GRAD_KERNEL_NODEP(Floor);
DECLARE_ACTIVATION_GRAD_KERNEL_NODEP(Ceil);
10 changes: 10 additions & 0 deletions paddle/phi/kernels/cpu/activation_grad_kernel.cc
@@ -150,6 +150,7 @@ DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Relu, ReluGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Tanh, TanhGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Sigmoid, SigmoidGradFunctor);

DEFINE_CPU_ACTIVATION_GRAD_KERNEL_NODEP(Rint, ZeroGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_NODEP(Round, ZeroGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_NODEP(Floor, ZeroGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_NODEP(Ceil, ZeroGradFunctor);
@@ -488,6 +489,15 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(celu_grad, CeluGradKernel)
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(celu_double_grad,
CeluDoubleGradKernel)

PD_REGISTER_KERNEL(rint_grad,
CPU,
ALL_LAYOUT,
phi::RintGradKernel,
float,
double,
int,
int64_t) {}

PD_REGISTER_KERNEL(round_grad,
CPU,
ALL_LAYOUT,
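Like round, floor, and ceil above, rint is piecewise constant, so its analytic gradient is zero almost everywhere; wiring it to ZeroGradFunctor simply fills x_grad with zeros shaped like out_grad. A rough standalone sketch of the semantics, assumed for illustration rather than taken from the phi kernels:

#include <vector>

// Zero-gradient backward pass for a piecewise-constant op such as rint:
// the upstream gradient's values are ignored, only its shape matters.
template <typename T>
std::vector<T> RintGradReference(const std::vector<T>& out_grad) {
  return std::vector<T>(out_grad.size(), static_cast<T>(0));
}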
4 changes: 4 additions & 0 deletions paddle/phi/kernels/cpu/activation_kernel.cc
@@ -97,6 +97,7 @@ DEFINE_CPU_ACTIVATION_KERNEL(LogSigmoid, LogSigmoidFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Floor, FloorFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Ceil, CeilFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Negative, NegativeFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Rint, RintFunctor)

DEFINE_CPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Log, LogFunctor)
DEFINE_CPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Log2, Log2Functor)
@@ -257,6 +258,9 @@ PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel)
PD_REGISTER_ACTIVATION_KERNEL(ceil, CeilKernel)
PD_REGISTER_ACTIVATION_KERNEL(celu, CeluKernel)

PD_REGISTER_KERNEL(
rint, CPU, ALL_LAYOUT, phi::RintKernel, int, int64_t, float, double) {}

PD_REGISTER_KERNEL(round,
CPU,
ALL_LAYOUT,
39 changes: 39 additions & 0 deletions paddle/phi/kernels/funcs/activation_functor.h
@@ -3005,6 +3005,26 @@ struct FloorFunctor : public BaseActivationFunctor<T> {
}
};

// rint(x): round x to the nearest integer (ties to even under the default rounding mode)
template <typename T, typename Enable = void>
struct RintFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.unaryExpr([](const T& val) {
return (std::isnan(val) || std::isinf(val)) ? val : std::rint(val);
});
}
};

template <typename T>
struct RintFunctor<T, std::enable_if_t<std::is_integral_v<T>>>
: public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x;
}
};

// round(x) = [x]
template <typename T, typename Enable = void>
struct RoundFunctor : public BaseActivationFunctor<T> {
@@ -5410,6 +5430,25 @@ struct CudaFloorFunctor : public BaseActivationFunctor<T> {
}
};

template <typename T, typename Enable = void>
struct CudaRintFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

// rint(x): round x to the nearest integer (ties to even under the default rounding mode)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
if (isnan(x) || isinf(x)) return arg_x;
return static_cast<T>(std::rint(x));
}
};

template <typename T>
struct CudaRintFunctor<T, std::enable_if_t<std::is_integral_v<T>>>
: public BaseActivationFunctor<T> {
// rint(x) = x
__device__ __forceinline__ T operator()(const T arg_x) const { return arg_x; }
};

template <typename T, typename Enable = void>
struct CudaRoundFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
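A compact host-only restatement of the two specializations added above, illustrative rather than the actual Eigen/CUDA code paths: floating-point inputs go through std::rint with NaN/Inf passed through unchanged, and integral inputs are returned as-is.

#include <cmath>
#include <type_traits>

// Reference semantics of RintFunctor / CudaRintFunctor (C++17):
// - floating point: nearest integer, ties to even; NaN/Inf propagate unchanged
// - integral types: identity
template <typename T>
T RintReference(T x) {
  if constexpr (std::is_integral_v<T>) {
    return x;  // integers are already rounded
  } else {
    if (std::isnan(x) || std::isinf(x)) return x;  // mirror the explicit guard
    return static_cast<T>(std::rint(x));
  }
}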
11 changes: 11 additions & 0 deletions paddle/phi/kernels/gpu/activation_grad_kernel.cu
@@ -178,6 +178,7 @@ DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Relu, CudaReluGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Tanh, CudaTanhGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Sigmoid, CudaSigmoidGradFunctor);

DEFINE_GPU_ACTIVATION_GRAD_KERNEL_NODEP(Rint, CudaZeroGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_NODEP(Round, CudaZeroGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_NODEP(Floor, CudaZeroGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_NODEP(Ceil, CudaZeroGradFunctor);
@@ -558,6 +559,16 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(ceil_grad, CeilGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(celu_grad, CeluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(celu_double_grad, CeluDoubleGradKernel)

PD_REGISTER_KERNEL(rint_grad,
GPU,
ALL_LAYOUT,
phi::RintGradKernel,
int,
int64_t,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(round_grad,
GPU,
ALL_LAYOUT,
11 changes: 11 additions & 0 deletions paddle/phi/kernels/gpu/activation_kernel.cu
@@ -116,6 +116,7 @@ DEFINE_GPU_ACTIVATION_KERNEL(Sigmoid, CudaSigmoidFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(LogSigmoid, CudaLogSigmoidFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Floor, CudaFloorFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Ceil, CudaCeilFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Rint, CudaRintFunctor)

DEFINE_GPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Log, CudaLogFunctor)
DEFINE_GPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Log2, CudaLog2Functor)
@@ -352,6 +353,16 @@ PD_REGISTER_ACTIVATION_KERNEL(celu, CeluKernel)
PD_REGISTER_ACTIVATION_KERNEL(selu, SeluKernel)
PD_REGISTER_ACTIVATION_KERNEL(logit, LogitCUDAKernel)

PD_REGISTER_KERNEL(rint,
GPU,
ALL_LAYOUT,
phi::RintKernel,
int,
int64_t,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(round,
GPU,
ALL_LAYOUT,
11 changes: 11 additions & 0 deletions paddle/phi/ops/yaml/backward.yaml
@@ -2896,6 +2896,17 @@
output : Tensor(x_grad)
invoke : reverse(out_grad, axis)

- backward_op : rint_grad
forward : rint(Tensor x) -> Tensor(out)
args : (Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param: [out_grad]
kernel :
func : rint_grad
inplace : (out_grad -> x_grad)

- backward_op : rms_norm_grad
forward : rms_norm (Tensor x, Tensor bias, Tensor residual, Tensor norm_weight, Tensor norm_bias, float epsilon, int begin_norm_axis, float quant_scale, int quant_round_type, float quant_max_bound, float quant_min_bound) -> Tensor(out), Tensor(residual_out), Tensor(inv_var)
args : (Tensor x, Tensor bias, Tensor residual, Tensor norm_weight, Tensor norm_bias, Tensor inv_var, Tensor out_grad, float epsilon, int begin_norm_axis, float quant_scale)
7 changes: 7 additions & 0 deletions paddle/phi/ops/yaml/op_compat.yaml
@@ -3206,6 +3206,13 @@
support_tensor : true
manual_signature : [reverse]

- op : rint
backward : rint_grad
inputs :
x : X
outputs :
out : Out
Comment on lines +3209 to +3214

Reviewer (Contributor): Does this entry need to be added at all?

Author (Member): It will be removed in a follow-up PR.


- op : rmsprop_ (rmsprop)
inputs :
{param: Param, mean_square: MeanSquare, mean_grad: MeanGrad, learning_rate: LearningRate, grad: Grad, moment: Moment, master_param: MasterParam}
12 changes: 12 additions & 0 deletions paddle/phi/ops/yaml/ops.yaml
@@ -4469,6 +4469,18 @@
backward : reverse_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface, paddle::dialect::LayoutTransformationInterface

- op : rint
args : (Tensor x)
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : rint
inplace : (x -> out)
backward : rint_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : rms_norm
args : (Tensor x, Tensor bias, Tensor residual, Tensor norm_weight, Tensor norm_bias, float epsilon, int begin_norm_axis, float quant_scale, int quant_round_type, float quant_max_bound, float quant_min_bound)
output : Tensor(out), Tensor(residual_out), Tensor(inv_var)
38 changes: 37 additions & 1 deletion test/ir/pir/cinn/test_cinn_fused_rms_norm.py
@@ -23,6 +23,7 @@ class TestFusedRmsNorm(unittest.TestCase):
def setUp(self):
paddle.seed(123)
self.init_data()
self.modify_data()
self.func = paddle.incubate.nn.functional.fused_rms_norm

def tearDown(self):
@@ -51,6 +52,9 @@ def init_data(self):
self.quant_max_bound = 0
self.quant_min_bound = 0

def modify_data(self):
pass

def inputs(self):
return (
self.x,
@@ -66,7 +70,7 @@ def inputs(self):
self.quant_min_bound,
)

def test_eval(self):
def compute(self):
inputs = self.inputs()
dy_out = self.func(*inputs)
static_func = paddle.jit.to_static(
@@ -75,11 +79,43 @@ def test_eval(self):
input_spec=None,
)(self.func)
st_out = static_func(*inputs)
return dy_out, st_out

def test_eval(self):
dy_out, st_out = self.compute()
for a, b in zip(
paddle.utils.flatten(dy_out), paddle.utils.flatten(st_out)
):
numpy.testing.assert_allclose(a, b, atol=1e-6, rtol=1e-6)


class TestFusedRmsNormQuantRint(TestFusedRmsNorm):
def modify_data(self):
self.quant_scale = 0.15
self.quant_round_type = 0
self.quant_max_bound = 127
self.quant_min_bound = -127

def test_eval(self):
# There is a small precision difference after decomposition, which
# leads to different results after dequantization, so we only run
# the computation here instead of comparing outputs.
self.compute()


class TestFusedRmsNormQuantRound(TestFusedRmsNorm):
def modify_data(self):
self.quant_scale = 0.15
self.quant_round_type = 1
self.quant_max_bound = 127
self.quant_min_bound = -127

def test_eval(self):
# There is a small precision difference after decomposition, which
# leads to different results after dequantization, so we only run
# the computation here instead of comparing outputs.
self.compute()


if __name__ == '__main__':
unittest.main()
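A side note on why the two quantized subclasses only call compute() and do not compare outputs: when the scaled value lands near a rounding boundary, a tiny numerical difference between the fused kernel and the decomposed program can flip the rounded integer and therefore the quantized result. An illustrative C++ sketch with made-up numbers:

#include <cmath>
#include <cstdio>

int main() {
  // Hypothetical values that differ only slightly but straddle a rounding
  // boundary, so the quantized integers differ by 1.
  float fused_scale_out = 63.499996f;
  float decomposed_scale_out = 63.500004f;
  std::printf("fused  -> %d\n", static_cast<int>(std::rint(fused_scale_out)));       // 63
  std::printf("decomp -> %d\n", static_cast<int>(std::rint(decomposed_scale_out)));  // 64
  return 0;
}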