127 changes: 23 additions & 104 deletions paddle/cinn/hlir/op/contrib/gather_nd.cc
@@ -42,46 +42,21 @@ namespace op {
using cinn::common::CINNValue;
using cinn::common::CINNValuePack;

ir::Tensor GatherNd(const ir::Tensor &x,
const ir::Tensor &index,
const std::string &name) {
std::vector<Expr> x_shape = x->shape;
std::vector<Expr> index_shape = index->shape;
size_t x_shape_size = x_shape.size();
size_t index_shape_size = index_shape.size();
std::vector<Expr> out_shape;
out_shape.insert(out_shape.end(), index_shape.begin(), index_shape.end() - 1);
out_shape.insert(out_shape.end(),
x_shape.begin() + index_shape.back().as_int32(),
x_shape.end());
auto res = Compute(
out_shape,
[=](const std::vector<Expr> &indices) {
std::vector<Expr> indices_position;
for (size_t i = 0; i < index_shape_size - 1; ++i) {
indices_position.push_back(
ir::Cast::Make(cinn::common::Int(32), indices[i]));
}
indices_position.push_back(
ir::Cast::Make(cinn::common::Int(32), Expr(0)));
size_t indices_position_size = indices_position.size();
std::vector<Expr> real_indices;
for (size_t i = 0; i < index_shape.back().as_int32(); ++i) {
indices_position[indices_position_size - 1] =
ir::Cast::Make(cinn::common::Int(32), Expr(i));
real_indices.push_back(
ir::Cast::Make(cinn::common::Int(32), index(indices_position)));
}
if (real_indices.size() == x_shape_size) {
return x(real_indices);
}
for (size_t i = index_shape_size - 1; i < indices.size(); ++i) {
real_indices.push_back(indices[i]);
}
return x(real_indices);
},
name);
return res;
Expr CastShapeElementType(Expr shape_element) {
// constant values should already have been cast to int64_t for shape calc
// make sure the shape expr does not contain mismatched int types inside min/max
if (shape_element.is_constant()) return shape_element;
if (auto max_elem = shape_element.As<ir::Max>()) {
max_elem->a() = CastShapeElementType(max_elem->a());
max_elem->b() = CastShapeElementType(max_elem->b());
} else if (auto min_elem = shape_element.As<ir::Min>()) {
min_elem->a() = CastShapeElementType(min_elem->a());
min_elem->b() = CastShapeElementType(min_elem->b());
} else {
shape_element = ir::Call::Make(
ir::Int(64), "int64_t", {shape_element}, {}, ir::CallType::Intrinsic);
}
return shape_element;
}
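
The recursion above only wraps the non-constant leaves: constants are assumed to already be int64, and Min/Max nodes are descended into so both operands end up with matching integer types (a Min(S0, 32) whose symbol stayed int32 would otherwise reach codegen with mismatched min operands). A minimal Python sketch of the same rule, using hypothetical Expr/MinNode/MaxNode stand-ins rather than CINN's IR:

# Minimal sketch of the casting rule above; Expr, MinNode and MaxNode are
# hypothetical stand-ins for CINN's ir::Expr, ir::Min and ir::Max.
from dataclasses import dataclass

@dataclass
class Expr:
    value: object = None  # non-None means a constant
    name: str = ""        # symbolic leaf, e.g. a shape symbol "S0"

    def is_constant(self):
        return self.value is not None

@dataclass
class MinNode:
    a: object
    b: object

@dataclass
class MaxNode:
    a: object
    b: object

def cast_shape_element_type(e):
    # Constants are assumed to already be int64; leave them untouched.
    if isinstance(e, Expr) and e.is_constant():
        return e
    # Recurse through Min/Max so both operands get matching integer types.
    if isinstance(e, (MinNode, MaxNode)):
        e.a = cast_shape_element_type(e.a)
        e.b = cast_shape_element_type(e.b)
        return e
    # Any other non-constant leaf is wrapped in an int64 cast.
    return ("int64_t", e)

# Min(S0, 32) becomes Min(int64_t(S0), 32): only the symbol is wrapped.
print(cast_shape_element_type(MinNode(Expr(name="S0"), Expr(value=32))))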

ir::Tensor GatherNdSymbolic(const ir::Tensor &x,
@@ -111,8 +86,15 @@ ir::Tensor GatherNdSymbolic(const ir::Tensor &x,
for (size_t i = 0; i < index_shape.back().as_int64(); ++i) {
indices_position[indices_position_size - 1] =
ir::Cast::Make(cinn::common::Int(64), Expr(i));
// support negative indices
auto idx_expr =
ir::Cast::Make(cinn::common::Int(64), index(indices_position));
auto real_expr =
ir::Select::Make(ir::GE::Make(idx_expr, Expr(0)),
idx_expr,
CastShapeElementType(x_shape[i]) + idx_expr);
real_indices.push_back(
ir::Cast::Make(cinn::common::Int(64), index(indices_position)));
ir::Cast::Make(cinn::common::Int(64), real_expr));
}
if (real_indices.size() == x_shape_size) {
return x(real_indices);
@@ -126,67 +108,6 @@ ir::Tensor GatherNdSymbolic(const ir::Tensor &x,
return res;
}
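
The Select added above normalizes each gathered index before it is used: a non-negative index passes through unchanged, and a negative index has the corresponding dimension size added to it, i.e. Python-style negative indexing. A NumPy sketch of the intended semantics (it mirrors the Select(idx >= 0, idx, dim + idx) rule; it is not the CINN lowering itself):

# NumPy reference for gather_nd with negative-index support; the `real`
# tuple below mirrors Select(idx >= 0, idx, dim_size + idx).
import numpy as np

def gather_nd_ref(x, index):
    index_depth = index.shape[-1]  # how many leading dims of x are addressed
    out_shape = index.shape[:-1] + x.shape[index_depth:]
    out = np.empty(out_shape, dtype=x.dtype)
    for pos in np.ndindex(index.shape[:-1]):
        idx = index[pos]
        # Normalize negative indices by adding the dimension size.
        real = tuple(i if i >= 0 else i + x.shape[d] for d, i in enumerate(idx))
        out[pos] = x[real]
    return out

x = np.arange(12.0).reshape(3, 4)
index = np.array([[0, -1], [-3, 2]])  # -1 -> 3 and -3 -> 0 after normalization
print(gather_nd_ref(x, index))        # [x[0, 3], x[0, 2]] == [3. 2.]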

std::shared_ptr<framework::OpStrategy> StrategyForGatherNd(
const framework::NodeAttr &attrs,
const std::vector<ir::Tensor> &inputs,
const std::vector<Type> &out_type,
const std::vector<std::vector<int>> &output_shapes,
const Target &target) {
std::string op_name("gather_nd");

framework::CINNCompute gather_nd_compute([=](lang::Args args,
lang::RetValue *ret) {
PADDLE_ENFORCE_NE(
args.empty(),
true,
::common::errors::InvalidArgument(
"The input argument of %s compute is empty! Please check.",
op_name));
CINNValuePack pack_args = args[0];
PADDLE_ENFORCE_GE(
pack_args.size(),
2U,
::common::errors::InvalidArgument("2 input tensors for compute\n"));
Expr x = pack_args[0];
Expr index = pack_args[1];
PADDLE_ENFORCE_NOT_NULL(x.as_tensor(),
::common::errors::InvalidArgument(
"Required x must be a tensor. Please check."));
PADDLE_ENFORCE_NOT_NULL(
index.as_tensor(),
::common::errors::InvalidArgument(
"Required index must be a tensor. Please check."));
PADDLE_ENFORCE_NE(
output_shapes.empty(),
true,
::common::errors::InvalidArgument(
"The output shape of gather_nd is empty! Please check."));
auto tensor_x = x.as_tensor_ref();
auto tensor_index = index.as_tensor_ref();
VLOG(3) << "x shape: " << utils::Join(tensor_x->shape, ", ")
<< ", index shape: " << utils::Join(tensor_index->shape, ", ")
<< ", output_shapes: " << utils::Join(output_shapes[0], ", ");
PADDLE_ENFORCE_EQ(pack_args.size(),
3U,
::common::errors::InvalidArgument(
"The size of pack_args should be 3\n"));
std::string tensor_name = pack_args[2].operator std::string();
ir::Tensor out = GatherNd(tensor_x, tensor_index, tensor_name);
std::vector<CINNValue> res;
res.push_back(CINNValue(out));
PADDLE_ENFORCE_NE(
out_type.empty(),
true,
::common::errors::InvalidArgument(
"The output type of gather_nd is empty! Please check."));
*ret = CINNValuePack{res};
});

auto strategy = std::make_shared<framework::OpStrategy>();
strategy->AddImpl(gather_nd_compute, "strategy.gather_nd.x86", 1);
return strategy;
}

std::shared_ptr<framework::OpStrategy> StrategyForGatherNdSymbolic(
const framework::NodeAttr &attrs,
const std::vector<ir::Tensor> &inputs,
@@ -259,8 +180,6 @@ CINN_REGISTER_HELPER(gather_nd_ops) {
.set_num_outputs(1)
.set_attr<cinn::hlir::framework::StrategyFunctionSymbolic>(
"CINNStrategySymbolic", cinn::hlir::op::StrategyForGatherNdSymbolic)
.set_attr<cinn::hlir::framework::StrategyFunction>(
"CINNStrategy", cinn::hlir::op::StrategyForGatherNd)
.set_attr<cinn::hlir::framework::OpPatternKind>(
"OpPattern", cinn::hlir::framework::OpPatternKind::kInjective)
.set_support_level(4);
126 changes: 126 additions & 0 deletions test/ir/pir/cinn/test_cinn_gather_nd.py
@@ -0,0 +1,126 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy
import utils

import paddle


class TestGatherNd(unittest.TestCase):
# Note that GatherNd is also used in index_put's backward, so we can
# test it via index_put (see the sketch after this file).
def eval(self, dy_compute, inputs, input_spec=None):
dy_out = dy_compute(*inputs)

static_compute = utils.apply_to_static(
dy_compute, use_cinn=True, input_spec=None
)
st_out = static_compute(*inputs)

for a, b in zip(
paddle.utils.flatten(dy_out), paddle.utils.flatten(st_out)
):
numpy.testing.assert_allclose(a, b, atol=1e-6, rtol=1e-6)

@staticmethod
def get_input(x_shape, indices_shape, value_shape, has_negative_index=True):
n_indices = indices_shape[0]
index_dim_size = indices_shape[1] if len(indices_shape) > 1 else 1

x_pd = paddle.randn(x_shape)
x_pd.stop_gradient = False

indices_pd = tuple(
[
paddle.randint(
-x_shape[i] if has_negative_index else 0,
x_shape[i],
[n_indices],
)
for i in range(max(index_dim_size, 1))
]
)
value_pd = paddle.randn(value_shape)
value_pd.stop_gradient = False

dout_pd = paddle.randn(x_shape)
dout_pd.stop_gradient = False
return x_pd, indices_pd, value_pd, dout_pd

@staticmethod
def get_input_spec(indice_dim):
return [
paddle.static.InputSpec(shape=[-1, -1], dtype="float32"),
tuple(
paddle.static.InputSpec(shape=[-1], dtype="int64")
for _ in range(indice_dim)
),
paddle.static.InputSpec(shape=[-1, -1], dtype="float32"),
paddle.static.InputSpec(shape=[-1, -1], dtype="float32"),
]

@staticmethod
def index_put_grad(x, indices, v, dy):
y = paddle.index_put(x, indices, v, True)
return paddle.grad(y, [x, v], dy)

def test_index_put_grad_non_negative_index(self):
x_pd, indices_pd, value_pd, dout_pd = self.get_input(
[12, 13, 14], [88, 2], [88, 14], False
)

self.eval(
TestGatherNd.index_put_grad,
[x_pd, indices_pd, value_pd, dout_pd],
input_spec=self.get_input_spec(2),
)

def test_index_put_grad_negative_index_1(self):
x_pd, indices_pd, value_pd, dout_pd = self.get_input(
[12, 13, 14], [88, 1], [88, 13, 14]
)

self.eval(
TestGatherNd.index_put_grad,
[x_pd, indices_pd, value_pd, dout_pd],
input_spec=self.get_input_spec(1),
)

def test_index_put_grad_negative_index_2(self):
x_pd, indices_pd, value_pd, dout_pd = self.get_input(
[16, 16], [20, 2], [20]
)

self.eval(
TestGatherNd.index_put_grad,
[x_pd, indices_pd, value_pd, dout_pd],
input_spec=self.get_input_spec(2),
)

def test_gather_nd_fusion(self):
x_pd = paddle.randn([256, 128])
y_pd = paddle.randn_like(x_pd)
z_pd = paddle.randn([100])
indices_pd = paddle.randint(-128, 128, [100, 2])

def func(x, y, z, indices):
return paddle.gather_nd(x * y, indices) + z

self.eval(func, [x_pd, y_pd, z_pd, indices_pd])


if __name__ == "__main__":
unittest.main()
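
As the note at the top of the test class says, these cases reach GatherNd through index_put's backward: with accumulate=True, y = index_put(x, indices, v) adds v into x at indices, so the gradient of y with respect to v is simply the upstream gradient gathered at those indices. A small sketch of that relationship (assuming accumulate=True, as the tests above use):

# Sketch: index_put's value-gradient is gather_nd of the upstream gradient.
import numpy as np
import paddle

x = paddle.randn([16, 16])
x.stop_gradient = False
v = paddle.randn([20])
v.stop_gradient = False
indices = tuple(paddle.randint(0, 16, [20]) for _ in range(2))
dy = paddle.randn([16, 16])

y = paddle.index_put(x, indices, v, True)  # accumulate=True
dx, dv = paddle.grad(y, [x, v], dy)

# dv should equal gather_nd(dy, indices): each value's gradient is the
# upstream gradient at the position it was accumulated into.
stacked = paddle.stack(list(indices), axis=-1)  # [20, 2] index matrix
np.testing.assert_allclose(
    dv.numpy(), paddle.gather_nd(dy, stacked).numpy(), rtol=1e-6, atol=1e-6
)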