Running PyTorch on Multiple GPUs

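import torch

self.net = Network()  # Network: your model class (an nn.Module subclass)
n_gpu = 1  # number of GPUs to use
if n_gpu == 1:
    # Single GPU: pin device_ids=[0]; without it DataParallel uses every visible GPU
    self.net = torch.nn.DataParallel(self.net, device_ids=[0]).cuda(device=0)
else:
    gpus = list(range(n_gpu))  # device ids [0, 1, ..., n_gpu - 1]
    # Wrap once, after the id list is complete (not inside a loop, which would nest wrappers)
    self.net = torch.nn.DataParallel(self.net, device_ids=gpus).cuda()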
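
The snippet hard-codes the GPU count and assumes CUDA is present. The sketch below packages the same idea as a reusable function. It assumes a hypothetical helper name, `wrap_model`, and uses `nn.Linear` as a stand-in for a real model. The function caps the request at `torch.cuda.device_count()` and falls back to the CPU when no GPU is visible. `nn.DataParallel` splits each input batch along dimension 0 across the listed GPUs and gathers the outputs back on `device_ids[0]`.

```python
import torch
import torch.nn as nn


def wrap_model(net: nn.Module, n_gpu: int) -> nn.Module:
    """Hypothetical helper: wrap `net` for up to n_gpu GPUs, falling back to CPU."""
    # Never ask for more GPUs than are actually visible.
    n_gpu = min(n_gpu, torch.cuda.device_count())
    if n_gpu == 0:
        return net  # no GPU available: run the plain module on the CPU
    device_ids = list(range(n_gpu))
    # DataParallel expects the parameters on device_ids[0] before the forward pass.
    net = net.cuda(device_ids[0])
    return nn.DataParallel(net, device_ids=device_ids)


if __name__ == "__main__":
    model = wrap_model(nn.Linear(8, 2), n_gpu=2)  # nn.Linear stands in for a real model
    device = next(model.parameters()).device
    x = torch.randn(16, 8, device=device)  # the batch dimension (16) is what gets split
    y = model(x)
    print(tuple(y.shape))  # (16, 2)
    # The wrapped model lives in .module; its state_dict has no "module." key prefix.
    inner = model.module if isinstance(model, nn.DataParallel) else model
    print(sorted(inner.state_dict().keys()))  # ['bias', 'weight']
```

Because `nn.DataParallel` keeps the original model in its `.module` attribute, saving `model.module.state_dict()` instead of `model.state_dict()` yields checkpoint keys without the `module.` prefix. Such a checkpoint can later be loaded into an unwrapped model.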


Reposted from blog.csdn.net/luolinll1212/article/details/85235549