Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/paddle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@
from .tensor.stat import var # noqa: F401
from .tensor.stat import numel # noqa: F401
from .tensor.stat import median # noqa: F401
from .tensor.stat import quantile # noqa: F401
from .device import get_cudnn_version # noqa: F401
from .device import set_device # noqa: F401
from .device import get_device # noqa: F401
Expand Down Expand Up @@ -478,6 +479,7 @@
'load',
'numel',
'median',
'quantile',
'no_grad',
'set_grad_enabled',
'is_grad_enabled',
Expand Down
150 changes: 150 additions & 0 deletions python/paddle/fluid/tests/unittests/test_quantile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
import paddle


class TestQuantile(unittest.TestCase):
    """Check ``paddle.quantile`` against ``np.quantile`` for a scalar ``q``."""

    def setUp(self):
        # A fixed seed keeps the random fixture reproducible between runs.
        np.random.seed(678)
        self.input_data = np.random.rand(6, 7, 8, 9, 10)

    def _check_against_numpy(self, q, axis=None, keepdim=False):
        """Run both implementations with the same arguments and compare.

        ``np.testing.assert_allclose`` is used instead of
        ``assertTrue(np.allclose(...))`` so that a failure reports the
        mismatching values rather than just "False is not true". The
        tolerances are the defaults of ``np.allclose``.
        """
        x = paddle.to_tensor(self.input_data)
        paddle_res = paddle.quantile(x, q=q, axis=axis, keepdim=keepdim)
        np_res = np.quantile(
            self.input_data, q=q, axis=axis, keepdims=keepdim)
        np.testing.assert_allclose(
            paddle_res.numpy(),
            np_res,
            rtol=1e-05,
            atol=1e-08,
            equal_nan=False)

    def test_quantile_single_q(self):
        self._check_against_numpy(q=0.5, axis=2)

    def test_quantile_with_no_axis(self):
        self._check_against_numpy(q=0.35)

    def test_quantile_with_multi_axis(self):
        self._check_against_numpy(q=0.75, axis=[0, 2, 3])

    def test_quantile_with_keepdim(self):
        self._check_against_numpy(q=0.35, axis=4, keepdim=True)

    def test_quantile_with_keepdim_and_multiple_axis(self):
        self._check_against_numpy(q=0.1, axis=[1, 4], keepdim=True)

    def test_quantile_with_boundary_q(self):
        # Both ends of the valid range: q=0 is the minimum, q=1 the maximum.
        self._check_against_numpy(q=0, axis=3)
        self._check_against_numpy(q=1, axis=3)

    def test_quantile_include_NaN(self):
        # A NaN in the reduced slice is expected to propagate to the result.
        input_data = np.random.randn(2, 3, 4)
        input_data[0, 1, 1] = np.nan
        x = paddle.to_tensor(input_data)
        paddle_res = paddle.quantile(x, q=0.35, axis=0)
        self.assertTrue(paddle.isnan(paddle_res[1, 1]))


class TestQuantileMuitlpleQ(unittest.TestCase):
    """Check ``paddle.quantile`` against ``np.quantile`` when ``q`` is a list."""

    def setUp(self):
        # A fixed seed keeps the random fixture reproducible between runs.
        np.random.seed(678)
        self.input_data = np.random.rand(10, 3, 4, 5, 4)

    def test_quantile(self):
        probs = [0.3, 0.44]
        actual = paddle.quantile(
            paddle.to_tensor(self.input_data), q=probs, axis=-2)
        expected = np.quantile(self.input_data, q=probs, axis=-2)
        self.assertTrue(np.allclose(actual.numpy(), expected))

    def test_quantile_multiple_axis(self):
        probs = [0.2, 0.67]
        reduce_axes = [1, -1]
        actual = paddle.quantile(
            paddle.to_tensor(self.input_data), q=probs, axis=reduce_axes)
        expected = np.quantile(self.input_data, q=probs, axis=reduce_axes)
        self.assertTrue(np.allclose(actual.numpy(), expected))

    def test_quantile_multiple_axis_keepdim(self):
        probs = [0.1, 0.2, 0.3]
        reduce_axes = [1, 2]
        actual = paddle.quantile(
            paddle.to_tensor(self.input_data),
            q=probs,
            axis=reduce_axes,
            keepdim=True)
        expected = np.quantile(
            self.input_data, q=probs, axis=reduce_axes, keepdims=True)
        self.assertTrue(np.allclose(actual.numpy(), expected))


class TestQuantileError(unittest.TestCase):
    """Check that ``paddle.quantile`` rejects invalid ``x``, ``q`` and ``axis``."""

    def setUp(self):
        self.x = paddle.randn((2, 3, 4))

    def test_errors(self):
        # q above the valid range [0, 1].
        with self.assertRaises(ValueError):
            paddle.quantile(self.x, q=1.5)

        # One element of q below the valid range.
        with self.assertRaises(ValueError):
            paddle.quantile(self.x, q=[0.2, -0.3])

        # Empty q.
        with self.assertRaises(ValueError):
            paddle.quantile(self.x, q=[])

        # x is a plain Python list rather than a Tensor.
        with self.assertRaises(TypeError):
            paddle.quantile([1, 3, 4], q=0.9)

        # Non-integer axis.
        with self.assertRaises(ValueError):
            paddle.quantile(self.x, q=0.4, axis=0.4)

        # Non-integer element inside an axis list.
        with self.assertRaises(ValueError):
            paddle.quantile(self.x, q=0.4, axis=[1, 0.4])

        # Axis beyond the rank of x.
        with self.assertRaises(ValueError):
            paddle.quantile(self.x, q=0.4, axis=10)

        # Axis list element below -rank(x).
        with self.assertRaises(ValueError):
            paddle.quantile(self.x, q=0.4, axis=[1, -10])

        # Empty axis list.
        with self.assertRaises(ValueError):
            paddle.quantile(self.x, q=0.4, axis=[])


# Run the test cases in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
3 changes: 3 additions & 0 deletions python/paddle/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,8 @@
from .stat import var # noqa: F401
from .stat import numel # noqa: F401
from .stat import median # noqa: F401
from .stat import quantile # noqa: F401

from .to_string import set_printoptions # noqa: F401

from .array import array_length # noqa: F401
Expand Down Expand Up @@ -429,6 +431,7 @@
'var',
'numel',
'median',
'quantile',
'is_complex',
'is_integer',
'rank',
Expand Down
124 changes: 124 additions & 0 deletions python/paddle/tensor/stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,3 +333,127 @@ def median(x, axis=None, keepdim=False, name=None):
newshape = out_tensor.shape
out_tensor = out_tensor.reshape(newshape, name=name)
return out_tensor


def quantile(x, q, axis=None, keepdim=False):
    """
    Compute the quantile of the input along the specified axis.

    The reduced axis is sorted and the value at the fractional position
    ``q * (N - 1)`` is obtained by linear interpolation between the two
    neighbouring elements, where ``N`` is the number of reduced elements.

    Args:
        x (Tensor): The input Tensor, its data type can be float32, float64.
        q (int|float|list|tuple): The q for calculate quantile, which should be in range [0, 1]. If q is a list
            or tuple, each q will be calculated and the first dimension of output is same to the number of ``q`` .
        axis (int|list|tuple, optional): The axis along which to calculate quantile. ``axis`` should be int or
            list/tuple of int. ``axis`` should be in range [-D, D), where D is the dimensions of ``x`` .
            If ``axis`` is less than 0, it works the same way as :math:`axis + D`.
            If ``axis`` is a list or tuple, quantile is calculated over all elements of given axes.
            If ``axis`` is None, quantile is calculated over all elements of ``x``. Default is None.
        keepdim (bool, optional): Whether to reserve the reduced dimension(s)
            in the output Tensor. If ``keepdim`` is True, the dimensions of
            the output Tensor is the same as ``x`` except in the reduced
            dimensions(it is of size 1 in this case). Otherwise, the shape of
            the output Tensor is squeezed in ``axis`` . Default is False.

    Returns:
        Tensor, results of quantile along ``axis`` of ``x``. If data type of ``x`` is float64, data type of results will be float64, otherwise data type will be float32.

    Raises:
        TypeError: If ``x`` is not a Tensor, or ``q`` is not an int, float, list or tuple.
        ValueError: If ``q`` is empty or has a value outside [0, 1], or if ``axis`` is empty,
            contains a non-int, out-of-range or repeated dimension.

    Examples:
        .. code-block:: python

            import paddle

            x = paddle.randn((2,3))
            #[[-1.28740597,  0.49533170, -1.00698614],
            # [-1.11656201, -1.01010525, -2.23457789]]

            y1 = paddle.quantile(x, q=0.5, axis=[0, 1])
            # y1 = -1.06333363

            y2 = paddle.quantile(x, q=0.5, axis=1)
            # y2 = [-1.00698614, -1.11656201]

            y3 = paddle.quantile(x, q=[0.3, 0.5], axis=1)
            # y3 =[[-1.11915410, -1.56376839],
            #      [-1.00698614, -1.11656201]]

            y4 = paddle.quantile(x, q=0.8, axis=1, keepdim=True)
            # y4 = [[-0.10559537],
            #       [-1.05268800]]
    """
    if not isinstance(x, Variable):
        raise TypeError("input x should be a Tensor.")
    dims = len(x.shape)
    # Work on a copy: the reduced dimensions are overwritten with 1 below and
    # that must never touch a list owned by ``x``.
    out_shape = list(x.shape)
    if axis is None:
        x = paddle.flatten(x)
        axis = 0
        out_shape = [1] * dims
    elif isinstance(axis, (list, tuple)):
        if len(axis) <= 0:
            raise ValueError("axis should not be empty")
        axis_src = []
        for axis_single in axis:
            if not isinstance(axis_single, int) or not (
                    axis_single < dims and axis_single >= -dims):
                raise ValueError(
                    "Axis should be None, int, or a list, element should in range [-rank(x), rank(x))."
                )
            if axis_single < 0:
                axis_single = axis_single + dims
            if axis_single in axis_src:
                raise ValueError("axis should not contain repeated dimensions")
            axis_src.append(axis_single)
            out_shape[axis_single] = 1
        # Move every reduced axis to the end and merge them into one trailing
        # axis. Destination axes are spelled out as non-negative indices so
        # the position of the merged axis does not depend on how
        # ``paddle.moveaxis`` treats its ``destination`` argument: a negative
        # index computed for rank ``dims`` would point at the wrong axis once
        # ``flatten`` has reduced the rank.
        first_reduced = dims - len(axis_src)
        axis_dst = list(range(first_reduced, dims))
        x = paddle.moveaxis(x, axis_src, axis_dst)
        x = paddle.flatten(x, first_reduced, dims - 1)
        axis = first_reduced
    else:
        if not isinstance(axis, int) or not (axis < dims and axis >= -dims):
            raise ValueError(
                "Axis should be None, int, or a list, element should in range [-rank(x), rank(x))."
            )
        if axis < 0:
            axis += dims
        out_shape[axis] = 1

    # Fractional position of each requested quantile inside the sorted axis.
    last_index = x.shape[axis] - 1
    indices = []
    if isinstance(q, (int, float)):
        if q < 0 or q > 1:
            raise ValueError("q should be in range [0, 1]")
        indices.append(q * last_index)
    elif isinstance(q, (list, tuple)):
        if len(q) <= 0:
            raise ValueError("q should not be empty")
        for q_num in q:
            if q_num < 0 or q_num > 1:
                raise ValueError("q should be in range [0, 1]")
            indices.append(q_num * last_index)
    else:
        raise TypeError("Type of q should be int, float, list or tuple.")
    num_q = len(indices)
    indices = paddle.to_tensor(indices).astype(paddle.float32)
    sorted_tensor = paddle.sort(x, axis)
    indices_below = paddle.floor(indices).astype(paddle.int32)
    indices_upper = paddle.ceil(indices).astype(paddle.int32)
    outputs = []

    # TODO(chenjianye): replace the for-loop to directly take elements.
    for i in range(num_q):
        if (indices_upper[i] != indices_below[i]):
            # The position falls between two elements: interpolate linearly
            # using the fractional part of the position as the weight.
            tensor_below = paddle.take_along_axis(sorted_tensor,
                                                  indices_below[i], axis)
            tensor_upper = paddle.take_along_axis(sorted_tensor,
                                                  indices_upper[i], axis)
            weights = (indices[i] - indices_below[i]).astype(x.dtype)
            out = paddle.lerp(tensor_below, tensor_upper, weights)
        else:
            # The position is an exact element index: no interpolation needed.
            out = paddle.take_along_axis(sorted_tensor, indices_below[i], axis)
        if not keepdim:
            out = paddle.squeeze(out, axis=axis)
        else:
            out = out.reshape(out_shape)
        outputs.append(out)
    # A sequence of q yields one result per q, stacked on a new leading axis.
    if isinstance(q, (list, tuple)):
        return paddle.stack(outputs, 0)
    else:
        return outputs[0]