[Auto Parallel] Fix inplace ops saving wrong dist_attr for backward in auto parallel #73836
PR Category
Auto Parallel
PR Types
Bug fixes
Description
For inplace ops, the API takes its argument by reference, which means the API's input and output point to the same address. Because SetGradOutMeta records its information after the API call, the values it records are the already-modified (and therefore incorrect) ones, not the values that were actually passed into the API.
Outside auto parallel, SetGradOutMeta only records the meta information of the input x, so this is harmless. Under auto parallel, however, SetGradOutMeta also records the dist_attr and dims of x as modified by the API. These wrong values then cause the node's automatic inference during backward computation to produce an incorrect dist_attr.
This PR fixes the issue by saving a copy of the correct dist_attr and dims before the API call and using those saved values when calling SetGradOutMeta. As shown in the figure below, green marks the correct operations, red the incorrect ones, and dashed lines the operations introduced by this PR.
TODO: This change should eventually apply to all inplace ops, which would have a wide impact; for now it is only enabled for reshape_.
pcard-86802
dygraph_functions.cc before this change
dygraph_functions.cc after this change
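Since the generated code is not reproduced here, the following is a minimal standalone sketch of the ordering issue and the fix. The types and functions (`Tensor`, `DistAttr`, `reshape_`, `SetGradOutMeta`) are hypothetical stand-ins, not Paddle's actual API; the sketch only illustrates why recording meta after an inplace call captures post-modification values, and how snapshotting before the call avoids that.

```cpp
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-ins for a tensor's distributed metadata; not Paddle types.
struct DistAttr {
  std::string placements;  // e.g. "Shard(0)" or "Replicate"
};

struct Tensor {
  std::vector<int64_t> dims;
  DistAttr dist_attr;
};

// Hypothetical inplace op: input and output are the same object, so calling it
// rewrites the caller's dims (and possibly dist_attr) in place, like reshape_.
Tensor& reshape_(Tensor& x, const std::vector<int64_t>& new_dims) {
  x.dims = new_dims;
  x.dist_attr.placements = "Replicate";
  return x;
}

// Hypothetical grad-node hook: under auto parallel it records the forward
// input's dims and dist_attr so backward can infer the right dist_attr.
void SetGradOutMeta(const std::vector<int64_t>& dims, const DistAttr& attr) {
  std::cout << "recorded dims[0]=" << dims[0]
            << " placements=" << attr.placements << "\n";
}

int main() {
  Tensor x{{2, 3}, {"Shard(0)"}};

  // Buggy order (before this change): call the inplace API first, then record
  // meta from x -- but x has already been overwritten by the call.
  //   reshape_(x, {6});
  //   SetGradOutMeta(x.dims, x.dist_attr);   // would record {6} / "Replicate"

  // Fixed order (what this PR does conceptually): snapshot the input's dims
  // and dist_attr before the API call, and record the snapshot instead.
  auto saved_dims = x.dims;
  auto saved_attr = x.dist_attr;
  reshape_(x, {6});
  SetGradOutMeta(saved_dims, saved_attr);     // records {2, 3} / "Shard(0)"
  return 0;
}
```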