- Notifications
You must be signed in to change notification settings - Fork 5.9k
Add probability distribution transformation APIs #40536
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add probability distribution transformation APIs #40536
Conversation
| ✅ This PR's description meets the template requirements! |
| Thanks for your contribution! |
e7850b7 to 141c642 Compare 141c642 to ae14d10 Compare ae14d10 to 540de48 Compare 540de48 to b14dae1 Compare f37b54c to 9fcd9e8 Compare 9fcd9e8 to 983c1fb Compare 983c1fb to b96f2e1 Compare b96f2e1 to 3421235 Compare 3421235 to 6985966 Compare There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
typo
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
updated
TCChenlong left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
TODO:Add Chinese documentation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"power" means paddle.distribution.PowerTransform?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已更新,粘贴错误,粘贴了一个旧的代码示例
| Examples in Describe above: print(affine.inverse(power.forward(x))), "power" means paddle.distribution.PowerTransform? or should be affine? |
是 ''affine'',粘贴错误,文档和PR描述均已更新 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
shall we add some abstract methods in base class, e.g. mean, variance and rsample
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已添加,这样能保持不同子类持有的方法一致;存量 normal, uniform, categorical目前是NotImplementedError,按照设计文档计划后续统一更新
6985966 to 176b413 Compare 176b413 to 4af8502 Compare 4af8502 to 723a279 Compare 723a279 to c8ef422 Compare
jeff41404 left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
XiaoguangHu01 left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
| from paddle.distribution.kl import kl_divergence, register_kl | ||
| from paddle.distribution.multinomial import Multinomial | ||
| from paddle.distribution.normal import Normal | ||
| from paddle.distribution.transform import * # noqa: F403 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里需要import *的原因是什么呢?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
distribution/transform.py 文件中定义了需要公开API的__ALL__列表,在distribution/__init__中用import * 全部导出,并添加到__init__.py 的__all__列表中, 可以通过paddle.distribution.xxx访问,访问路径和竞品保持一致
PR types
New features
PR changes
APIs
Describe
Adds 13 transformation APIs and 2 distribution APIs :
new transformation APIs:
new distribution APIs:
Examples: