|
| 1 | +# PoissonNLLLoss 设计文档 |
| 2 | + |
| 3 | +| API名称 | PoissonNLLoss | |
| 4 | +| ------------------------------------------------------------ |------------------------------------------------| |
| 5 | +| 提交作者<input type="checkbox" class="rowselector hidden"> | LyndonKong | |
| 6 | +| 提交时间<input type="checkbox" class="rowselector hidden"> | 2023-03-01 | |
| 7 | +| 版本号 | V1.0 | |
| 8 | +| 依赖飞桨版本<input type="checkbox" class="rowselector hidden"> | Develop | |
| 9 | +| 文件名 | 20230301_api_design_for_poissonllloss.md<br> | |
| 10 | + |
| 11 | + |
| 12 | +# 一、概述 |
| 13 | + |
| 14 | +## 1、相关背景 |
| 15 | + |
| 16 | +paddle.nn.PoissonNLLLoss 和 paddle.nn.functional.Poisson_nll_loss API 用于计算真实标签服从泊松分布的负对数似然损失。 |
| 17 | +该函数计算公式为: |
| 18 | + |
| 19 | +$$ |
| 20 | +\text{loss}(\text{input}, \text{target}) = \text{input} - \text{target} * \log(\text{input}) + \log(\text{target!}) |
| 21 | +$$ |
| 22 | + |
| 23 | +损失函数中的最后一项可以使用Stirling公式近似,对target的值超过1的索引处考虑此项近似,对target的值小于等于1的索引设置为0。 |
| 24 | + |
| 25 | +## 2、功能目标 |
| 26 | + |
| 27 | +在飞桨中增加 paddle.nn.PoissonNLLLoss 和 paddle.nn.functional.Poisson_nll_loss API。 |
| 28 | + |
| 29 | +## 3、意义 |
| 30 | + |
| 31 | +飞桨将支持 paddle.nn.PoissonNLLLoss 和 paddle.nn.functional.Poisson_nll_loss API。 |
| 32 | + |
| 33 | +# 二、飞桨现状 |
| 34 | + |
| 35 | +飞桨中还没有 PoissonNLLLoss API,可以简单通过log,exp等函数构造该函数。 |
| 36 | + |
| 37 | + |
| 38 | +# 三、业内方案调研 |
| 39 | + |
| 40 | +PyTorch:PyTorch 支持 torch.nn.poissonNLLLoss 和 torch.nn.functional.Poisson_nll_loss,由python代码提供接口: |
| 41 | + |
| 42 | +```python |
| 43 | +def poisson_nll_loss( |
| 44 | + input: Tensor, |
| 45 | + target: Tensor, |
| 46 | + log_input: bool = True, |
| 47 | + full: bool = False, |
| 48 | + size_average: Optional[bool] = None, |
| 49 | + eps: float = 1e-8, |
| 50 | + reduce: Optional[bool] = None, |
| 51 | + reduction: str = "mean", |
| 52 | +) -> Tensor: |
| 53 | + if has_torch_function_variadic(input, target): |
| 54 | + return handle_torch_function( |
| 55 | + poisson_nll_loss, |
| 56 | + (input, target), |
| 57 | + input, |
| 58 | + target, |
| 59 | + log_input=log_input, |
| 60 | + full=full, |
| 61 | + size_average=size_average, |
| 62 | + eps=eps, |
| 63 | + reduce=reduce, |
| 64 | + reduction=reduction, |
| 65 | + ) |
| 66 | + if size_average is not None or reduce is not None: |
| 67 | + reduction = _Reduction.legacy_get_string(size_average, reduce) |
| 68 | + if reduction != "none" and reduction != "mean" and reduction != "sum": |
| 69 | + ret = input |
| 70 | + raise ValueError(reduction + " is not valid") |
| 71 | + |
| 72 | + ret = torch.poisson_nll_loss(input, target, log_input, full, eps, _Reduction.get_enum(reduction)) |
| 73 | + return ret |
| 74 | +``` |
| 75 | + |
| 76 | +由cpp提供具体实现 |
| 77 | +```cpp |
| 78 | + Tensor poisson_nll_loss(const Tensor& input, const Tensor& target, const bool log_input, const bool full, const double eps, const int64_t reduction) |
| 79 | + { |
| 80 | + Tensor loss; |
| 81 | + if (log_input) { |
| 82 | + loss = at::exp(input) - target * input; |
| 83 | + } else { |
| 84 | + loss = input - target * at::log(input + eps); |
| 85 | + } |
| 86 | + |
| 87 | + if (full) { |
| 88 | + auto stirling_term = target * at::log(target) - target + 0.5 * at::log(2 * c10::pi<double> * target); |
| 89 | + loss += stirling_term.masked_fill(target <= 1, 0); |
| 90 | + } |
| 91 | + |
| 92 | + return apply_loss_reduction(loss, reduction); |
| 93 | + } |
| 94 | + |
| 95 | + static inline at::Tensor apply_loss_reduction(const at::Tensor& unreduced, int64_t reduction) { |
| 96 | + if (reduction == at::Reduction::Mean) { |
| 97 | + return unreduced.mean(); |
| 98 | + } else if (reduction == at::Reduction::Sum) { |
| 99 | + return unreduced.sum(); |
| 100 | + } |
| 101 | + return unreduced; |
| 102 | + } |
| 103 | +``` |
| 104 | +
|
| 105 | +无其它相关库支持该 Loss 函数。 |
| 106 | +
|
| 107 | +# 四、对比分析 |
| 108 | +
|
| 109 | +设计方案将参考PyTorch的实现,在当前版本的API当中仅实现Python版本。由于PyTroch实现当中参数 ``size_average`` 和 ``reduce`` 的功能因为被 ``reduction``包含而废弃,在此版本的API实现我们中对这些被抛弃的参数不再做兼容。 |
| 110 | +
|
| 111 | +# 五、设计思路与实现方案 |
| 112 | +
|
| 113 | +## 命名与参数设计 |
| 114 | +
|
| 115 | +共添加以下两个 API: |
| 116 | +
|
| 117 | +`paddle.nn.functional.poisson_nll_loss( |
| 118 | + input, |
| 119 | + target, |
| 120 | + log_input = True, |
| 121 | + full = False, |
| 122 | + eps = 1e-8, |
| 123 | + reduction = "mean", |
| 124 | + name:str = None |
| 125 | +) -> Tensor:` |
| 126 | + - Input(Tensor): 期望服从泊松分布的输入,形状为`(N, *)` 或 `(*)` 其中 `*`表示任何数量的额外维度。 |
| 127 | + - Target(Tensor): 为泊松分布的随机样本,形状为`(N, *)` 或 `(*)`,与输入的形状相同, |
| 128 | +或与输入的形状相同但有一个维度等于1(允许广播)。 |
| 129 | + - Log_input(bool): 输入和目标是否为对数。如果为`True`,则损失函数的前两项的计算方式为$\exp(\text{input}) - \exp\text{target} * \text{target}$。如果设置为`False`,则损失函数的前两项计算方式为$\text{input} - \text{target} * \log(\text{input}+\text{eps})$。默认为`True`。 |
| 130 | + - Full(bool):是否计算完整的损失。如果为`True`,则添加Stirling逼近项$\text{target}*\log(\text{target}) - \text{target} + 0.5 * \log(2\pi\text{target})$。默认为`False`。 |
| 131 | + - Eps: 避免在`log_input=False`时计算$\log(0)$的小量。默认值为1e-8/。 |
| 132 | + - Reduction:指定应用于输出结果的计算方式,指定应用于输出结果的计算方式,可选值有:"none", "mean", "sum"。默认为"mean",计算`Poisson_nll_loss`的均值;设置为"sum"时,计算`Poisson_nll_loss`的总和;设置为"none"时,则返回`Poisson_nll_loss`。 |
| 133 | + - Name: 操作的名称,默认为None。 |
| 134 | +
|
| 135 | +和 |
| 136 | +
|
| 137 | +`paddle.nn.PoissonNLLLoss( |
| 138 | + log_input = True, |
| 139 | + full = False, |
| 140 | + eps = 1e-8, |
| 141 | + reduction = "mean", |
| 142 | + name = None |
| 143 | +) -> Tensor:` |
| 144 | +- log_input(bool): 输入和目标是否为对数。如果为`True`,则损失函数的前两项的计算方式为$\exp(\text{input}) - \exp\text{target} * \text{target}$。如果设置为`False`,则损失函数的前两项计算方式为$\text{input} - \text{target} * \log(\text{input}+\text{eps})$。默认为`True`。 |
| 145 | + - Full(bool):是否计算完整的损失。如果为`True`,则添加Stirling逼近项$\text{target}*\log(\text{target}) - \text{target} + 0.5 * \log(2\pi\text{target})$。默认为`False`。 |
| 146 | + - Eps: 避免在`log_input=False`时计算$\log(0)$的小量。默认值为1e-8/。 |
| 147 | + - Reduction:指定应用于输出结果的计算方式,指定应用于输出结果的计算方式,可选值有:"none", "mean", "sum"。默认为"mean",计算`Poisson_nll_loss`的均值;设置为"sum"时,计算`Poisson_nll_loss`的总和;设置为"none"时,则返回`Poisson_nll_loss`。 |
| 148 | + - Name: 操作的名称,默认为None。 |
| 149 | +
|
| 150 | +参数与文档要求进行对齐。 |
| 151 | +
|
| 152 | +## API实现方案 |
| 153 | +
|
| 154 | +参考 pytorch 的处理方式通过 paddle.exp, paddle.log , paddle.where函数实现。 |
| 155 | +1. 检查参数 |
| 156 | +
|
| 157 | + 1. 检查 reduction 有效性(同其余 functional loss 中的实现) |
| 158 | + 2. 检查输入参数 eps 是否为正数 |
| 159 | + 3. 检查输入(含 `input`、`target`)的size和dtype(同其余 functional loss 中的实现) |
| 160 | +
|
| 161 | +2. 计算 |
| 162 | +
|
| 163 | + 1. 判断`log_input`是否为`True`,计算loss的前两项 |
| 164 | + 2. 判断`full`是否为`True`,计算Stirling逼近项 |
| 165 | + 3. 计算loss |
| 166 | +
|
| 167 | +
|
| 168 | +3. 根据 `reduction`,输出 loss(同其余 functional loss 中的实现) |
| 169 | +
|
| 170 | +# 六、测试和验收的考量 |
| 171 | +
|
| 172 | +由于在numpy当中没有Poisson_nll_loss的实现,我们基于numpy自己实现了此函数,并于参考方案对比验证了numpy实现的正确性。在此基础上我们进行了前向的测验和验收: |
| 173 | +1. 结果正确性: |
| 174 | +
|
| 175 | + - 前向计算:`paddle.nn.PoissonNLLLoss` 和 `paddle.nn.functional.poisson_nll_loss` 计算结果与numpy实现计算结果一致。 |
| 176 | + - 反向计算:由 Python API 组合新增 API 无需验证反向计算。 |
| 177 | +
|
| 178 | +2. 硬件场景: 在 CPU 和 GPU 硬件条件下的运行结果一致。 |
| 179 | +
|
| 180 | +3. 异常测试: |
| 181 | +
|
| 182 | + - 数据类型检验: |
| 183 | + - input和target的数据类型检验 |
| 184 | + - 可选参数的数据类型检验 |
| 185 | + - 具体数值检验: |
| 186 | + - input 与 target 的维度一致检查 |
| 187 | + - 若 eps 有输入, 则要为正 |
| 188 | +
|
| 189 | +
|
| 190 | +# 七、可行性分析和排期规划 |
| 191 | +方案主要依赖现有 paddle API 组合,待该设计文档通过验收后可尽快提交。 |
| 192 | +
|
| 193 | +# 八、影响面 |
| 194 | +
|
| 195 | +在 paddle.nn.functional.loss 文件中import math,新增的API对其他模块没有影响 |
| 196 | +
|
| 197 | +# 名词解释 |
| 198 | +
|
| 199 | +# 附件及参考资料 |
| 200 | +
|
| 201 | +[torch实现](https://pytorch.org/docs/stable/_modules/torch/nn/functional.html#poisson_nll_loss) |
0 commit comments