|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
| 15 | +import os |
15 | 16 | import unittest |
16 | 17 |
|
17 | 18 | import numpy as np |
@@ -908,6 +909,45 @@ def test_pir_out_type(self): |
908 | 909 | self.assertTrue(out.dtype == core.DataType.INT64) |
909 | 910 |
|
910 | 911 |
|
| 912 | +class TestGatherBackward(unittest.TestCase): |
| 913 | + def setUp(self): |
| 914 | + self.shape = [10, 20] |
| 915 | + self.dtype = 'float32' |
| 916 | + self.index = (1, 3, 5) |
| 917 | + self.index_dtype = 'int64' |
| 918 | + self.places = [] |
| 919 | + if ( |
| 920 | + os.environ.get('FLAGS_CI_both_cpu_and_gpu', 'False').lower() |
| 921 | + in ['1', 'true', 'on'] |
| 922 | + or not paddle.is_compiled_with_cuda() |
| 923 | + ): |
| 924 | + self.places.append(paddle.CPUPlace()) |
| 925 | + if paddle.is_compiled_with_cuda(): |
| 926 | + self.places.append(paddle.CUDAPlace(0)) |
| 927 | + |
| 928 | + def test_gather_backward(self): |
| 929 | + if len(self.places) != 2: |
| 930 | + return |
| 931 | + res_list = [] |
| 932 | + x_np = np.random.random(self.shape).astype(self.dtype) |
| 933 | + index_np = np.array(self.index, dtype=self.index_dtype) |
| 934 | + grad_out_np = np.random.random(self.shape).astype(self.dtype) |
| 935 | + for place in self.places: |
| 936 | + with base.dygraph.guard(place): |
| 937 | + x = paddle.to_tensor(x_np, dtype=self.dtype) |
| 938 | + x.stop_gradient = False |
| 939 | + index = paddle.to_tensor(index_np, dtype=self.index_dtype) |
| 940 | + out = paddle.gather(x, index, -1) |
| 941 | + grad_out = paddle.to_tensor(grad_out_np, dtype=self.dtype) |
| 942 | + (re,) = paddle.grad( |
| 943 | + outputs=out, |
| 944 | + inputs=x, |
| 945 | + grad_outputs=grad_out, |
| 946 | + ) |
| 947 | + res_list.append(re.numpy()) |
| 948 | + np.testing.assert_allclose(res_list[0], res_list[1]) |
| 949 | + |
| 950 | + |
911 | 951 | if __name__ == "__main__": |
912 | 952 | paddle.enable_static() |
913 | 953 | unittest.main() |
0 commit comments