
Conversation

@jonb377 (Collaborator) commented Nov 19, 2023

For async checkpointing, it is important to unblock training as quickly as possible. Training can only be unblocked once all data has been moved off the device onto the host, freeing up device memory.

One bottleneck I found was that in TransferFromServer, the tensors are transferred using ToLiteralSync, meaning each tensor is transferred sequentially. When benchmarking a 2B-parameter model, parallelizing these transfers decreased the time spent in TransferFromServer from 5.1s to 1.8s, a ~65% reduction.
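
For illustration, here is a minimal sketch of the parallelization idea: instead of calling the blocking per-buffer transfer in a loop, each transfer is launched on its own thread and joined only after all of them are in flight. The `Buffer`, `Literal`, and `BlockingTransfer` names below are placeholders, not the actual runtime API touched by this PR.

```cpp
// Minimal sketch of parallelizing per-tensor device-to-host transfers.
// Buffer, Literal, and BlockingTransfer stand in for the real runtime
// types and calls (PJRT buffers, xla::Literal, ToLiteralSync); this is
// not the actual PR implementation.
#include <future>
#include <vector>

struct Buffer {};   // stand-in for a device buffer handle
struct Literal {};  // stand-in for the host-side result

// Placeholder for the blocking per-buffer transfer.
Literal BlockingTransfer(const Buffer& buffer) { return Literal{}; }

std::vector<Literal> TransferFromServerParallel(
    const std::vector<Buffer>& buffers) {
  // Launch every transfer concurrently instead of one at a time.
  std::vector<std::future<Literal>> futures;
  futures.reserve(buffers.size());
  for (const Buffer& buffer : buffers) {
    futures.push_back(std::async(
        std::launch::async, [&buffer] { return BlockingTransfer(buffer); }));
  }
  // Wait for the results only after all transfers are in flight.
  std::vector<Literal> literals;
  literals.reserve(futures.size());
  for (auto& f : futures) {
    literals.push_back(f.get());
  }
  return literals;
}
```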

There is still significant overhead from copying the resulting xla::Literal into a torch.Tensor, but that's for another PR.

@will-cromar (Collaborator) left a comment

Thanks!

@yeounoh (Contributor) left a comment

LGTM

@jonb377 jonb377 merged commit b9475d9 into master Nov 29, 2023
@jonb377 jonb377 deleted the jonbolin/d2h branch November 29, 2023 21:40
ManfeiBai pushed a commit to ManfeiBai/PyTorchXLA that referenced this pull request Dec 1, 2023
chunnienc pushed a commit to chunnienc/xla that referenced this pull request Dec 14, 2023
golechwierowicz pushed a commit that referenced this pull request Jan 12, 2024
bhavya01 pushed a commit that referenced this pull request Apr 22, 2024