AI开发平台MODELARTS-多机多卡数据并行-DistributedDataParallel(DDP):代码改造点

时间:2024-08-26 19:27:18

代码改造点

  • 引入多进程启动机制:初始化进程
  • 引入几个变量:tcp协议,rank进程序号,worldsize开启的进程数量
  • 分发数据:DataLoader中多了一个Sampler参数,避免不同进程数据重复
  • 模型分发:DistributedDataParallel(model)
  • 模型保存:在序号为0的进程下保存模型
import torch
class Net(torch.nn.Module):
	pass

model = Net().cuda()

### DistributedDataParallel Begin ###
model = torch.nn.parallel.DistributedDataParallel(Net().cuda())
### DistributedDataParallel End ###
support.huaweicloud.com/develop-modelarts/modelarts-distributed-0008.html