在Ubuntu上进行PyTorch的分布式训练,你需要遵循以下步骤:
安装PyTorch: 确保你已经安装了PyTorch。你可以从PyTorch官网根据你的系统配置选择合适的安装命令。
准备环境: 在开始分布式训练之前,确保所有参与训练的节点都能够通过网络互相访问,并且可以SSH无密码登录。
设置环境变量: 为了启动分布式训练,你需要设置一些环境变量,例如WORLD_SIZE
(参与训练的总进程数)、RANK
(当前进程的排名)和MASTER_ADDR
(主节点的IP地址)等。
编写分布式训练脚本: PyTorch提供了torch.distributed
包来支持分布式训练。你需要在训练脚本中使用这个包来初始化分布式环境,并根据rank来分配不同的数据给每个进程。
启动分布式训练: 使用mpirun
或torch.distributed.launch
来启动分布式训练。例如,如果你使用的是mpirun
,命令可能如下所示:
mpirun --nproc_per_node=NUM_GPUS_YOU_HAVE -np WORLD_SIZE python -m torch.distributed.launch YOUR_TRAINING_SCRIPT.py
其中NUM_GPUS_YOU_HAVE
是每个节点上的GPU数量,WORLD_SIZE
是总的进程数,YOUR_TRAINING_SCRIPT.py
是你的训练脚本。
运行训练: 启动上述命令后,每个进程将会在不同的节点上运行,并且会自动连接到主节点开始分布式训练。
下面是一个简单的分布式训练脚本示例:
import torch import torch.nn as nn import torch.optim as optim from torch.nn.parallel import DistributedDataParallel as DDP import torch.distributed as dist def main(rank, world_size): # 初始化进程组 dist.init_process_group(backend='nccl', init_method='tcp://<master_ip>:<master_port>', world_size=world_size, rank=rank) # 创建模型并将其移动到对应的GPU model = ... # 定义你的模型 model.cuda(rank) model = DDP(model, device_ids=[rank]) # 创建损失函数和优化器 criterion = nn.CrossEntropyLoss().cuda(rank) optimizer = optim.SGD(model.parameters(), lr=0.01) # 加载数据 dataset = ... # 定义你的数据集 sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=world_size, rank=rank) dataloader = torch.utils.data.DataLoader(dataset, batch_size=..., sampler=sampler) # 训练模型 for epoch in range(...): # 定义epoch的数量 sampler.set_epoch(epoch) for data, target in dataloader: data, target = data.cuda(rank), target.cuda(rank) optimizer.zero_grad() output = model(data) loss = criterion(output, target) loss.backward() optimizer.step() # 清理 dist.destroy_process_group() if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() parser.add_argument('--world-size', default=-1, type=int, help='number of processes participating in the job') parser.add_argument('--rank', default=-1, type=int, help='rank of the process') args = parser.parse_args() main(args.rank, args.world_size)
请注意,这只是一个基本的示例,实际的分布式训练脚本可能需要更多的配置和优化。此外,你还需要确保所有节点上的时间同步,以及正确配置防火墙规则以允许分布式训练所需的端口通信。