Skip to content

Commit 057df19

Browse files
authored
[Accuracy diff No.80、150] Fix accuracy diff for cumulative_trapezoid, trapezoid API (#73317)
1 parent 99fa5ab commit 057df19

File tree

2 files changed

+44
-2
lines changed

2 files changed

+44
-2
lines changed

paddle/phi/kernels/funcs/gather.h

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -265,12 +265,14 @@ void GatherV2GradFunction(const phi::CPUContext& ctx,
265265
for (int64_t i = 0; i < inner_dim_size; i++) {
266266
for (int64_t j = 0; j < input_index_dim_size; j++) {
267267
const int64_t index_data_j =
268-
(index_data[j] < 0 ? index_data[j] + input_index_dim_size
268+
(index_data[j] < 0 ? index_data[j] + out_index_dim_size
269269
: index_data[j]);
270270
for (int64_t k = 0; k < outer_dim_size; k++) {
271271
int64_t index = k + index_data_j * outer_dim_size +
272272
i * outer_dim_size * out_index_dim_size;
273-
out_data[index] += input_data[j * outer_dim_size + k];
273+
out_data[index] +=
274+
input_data[i * input_index_dim_size * outer_dim_size +
275+
j * outer_dim_size + k];
274276
}
275277
}
276278
}

test/legacy_test/test_gather_op.py

Lines changed: 40 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import os
1516
import unittest
1617

1718
import numpy as np
@@ -908,6 +909,45 @@ def test_pir_out_type(self):
908909
self.assertTrue(out.dtype == core.DataType.INT64)
909910

910911

912+
class TestGatherBackward(unittest.TestCase):
    """Compare the gradient of ``paddle.gather`` across the collected places.

    ``setUp`` collects a CPU place and/or a CUDA place. The comparison in
    ``test_gather_backward`` needs exactly two places; with any other number
    the test is reported as skipped instead of passing without checking
    anything.
    """

    def setUp(self):
        """Set the input shape, dtypes, gather index and places to run on."""
        self.shape = [10, 20]
        self.dtype = 'float32'
        self.index = (1, 3, 5)
        self.index_dtype = 'int64'
        self.places = []
        # The CPU place is added when the CI flag requests both devices, or
        # when Paddle was built without CUDA.
        if (
            os.environ.get('FLAGS_CI_both_cpu_and_gpu', 'False').lower()
            in ['1', 'true', 'on']
            or not paddle.is_compiled_with_cuda()
        ):
            self.places.append(paddle.CPUPlace())
        if paddle.is_compiled_with_cuda():
            self.places.append(paddle.CUDAPlace(0))

    def test_gather_backward(self):
        """Run gather backward on both places and require matching gradients."""
        if len(self.places) != 2:
            # Report a skip rather than a silent pass so the test log shows
            # that no comparison was made.
            self.skipTest(
                "gather backward comparison needs both a CPU and a CUDA place"
            )
        res_list = []
        x_np = np.random.random(self.shape).astype(self.dtype)
        index_np = np.array(self.index, dtype=self.index_dtype)
        # NOTE(review): grad_out is built from self.shape rather than derived
        # from the shape of `out` -- confirm this is intended.
        grad_out_np = np.random.random(self.shape).astype(self.dtype)
        for place in self.places:
            with base.dygraph.guard(place):
                x = paddle.to_tensor(x_np, dtype=self.dtype)
                x.stop_gradient = False
                index = paddle.to_tensor(index_np, dtype=self.index_dtype)
                out = paddle.gather(x, index, -1)
                grad_out = paddle.to_tensor(grad_out_np, dtype=self.dtype)
                # Named x_grad (not `re`) to avoid shadowing the stdlib module.
                (x_grad,) = paddle.grad(
                    outputs=out,
                    inputs=x,
                    grad_outputs=grad_out,
                )
                res_list.append(x_grad.numpy())
        np.testing.assert_allclose(res_list[0], res_list[1])
949+
950+
911951
if __name__ == "__main__":
912952
paddle.enable_static()
913953
unittest.main()

0 commit comments

Comments
 (0)