Skip to content
19 changes: 16 additions & 3 deletions paddle/fluid/framework/infershape_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -771,7 +771,7 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
}
infer_meta_context.EmplaceBackAttr(std::move(scalar_list));
} break;
case proto::AttrType::SCALARS: {
case framework::proto::AttrType::SCALARS: {
const auto& vec = PADDLE_GET_CONST(
std::vector<paddle::experimental::Scalar>, attr);
std::vector<phi::Scalar> scalar_list{vec.begin(), vec.end()};
Expand Down Expand Up @@ -805,8 +805,21 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
infer_meta_context.EmplaceBackAttr(PADDLE_GET_CONST(bool, attr));
break;
case phi::AttributeType::INT64:
infer_meta_context.EmplaceBackAttr(
PADDLE_GET_CONST(int64_t, attr));
switch (AttrTypeID(attr)) {
case framework::proto::AttrType::LONG:
infer_meta_context.EmplaceBackAttr(
PADDLE_GET_CONST(int64_t, attr));
break;
case framework::proto::AttrType::INT: {
const auto val = PADDLE_GET_CONST(int, attr);
infer_meta_context.EmplaceBackAttr(static_cast<int64_t>(val));
} break;
default:
PADDLE_THROW(common::errors::Unimplemented(
"Unsupported cast op attribute `%s` to int64_t when "
"construct InferMetaContext.",
attr_names[i]));
}
break;
case phi::AttributeType::INT32S:
infer_meta_context.EmplaceBackAttr(
Expand Down
17 changes: 15 additions & 2 deletions paddle/fluid/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3504,8 +3504,21 @@ void OperatorWithKernel::BuildPhiKernelContext(
PADDLE_GET_CONST(bool, attr_iter->second));
break;
case phi::AttributeType::INT64:
phi_kernel_context->EmplaceBackAttr(
PADDLE_GET_CONST(int64_t, attr_iter->second));
switch (AttrTypeID(attr_iter->second)) {
case proto::AttrType::LONG:
phi_kernel_context->EmplaceBackAttr(
PADDLE_GET_CONST(int64_t, attr_iter->second));
break;
case proto::AttrType::INT: {
const auto val = PADDLE_GET_CONST(int, attr_iter->second);
phi_kernel_context->EmplaceBackAttr(static_cast<int64_t>(val));
} break;
default:
PADDLE_THROW(common::errors::Unimplemented(
"Unsupported cast op attribute `%s` to int64_t when "
"construct KernelContext.",
attr_names[i]));
}
break;
case phi::AttributeType::INT32S: // NOLINT
phi_kernel_context->EmplaceBackAttr(
Expand Down
6 changes: 6 additions & 0 deletions paddle/fluid/pir/serialize_deserialize/patch/0.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
op_patches:
- op_name : pd_op.kthvalue
actions:
- action : modify_attr
object : k
type : pir::Int64Attribute
2 changes: 1 addition & 1 deletion paddle/fluid/primitive/decomp_rule/decomp_vjp/details.h
Original file line number Diff line number Diff line change
Expand Up @@ -3154,7 +3154,7 @@ template <typename T>
void kthvalue_grad(const Tensor& x,
const Tensor& indices,
const Tensor& out_grad,
int k,
int64_t k UNUSED,
int axis,
bool keepdim,
Tensor* x_grad) {
Expand Down
6 changes: 3 additions & 3 deletions paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2486,7 +2486,7 @@ void IsfiniteInferMeta(const MetaTensor& x, MetaTensor* out) {
}

void KthvalueInferMeta(const MetaTensor& x,
int k,
int64_t k,
int axis,
bool keepdim,
MetaTensor* out,
Expand Down Expand Up @@ -2523,7 +2523,7 @@ void KthvalueInferMeta(const MetaTensor& x,
k,
1,
common::errors::InvalidArgument(
"the k in the kthvalue must >= 1, but received %d .", k));
"the k in the kthvalue must >= 1, but received %lld .", k));
PADDLE_ENFORCE_GE(input_dims.size(),
0,
common::errors::InvalidArgument(
Expand All @@ -2533,7 +2533,7 @@ void KthvalueInferMeta(const MetaTensor& x,
input_dims[axis],
k,
common::errors::InvalidArgument(
"input of kthvalue must have >= %d columns in axis of %d",
"input of kthvalue must have >= %lld columns in axis of %d",
k,
axis));
}
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/infermeta/unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ void IsEmptyInferMeta(const MetaTensor& x, MetaTensor* out);
void IsfiniteInferMeta(const MetaTensor& input, MetaTensor* out);

void KthvalueInferMeta(const MetaTensor& x,
int k,
int64_t k,
int axis,
bool keepdim,
MetaTensor* out,
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/kernels/cpu/kthvalue_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ void KthvalueGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& indices,
const DenseTensor& d_out,
int k UNUSED,
int64_t k UNUSED,
int axis,
bool keepdim,
DenseTensor* d_x) {
Expand Down
6 changes: 3 additions & 3 deletions paddle/phi/kernels/cpu/kthvalue_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ static void getKthvalue(Type input_height,
const DenseTensor* input,
T* t_out,
Type* t_indices,
const int& k) {
const int64_t& k) {
bool partial_sort_flag = (k * 64) < input_width;
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
Expand Down Expand Up @@ -75,7 +75,7 @@ static void getKthvalue(Type input_height,
template <typename T, typename Context>
void KthvalueKernel(const Context& dev_ctx,
const DenseTensor& x,
int k,
int64_t k,
int axis,
bool keepdim,
DenseTensor* output,
Expand All @@ -98,7 +98,7 @@ void KthvalueKernel(const Context& dev_ctx,
1,
common::errors::InvalidArgument(
"the k in the kthvalue must less equal than the "
"elements number of the input X, but received %d .",
"elements number of the input X, but received %lld .",
k));

phi::Copy<Context>(dev_ctx, x, dev_ctx.GetPlace(), false, output);
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/kernels/gpu/kthvalue_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ void KthvalueGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& indices,
const DenseTensor& d_out,
int k,
int64_t k,
int axis,
bool keepdim,
DenseTensor* d_x) {
Expand Down
7 changes: 2 additions & 5 deletions paddle/phi/kernels/gpu/kthvalue_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -158,14 +158,11 @@ bool SortKthvalue(const phi::GPUContext& dev_ctx,
template <typename T, typename Context>
void KthvalueKernel(const Context& dev_ctx,
const DenseTensor& x,
int k,
int64_t k,
int axis,
bool keepdim,
DenseTensor* output,
DenseTensor* indices) {
// TODO(cangtianhuang): support int64_t k
k = static_cast<int64_t>(k);

if (x.numel() == 0) {
phi::Full<T, Context>(
dev_ctx, phi::IntArray(common::vectorize(output->dims())), NAN, output);
Expand All @@ -186,7 +183,7 @@ void KthvalueKernel(const Context& dev_ctx,
1,
common::errors::InvalidArgument(
"the k in the kthvalue must less equal than the "
"elements number of the input X, but received %d .",
"elements number of the input X, but received %lld .",
k));

phi::Copy<Context>(dev_ctx, x, dev_ctx.GetPlace(), false, output);
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/kernels/kthvalue_grad_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ void KthvalueGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& indices,
const DenseTensor& d_out,
int k,
int64_t k,
int axis,
bool keepdim,
DenseTensor* d_x);
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/kernels/kthvalue_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ namespace phi {
template <typename T, typename Context>
void KthvalueKernel(const Context& dev_ctx,
const DenseTensor& x,
int k,
int64_t k,
int axis,
bool keepdim,
DenseTensor* out,
Expand Down
4 changes: 2 additions & 2 deletions paddle/phi/ops/yaml/backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1849,8 +1849,8 @@
data_type : out_grad

- backward_op : kthvalue_grad
forward : kthvalue(Tensor x, int k, int axis, bool keepdim) -> Tensor(out), Tensor(indices)
args : (Tensor x, Tensor indices, Tensor out_grad, int k, int axis, bool keepdim)
forward : kthvalue(Tensor x, int64_t k, int axis, bool keepdim) -> Tensor(out), Tensor(indices)
args : (Tensor x, Tensor indices, Tensor out_grad, int64_t k, int axis, bool keepdim)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/ops/yaml/legacy/backward_exclude.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
- fused_softmax_mask_grad
- fused_softmax_mask_upper_triangle_grad
- hsigmoid_loss_grad
- kthvalue_grad
- lp_pool2d_grad
- max_grad
- mean_double_grad
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/ops/yaml/legacy/ops_exclude.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
- gaussian
- hsigmoid_loss
- increment
- kthvalue
- linspace
- logspace
- lp_pool2d
Expand Down
11 changes: 11 additions & 0 deletions paddle/phi/ops/yaml/legacy/static_backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,17 @@
param : [x, out_grad]
inplace : (out_grad -> x_grad)

- backward_op : kthvalue_grad
forward : kthvalue(Tensor x, int k, int axis, bool keepdim) -> Tensor(out), Tensor(indices)
args : (Tensor x, Tensor indices, Tensor out_grad, int k, int axis, bool keepdim)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param: [x]
kernel :
func : kthvalue_grad
data_type : out_grad

- backward_op : legacy_bilinear_interp_grad
forward : legacy_bilinear_interp (Tensor x, Tensor out_size, Tensor[] size_tensor, Tensor scale_tensor, str data_format="NCHW", int out_d=0, int out_h=0, int out_w=0, float scale=0.0, str interp_method="bilinear", bool align_corners=true, int align_mode=1) -> Tensor(output)
args : (Tensor x, Tensor out_size, Tensor[] size_tensor, Tensor scale_tensor, Tensor output_grad, str data_format, int out_d, int out_h, int out_w, float scale, str interp_method, bool align_corners, int align_mode)
Expand Down
10 changes: 10 additions & 0 deletions paddle/phi/ops/yaml/legacy/static_ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,16 @@
data_type: x
traits : paddle::dialect::ForwardOnlyTrait

- op : kthvalue
args : (Tensor x, int k = 1, int axis = -1, bool keepdim = false)
output : Tensor(out), Tensor(indices)
infer_meta :
func : KthvalueInferMeta
kernel :
func : kthvalue
backward : kthvalue_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface, paddle::dialect::LayoutTransformationInterface

- op : legacy_bilinear_interp
args : (Tensor x, Tensor out_size, Tensor[] size_tensor, Tensor scale_tensor, str data_format="NCHW", int out_d=0, int out_h=0, int out_w=0, float scale=0.0, str interp_method="bilinear", bool align_corners=true, int align_mode=1)
output : Tensor(output)
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/ops/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2976,7 +2976,7 @@
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : kthvalue
args : (Tensor x, int k = 1, int axis = -1, bool keepdim = false)
args : (Tensor x, int64_t k = 1, int axis = -1, bool keepdim = false)
output : Tensor(out), Tensor(indices)
infer_meta :
func : KthvalueInferMeta
Expand Down