
Assertion error when scanning over a model without weights #8753

@zpcore

Description


🐛 Bug

When scanning over modules that don't contain any model weights, scan_layers fails with the following error from the empty-PyTree check in scan:

  File "/workspaces/pytorch/xla/torch_xla/experimental/scan_layers.py", line 93, in scan_layers
    final_carry, _ = scan(
  File "/workspaces/pytorch/xla/torch_xla/experimental/scan.py", line 156, in scan
    raise ValueError(f"`xs` {xs} is an empty PyTree.")
ValueError: `xs` ({}, {}) is an empty PyTree.
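The message shows that scan received xs = ({}, {}), a tuple of two empty dicts. Presumably these hold the stacked parameters and buffers of the layers (an assumption based only on the message), and for weightless layers both are empty. A pytree built only from containers has no leaves, so the emptiness check fires. The plain-Python sketch below reproduces the message; tree_leaves and check_xs are hypothetical stand-ins for a real pytree flatten and the guard in scan.py, not torch_xla code. Tuples and dicts are treated as containers and everything else (e.g. a list standing in for a stacked tensor) as a leaf.

```python
from typing import Any, List


def tree_leaves(tree: Any) -> List[Any]:
    """Hypothetical simplified flatten: tuples and dicts are containers, all else is a leaf."""
    if isinstance(tree, tuple):
        return [leaf for child in tree for leaf in tree_leaves(child)]
    if isinstance(tree, dict):
        return [leaf for key in sorted(tree) for leaf in tree_leaves(tree[key])]
    return [tree]


def check_xs(xs: Any) -> None:
    """Hypothetical guard mirroring the traceback: reject an xs with no leaves."""
    if len(tree_leaves(xs)) == 0:
        raise ValueError(f"`xs` {xs} is an empty PyTree.")


# Assumed layer state for weightless layers: (stacked params, stacked buffers), both empty.
xs_no_weights = ({}, {})
try:
    check_xs(xs_no_weights)
except ValueError as err:
    print(err)  # `xs` ({}, {}) is an empty PyTree.
```

With at least one stacked weight, e.g. ({"w": [1, 2]}, {}), the tuple has a leaf and the check passes.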

To Reproduce

Run the test in 1. The test fails: the output of fake_fa_wrapper triggers raise ValueError(f"`xs` {xs} is an empty PyTree.") in scan.py.

Expected behavior

The output of fake_fa_wrapper should not trigger this error; scanning over layers that have no weights should work.
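One possible direction, sketched in plain Python rather than against torch_xla: when xs has no leaves, take the number of iterations from an explicit length (for example, the number of layers) instead of from xs. jax.lax.scan accepts a `length` argument for this case; the `length` parameter on the scan below, and the helpers tree_leaves and tree_index, are hypothetical and are not the torch_xla API or its actual fix. Lists stand in for tensors stacked along a leading layer axis.

```python
from typing import Any, Callable, List, Optional, Tuple


def tree_leaves(tree: Any) -> List[Any]:
    """Hypothetical simplified flatten: tuples and dicts are containers, all else is a leaf."""
    if isinstance(tree, tuple):
        return [leaf for child in tree for leaf in tree_leaves(child)]
    if isinstance(tree, dict):
        return [leaf for key in sorted(tree) for leaf in tree_leaves(tree[key])]
    return [tree]


def tree_index(tree: Any, i: int) -> Any:
    """Select entry i along the leading (layer) axis of every leaf."""
    if isinstance(tree, tuple):
        return tuple(tree_index(child, i) for child in tree)
    if isinstance(tree, dict):
        return {key: tree_index(value, i) for key, value in tree.items()}
    return tree[i]


def scan(
    fn: Callable[[Any, Any], Tuple[Any, Any]],
    init: Any,
    xs: Any,
    length: Optional[int] = None,
) -> Tuple[Any, List[Any]]:
    """Sequential scan that tolerates an empty xs when `length` is given."""
    leaves = tree_leaves(xs)
    if leaves:
        n = len(leaves[0])
        if length is not None and length != n:
            raise ValueError(f"length {length} does not match leading axis {n}")
    elif length is not None:
        n = length  # No per-layer state: iterate a fixed number of times.
    else:
        raise ValueError(f"`xs` {xs} is an empty PyTree.")
    carry, ys = init, []
    for i in range(n):
        carry, y = fn(carry, tree_index(xs, i))
        ys.append(y)
    return carry, ys


def weightless_layer(carry: int, layer_state: Any) -> Tuple[int, None]:
    # Hypothetical layer with no weights: layer_state is ({}, {}) at every step.
    return carry * 2, None


final, _ = scan(weightless_layer, 1, ({}, {}), length=3)
print(final)  # 8
```

The design choice is that an empty xs is only an error when nothing else tells scan how many steps to run; with layers that carry weights, the leading axis of xs still determines the count.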

Metadata


Assignees: No one assigned
Labels: bug (Something isn't working)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
