import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from collections import OrderedDict
from collections import namedtuple
from itertools import product
import torchvision
import torchvision.transforms as transforms
class RunBuilder:
    """Expand a hyperparameter grid into individual run configurations."""

    @staticmethod
    def get_runs(params):
        """Return a list with one ``Run`` namedtuple per grid combination.

        params: mapping of hyperparameter name -> iterable of candidate values.
        The combinations are the Cartesian product of the mapping's values,
        taken in the mapping's key order; each ``Run`` exposes one field per key.
        """
        Run = namedtuple("Run", params.keys())
        return [Run(*combo) for combo in product(*params.values())]
class Network(nn.Module):
    """Small CNN classifier: two conv + max-pool stages, then three linear layers.

    The flatten step reshapes conv features to 12 * 4 * 4 values per sample,
    which is what a 1x28x28 input produces (28 -> conv5 -> 24 -> pool -> 12
    -> conv5 -> 8 -> pool -> 4). Other spatial sizes will mis-flatten or fail.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor (single-channel input).
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)
        # Classifier head; 12 * 4 * 4 is the flattened size after the second pool.
        self.fc1 = nn.Linear(in_features=12 * 4 * 4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=60)
        self.out = nn.Linear(in_features=60, out_features=10)

    def forward(self, t):
        """Return unnormalized class scores (logits), 10 per sample."""
        # Each stage: convolution -> ReLU -> 2x2 max-pool.
        for conv in (self.conv1, self.conv2):
            t = F.max_pool2d(F.relu(conv(t)), kernel_size=2, stride=2)
        t = t.reshape(-1, 12 * 4 * 4)
        # Hidden fully-connected layers with ReLU; the output layer stays linear.
        for hidden in (self.fc1, self.fc2):
            t = F.relu(hidden(t))
        return self.out(t)
def get_num_correct(preds, labels):
    """Count predictions whose highest-scoring class equals the label.

    preds: score tensor with classes along dim 1; labels: class indices, one
    per row of preds. Returns the count as a Python number via ``.item()``.
    """
    predicted_classes = preds.argmax(dim=1)
    hits = predicted_classes == labels
    return hits.sum().item()
# Widen the line width PyTorch uses when printing tensors.
torch.set_printoptions(linewidth=120)

# Report library versions.
print("torch version", torch.__version__)
print("torchvision version", torchvision.__version__)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("device", device)

# FashionMNIST training split, converted to tensors (downloaded if missing).
train_set = torchvision.datasets.FashionMNIST(
    root="./data/FashionMNIST",
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
    ]),
)

# Hyperparameter grid: every combination of these values becomes one run.
params = OrderedDict(
    lr=[0.01],
    batch_size=[100, 1000],
    epoch_num=[4],
)

# One training run per hyperparameter combination; run_count is 1-based.
for run_count, run in enumerate(RunBuilder.get_runs(params), start=1):
    print("run", run_count, run)

    batch_size = run.batch_size
    lr = run.lr
    epoch_num = run.epoch_num

    # Build a fresh model and optimizer for every run so that each run starts
    # from newly initialized weights and actually uses its own learning rate;
    # otherwise runs are not independent and cannot be compared.
    network = Network().to(device)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
    optimizer = optim.Adam(network.parameters(), lr=lr)

    for epoch in range(epoch_num):
        # Per-epoch accumulators.
        total_loss = 0
        total_correct = 0

        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)

            # Forward pass and loss (cross_entropy averages over the batch by default).
            preds = network(images)
            loss = F.cross_entropy(preds, labels)

            # Backpropagation.
            optimizer.zero_grad()  # clear gradients left over from the previous step
            loss.backward()        # compute gradients
            optimizer.step()       # update weights

            total_loss += loss.item()
            total_correct += get_num_correct(preds, labels)

        # total_loss is a sum of per-batch mean losses, so the mean loss is
        # that sum divided by the number of batches, not by the batch size.
        print("epoch:", epoch, "total_correct:", total_correct, "total_loss:", total_loss, "mean_loss",
              total_loss / len(train_loader))
# [PyTorch] Using RunBuilder to expand a parameter list for large-scale hyperparameter tuning
# (original title: "[Pytorch] Cómo usar RunBuilder para separar la lista de parámetros para lograr un ajuste a gran escala")
# Source: blog.csdn.net/weixin_66896881/article/details/128688311