
Conversation

@rpsilva-aws
Collaborator

Currently, the existing parameter mapping for the lowering context is not well suited for SPMD. For large models, it creates a large synchronous bottleneck by transferring all device data to the host: each ReplicateShardedData computation gathers and reassembles a tensor's shards across multiple devices. This is by design, since the mapping is expected to collect all parameters regardless of their allocation.

In this PR, we introduce a new mapping that does not invoke the sharded replication, but instead holds references to the device data. This is generally sufficient and preferred in most cases, where the user only wants to access the valid parameters (those for which tensor_parameter_id does not return -1; a return of -1 marks a 'fake' parameter).
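A minimal usage sketch, assuming the `LoweringContext` Python binding (`torch_xla._XLAC.lowering.LoweringContext`) exposes the methods described above; names and the exact binding surface are illustrative and may differ in a given release:

```python
# Minimal sketch, assuming parameter_id_tensor_mapping,
# device_parameter_id_tensor_mapping and tensor_parameter_id are exposed on
# the LoweringContext binding as described in this PR; names are illustrative.
import torch
import torch_xla
import torch_xla.core.xla_model as xm

t = torch.randn(4, 4, device=xm.xla_device())
out = torch.sin(t)

ctx = torch_xla._XLAC.lowering.LoweringContext("SketchComputation")
ctx.build([out])

# Existing mapping: runs ReplicateShardedData for each sharded parameter and
# transfers the reassembled data to the host.
host_mapping = ctx.parameter_id_tensor_mapping()

# New mapping: keeps references to the device data, no sharded replication.
device_mapping = ctx.device_parameter_id_tensor_mapping()

# Only valid parameters are of interest; -1 marks a 'fake' parameter.
param_id = ctx.tensor_parameter_id(t)
if param_id != -1:
    device_tensor = device_mapping[param_id]
```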

@rpsilva-aws
Collaborator Author

Re-opened from #8453, cleaned up the merge commit.

@tengyifei tengyifei self-requested a review December 5, 2024 23:51
@tengyifei tengyifei added the tpuci label Dec 5, 2024
@tengyifei tengyifei marked this pull request as ready for review December 5, 2024 23:52
@rpsilva-aws rpsilva-aws force-pushed the rpsilva_lc_mapping_v3 branch from 8fd7ac7 to 9858577 December 5, 2024 23:53
@tengyifei tengyifei merged commit 5d11f66 into pytorch:master Dec 7, 2024
12 checks passed
@rpsilva-aws rpsilva-aws deleted the rpsilva_lc_mapping_v3 branch December 9, 2024 19:03
tengyifei added a commit that referenced this pull request Jan 2, 2025
Previously, scan used `parameter_id_tensor_mapping` to fetch tensors hoisted to HLO parameters, e.g. the fn being scanned may create additional tensors while it is running. `parameter_id_tensor_mapping` fetches those tensors back to the host as XLA literals and creates new tensors wrapping them, resulting in additional host RAM usage. PR #8460 added `device_parameter_id_tensor_mapping`, which returns the actual device-backed tensors instead of another copy. So we'll use that and test that this avoids host transfers.
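As a rough illustration (not the actual test from this PR), one way to check the "no host transfer" property is to read the device-backed mapping and then inspect torch_xla's debug metrics; the metric name below is an assumption and has changed across releases (older ones used `TransferFromServerTime`):

```python
# Hedged sketch: assumes the metric name 'TransferFromDeviceTime'; this is
# not the exact test added in #8460, only an illustration of the idea.
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

t = torch.randn(8, 8, device=xm.xla_device())
out = t * 2

ctx = torch_xla._XLAC.lowering.LoweringContext("ScanSketch")
ctx.build([out])

met.clear_all()
# Device-backed mapping: no XLA literals should be fetched back to the host.
_ = ctx.device_parameter_id_tensor_mapping()
assert met.metric_data('TransferFromDeviceTime') is None
```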