@@ -3005,6 +3005,26 @@ struct FloorFunctor : public BaseActivationFunctor<T> {
30053005 }
30063006};
30073007
3008+ //  rint(x) = [x]
3009+ template  <typename  T, typename  Enable = void >
3010+ struct  RintFunctor  : public  BaseActivationFunctor <T> {
3011+  template  <typename  Device, typename  X, typename  Out>
3012+  void  operator ()(Device d, X x, Out out) const  {
3013+  out.device (d) = x.unaryExpr ([](const  T& val) {
3014+  return  (std::isnan (val) || std::isinf (val)) ? val : std::rint (val);
3015+  });
3016+  }
3017+ };
3018+ 
3019+ template  <typename  T>
3020+ struct  RintFunctor <T, std::enable_if_t <std::is_integral_v<T>>>
3021+  : public BaseActivationFunctor<T> {
3022+  template  <typename  Device, typename  X, typename  Out>
3023+  void  operator ()(Device d, X x, Out out) const  {
3024+  out.device (d) = x;
3025+  }
3026+ };
3027+ 
30083028//  round(x) = [x]
30093029template  <typename  T, typename  Enable = void >
30103030struct  RoundFunctor  : public  BaseActivationFunctor <T> {
@@ -5410,6 +5430,25 @@ struct CudaFloorFunctor : public BaseActivationFunctor<T> {
54105430 }
54115431};
54125432
5433+ template  <typename  T, typename  Enable = void >
5434+ struct  CudaRintFunctor  : public  BaseActivationFunctor <T> {
5435+  using  MPType = typename  phi::dtype::MPTypeTrait<T>::Type;
5436+ 
5437+  //  rint(x) = rint(x)
5438+  __device__ __forceinline__ T operator ()(const  T arg_x) const  {
5439+  MPType x = static_cast <MPType>(arg_x);
5440+  if  (isnan (x) || isinf (x)) return  arg_x;
5441+  return  static_cast <T>(std::rint (x));
5442+  }
5443+ };
5444+ 
5445+ template  <typename  T>
5446+ struct  CudaRintFunctor <T, std::enable_if_t <std::is_integral_v<T>>>
5447+  : public BaseActivationFunctor<T> {
5448+  //  rint(x) = x
5449+  __device__ __forceinline__ T operator ()(const  T arg_x) const  { return  arg_x; }
5450+ };
5451+ 
54135452template  <typename  T, typename  Enable = void >
54145453struct  CudaRoundFunctor  : public  BaseActivationFunctor <T> {
54155454 using  MPType = typename  phi::dtype::MPTypeTrait<T>::Type;
0 commit comments