argwhere

paddle. argwhere ( input ) [源代码]

返回输入 x 中非零元素的坐标。如果输入 xn 维,共包含 z 个非零元素,返回结果是一个 shape 等于 [z x n]Tensor,第 i 行代表输入中第 i 个非零元素的坐标。

参数

  • input (Tensor)– 输入的 Tensor。

返回

  • Tensor(1-D Tensor),数据类型为 INT64

代码示例

>>> import paddle >>> x = paddle.to_tensor([[1.0, 0.0, 0.0], ...  [0.0, 2.0, 0.0], ...  [0.0, 0.0, 3.0]]) >>> out = paddle.tensor.search.argwhere(x) >>> print(out) Tensor(shape=[3, 2], dtype=int64, place=Place(cpu), stop_gradient=True, [[0, 0],  [1, 1],  [2, 2]])