Skip to content

Commit 74005df

Browse files
committed
【Hackathon 4 No.16】为 Paddle 新增 PoissonNLLLoss API
1 parent 1d7fe0f commit 74005df

File tree

1 file changed

+201
-0
lines changed

1 file changed

+201
-0
lines changed
Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
# PoissonNLLLoss 设计文档
2+
3+
| API名称 | PoissonNLLoss |
4+
| ------------------------------------------------------------ |------------------------------------------------|
5+
| 提交作者<input type="checkbox" class="rowselector hidden"> | LyndonKong |
6+
| 提交时间<input type="checkbox" class="rowselector hidden"> | 2023-03-01 |
7+
| 版本号 | V1.0 |
8+
| 依赖飞桨版本<input type="checkbox" class="rowselector hidden"> | Develop |
9+
| 文件名 | 20230301_api_design_for_poissonllloss.md<br> |
10+
11+
12+
# 一、概述
13+
14+
## 1、相关背景
15+
16+
paddle.nn.PoissonNLLLoss 和 paddle.nn.functional.Poisson_nll_loss API 用于计算真实标签服从泊松分布的负对数似然损失。
17+
该函数计算公式为:
18+
19+
$$
20+
\text{loss}(\text{input}, \text{target}) = \text{input} - \text{target} * \log(\text{input}) + \log(\text{target!})
21+
$$
22+
23+
损失函数中的最后一项可以使用Stirling公式近似,对target的值超过1的索引处考虑此项近似,对target的值小于等于1的索引设置为0。
24+
25+
## 2、功能目标
26+
27+
在飞桨中增加 paddle.nn.PoissonNLLLoss 和 paddle.nn.functional.Poisson_nll_loss API。
28+
29+
## 3、意义
30+
31+
飞桨将支持 paddle.nn.PoissonNLLLoss 和 paddle.nn.functional.Poisson_nll_loss API。
32+
33+
# 二、飞桨现状
34+
35+
飞桨中还没有 PoissonNLLLoss API,可以简单通过log,exp等函数构造该函数。
36+
37+
38+
# 三、业内方案调研
39+
40+
PyTorch:PyTorch 支持 torch.nn.poissonNLLLoss 和 torch.nn.functional.Poisson_nll_loss,由python代码提供接口:
41+
42+
```python
43+
def poisson_nll_loss(
44+
input: Tensor,
45+
target: Tensor,
46+
log_input: bool = True,
47+
full: bool = False,
48+
size_average: Optional[bool] = None,
49+
eps: float = 1e-8,
50+
reduce: Optional[bool] = None,
51+
reduction: str = "mean",
52+
) -> Tensor:
53+
if has_torch_function_variadic(input, target):
54+
return handle_torch_function(
55+
poisson_nll_loss,
56+
(input, target),
57+
input,
58+
target,
59+
log_input=log_input,
60+
full=full,
61+
size_average=size_average,
62+
eps=eps,
63+
reduce=reduce,
64+
reduction=reduction,
65+
)
66+
if size_average is not None or reduce is not None:
67+
reduction = _Reduction.legacy_get_string(size_average, reduce)
68+
if reduction != "none" and reduction != "mean" and reduction != "sum":
69+
ret = input
70+
raise ValueError(reduction + " is not valid")
71+
72+
ret = torch.poisson_nll_loss(input, target, log_input, full, eps, _Reduction.get_enum(reduction))
73+
return ret
74+
```
75+
76+
由cpp提供具体实现
77+
```cpp
78+
Tensor poisson_nll_loss(const Tensor& input, const Tensor& target, const bool log_input, const bool full, const double eps, const int64_t reduction)
79+
{
80+
Tensor loss;
81+
if (log_input) {
82+
loss = at::exp(input) - target * input;
83+
} else {
84+
loss = input - target * at::log(input + eps);
85+
}
86+
87+
if (full) {
88+
auto stirling_term = target * at::log(target) - target + 0.5 * at::log(2 * c10::pi<double> * target);
89+
loss += stirling_term.masked_fill(target <= 1, 0);
90+
}
91+
92+
return apply_loss_reduction(loss, reduction);
93+
}
94+
95+
static inline at::Tensor apply_loss_reduction(const at::Tensor& unreduced, int64_t reduction) {
96+
if (reduction == at::Reduction::Mean) {
97+
return unreduced.mean();
98+
} else if (reduction == at::Reduction::Sum) {
99+
return unreduced.sum();
100+
}
101+
return unreduced;
102+
}
103+
```
104+
105+
无其它相关库支持该 Loss 函数。
106+
107+
# 四、对比分析
108+
109+
设计方案将参考PyTorch的实现,在当前版本的API当中仅实现Python版本。由于PyTroch实现当中参数 ``size_average`` 和 ``reduce`` 的功能因为被 ``reduction``包含而废弃,在此版本的API实现我们中对这些被抛弃的参数不再做兼容。
110+
111+
# 五、设计思路与实现方案
112+
113+
## 命名与参数设计
114+
115+
共添加以下两个 API:
116+
117+
`paddle.nn.functional.poisson_nll_loss(
118+
input,
119+
target,
120+
log_input = True,
121+
full = False,
122+
eps = 1e-8,
123+
reduction = "mean",
124+
name:str = None
125+
) -> Tensor:`
126+
- Input(Tensor): 期望服从泊松分布的输入,形状为`(N, *)` 或 `(*)` 其中 `*`表示任何数量的额外维度。
127+
- Target(Tensor): 为泊松分布的随机样本,形状为`(N, *)` 或 `(*)`,与输入的形状相同,
128+
或与输入的形状相同但有一个维度等于1(允许广播)。
129+
- Log_input(bool): 输入和目标是否为对数。如果为`True`,则损失函数的前两项的计算方式为$\exp(\text{input}) - \exp\text{target} * \text{target}$。如果设置为`False`,则损失函数的前两项计算方式为$\text{input} - \text{target} * \log(\text{input}+\text{eps})$。默认为`True`。
130+
- Full(bool):是否计算完整的损失。如果为`True`,则添加Stirling逼近项$\text{target}*\log(\text{target}) - \text{target} + 0.5 * \log(2\pi\text{target})$。默认为`False`。
131+
- Eps: 避免在`log_input=False`时计算$\log(0)$的小量。默认值为1e-8/。
132+
- Reduction:指定应用于输出结果的计算方式,指定应用于输出结果的计算方式,可选值有:"none", "mean", "sum"。默认为"mean",计算`Poisson_nll_loss`的均值;设置为"sum"时,计算`Poisson_nll_loss`的总和;设置为"none"时,则返回`Poisson_nll_loss`。
133+
- Name: 操作的名称,默认为None。
134+
135+
136+
137+
`paddle.nn.PoissonNLLLoss(
138+
log_input = True,
139+
full = False,
140+
eps = 1e-8,
141+
reduction = "mean",
142+
name = None
143+
) -> Tensor:`
144+
- log_input(bool): 输入和目标是否为对数。如果为`True`,则损失函数的前两项的计算方式为$\exp(\text{input}) - \exp\text{target} * \text{target}$。如果设置为`False`,则损失函数的前两项计算方式为$\text{input} - \text{target} * \log(\text{input}+\text{eps})$。默认为`True`。
145+
- Full(bool):是否计算完整的损失。如果为`True`,则添加Stirling逼近项$\text{target}*\log(\text{target}) - \text{target} + 0.5 * \log(2\pi\text{target})$。默认为`False`。
146+
- Eps: 避免在`log_input=False`时计算$\log(0)$的小量。默认值为1e-8/。
147+
- Reduction:指定应用于输出结果的计算方式,指定应用于输出结果的计算方式,可选值有:"none", "mean", "sum"。默认为"mean",计算`Poisson_nll_loss`的均值;设置为"sum"时,计算`Poisson_nll_loss`的总和;设置为"none"时,则返回`Poisson_nll_loss`。
148+
- Name: 操作的名称,默认为None。
149+
150+
参数与文档要求进行对齐。
151+
152+
## API实现方案
153+
154+
参考 pytorch 的处理方式通过 paddle.exp, paddle.log , paddle.where函数实现。
155+
1. 检查参数
156+
157+
1. 检查 reduction 有效性(同其余 functional loss 中的实现)
158+
2. 检查输入参数 eps 是否为正数
159+
3. 检查输入(含 `input`、`target`)的size和dtype(同其余 functional loss 中的实现)
160+
161+
2. 计算
162+
163+
1. 判断`log_input`是否为`True`,计算loss的前两项
164+
2. 判断`full`是否为`True`,计算Stirling逼近项
165+
3. 计算loss
166+
167+
168+
3. 根据 `reduction`,输出 loss(同其余 functional loss 中的实现)
169+
170+
# 六、测试和验收的考量
171+
172+
由于在numpy当中没有Poisson_nll_loss的实现,我们基于numpy自己实现了此函数,并于参考方案对比验证了numpy实现的正确性。在此基础上我们进行了前向的测验和验收:
173+
1. 结果正确性:
174+
175+
- 前向计算:`paddle.nn.PoissonNLLLoss` 和 `paddle.nn.functional.poisson_nll_loss` 计算结果与numpy实现计算结果一致。
176+
- 反向计算:由 Python API 组合新增 API 无需验证反向计算。
177+
178+
2. 硬件场景: 在 CPU 和 GPU 硬件条件下的运行结果一致。
179+
180+
3. 异常测试:
181+
182+
- 数据类型检验:
183+
- input和target的数据类型检验
184+
- 可选参数的数据类型检验
185+
- 具体数值检验:
186+
- input 与 target 的维度一致检查
187+
- 若 eps 有输入, 则要为正
188+
189+
190+
# 七、可行性分析和排期规划
191+
方案主要依赖现有 paddle API 组合,待该设计文档通过验收后可尽快提交。
192+
193+
# 八、影响面
194+
195+
在 paddle.nn.functional.loss 文件中import math,新增的API对其他模块没有影响
196+
197+
# 名词解释
198+
199+
# 附件及参考资料
200+
201+
[torch实现](https://pytorch.org/docs/stable/_modules/torch/nn/functional.html#poisson_nll_loss)

0 commit comments

Comments
 (0)