Оглавление
1. Вывести максимальное значение всего тензора
2. Проиндексируйте максимальное значение по указанному измерению.
3. Сравните и выведите максимальное значение двух тензоров.
1. Устройства хранения данных и модели разные
4. несоответствие типа тензора
3. Набор данных и загрузчик данных
1. Учебник по PyTorch
Есть много функций факела, примеры иллюстрируют использование torch.max
# 导入库函数
import torch
# 生成示例数据
x = torch.randn(4,5)
y = torch.randn(4,5)
z = torch.randn(4,5)
print(x)
print(y)
print(z)
1. Вывести максимальное значение всего тензора
m = torch.max(x)
print(m)
2. Проиндексируйте максимальное значение по указанному измерению.
# torch.max(input, dim, keepdim=False, *, out=None)
'''
input (Tensor) - 输入Tensor
dim (int) - 指定维度(0, 按列计算;1, 按行计算)
keepdim (bool) - 输出张量是否保持与输入张量有相同数量的维度
out (tuple,optional) - 结果张量
'''
# 1. 根据位置参数调用有参函数
m, idx = torch.max(x,0)
m, idx = torch.max(x,0,False)
# 2. 关键字参数调用有参函数
m, idx = torch.max(input=x,dim=0)
m,idx = torch.max(x,1,keepdim=True)
p = (m,idx)
torch.max(x,0,False,out=p)
print(m)
print(idx)
print(p[0])
print(p[1])
3. Сравните и выведите максимальное значение двух тензоров.
t = torch.max(x,y)
print(t)
2. Распространенные ошибки
1. Устройства хранения данных и модели разные
import torch
# 错误代码
model = torch.nn.Linear(5,1).to("cuda:0")
x = torch.Tensor([1,2,3,4,5]).to("cpu")
y = model(x)
# 修改后
x = torch.Tensor([1,2,3,4,5]).to("cuda:0")
y = model(x)
print(y.shape)
2. Переменные размеры разные
# 错误代码
x = torch.randn(4,5)
y= torch.randn(5,4)
z = x + y
# 修改后
y= y.transpose(0,1)
z = x + y
print(z.shape)
3. cuda перегрузка памяти
import torch
import torchvision.models as models
# 错误代码
resnet18 = models.resnet18().to("cuda:0") # Neural Networks for Image Recognition
data = torch.randn(2048,3,244,244) # Create fake data (512 images)
out = resnet18(data.to("cuda:0")) # Use Data as Input and Feed to Model
print(out.shape)
# 修改后
for d in data:
out = resnet18(d.to("cuda:0").unsqueeze(0))
print(out.shape)
4. несоответствие типа тензора
import torch.nn as nn
# 错误代码
L = nn.CrossEntropyLoss()
outs = torch.randn(5,5)
labels = torch.Tensor([1,2,3,4,0])
lossval = L(outs,labels) # Calculate CrossEntropyLoss between outs and labels
# 修改后
labels = labels.long()
lossval = L(outs,labels)
print(lossval)
3. Набор данных и загрузчик данных
Пример простого набора данных и загрузчика данных выглядит следующим образом:
# Dataset
dataset = "abcdefghijklmnopqrstuvwxyz"
# Dataloader
for datapoint in dataset:
print(datapoint)
Используйте библиотечные функции torch для настройки наборов данных и загрузчиков, упрощая изменение наборов данных и вызов сетей обучения данных.
import torch
import torch.utils.data
class ExampleDataset(torch.utils.data.Dataset):
def __init__(self):
self.data = "abcdefghijklmnopqrstuvwxyz"
def __getitem__(self,idx): # if the index is idx, what will be the data?
return self.data[idx]
def __len__(self): # What is the length of the dataset
return len(self.data)
dataset1 = ExampleDataset() # create the dataset
dataloader = torch.utils.data.DataLoader(dataset = dataset1,shuffle = True,batch_size = 1)
for datapoint in dataloader:
print(datapoint)
Если вы хотите улучшить данные, вам нужно всего лишь изменить блок кода в пользовательском классе. В следующем примере строка «abcdefghijklmnopqrstuvwxyz» усиливается до 2 раз и преобразует усиленные данные из строчных символов в прописные.
import torch.utils.data
class ExampleDataset(torch.utils.data.Dataset):
def __init__(self):
self.data = "abcdefghijklmnopqrstuvwxyz"
def __getitem__(self,idx): # if the index is idx, what will be the data?
if idx >= len(self.data): # if the index >= 26, return upper case letter
return self.data[idx%26].upper()
else: # if the index < 26, return lower case, return lower case letter
return self.data[idx]
def __len__(self): # What is the length of the dataset
return 2 * len(self.data) # The length is now twice as large
dataset1 = ExampleDataset() # create the dataset
dataloader = torch.utils.data.DataLoader(dataset = dataset1,shuffle = False,batch_size = 1)
for datapoint in dataloader:
print(datapoint)
Ссылка на ссылку: https://colab.research.google.com/github/ga642381/ML2021-Spring/blob/main/Pytorch/Pytorch_Tutorial.ipynb