
Conversation

@yeounoh yeounoh commented Mar 19, 2024

The auto-sharding pass started complaining about this for the default mesh shape (it works for non-default shapes). There is a fix CL in XLA, and we also need to land this change to work with it.

This also moves OpenXLA pin to 25c8a6781af6be51d3bc43a0953b07803ab761ea, the companion CL.
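For context, the issue concerns the logical device IDs a default-shaped SPMD mesh hands to the auto-sharding pass. The following is a hypothetical sketch only (not the PR's actual code; `mesh_device_ids` is an invented helper) of the invariant that a default 1-D mesh and a reshaped mesh over the same devices cover the same ID set:

```python
# Hypothetical sketch: the flat list of logical device IDs for a mesh shape,
# as an SPMD mesh might pass them to the auto-sharding pass.
import math

def mesh_device_ids(mesh_shape):
    """Return logical device IDs 0..N-1 for a mesh of the given shape."""
    n = math.prod(mesh_shape)
    return list(range(n))

# A default 1-D mesh over 4 devices and a 2x2 mesh cover the same IDs.
assert mesh_device_ids((4,)) == mesh_device_ids((2, 2)) == [0, 1, 2, 3]
```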

@yeounoh yeounoh added the distributed (SPMD and other distributed things) label Mar 19, 2024
@yeounoh yeounoh requested a review from JackCaoG March 19, 2024 01:53
@yeounoh yeounoh self-assigned this Mar 19, 2024
yeounoh commented Mar 19, 2024

Tested locally with the local libtpu build. cc @JackCaoG

@yeounoh yeounoh force-pushed the auto_spmd_mesh_ids branch from d8785ea to 2837c53 Compare March 19, 2024 04:51
yeounoh commented Mar 19, 2024

The companion CL landed in XLA, and the pin update didn't require any changes on our end. I will update the pin to 25c8a6781af6be51d3bc43a0953b07803ab761ea and verify again with CI.

@yeounoh yeounoh force-pushed the auto_spmd_mesh_ids branch from 339222b to 0a66c90 Compare March 19, 2024 05:10
@yeounoh yeounoh requested review from lsy323 and zpcore March 19, 2024 05:22
@lsy323 lsy323 left a comment

LGTM, thanks!

yeounoh commented Mar 19, 2024

The ResNet throughput looks good:

Training Device=xla:0/2 Epoch=1 Step=520 Loss=0.00356 Rate=1777.26 GlobalRate=540.29 Time=17:25:25
Training Device=xla:0/1 Epoch=1 Step=520 Loss=0.00356 Rate=1776.53 GlobalRate=539.22 Time=17:25:25
Training Device=xla:0/0 Epoch=1 Step=520 Loss=0.00356 Rate=1776.87 GlobalRate=538.80 Time=17:25:25
Training Device=xla:0/3 Epoch=1 Step=520 Loss=0.00356 Rate=1776.78 GlobalRate=538.95 Time=17:25:25
Training Device=xla:0/2 Epoch=1 Step=540 Loss=0.00336 Rate=1777.40 GlobalRate=554.56 Time=17:25:26
Training Device=xla:0/1 Epoch=1 Step=540 Loss=0.00336 Rate=1777.12 GlobalRate=553.47 Time=17:25:26
Training Device=xla:0/3 Epoch=1 Step=540 Loss=0.00336 Rate=1777.21 GlobalRate=553.20 Time=17:25:26
Training Device=xla:0/0 Epoch=1 Step=540 Loss=0.00336 Rate=1776.73 GlobalRate=553.05 Time=17:25:26
Training Device=xla:0/0 Epoch=1 Step=560 Loss=0.00318 Rate=1768.65 GlobalRate=566.92 Time=17:25:27
Training Device=xla:0/1 Epoch=1 Step=560 Loss=0.00318 Rate=1768.31 GlobalRate=567.35 Time=17:25:27
Training Device=xla:0/3 Epoch=1 Step=560 Loss=0.00318 Rate=1768.32 GlobalRate=567.07 Time=17:25:27
Training Device=xla:0/2 Epoch=1 Step=560 Loss=0.00318 Rate=1768.07 GlobalRate=568.44 Time=17:25:27
yeounoh commented Mar 19, 2024

Waiting on #6772 to merge first cc @lsy323

lsy323 commented Mar 19, 2024

#6772 landed, please go ahead and merge

@yeounoh yeounoh force-pushed the auto_spmd_mesh_ids branch from 0a66c90 to ada802e Compare March 19, 2024 22:06

Labels

backport_2.3, distributed (SPMD and other distributed things)

4 participants