Pytorch学习中遇到的问题

Pytorch 教程地址:pytorch handbook

有时候github页面会加载不出来,Github加载不出来的解决办法:Github网站css加载不出来的处理方法

修改host以后,如果没有恢复正常,是需要刷新DNS缓存,告诉电脑我的hosts文件已经修改。Windows下刷新DNS缓存的方法:进入命令行,输入命令:ipconfig /flushdns


1.Pytorch-handbook 3.2 MNIST数据集手写数字识别

1.1定义Test部分:

pred = output.max(1, keepdim=True)[1] # 找到概率最大的下标

torch.max()函数---解释及例子

a = torch.randn(3,3)
torch.max)(a,0) #返回每一列中最大值的那个元素,且返回索引(返回最大元素在这一列的行索引)
torch.max(a,1) #返回每一行中最大值的那个元素,且返回其索引(返回最大元素在这一行的列索引)

在这里就是:

output.max(1, keepdim=True)--->返回每一行中最大的元素并返回索引,返回了两个数组

output.max(1, keepdim=True)[1] 就是取第二个数组,取索引数组。

1.2 数据集部分

batch_size=512, 训练集大小为60000。所以一共60000/512=117.18

train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=True, download=True, 
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=BATCH_SIZE, shuffle=True)
len(train_loader)

>>118

我们用batch_size等参数制作了一个train_loader, transform 把数据转换成torch 的tensor, transforms.Normalize((0.1307,), (0.3081,))是数据进行归一化,均值和方差是0.1307,0.3081是根据数据集算好的。

2.多GPU训练

2.1 单机多GPU  torch.nn.DataParalle

用torch.nn.DataParalle ,我们只要将我们自己的模型作为参数,直接传入即可

#使用内置的一个模型,我们这里以resnet50为例
model = torchvision.models.resnet50()
#模型使用多GPU
mdp = torch.nn.DataParallel(model)
mdp

2.2 torch.distributed

2.3 torch.utils.checkpoint

发布了10 篇原创文章 · 获赞 10 · 访问量 7509

猜你喜欢

转载自blog.csdn.net/qq_41647438/article/details/103080742