@@ -147,57 +147,6 @@ void cast_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
}
}

-template <typename T>
-void gather_grad(const Tensor& x,
-                 const Tensor& index,
-                 const Tensor& out_grad,
-                 const Scalar& axis,
-                 Tensor* grad_x) {
-  auto zero_tensor =
-      full<T>(common::vectorize(x.dims()), 0.0, x.dtype(), x.place());
-  std::vector<int> tmp_perm;
-
-  // change axis to rank 0
-  int axis_value = axis.to<int>();
-  int rank = x.dims().size();
-  if (axis_value < 0) {
-    axis_value += rank;
-  }
-  tmp_perm.push_back(axis_value);
-  // make other ranks
-  for (int i = 0; i < rank; ++i) {
-    if (i != axis_value) {
-      tmp_perm.push_back(i);
-    }
-  }
-  std::vector<int> reverse_perm(tmp_perm);
-  // make origin ranks
-  for (int i = 0; i < static_cast<int>(tmp_perm.size()); ++i) {
-    if (tmp_perm[i] >= 0) {
-      reverse_perm[tmp_perm[i]] = i;
-    } else {
-      reverse_perm[tmp_perm[i] + tmp_perm.size()] = i;
-    }
-  }
-
-  // transpose out_grad and zero grad to target rank.
-  auto tmp_zero_x_grad = zero_tensor;
-  auto tmp_out_grad = out_grad;
-  if (zero_tensor.dims().size() > 0) {
-    tmp_zero_x_grad = transpose<T>(zero_tensor, tmp_perm);
-  }
-  if (out_grad.dims().size() > 0) {
-    tmp_out_grad = transpose<T>(out_grad, tmp_perm);
-  }
-  // scatter grad to grad_x
-  auto tmp_grad_x = scatter<T>(tmp_zero_x_grad, index, tmp_out_grad, false);
-  auto tmp_grad_x_transposed = tmp_grad_x;
-  if (tmp_grad_x.dims().size() > 0) {
-    tmp_grad_x_transposed = transpose<T>(tmp_grad_x, reverse_perm);
-  }
-  set_output<T>(tmp_grad_x_transposed, grad_x);
-}

template <typename T>
void tanh_grad(const Tensor& out, const Tensor& grad_out, Tensor* grad_x) {
  if (!grad_x) return;
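For reference, a minimal NumPy sketch of the composite rule that the removed gather_grad implemented (not part of the PR; the helper name and the use of np.add.at to mimic accumulating scatter are my own assumptions): move `axis` to the front, scatter-accumulate out_grad at `index` into a zero tensor, then transpose back.

import numpy as np

def gather_grad_sketch(x, index, out_grad, axis):
    """Hypothetical NumPy equivalent of the removed composite gather_grad:
    transpose `axis` to the front, scatter-accumulate out_grad rows at
    `index` into a zero tensor, then transpose back to the original layout."""
    rank = x.ndim
    axis = axis + rank if axis < 0 else axis
    perm = [axis] + [i for i in range(rank) if i != axis]   # tmp_perm
    inv_perm = list(np.argsort(perm))                        # reverse_perm
    grad = np.zeros_like(x).transpose(perm)
    # scatter<T>(..., overwrite=false) accumulates duplicate indices; np.add.at
    # mirrors that behaviour (an assumption about the scatter semantics).
    np.add.at(grad, np.asarray(index), out_grad.transpose(perm))
    return grad.transpose(inv_perm)

# Example: gradient of gather along axis 1 with a repeated index.
x = np.zeros((2, 4))
index = np.array([1, 3, 3])
out_grad = np.ones((2, 3))
print(gather_grad_sketch(x, index, out_grad, axis=1))
# [[0. 1. 0. 2.]
#  [0. 1. 0. 2.]]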
paddle/phi/ops/yaml/backward.yaml (8 changes: 7 additions & 1 deletion)
@@ -1388,6 +1388,12 @@
  kernel :
    func : gammaln_grad

+- backward_op : gather_double_grad
+  forward : gather_grad(Tensor x, Tensor index, Tensor grad_out, Scalar axis=0) -> Tensor(grad_x)
+  args : (Tensor index, Tensor grad_x_grad, Scalar axis)
+  output : Tensor(grad_out_grad)
+  invoke: gather(grad_x_grad, index, axis)

- backward_op : gather_grad
  forward : gather(Tensor x, Tensor index, Scalar axis=0) -> Tensor(out)
  args : (Tensor x, Tensor index, Tensor out_grad, Scalar axis=0)
@@ -1398,8 +1404,8 @@
  kernel :
    data_type: out_grad
    func : gather_grad
-  composite : gather_grad(x, index, out_grad, axis, x_grad)
  no_need_buffer : x
+  backward : gather_double_grad

- backward_op : gather_nd_double_grad
  forward : gather_nd_grad (Tensor x, Tensor index, Tensor grad_out) -> Tensor(grad_x)
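Why invoking gather again is a valid double grad (a sanity check of my own, not part of the PR): gather_grad scatters out_grad back into x's shape and is linear in out_grad, i.e. it is the adjoint of gather. Differentiating grad_x with respect to out_grad therefore maps grad_x_grad to gather(grad_x_grad, index, axis), which is what the new gather_double_grad entry above declares via its invoke line. A small NumPy check of the adjoint identity:

import numpy as np

# Hypothetical check (not from the PR): gather_grad is the adjoint of gather,
# so <scatter_add(out_grad), grad_x_grad> == <out_grad, gather(grad_x_grad)>.
# That linearity is why grad_out_grad = gather(grad_x_grad, index, axis).
rng = np.random.default_rng(0)
axis = 0
x_shape = (5, 3)
index = np.array([0, 2, 2])            # duplicate index to exercise accumulation

out_grad = rng.standard_normal((len(index), x_shape[1]))
grad_x_grad = rng.standard_normal(x_shape)

grad_x = np.zeros(x_shape)
np.add.at(grad_x, index, out_grad)     # gather_grad: scatter-add along axis 0

lhs = np.vdot(grad_x, grad_x_grad)
rhs = np.vdot(out_grad, np.take(grad_x_grad, index, axis=axis))  # gather
assert np.isclose(lhs, rhs)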
paddle/phi/ops/yaml/op_compat.yaml (2 changes: 1 addition & 1 deletion)
@@ -1802,7 +1802,7 @@
    out : Out

- op : gather
-  backward : gather_grad
+  backward : gather_grad, gather_double_grad
  inputs :
    {x : X, index : Index}
  outputs :
test/legacy_test/test_gather_op.py (6 changes: 1 addition & 5 deletions)
@@ -45,9 +45,7 @@ def test_check_output(self):
        self.check_output(check_pir=True, check_symbol_infer=False)

    def test_check_grad(self):
-        self.check_grad(
-            ['X'], 'Out', check_prim=True, check_pir=True, check_prim_pir=True
-        )
+        self.check_grad(['X'], 'Out', check_pir=True, check_prim_pir=True)

    def config(self):
        """
@@ -131,7 +129,6 @@ def test_check_grad(self):
            paddle.CUDAPlace(0),
            ['X'],
            'Out',
-            check_prim=True,
            check_pir=True,
            check_prim_pir=True,
        )
@@ -703,7 +700,6 @@ def test_check_grad(self):
            ['X'],
            'Out',
            check_pir=True,
-            check_prim=True,
            check_prim_pir=True,
        )

test/legacy_test/test_index_select_op.py (16 changes: 5 additions & 11 deletions)
@@ -68,19 +68,15 @@ def init_dtype_type(self):

    def test_check_output(self):
        if self.x_type == np.complex64 or self.x_type == np.complex128:
-            self.check_output(
-                check_prim=False, check_pir=True, check_prim_pir=False
-            )
+            self.check_output(check_pir=True, check_prim_pir=False)
        else:
-            self.check_output(
-                check_prim=True, check_pir=True, check_prim_pir=True
-            )
+            self.check_output(check_pir=True, check_prim_pir=True)

    def test_check_grad_normal(self):
        if self.x_type == np.complex64 or self.x_type == np.complex128:
-            self.check_grad(['X'], 'Out', check_prim=False, check_pir=True)
+            self.check_grad(['X'], 'Out', check_pir=True)
        else:
-            self.check_grad(['X'], 'Out', check_prim=True, check_pir=True)
+            self.check_grad(['X'], 'Out', check_pir=True)


class TestIndexSelectOpCase2(TestIndexSelectOp):
@@ -223,9 +219,7 @@ def test_check_output(self):

    def test_check_grad_normal(self):
        place = core.CUDAPlace(0)
-        self.check_grad_with_place(
-            place, ['X'], 'Out', check_prim=True, check_pir=True
-        )
+        self.check_grad_with_place(place, ['X'], 'Out', check_pir=True)


class TestIndexSelectComplex64(TestIndexSelectOp):