目录
一、Pytorch Tutorial
torch函数有很多,举例说明torch.max的用法
# 导入库函数
import torch
# Create three independent 4x5 sample tensors (drawn in the order x, y, z) and show them.
x, y, z = (torch.randn(4, 5) for _ in range(3))
print(x, y, z, sep="\n")
1. 输出整个tensor的最大值
# With only the input given, torch.max reduces over every element of x.
m = torch.max(input=x)
print(m)
2. 按指定维度求最大值及其索引
# torch.max(input, dim, keepdim=False, *, out=None)
# English summary of the parameter notes kept in the string below:
#   input (Tensor)        - the tensor to reduce
#   dim (int)             - dimension to reduce (0: one maximum per column; 1: one per row)
#   keepdim (bool)        - whether the outputs keep the same number of dimensions as the input
#   out (tuple, optional) - tensors that receive the result
'''
input (Tensor) - 输入Tensor
dim (int) - 指定维度(0, 按列计算;1, 按行计算)
keepdim (bool) - 输出张量是否保持与输入张量有相同数量的维度
out (tuple,optional) - 结果张量
'''
# 1. Calling with positional arguments
m, idx = torch.max(x, 0)
m, idx = torch.max(x, 0, False)

# 2. Calling with keyword arguments
m, idx = torch.max(input=x, dim=0)
m, idx = torch.max(x, 1, keepdim=True)

# 3. Writing the result into caller-supplied tensors through `out`.
# Zero-element buffers are used on purpose: handing the keepdim results from above to a
# dim-0 reduction would force PyTorch to resize non-empty outputs, which is deprecated
# and emits a UserWarning.
p = (x.new_empty(0), x.new_empty(0, dtype=torch.long))
torch.max(x, 0, False, out=p)
m, idx = p  # m/idx and p[0]/p[1] refer to the same result tensors
print(m)
print(idx)
print(p[0])
print(p[1])
3. 逐元素比较两个tensor,输出对应位置的较大值
# Given a second tensor, torch.max compares the two inputs position by position.
t = torch.max(input=x, other=y)
print(t)
二、常见错误
1. 数据和模型存储设备不同
import torch
# Demonstration of the "data and model on different devices" error.
# Incorrect: the model is moved to "cuda:0" while the input tensor stays on "cpu",
# so the two are on different devices when the model is called.
model = torch.nn.Linear(5,1).to("cuda:0")
x = torch.Tensor([1,2,3,4,5]).to("cpu")
y = model(x)
# Fixed: move the input to the same device as the model before calling it.
x = torch.Tensor([1,2,3,4,5]).to("cuda:0")
y = model(x)
print(y.shape)
2. 变量维度不同
# Demonstration of the "mismatched dimensions" error.
# Incorrect: x is created as 4x5 but y as 5x4, and the two are added directly.
x = torch.randn(4,5)
y= torch.randn(5,4)
z = x + y
# Fixed: swap y's two dimensions so that it is 4x5 like x, then add.
y= y.transpose(0,1)
z = x + y
print(z.shape)
3. cuda内存过载
import torch
import torchvision.models as models
# Demonstration of running out of CUDA memory.
# Incorrect: the whole data set is moved to the GPU and pushed through the network at once.
resnet18 = models.resnet18().to("cuda:0")  # image-recognition network, placed on the GPU
data = torch.randn(2048, 3, 244, 244)  # fake data: 2048 random images of size 3x244x244
out = resnet18(data.to("cuda:0"))  # feed all images to the model in a single forward pass
print(out.shape)
# Fixed: feed the images one at a time; unsqueeze(0) adds the batch dimension the model
# expects, so only a single image is on the GPU per forward pass.
for d in data:
    out = resnet18(d.to("cuda:0").unsqueeze(0))
# Printed once after the loop, mirroring the single print of the failing version.
print(out.shape)
4. tensor类型不匹配
import torch.nn as nn
# Demonstration of the "tensor type mismatch" error.
# Incorrect: the labels are handed to the loss without first being converted to a
# long (integer) tensor.
L = nn.CrossEntropyLoss()
outs = torch.randn(5,5)
labels = torch.Tensor([1,2,3,4,0])
lossval = L(outs,labels) # Calculate CrossEntropyLoss between outs and labels
# Fixed: cast the labels to long before computing the loss.
labels = labels.long()
lossval = L(outs,labels)
print(lossval)
三、Dataset和Dataloader
下面是一个简单的数据集和数据加载器示例:
# Dataset: a plain string stands in for a data set, one character per sample.
dataset = "abcdefghijklmnopqrstuvwxyz"
# Dataloader: iterating over the string plays the role of a loader yielding one sample at a time.
for datapoint in dataset:
    print(datapoint)
利用torch的库函数可以自定义数据集和加载器,方便修改数据集和调用数据训练网络
import torch
import torch.utils.data
class ExampleDataset(torch.utils.data.Dataset):
    """Dataset whose samples are the 26 lowercase ASCII letters."""

    def __init__(self):
        self.data = "abcdefghijklmnopqrstuvwxyz"

    def __getitem__(self, idx):
        """Return the letter stored at position ``idx``."""
        return self.data[idx]

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.data)
dataset1 = ExampleDataset()  # create the dataset
# shuffle=True requests a random visiting order; batch_size=1 puts one sample in each batch.
dataloader = torch.utils.data.DataLoader(dataset=dataset1, shuffle=True, batch_size=1)
for datapoint in dataloader:
    print(datapoint)
想增强数据时,只需修改自定义类中的代码即可。下面的示例将"abcdefghijklmnopqrstuvwxyz"字符串扩增至2倍,并将扩增部分由小写字符变换成大写字符
import torch.utils.data
class ExampleDataset(torch.utils.data.Dataset):
    """Augmented letter dataset: the lowercase alphabet followed by its uppercase copy."""

    def __init__(self):
        self.data = "abcdefghijklmnopqrstuvwxyz"

    def __getitem__(self, idx):
        """Return the sample at ``idx``.

        Indices in the first half return the stored lowercase letter; indices in the
        second half return the matching letter in upper case.

        Raises:
            IndexError: if ``idx`` is at or beyond ``len(self)``. Without this, plain
                iteration over the dataset (which stops on IndexError) never ends.
        """
        base = len(self.data)
        if idx >= 2 * base:
            raise IndexError(f"index {idx} is out of range for dataset of length {2 * base}")
        if idx >= base:
            # Augmented half: same letters as the first half, upper-cased.
            return self.data[idx - base].upper()
        return self.data[idx]

    def __len__(self):
        """Return the dataset length, which is twice the number of stored letters."""
        return 2 * len(self.data)
dataset1 = ExampleDataset()  # create the dataset
# shuffle=False keeps the samples in index order; batch_size=1 puts one sample in each batch.
dataloader = torch.utils.data.DataLoader(dataset=dataset1, shuffle=False, batch_size=1)
for datapoint in dataloader:
    print(datapoint)