[PyTorch] Cómo usar RunBuilder para separar la lista de parámetros y lograr un ajuste a gran escala

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from collections import OrderedDict
from collections import namedtuple
from itertools import product

import torchvision
import torchvision.transforms as transforms


class RunBuilder:
    """Expand a grid of hyperparameter choices into individual runs."""

    @staticmethod
    def get_runs(params):
        """Return a list with one ``Run`` namedtuple per parameter combination.

        ``params`` maps each parameter name to an iterable of candidate
        values. The fields of ``Run`` follow the key order of ``params``, and
        combinations come out in ``itertools.product`` order (the last
        parameter varies fastest).
        """
        run_type = namedtuple("Run", params.keys())
        return [run_type(*combo) for combo in product(*params.values())]


class Network(nn.Module):
    """Small CNN: two conv + max-pool stages followed by three linear layers.

    Takes single-channel images and produces 10 raw scores per sample. The
    flattened feature size ``12 * 4 * 4`` corresponds to 28x28 inputs
    (28 -> 24 -> 12 -> 8 -> 4 through the conv/pool stack).
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5)
        self.conv2 = nn.Conv2d(6, 12, kernel_size=5)

        # Fully connected classifier head
        self.fc1 = nn.Linear(12 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 60)
        self.out = nn.Linear(60, 10)

    def forward(self, t):
        """Run the forward pass and return the unnormalised class scores."""
        # Each conv stage: convolution -> ReLU -> 2x2 max pooling
        for conv in (self.conv1, self.conv2):
            t = F.max_pool2d(F.relu(conv(t)), kernel_size=2, stride=2)

        # Flatten to (rows, 12 * 4 * 4) for the linear layers
        flat = t.reshape(-1, 12 * 4 * 4)

        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))

        # No activation on the output layer; the loss consumes raw scores
        return self.out(hidden)


def get_num_correct(preds, labels):
    """Count how many rows of ``preds`` have their argmax equal to ``labels``.

    The predicted class of each row is the index of its largest value along
    dimension 1. The count is returned as a plain Python number.
    """
    predicted = preds.argmax(dim=1)
    hits = predicted.eq(labels)
    return hits.sum().item()


# Set the line width PyTorch uses when printing tensors
torch.set_printoptions(linewidth=120)
# Print version information
print("torch version", torch.__version__)
print("torchvision version", torchvision.__version__)
# Select the device: GPU when CUDA is available, otherwise CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
# Print the selected device
print("device", device)
# Training dataset (downloaded into ./data/FashionMNIST when missing)
train_set = torchvision.datasets.FashionMNIST(
    root="./data/FashionMNIST",
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
    ]),
)
# Hyperparameter grid: every combination of these values becomes one run
params = OrderedDict(
    lr=[0.01],
    batch_size=[100, 1000],
    epoch_num=[4],
)
# Number of parameter combinations processed so far
run_count = 0
# Iterate over every parameter combination
for run in RunBuilder.get_runs(params):
    # Advance the run counter
    run_count += 1
    # Print the current parameter combination
    print("run", run_count, run)

    # Read the hyperparameters of this run
    batch_size = run.batch_size
    lr = run.lr
    epoch_num = run.epoch_num
    # Build a fresh model for every run so that each parameter combination
    # starts from newly initialised weights instead of continuing from the
    # weights trained by the previous run.
    network = Network().to(device)
    # Data loader
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
    # Optimizer; uses the run's learning rate rather than a hard-coded value
    optimizer = optim.Adam(network.parameters(), lr=lr)
    # Train for the requested number of epochs
    for epoch in range(epoch_num):
        # Reset the per-epoch statistics
        total_loss = 0
        total_correct = 0
        # Process the data batch by batch
        for batch in train_loader:
            # Unpack the batch
            images, labels = batch
            # Move the data to the selected device
            images = images.to(device)
            labels = labels.to(device)
            # Forward pass
            preds = network(images)
            # Loss (mean over the samples of the batch)
            loss = F.cross_entropy(preds, labels)
            # Backward pass
            optimizer.zero_grad()  # clear old gradients
            loss.backward()  # compute gradients
            optimizer.step()  # update weights
            # Accumulate statistics
            total_loss += loss.item()
            total_correct += get_num_correct(preds, labels)
        # Print the epoch results. total_loss is a sum of per-batch mean
        # losses, so the mean is taken over the number of batches, not over
        # the batch size.
        print("epoch:", epoch, "total_correct:", total_correct, "total_loss:", total_loss, "mean_loss",
              total_loss / len(train_loader))

Supongo que te gusta

Origin blog.csdn.net/weixin_66896881/article/details/128688311
Recomendado
Clasificación