Add Sequence parallel and 2D parallel examples #1149
Merged
Commits (5):
- 53cc387 Add Sequence parallel and 2D parallel examples (fduwjj)
- a9c7bb5 Split files and extract common logic (fduwjj)
- 56d684b Update test script (fduwjj)
- 8b114dc Add sq safe guard since it is still not released (fduwjj)
- 27362ce fix import (fduwjj)
distributed/tensor_parallelism/sequence_parallel_example.py (83 additions, 0 deletions)
```python
import argparse

import torch
import torch.multiprocessing as mp

from torch.distributed._tensor import DeviceMesh
from torch.distributed.tensor.parallel import parallelize_module
from utils import cleanup, setup, ToyModel

try:
    from torch.distributed.tensor.parallel import SequenceParallel

    SP_AVAILABLE = True
except BaseException:
    # Safe guard: SequenceParallel is not yet available in released PyTorch builds.
    SP_AVAILABLE = False


"""
This is the script to test Sequence Parallel (SP) on a toy model in a
Megatron-LM SPMD style. We show an end-to-end working flow covering forward,
backward, and optimization.

We use the example of two `nn.Linear` layers with an element-wise `nn.ReLU`
in between to show sequence parallelism, which was proposed in the paper:

https://arxiv.org/pdf/2205.05198.pdf.

Like tensor parallelism, we parallelize the first linear layer by column
and the second linear layer by row. But the input on each rank is now
different, so we need one all-gather for the input and one reduce-scatter
at the end of the second linear layer.
"""


def demo_sp(rank, args):
    """
    Main body of the demo of a basic version of sequence parallelism using
    PyTorch native APIs.
    """
    print(f"Running SP example on rank {rank}.")
    setup(rank, args.world_size)

    # Create a sharding plan based on the given world_size.
    device_mesh = DeviceMesh("cuda", torch.arange(0, args.world_size))

    # Create the model and move it to the GPU with id `rank`.
    model = ToyModel().cuda(rank)
    # Parallelize the module based on the given parallel style.
    model = parallelize_module(model, device_mesh, SequenceParallel())
    # Create an optimizer for the parallelized module
    # (after parallelization, so it tracks the sharded parameters).
    LR = 0.25
    optimizer = torch.optim.SGD(model.parameters(), lr=LR)

    # Perform a number of iterations of forward/backward
    # and optimization for the sharded module.
    for _ in range(args.iter_nums):
        # For SP, the input can be different across ranks.
        inp = torch.rand(20, 10).cuda(rank)
        output = model(inp)
        output.sum().backward()
        optimizer.step()

    cleanup()


if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    parser = argparse.ArgumentParser()
    # These are passed in via the command line.
    parser.add_argument("--world_size", type=int, default=n_gpus)
    parser.add_argument("--iter_nums", type=int, default=10)
    args = parser.parse_args()
    # The main entry point is called directly without using subprocess.
    if n_gpus < 2:
        print("Requires at least 2 GPUs to run.")
    elif not SP_AVAILABLE:
        print(
            "PyTorch doesn't have Sequence Parallelism available;"
            " a nightly build is needed."
        )
    else:
        mp.spawn(demo_sp, args=(args,), nprocs=args.world_size, join=True)
```
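To make the communication pattern in the docstring concrete, here is a minimal sketch (not part of the PR) of the collectives that sequence parallelism uses for the two-linear `ToyModel`: one all-gather of the sequence-sharded input before the column-parallel first linear, and one reduce-scatter after the row-parallel second linear. It assumes an already-initialized process group (e.g. via `utils.setup`) and that the world size divides both the local sequence length times world size and the hidden dimension; the name `manual_sp_forward` and the bias-free weight shards are illustrative only, not an API from the PR.

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F


def manual_sp_forward(x_shard, w1_col_shard, w2_row_shard):
    """
    Hypothetical hand-written SP forward pass (biases omitted for brevity).

    x_shard:      [seq / world_size, in_features]    local sequence shard of the input
    w1_col_shard: [in_features, hidden / world_size] column shard of the first linear
    w2_row_shard: [hidden / world_size, out_features] row shard of the second linear
    """
    world_size = dist.get_world_size()

    # 1) All-gather the sequence-sharded input so every rank sees the full sequence.
    gathered = [torch.empty_like(x_shard) for _ in range(world_size)]
    dist.all_gather(gathered, x_shard)
    x_full = torch.cat(gathered, dim=0)        # [seq, in_features]

    # 2) Column-parallel first linear plus element-wise ReLU: no communication.
    h_local = F.relu(x_full @ w1_col_shard)    # [seq, hidden / world_size]

    # 3) Row-parallel second linear yields a partial sum of the full output.
    out_partial = h_local @ w2_row_shard       # [seq, out_features]

    # 4) Reduce-scatter: sum the partials across ranks and give each rank
    #    its own sequence shard of the result.
    out_shard = torch.empty(
        x_shard.shape[0], w2_row_shard.shape[1],
        device=out_partial.device, dtype=out_partial.dtype,
    )
    dist.reduce_scatter(
        out_shard, [c.contiguous() for c in out_partial.chunk(world_size, dim=0)]
    )
    return out_shard                           # [seq / world_size, out_features]
```

Because the example script spawns its own workers with `mp.spawn`, it can be launched directly with `python sequence_parallel_example.py` on a host with at least 2 GPUs and a nightly PyTorch build.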
distributed/tensor_parallelism/two_d_parallel_example.py (126 additions, 0 deletions)
```python
import argparse

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from torch.distributed._tensor import DeviceMesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor.parallel import (
    PairwiseParallel,
    parallelize_module,
)
from torch.distributed.tensor.parallel.fsdp import enable_2d_with_fsdp

from utils import cleanup, setup, ToyModel

try:
    from torch.distributed.tensor.parallel import SequenceParallel

    SP_AVAILABLE = True
except BaseException:
    # Safe guard: SequenceParallel is not yet available in released PyTorch builds.
    SP_AVAILABLE = False


"""
This is the script to test 2D Parallel, which combines Tensor/Sequence
Parallel with Fully Sharded Data Parallel (TP/SP + FSDP) on a toy model
in the SPMD style. We show an end-to-end working flow covering forward,
backward, and optimization.

We enable Fully Sharded Data Parallel + Tensor Parallel in
separate parallel dimensions:
    Data Parallel across hosts
    Tensor Parallel within each host

We use a simple diagram to illustrate below:

======================================================================
------------       ------------       ------------       ------------
|  Host 1  |       |  Host 2  |       |          |       |  Host N  |
|  8 GPUs  |       |  8 GPUs  |       |          |       |  8 GPUs  |
|          |       |          |       |   ...    |       |          |
|   (TP)   |       |   (TP)   |       |          |       |   (TP)   |
|[0,1,..,7]|       |[8,9..,15]|       |          |       |[8N-8,8N-7|
|          |       |          |       |          |       | .., 8N-1]|
|          |       |          |       |          |       |          |
------------       ------------       ------------       ------------
FSDP:
[0, 8, ..., 8N-8], [1, 9, ..., 8N-7], ..., [7, 15, ..., 8N-1]
======================================================================

More details can be seen in the slide:
https://docs.google.com/presentation/d/17g6WqrO00rP3MsxbRENsPpjrlSkwiA_QB4r93_eB5is/
"""


def demo_2d(rank, args):
    """
    Main body of the demo of 2D parallelism (TP/SP + FSDP) using
    PyTorch native APIs.
    """
    print(f"Running 2D parallel example on rank {rank}.")
    setup(rank, args.world_size)
    assert (
        args.world_size % args.tp_size == 0
    ), "World size needs to be divisible by TP size"

    # Create a 2D sharding plan based on the given world_size.
    device_mesh = DeviceMesh(
        "cuda", torch.arange(0, args.world_size).view(-1, args.tp_size)
    )

    # Create the model and move it to the GPU with id `rank`.
    model = ToyModel().cuda(rank)
    # Parallelize the module based on the given parallel style.
    parallel_style = SequenceParallel() if args.run_seq_parallel else PairwiseParallel()
    model = parallelize_module(model, device_mesh, parallel_style, tp_mesh_dim=1)

    # We need to register hooks for TP + FSDP integration.
    assert (
        enable_2d_with_fsdp()
    ), "FSDP 2D hook is not registered. Please use PyTorch with version >= 2.0"
    model = FSDP(model)
    # Create an optimizer for the parallelized and sharded module
    # (after wrapping, so it tracks the FSDP-managed parameters).
    LR = 0.25
    optimizer = torch.optim.SGD(model.parameters(), lr=LR)

    # Perform a number of iterations of forward/backward
    # and optimization for the sharded module.
    for i in range(args.iter_nums):
        # For TP, the input needs to be the same across all TP ranks,
        # while for SP, the input can be different across ranks.
        # Setting the random seed here mimics the behavior of a dataloader.
        dp_rank = (
            rank
            if args.run_seq_parallel
            else dist.get_rank(device_mesh.get_dim_groups()[0])
        )
        torch.manual_seed(i + dp_rank)
        inp = torch.rand(20, 10).cuda(rank)
        output = model(inp)
        output.sum().backward()
        optimizer.step()

    cleanup()


if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    parser = argparse.ArgumentParser()
    # These are passed in via the command line.
    parser.add_argument("--world_size", type=int, default=n_gpus)
    parser.add_argument("--iter_nums", type=int, default=10)
    # A flag rather than type=bool, which argparse would treat as truthy for any string.
    parser.add_argument("--run_seq_parallel", action="store_true", default=False)
    parser.add_argument("--tp_size", type=int, default=2)
    args = parser.parse_args()
    # The main entry point is called directly without using subprocess.
    if n_gpus < 4:
        print("Requires at least 4 GPUs to run.")
    elif not SP_AVAILABLE:
        print(
            "PyTorch doesn't have Sequence Parallelism available;"
            " a nightly build is needed."
        )
    else:
        mp.spawn(demo_2d, args=(args,), nprocs=args.world_size, join=True)
```
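As a quick illustration (not part of the PR) of how `torch.arange(world_size).view(-1, tp_size)` realizes the diagram above: the rows of that 2D mesh are the TP groups (GPUs within one host) and the columns are the FSDP data-parallel groups (the same GPU index across hosts). This needs no process group, just tensor arithmetic; the sizes below are assumptions matching the two-host case.

```python
import torch

# Assumed sizes: 2 hosts x 8 GPUs, TP inside each host (matches the diagram above).
world_size, tp_size = 16, 8
mesh = torch.arange(world_size).view(-1, tp_size)

# Rows of the mesh are the TP groups: [0..7] on host 1, [8..15] on host 2.
tp_groups = [row.tolist() for row in mesh]
# Columns are the FSDP (data-parallel) groups: [0, 8], [1, 9], ..., [7, 15].
fsdp_groups = [col.tolist() for col in mesh.t()]

print(mesh)
print("TP groups:  ", tp_groups)
print("FSDP groups:", fsdp_groups)
```

In the script itself, `tp_mesh_dim=1` tells `parallelize_module` to use the second mesh dimension (the rows) for TP, while `device_mesh.get_dim_groups()[0]` retrieves the data-parallel group whose rank seeds the fake dataloader.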
distributed/tensor_parallelism/utils.py (31 additions, 0 deletions)
```python
import argparse
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    # Initialize the process group.
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    dist.destroy_process_group()


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 32)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(32, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))
```
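As a quick, hypothetical sanity check (not part of the PR), `ToyModel` can be run on CPU without any parallelism to confirm the shapes both examples rely on: a `[20, 10]` input passes through `nn.Linear(10, 32)` and `nn.Linear(32, 5)` to produce a `[20, 5]` output.

```python
import torch
from utils import ToyModel  # assumes the utils.py above is importable from the working directory

model = ToyModel()
inp = torch.rand(20, 10)    # same per-rank input shape the examples use
out = model(inp)
print(out.shape)            # torch.Size([20, 5])
```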