
Need to reenable ZeRO1 for GPU to enable coverage for reduce-scatter/all-gather #6260

@jeffhataws


🐛 Bug

The ZeRO1 test (test/test_zero1.py) has been disabled for GPU since version 2.1 (#4912). We should re-enable it on GPU so that the reduce-scatter and all-gather collectives get test coverage.
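Here is a rough illustration of why this test matters for collective coverage. ZeRO-1 shards optimizer state across ranks, so a training step typically reduce-scatters the gradients, has each rank update only its own shard, and then all-gathers the updated parameters. The pure-Python sketch below simulates those semantics on plain lists. It does not call torch_xla. The names `reduce_scatter`, `all_gather` and `zero1_step` are hypothetical, and it sums gradients where real implementations may average them.

```python
from typing import List

Vector = List[float]


def reduce_scatter(per_rank: List[Vector]) -> List[Vector]:
    """Sum the full vectors element-wise across ranks; rank r keeps only shard r."""
    world = len(per_rank)
    n = len(per_rank[0])
    assert n % world == 0, "vector length must divide evenly by world size"
    shard = n // world
    summed = [sum(vals) for vals in zip(*per_rank)]
    return [summed[r * shard:(r + 1) * shard] for r in range(world)]


def all_gather(shards: List[Vector]) -> List[Vector]:
    """Concatenate every rank's shard and give the full vector to every rank."""
    full = [x for s in shards for x in s]
    return [list(full) for _ in shards]


def zero1_step(params: Vector, grads_per_rank: List[Vector], lr: float) -> List[Vector]:
    """Hypothetical ZeRO-1-style SGD step: reduce-scatter, local update, all-gather."""
    world = len(grads_per_rank)
    shard = len(params) // world
    grad_shards = reduce_scatter(grads_per_rank)
    # Each rank updates only the parameter slice whose optimizer state it owns.
    new_shards = [
        [p - lr * g for p, g in zip(params[r * shard:(r + 1) * shard], grad_shards[r])]
        for r in range(world)
    ]
    # Every rank ends up with the same full, updated parameter vector.
    return all_gather(new_shards)
```

For example, `zero1_step([1.0, 2.0, 3.0, 4.0], [[1.0] * 4, [1.0] * 4], 0.5)` returns `[0.0, 1.0, 2.0, 3.0]` on both simulated ranks. A GPU regression in either collective would break exactly this round trip, which is why the skipped test leaves a coverage gap.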

When I tried torch/xla version 2.2 (sha 7c46e4c), the test reported OK, but the process then hit a segmentation fault:

----------------------------------------------------------------------
Ran 1 test in 1.428s

OK
Segmentation fault (core dumped)

To Reproduce

Steps to reproduce the behavior:

  1. Build torch/xla as described in https://github.com/pytorch/xla/blob/master/docs/gpu.md
  2. Edit test/test_zero1.py and remove or comment out the line that starts with:
       @unittest.skipIf(pjrt.device_type() == 'GPU', "TODO(alanwaketan): Fix it for the token change.")
  3. Run the test:
       GPU_NUM_DEVICES=1 PJRT_DEVICE=CUDA python test/test_zero1.py
       GPU_NUM_DEVICES=2 PJRT_DEVICE=CUDA python test/test_zero1.py
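The reproduction steps can also be scripted. The sketch below assumes it is run from the torch/xla repository root with the skip decorator already removed. It runs the test once per device count and distinguishes "tests passed, then the process crashed" from an ordinary failure. The `classify` and `run` helpers are hypothetical and not part of torch/xla. `-X faulthandler` is a standard CPython option that may print a Python stack trace when the crash happens.

```python
import os
import signal
import subprocess
import sys

TEST = "test/test_zero1.py"  # assumed path, relative to the torch/xla repo root


def classify(returncode: int, output: str) -> str:
    """Label a run as 'pass', 'fail', 'segfault', or 'segfault_after_ok'."""
    # On POSIX, subprocess reports death-by-signal as a negative return code.
    crashed = returncode == -signal.SIGSEGV
    # unittest prints a line starting with "OK" (e.g. "OK (skipped=1)") on success.
    tests_ok = any(line.startswith("OK") for line in output.splitlines())
    if crashed:
        return "segfault_after_ok" if tests_ok else "segfault"
    return "pass" if returncode == 0 else "fail"


def run(num_devices: int) -> str:
    """Run the test with the env vars from the issue and classify the outcome."""
    env = dict(os.environ, GPU_NUM_DEVICES=str(num_devices), PJRT_DEVICE="CUDA")
    proc = subprocess.run(
        [sys.executable, "-X", "faulthandler", TEST],
        env=env,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # unittest writes its summary to stderr
        text=True,
    )
    print(proc.stdout)
    return classify(proc.returncode, proc.stdout)


if __name__ == "__main__":
    for n in (1, 2):
        print(f"GPU_NUM_DEVICES={n}: {run(n)}")
```

The shell's "Segmentation fault (core dumped)" message does not appear in captured output. For that reason, the crash is detected from the return code rather than by searching the text.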

Expected behavior

The test runs and passes on GPU without a segfault.

Environment

  • Reproducible on XLA backend [CPU/TPU]: GPU/CUDA
  • torch_xla version: 2.1

Additional context
