
Commit a6efaf2

move distribution.py into the distribution package and split it into separate files for better scalability
1 parent 22f14e7 commit a6efaf2

File tree

14 files changed: +2451 −2267 lines changed

python/paddle/distribution.py

Lines changed: 0 additions & 968 deletions
This file was deleted.
python/paddle/distribution/__init__.py

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .categorical import Categorical
from .distribution import Distribution
from .normal import Normal
from .uniform import Uniform

__all__ = ['Categorical', 'Distribution', 'Normal', 'Uniform']
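
With the split, the package root re-exports the same public names the old python/paddle/distribution.py module provided, so user code is unaffected. A minimal sketch of both import paths (assuming a Paddle build that includes this commit):

    import paddle
    from paddle.distribution import Categorical

    # Code written against the old single-module API keeps working.
    paddle.seed(100)
    logits = paddle.rand([6])
    cat = Categorical(logits)
    print(cat.sample([2, 3]).shape)  # [2, 3]

    # The implementation now lives in its own submodule but is the same class.
    from paddle.distribution.categorical import Categorical as CategoricalImpl
    assert CategoricalImpl is Categorical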
python/paddle/distribution/categorical.py

Lines changed: 357 additions & 0 deletions
@@ -0,0 +1,357 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import warnings

import numpy as np
from paddle import _C_ops

from ..fluid import core
from ..fluid.data_feeder import (check_dtype, check_type,
                                 check_variable_and_dtype, convert_dtype)
from ..fluid.framework import in_dygraph_mode
from ..fluid.layers import (control_flow, elementwise_add, elementwise_div,
                            elementwise_mul, elementwise_sub, nn, ops, tensor)
from ..tensor import arange, concat, gather_nd, multinomial
from .distribution import Distribution

class Categorical(Distribution):
    r"""
    Categorical distribution is a discrete probability distribution that
    describes the possible results of a random variable that can take on
    one of K possible categories, with the probability of each category
    separately specified.

    The probability mass function (pmf) is:

    .. math::

        pmf(k; p_i) = \prod_{i=1}^{k} p_i^{[x=i]}

    In the above equation:

    * :math:`[x=i]` : it evaluates to 1 if :math:`x==i` , 0 otherwise.

    Args:
        logits(list|tuple|numpy.ndarray|Tensor): The logits input of categorical distribution. The data type is float32 or float64.
        name(str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Examples:
        .. code-block:: python

            import paddle
            from paddle.distribution import Categorical

            paddle.seed(100) # on CPU device
            x = paddle.rand([6])
            print(x)
            # [0.5535528  0.20714243 0.01162981
            #  0.51577556 0.36369765 0.2609165 ]

            paddle.seed(200) # on CPU device
            y = paddle.rand([6])
            print(y)
            # [0.77663314 0.90824795 0.15685187
            #  0.04279523 0.34468332 0.7955718 ]

            cat = Categorical(x)
            cat2 = Categorical(y)

            paddle.seed(1000) # on CPU device
            cat.sample([2,3])
            # [[0, 0, 5],
            #  [3, 4, 5]]

            cat.entropy()
            # [1.77528]

            cat.kl_divergence(cat2)
            # [0.071952]

            value = paddle.to_tensor([2,1,3])
            cat.probs(value)
            # [0.00608027 0.108298 0.269656]

            cat.log_prob(value)
            # [-5.10271 -2.22287 -1.31061]

    """

    def __init__(self, logits, name=None):
        """
        Args:
            logits(list|tuple|numpy.ndarray|Tensor): The logits input of categorical distribution. The data type is float32 or float64.
            name(str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
        """
        if not in_dygraph_mode():
            check_type(logits, 'logits',
                       (np.ndarray, tensor.Variable, list, tuple),
                       'Categorical')

        self.name = name if name is not None else 'Categorical'
        self.dtype = 'float32'

        if self._validate_args(logits):
            self.logits = logits
            self.dtype = convert_dtype(logits.dtype)
        else:
            if isinstance(logits, np.ndarray) and str(
                    logits.dtype) in ['float32', 'float64']:
                self.dtype = logits.dtype
            self.logits = self._to_tensor(logits)[0]
            if self.dtype != convert_dtype(self.logits.dtype):
                self.logits = tensor.cast(self.logits, dtype=self.dtype)

    def sample(self, shape):
        """Generate samples of the specified shape.

        Args:
            shape (list): Shape of the generated samples.

        Returns:
            Tensor: A tensor with prepended dimensions shape.

        Examples:
            .. code-block:: python

                import paddle
                from paddle.distribution import Categorical

                paddle.seed(100) # on CPU device
                x = paddle.rand([6])
                print(x)
                # [0.5535528  0.20714243 0.01162981
                #  0.51577556 0.36369765 0.2609165 ]

                cat = Categorical(x)

                paddle.seed(1000) # on CPU device
                cat.sample([2,3])
                # [[0, 0, 5],
                #  [3, 4, 5]]

        """
        name = self.name + '_sample'
        if not in_dygraph_mode():
            check_type(shape, 'shape', (list), 'sample')

        num_samples = np.prod(np.array(shape))

        logits_shape = list(self.logits.shape)
        if len(logits_shape) > 1:
            sample_shape = shape + logits_shape[:-1]
            logits = nn.reshape(self.logits,
                                [np.prod(logits_shape[:-1]), logits_shape[-1]])
        else:
            sample_shape = shape
            logits = self.logits

        sample_index = multinomial(logits, num_samples, True)
        return nn.reshape(sample_index, sample_shape, name=name)

    def kl_divergence(self, other):
        """The KL-divergence between two Categorical distributions.

        Args:
            other (Categorical): instance of Categorical. The data type is float32.

        Returns:
            Tensor: kl-divergence between two Categorical distributions.

        Examples:
            .. code-block:: python

                import paddle
                from paddle.distribution import Categorical

                paddle.seed(100) # on CPU device
                x = paddle.rand([6])
                print(x)
                # [0.5535528  0.20714243 0.01162981
                #  0.51577556 0.36369765 0.2609165 ]

                paddle.seed(200) # on CPU device
                y = paddle.rand([6])
                print(y)
                # [0.77663314 0.90824795 0.15685187
                #  0.04279523 0.34468332 0.7955718 ]

                cat = Categorical(x)
                cat2 = Categorical(y)

                cat.kl_divergence(cat2)
                # [0.071952]

        """
        name = self.name + '_kl_divergence'
        if not in_dygraph_mode():
            check_type(other, 'other', Categorical, 'kl_divergence')

        logits = self.logits - nn.reduce_max(self.logits, dim=-1, keep_dim=True)
        other_logits = other.logits - nn.reduce_max(
            other.logits, dim=-1, keep_dim=True)
        e_logits = ops.exp(logits)
        other_e_logits = ops.exp(other_logits)
        z = nn.reduce_sum(e_logits, dim=-1, keep_dim=True)
        other_z = nn.reduce_sum(other_e_logits, dim=-1, keep_dim=True)
        prob = e_logits / z
        kl = nn.reduce_sum(
            prob * (logits - nn.log(z) - other_logits + nn.log(other_z)),
            dim=-1,
            keep_dim=True,
            name=name)

        return kl

    def entropy(self):
        """Shannon entropy in nats.

        Returns:
            Tensor: Shannon entropy of Categorical distribution. The data type is float32.

        Examples:
            .. code-block:: python

                import paddle
                from paddle.distribution import Categorical

                paddle.seed(100) # on CPU device
                x = paddle.rand([6])
                print(x)
                # [0.5535528  0.20714243 0.01162981
                #  0.51577556 0.36369765 0.2609165 ]

                cat = Categorical(x)

                cat.entropy()
                # [1.77528]

        """
        name = self.name + '_entropy'
        logits = self.logits - nn.reduce_max(self.logits, dim=-1, keep_dim=True)
        e_logits = ops.exp(logits)
        z = nn.reduce_sum(e_logits, dim=-1, keep_dim=True)
        prob = e_logits / z

        neg_entropy = nn.reduce_sum(
            prob * (logits - nn.log(z)), dim=-1, keep_dim=True)
        entropy = nn.scale(neg_entropy, scale=-1.0, name=name)
        return entropy

    def probs(self, value):
        """Probabilities of the given category (``value``).

        If ``logits`` is 2-D or of a higher dimension, the last dimension will be
        regarded as the category, and the other dimensions represent different
        distributions.
        If ``value`` is a 1-D Tensor, it will be broadcast to the same number of
        distributions as ``logits``.
        If ``value`` is not a 1-D Tensor, ``value`` should have the same number of
        distributions as ``logits``. That is, ``value.shape[:-1]`` must match
        ``logits.shape[:-1]``.

        Args:
            value (Tensor): The input tensor represents the selected category index.

        Returns:
            Tensor: probability according to the category index.

        Examples:
            .. code-block:: python

                import paddle
                from paddle.distribution import Categorical

                paddle.seed(100) # on CPU device
                x = paddle.rand([6])
                print(x)
                # [0.5535528  0.20714243 0.01162981
                #  0.51577556 0.36369765 0.2609165 ]

                cat = Categorical(x)

                value = paddle.to_tensor([2,1,3])
                cat.probs(value)
                # [0.00608027 0.108298 0.269656]

        """
        name = self.name + '_probs'

        dist_sum = nn.reduce_sum(self.logits, dim=-1, keep_dim=True)
        prob = self.logits / dist_sum

        shape = list(prob.shape)
        value_shape = list(value.shape)
        if len(shape) == 1:
            num_value_in_one_dist = np.prod(value_shape)
            index_value = nn.reshape(value, [num_value_in_one_dist, 1])
            index = index_value
        else:
            num_dist = np.prod(shape[:-1])
            num_value_in_one_dist = value_shape[-1]
            prob = nn.reshape(prob, [num_dist, shape[-1]])
            if len(value_shape) == 1:
                value = nn.expand(value, [num_dist])
                value_shape = shape[:-1] + value_shape
            index_value = nn.reshape(value, [num_dist, -1, 1])
            if shape[:-1] != value_shape[:-1]:
                raise ValueError(
                    "shape of value {} must match shape of logits {}".format(
                        str(value_shape[:-1]), str(shape[:-1])))

            index_prefix = nn.unsqueeze(
                arange(
                    num_dist, dtype=index_value.dtype), axes=-1)
            index_prefix = nn.expand(index_prefix, [1, num_value_in_one_dist])
            index_prefix = nn.unsqueeze(index_prefix, axes=-1)

            if index_value.dtype != index_prefix.dtype:
                index_prefix = tensor.cast(index_prefix, dtype=index_value.dtype)
            index = concat([index_prefix, index_value], axis=-1)

        # value is the category index to search for the corresponding probability.
        select_prob = gather_nd(prob, index)
        return nn.reshape(select_prob, value_shape, name=name)

    def log_prob(self, value):
        """Log probabilities of the given category. Refer to ``probs`` method.

        Args:
            value (Tensor): The input tensor represents the selected category index.

        Returns:
            Tensor: Log probability.

        Examples:
            .. code-block:: python

                import paddle
                from paddle.distribution import Categorical

                paddle.seed(100) # on CPU device
                x = paddle.rand([6])
                print(x)
                # [0.5535528  0.20714243 0.01162981
                #  0.51577556 0.36369765 0.2609165 ]

                cat = Categorical(x)

                value = paddle.to_tensor([2,1,3])
                cat.log_prob(value)
                # [-5.10271 -2.22287 -1.31061]

        """
        name = self.name + '_log_prob'

        return nn.log(self.probs(value), name=name)
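
For readers checking the math, the softmax-based computations in entropy and kl_divergence above can be reproduced with a standalone NumPy sketch (illustrative only, not part of the commit; the helper names below are made up). The expected values come from the docstring examples:

    import numpy as np

    def log_softmax(logits):
        # Subtract the max before exponentiating, mirroring the
        # reduce_max / exp / reduce_sum stabilization used above.
        shifted = logits - logits.max(axis=-1, keepdims=True)
        return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

    def entropy(logits):
        # H(p) = -sum_i p_i * log(p_i), Shannon entropy in nats.
        logp = log_softmax(logits)
        return -(np.exp(logp) * logp).sum(axis=-1, keepdims=True)

    def kl_divergence(logits, other_logits):
        # KL(p || q) = sum_i p_i * (log(p_i) - log(q_i)).
        logp = log_softmax(logits)
        logq = log_softmax(other_logits)
        return (np.exp(logp) * (logp - logq)).sum(axis=-1, keepdims=True)

    # Values from the docstring examples (paddle.seed(100) / paddle.seed(200)).
    x = np.array([0.5535528, 0.20714243, 0.01162981,
                  0.51577556, 0.36369765, 0.2609165], dtype=np.float32)
    y = np.array([0.77663314, 0.90824795, 0.15685187,
                  0.04279523, 0.34468332, 0.7955718], dtype=np.float32)
    print(entropy(x))           # ~[1.77528], matches cat.entropy()
    print(kl_divergence(x, y))  # ~[0.071952], matches cat.kl_divergence(cat2)

Note the asymmetry visible in the class: entropy and kl_divergence push the logits through a softmax, while probs normalizes the raw logits by their sum (logits / reduce_sum(logits)), treating them as unnormalized probabilities. That is why cat.probs(value) in the docstring does not equal the softmax probabilities.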

0 commit comments