
Commit fba4675

【Hackathon 6th No.11】Add the bernoulli_ API to Paddle - part (#64252)
* add bernoulli_
* update
* update docs
* fix typo
* update code and test
* add test case for backward
* add test broadcast error and update docs
1 parent fb432fb commit fba4675


4 files changed: +148 −0 lines changed


python/paddle/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -516,6 +516,7 @@
 )
 from .tensor.random import (
     bernoulli,
+    bernoulli_,
     binomial,
     check_shape,
     multinomial,
@@ -808,6 +809,7 @@
     'expm1',
     'expm1_',
     'bernoulli',
+    'bernoulli_',
     'binomial',
     'poisson',
     'standard_gamma',
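
With the import and __all__ entry added above, the new inplace op is callable at the top level as paddle.bernoulli_. A minimal usage sketch (the shapes and probability below are illustrative, not taken from the diff):

import paddle

x = paddle.rand([2, 3])              # tensor to be overwritten in place
out = paddle.bernoulli_(x, p=0.5)    # functional form exposed by this file
print(out is x)                      # True: the op fills x and returns the same tensor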

python/paddle/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -402,6 +402,7 @@
     vander,
 )
 from .random import (  # noqa: F401
+    bernoulli_,
     binomial,
     exponential_,
     multinomial,
@@ -751,6 +752,7 @@
     'put_along_axis',
     'select_scatter',
     'put_along_axis_',
+    'bernoulli_',
     'exponential_',
     'heaviside',
     'index_add',
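
The 'bernoulli_' entry in the tensor method list above is what makes the op available as a method on Tensor as well. A short sketch of the method form, using a broadcastable per-row probability tensor as in the docstring example (values are illustrative):

import paddle

x = paddle.rand([3, 4])
p = paddle.rand([3, 1])   # per-row probabilities, broadcastable to x
x.bernoulli_(p)           # Tensor-method form; overwrites x with 0/1 values
print(x.shape)            # [3, 4]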

python/paddle/tensor/random.py

Lines changed: 52 additions & 0 deletions
@@ -104,6 +104,58 @@ def bernoulli(x, name=None):
     return out


+@dygraph_only
+def bernoulli_(x, p=0.5, name=None):
+    """
+    This is the inplace version of api ``bernoulli``, which returns a Tensor filled
+    with random values sampled from a bernoulli distribution. The output Tensor will
+    be inplaced with input ``x``. Please refer to :ref:`api_tensor_bernoulli`.
+
+    Args:
+        x (Tensor): The input tensor to be filled with random values.
+        p (float|Tensor, optional): The success probability parameter of the output Tensor's bernoulli distribution.
+            If ``p`` is a float, all elements of the output Tensor share the same success probability.
+            If ``p`` is a Tensor, it holds per-element success probabilities, and its shape should be broadcastable to ``x``.
+            Default is 0.5.
+        name (str, optional): The default value is None. Normally there is no
+            need for the user to set this property. For more information, please
+            refer to :ref:`api_guide_Name`.
+
+    Returns:
+        A Tensor filled with random values sampled from the bernoulli distribution with success probability ``p``.
+
+    Examples:
+        .. code-block:: python
+
+            >>> import paddle
+            >>> x = paddle.randn([3, 4])
+            >>> x.bernoulli_()
+            >>> # doctest: +SKIP('random check')
+            >>> print(x)
+            Tensor(shape=[3, 4], dtype=float32, place=Place(cpu), stop_gradient=True,
+            [[1., 0., 1., 0.],
+             [0., 1., 1., 0.],
+             [0., 1., 1., 1.]])
+
+            >>> x = paddle.randn([3, 4])
+            >>> p = paddle.randn([3, 1])
+            >>> x.bernoulli_(p)
+            >>> # doctest: +SKIP('random check')
+            >>> print(x)
+            Tensor(shape=[3, 4], dtype=float32, place=Place(cpu), stop_gradient=True,
+            [[1., 0., 0., 1.],
+             [1., 0., 1., 0.],
+             [1., 0., 0., 0.]])
+
+    """
+    x.uniform_(0.0, 1.0)
+    ones_mask = x > p
+    zeros_mask = x < p
+    x.masked_fill_(ones_mask, 1.0)
+    x.masked_fill_(zeros_mask, 0.0)
+    return x
+
+
 def binomial(count, prob, name=None):
     r"""
     Returns a tensor filled with random number from the Binomial Distribution, which supports Tensor shape
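
For readers skimming the implementation above: it draws a fresh uniform sample into x and then thresholds that sample against p, where p may be a scalar or a tensor broadcastable to x. Below is a minimal NumPy sketch of the same mechanics; the helper name is made up for illustration, and ties at u == p are resolved to 0 here, whereas the code above leaves elements with u exactly equal to p untouched.

import numpy as np

def bernoulli_inplace_sketch(x, p=0.5):
    # One uniform [0, 1) draw per element, mirroring x.uniform_(0.0, 1.0).
    u = np.random.uniform(0.0, 1.0, size=x.shape)
    # Threshold with the same comparison direction as the diff:
    # elements with u > p become 1.0, elements with u < p become 0.0.
    x[...] = np.where(u > p, 1.0, 0.0)
    return x

x = np.empty((3, 4), dtype=np.float32)
p = np.full((3, 1), 0.5)   # broadcasts against x, like the Tensor p in the docstring example
bernoulli_inplace_sketch(x, p)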

test/deprecated/legacy_test/test_inplace.py

Lines changed: 92 additions & 0 deletions
@@ -1921,5 +1921,97 @@ def test_inplace_api(self):
         )


+class TestDygraphInplaceBernoulli(unittest.TestCase):
+    def setUp(self):
+        self.init_data()
+        self.set_np_compare_func()
+
+    def init_data(self):
+        self.shape = (100, 1000)
+        self.input_var_numpy = np.random.random(self.shape)
+        self.dtype = "float32"
+        self.p = 0.5
+
+    def set_np_compare_func(self):
+        self.np_compare = np.array_equal
+
+    def inplace_api_processing(self, var):
+        return paddle.bernoulli_(var, p=self.p)
+
+    def inplace_class_method_processing(self, var):
+        return var.bernoulli_(self.p)
+
+    def non_inplace_api_processing(self):
+        return paddle.bernoulli(paddle.full(self.shape, self.p))
+
+    def test_inplace_api(self):
+        var = paddle.to_tensor(self.input_var_numpy).astype(self.dtype)
+        non_inplace_var = self.non_inplace_api_processing()
+        inplace_var = self.inplace_api_processing(var)
+        self.assertTrue(id(var) == id(inplace_var))
+        np.testing.assert_allclose(
+            non_inplace_var.numpy().mean(),
+            inplace_var.numpy().mean(),
+            atol=0.01,
+        )
+        np.testing.assert_allclose(
+            non_inplace_var.numpy().var(), inplace_var.numpy().var(), atol=0.01
+        )
+
+    def test_inplace_api_backward(self):
+        var_a = paddle.to_tensor(self.input_var_numpy).astype(self.dtype)
+        var_a.stop_gradient = False
+        var_b = var_a.clone()
+        expected_gradient = np.zeros(self.shape)
+        inplace_var = self.inplace_api_processing(var_b)
+        inplace_var.backward()
+        np.testing.assert_equal(
+            var_a.grad.numpy(),
+            expected_gradient,
+        )
+
+    def test_inplace_class_method(self):
+        var = paddle.to_tensor(self.input_var_numpy).astype(self.dtype)
+        non_inplace_var = self.non_inplace_api_processing()
+        inplace_var = self.inplace_class_method_processing(var)
+        self.assertTrue(id(var) == id(inplace_var))
+        np.testing.assert_allclose(
+            non_inplace_var.numpy().mean(),
+            inplace_var.numpy().mean(),
+            atol=0.01,
+        )
+        np.testing.assert_allclose(
+            non_inplace_var.numpy().var(), inplace_var.numpy().var(), atol=0.01
+        )
+
+    def test_inplace_class_method_backward(self):
+        var_a = paddle.to_tensor(self.input_var_numpy).astype(self.dtype)
+        var_a.stop_gradient = False
+        var_b = var_a.clone()
+        expected_gradient = np.zeros(self.shape)
+        inplace_var = self.inplace_class_method_processing(var_b)
+        inplace_var.backward()
+        np.testing.assert_equal(
+            var_a.grad.numpy(),
+            expected_gradient,
+        )
+
+
+class TestDygraphInplaceBernoulli2(TestDygraphInplaceBernoulli):
+    def init_data(self):
+        self.shape = (100, 1000)
+        self.input_var_numpy = np.random.random(self.shape)
+        self.dtype = "float64"
+        self.p = 0.5
+
+
+class TestDygraphInplaceBernoulliError(unittest.TestCase):
+    def test_broadcast_error(self):
+        var = paddle.randn([3, 4])
+        p = paddle.randn([5])
+        with self.assertRaises(ValueError):
+            var.bernoulli_(p)
+
+
 if __name__ == '__main__':
     unittest.main()
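
The error case covered by TestDygraphInplaceBernoulliError can also be reproduced standalone; a small sketch (the exact error message depends on the build and is not asserted here):

import paddle

x = paddle.randn([3, 4])
p = paddle.randn([5])      # shape [5] is not broadcastable to [3, 4]
try:
    x.bernoulli_(p)
except ValueError:
    print("non-broadcastable p raises ValueError, as the test above expects")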
