AI开发平台MODELARTS-创建单机多卡的分布式训练(DataParallel):代码改造点

时间:2024-08-16 20:39:17

代码改造点

模型分发:DataParallel(model)

完整代码由于代码变动较少,此处进行简略介绍。

import torch
class Net(torch.nn.Module):
	pass

model = Net().cuda()

### DataParallel Begin ###
model = torch.nn.DataParallel(Net().cuda())
### DataParallel End ###
support.huaweicloud.com/usermanual-standard-modelarts/modelarts-distributed-0007.html