|  | 
|  | 1 | +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. | 
|  | 2 | +# | 
|  | 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | 
|  | 4 | +# you may not use this file except in compliance with the License. | 
|  | 5 | +# You may obtain a copy of the License at | 
|  | 6 | +# | 
|  | 7 | +# http://www.apache.org/licenses/LICENSE-2.0 | 
|  | 8 | +# | 
|  | 9 | +# Unless required by applicable law or agreed to in writing, software | 
|  | 10 | +# distributed under the License is distributed on an "AS IS" BASIS, | 
|  | 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
|  | 12 | +# See the License for the specific language governing permissions and | 
|  | 13 | +# limitations under the License. | 
|  | 14 | + | 
|  | 15 | +import unittest | 
|  | 16 | + | 
|  | 17 | +import numpy as np | 
|  | 18 | +from op_test import OpTest, convert_float_to_uint16 | 
|  | 19 | + | 
|  | 20 | +import paddle | 
|  | 21 | +from paddle import base | 
|  | 22 | +from paddle.base import Program, program_guard | 
|  | 23 | + | 
|  | 24 | + | 
|  | 25 | +def call_argwhere(x): | 
|  | 26 | + input = paddle.to_tensor(x) | 
|  | 27 | + return paddle.argwhere(input) | 
|  | 28 | + | 
|  | 29 | + | 
|  | 30 | +class TestArgwhereAPI(unittest.TestCase): | 
|  | 31 | + def test_argwhere_api(self): | 
|  | 32 | + paddle.enable_static() | 
|  | 33 | + data = np.array([[1, 0], [0, 1]], dtype="float32") | 
|  | 34 | + with program_guard(Program(), Program()): | 
|  | 35 | + x = paddle.static.data(name='x', shape=[-1, 2], dtype='float32') | 
|  | 36 | + if not paddle.framework.use_pir_api(): | 
|  | 37 | + x.desc.set_need_check_feed(False) | 
|  | 38 | + y = paddle.argwhere(x) | 
|  | 39 | + exe = base.Executor(base.CPUPlace()) | 
|  | 40 | + (res,) = exe.run( | 
|  | 41 | + feed={'x': data}, fetch_list=[y], return_numpy=False | 
|  | 42 | + ) | 
|  | 43 | + expect_out = np.array([[0, 0], [1, 1]]) | 
|  | 44 | + np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05) | 
|  | 45 | + | 
|  | 46 | + data = np.array([1, 1, 0], dtype="float32") | 
|  | 47 | + with program_guard(Program(), Program()): | 
|  | 48 | + x = paddle.static.data(name='x', shape=[-1], dtype='float32') | 
|  | 49 | + if not paddle.framework.use_pir_api(): | 
|  | 50 | + x.desc.set_need_check_feed(False) | 
|  | 51 | + y = paddle.argwhere(x) | 
|  | 52 | + exe = base.Executor(base.CPUPlace()) | 
|  | 53 | + (res,) = exe.run( | 
|  | 54 | + feed={'x': data}, fetch_list=[y], return_numpy=False | 
|  | 55 | + ) | 
|  | 56 | + expect_out = np.array([[0], [1]]) | 
|  | 57 | + np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05) | 
|  | 58 | + | 
|  | 59 | + def test_dygraph_api(self): | 
|  | 60 | + data_x = np.array([[True, False], [False, True]]) | 
|  | 61 | + with base.dygraph.guard(): | 
|  | 62 | + x = paddle.to_tensor(data_x) | 
|  | 63 | + z = paddle.argwhere(x) | 
|  | 64 | + np_z = z.numpy() | 
|  | 65 | + expect_out = np.array([[0, 0], [1, 1]]) | 
|  | 66 | + | 
|  | 67 | + | 
|  | 68 | +# Base case | 
|  | 69 | +class TestArgwhereOp(OpTest): | 
|  | 70 | + def setUp(self): | 
|  | 71 | + '''Test where_index op with random value''' | 
|  | 72 | + np.random.seed(2023) | 
|  | 73 | + self.op_type = "where_index" | 
|  | 74 | + self.python_api = call_argwhere | 
|  | 75 | + self.init_shape() | 
|  | 76 | + self.init_dtype() | 
|  | 77 | + | 
|  | 78 | + self.inputs = self.create_inputs() | 
|  | 79 | + self.outputs = self.return_outputs() | 
|  | 80 | + | 
|  | 81 | + def test_check_output(self): | 
|  | 82 | + self.check_output(check_pir=True, check_symbol_infer=False) | 
|  | 83 | + | 
|  | 84 | + def init_shape(self): | 
|  | 85 | + self.shape = [8, 8] | 
|  | 86 | + | 
|  | 87 | + def init_dtype(self): | 
|  | 88 | + self.dtype = np.float64 | 
|  | 89 | + | 
|  | 90 | + def create_inputs(self): | 
|  | 91 | + return { | 
|  | 92 | + 'Condition': np.random.randint(5, size=self.shape).astype( | 
|  | 93 | + self.dtype | 
|  | 94 | + ) | 
|  | 95 | + } | 
|  | 96 | + | 
|  | 97 | + def return_outputs(self): | 
|  | 98 | + return {'Out': np.argwhere(self.inputs['Condition'])} | 
|  | 99 | + | 
|  | 100 | + | 
|  | 101 | +class TestArgwhereComplex64Op(TestArgwhereOp): | 
|  | 102 | + def init_shape(self): | 
|  | 103 | + self.shape = [1, 2, 3] | 
|  | 104 | + | 
|  | 105 | + def init_dtype(self): | 
|  | 106 | + self.dtype = np.complex64 | 
|  | 107 | + | 
|  | 108 | + | 
|  | 109 | +class TestArgwhereComplex128Op(TestArgwhereOp): | 
|  | 110 | + def init_shape(self): | 
|  | 111 | + self.shape = [1, 2, 3] | 
|  | 112 | + | 
|  | 113 | + def init_dtype(self): | 
|  | 114 | + self.dtype = np.complex128 | 
|  | 115 | + | 
|  | 116 | + | 
|  | 117 | +class TestArgwhereFP32Op(TestArgwhereOp): | 
|  | 118 | + def init_shape(self): | 
|  | 119 | + self.shape = [2, 10, 2] | 
|  | 120 | + | 
|  | 121 | + def init_dtype(self): | 
|  | 122 | + self.dtype = np.float32 | 
|  | 123 | + | 
|  | 124 | + | 
|  | 125 | +class TestArgwhereFP16Op(TestArgwhereOp): | 
|  | 126 | + def init_shape(self): | 
|  | 127 | + self.shape = [3, 4, 7] | 
|  | 128 | + | 
|  | 129 | + def init_dtype(self): | 
|  | 130 | + self.dtype = np.float16 | 
|  | 131 | + | 
|  | 132 | + | 
|  | 133 | +class TestArgwhereBF16(OpTest): | 
|  | 134 | + def setUp(self): | 
|  | 135 | + '''Test where_index op with bfloat16 dtype''' | 
|  | 136 | + np.random.seed(2023) | 
|  | 137 | + self.op_type = "where_index" | 
|  | 138 | + self.python_api = call_argwhere | 
|  | 139 | + self.init_shape() | 
|  | 140 | + self.init_dtype() | 
|  | 141 | + | 
|  | 142 | + self.inputs = self.create_inputs() | 
|  | 143 | + self.outputs = self.return_outputs() | 
|  | 144 | + | 
|  | 145 | + def test_check_output(self): | 
|  | 146 | + self.check_output(check_pir=True, check_symbol_infer=False) | 
|  | 147 | + | 
|  | 148 | + def init_shape(self): | 
|  | 149 | + self.shape = [12, 9] | 
|  | 150 | + | 
|  | 151 | + def init_dtype(self): | 
|  | 152 | + self.dtype = np.uint16 | 
|  | 153 | + | 
|  | 154 | + def create_inputs(self): | 
|  | 155 | + return { | 
|  | 156 | + 'Condition': convert_float_to_uint16( | 
|  | 157 | + np.random.randint(5, size=self.shape).astype(np.float32) | 
|  | 158 | + ) | 
|  | 159 | + } | 
|  | 160 | + | 
|  | 161 | + def return_outputs(self): | 
|  | 162 | + return {'Out': np.argwhere(self.inputs['Condition'])} | 
|  | 163 | + | 
|  | 164 | + | 
|  | 165 | +class TestZeroSizeOp(TestArgwhereOp): | 
|  | 166 | + | 
|  | 167 | + def init_shape(self): | 
|  | 168 | + self.shape = [0, 10] | 
|  | 169 | + | 
|  | 170 | + def init_dtype(self): | 
|  | 171 | + self.dtype = np.float64 | 
|  | 172 | + | 
|  | 173 | + | 
|  | 174 | +class TestZeroSizeOpCase2(TestArgwhereOp): | 
|  | 175 | + | 
|  | 176 | + def init_shape(self): | 
|  | 177 | + self.shape = [0, 10] | 
|  | 178 | + | 
|  | 179 | + def init_dtype(self): | 
|  | 180 | + self.dtype = np.float64 | 
|  | 181 | + | 
|  | 182 | + def test_check_output(self): | 
|  | 183 | + self.check_output(check_pir=True, check_symbol_infer=True) | 
|  | 184 | + | 
|  | 185 | + | 
|  | 186 | +if __name__ == "__main__": | 
|  | 187 | + unittest.main() | 
0 commit comments