[0-size Tensor No.271] Add 0-size Tensor support for paddle.take_along_axis API. #73736
+85 −0
Add this suggestion to a batch that can be applied as a single commit. This suggestion is invalid because no changes were made to the code. Suggestions cannot be applied while the pull request is closed. Suggestions cannot be applied while viewing a subset of changes. Only one suggestion per line can be applied in a batch. Add this suggestion to a batch that can be applied as a single commit. Applying suggestions on deleted lines is not supported. You must change the existing code in this line in order to create a valid suggestion. Outdated suggestions cannot be applied. This suggestion has been applied or marked resolved. Suggestions cannot be applied from pending reviews. Suggestions cannot be applied on multi-line comments. Suggestions cannot be applied while the pull request is queued to merge. Suggestion cannot be applied right now. Please check back later.
PR Category
Operator Mechanism
PR Types
Bug fixes
Description
本次提交为完成任务 #72637 中关于
paddle.take_along_axis的部分。修改历程介绍如下:
问题复现与分析:
PaddleAPITest复现BUG。在PaddleAPITest的--accuracy=True精度对比模式下,由于paddle.take_along_axis(arr, index, axis)与torch.take_along_dim(input, indices, dim)的参数名不统一,反复出现TypeError: missing a required argument的参数绑定错误,无法准确定位问题。PaddleAPITest的--paddle_only=True模式后,成功触发了Python层的TypeError: take_along_axis() got an unexpected keyword argument 'index'报错,这暴露了API前后端参数名不一致的问题。unittest的OpTest框架编写单元测试。在未修复Kernel的情况下,成功复现了底层的C++错误:前向修复 (Forward Fix):
a. 定位API: 在
Paddle/python/paddle/tensor/manipulation.py中找到了def take_along_axis(...)的Python定义,其核心实现调用了_C_ops.take_along_axis。b. 定位算子定义: 使用
grep发现,该算子没有独立的.yml文件,其定义位于paddle/phi/ops/yaml/ops.yaml中。c. 检查InferMeta: 根据
ops.yaml的指引,在paddle/phi/infermeta/binary.cc中找到了TakeAlongAxisInferMeta函数。经分析,其out->set_dims(index.dims())逻辑能正确推导0-size Tensor的输出形状,无需修改。d. 修改Kernel:
* 根据
grep结果,定位到CPU Kernel文件为paddle/phi/kernels/cpu/take_along_axis_kernel.cc。* 参照标准修复范式,在
TakeAlongAxisKernel函数开头加入了对0-size情况的保护。核心逻辑是判断index.numel()是否为0,因为输出的形状完全由index决定。* 修复代码如下:
d. 依照以上原则修改CPU、GPU、XPU Kernel
反向修复 (Backward Fix):
4. 添加单测 (Add Unit Test):
在
test/legacy_test/test_take_along_axis_op.py文件中,为彻底解决因父类TestTakeAlongAxisOp的setUp方法无法兼容0-size数据而导致的CI报错(IndexError,ValueError),最终方案是放弃继承,添加了两个全新的、独立的OpTest测试类,分别覆盖两种不同的0-size边界场景。测试场景一:输入
arr为0-size,但index不为0-sizeTestTakeAlongAxis0Size1类进行验证。完整测试代码如下:测试场景二:索引
index为0-size,但arr不为0-sizeTestTakeAlongAxis0Size2类进行验证。完整测试代码如下:feature/fix_take_along_axis_0size分支上运行添加的OpTest单元测试,结果为OK,证明修复成功。--accuracy模式无法使用。在--paddle_only模式下,修复后可顺利通过。pcard-67164