(4) Implementing the loss, and some pitfalls hit while building the network

Let's talk about the loss first. In the earlier code we used the loss function that ships with PyTorch, the cross-entropy loss:

criterion = torch.nn.CrossEntropyLoss()

Cross-entropy loss is essentially made up of three steps: softmax ---> log ---> nll_loss. I'll write a separate article on the details when I have time. The implementation is below; I defined it as a class, but a plain function would work just as well.

class Compute_Loss(nn.Module):
    def __init__(self):
        super(Compute_Loss, self).__init__()

    def forward(self, pred, target):
        # pred: raw logits of shape (N, num_classes); target: class indices of shape (N,)
        pred = pred.to(device)
        target = target.to(device)
        log_soft = F.log_softmax(pred, dim=1)  # softmax followed by log, in one numerically stable step
        loss = F.nll_loss(log_soft, target)    # negative log-likelihood of the true classes
        return loss
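
As a quick sanity check (my own sketch, not part of the original post; the logits and labels below are made-up values), the softmax ---> log ---> nll_loss composition can be compared against nn.CrossEntropyLoss directly. Note that the inputs are raw scores and the targets are class indices, not one-hot vectors:

import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.randn(4, 3)            # raw scores straight from the fc layer, no softmax applied
labels = torch.tensor([0, 2, 1, 2])   # class indices, NOT one-hot vectors

manual = F.nll_loss(F.log_softmax(logits, dim=1), labels)   # softmax -> log -> nll_loss
builtin = nn.CrossEntropyLoss()(logits, labels)
print(manual.item(), builtin.item())  # the two values agree up to floating-point error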

Pitfalls encountered:

1. The last layer of the ResNet18 I built earlier is an fc (fully connected) layer. After I appended a softmax layer after the fc layer, the loss simply would not go down. After searching online and looking at how nn.CrossEntropyLoss is composed, I found:

  • The per-class confidence scores fed in (the input) should be raw logits, not passed through softmax or otherwise normalized. The reason is that this function applies softmax to the raw input scores itself, so what you feed it must be the raw score for each class. Also, the target must not be in one-hot encoded form; it should be class indices.

2. Likewise, when I added a ReLU layer after the FC layer, the loss also refused to drop; presumably feeding ReLU-clipped scores (all negative logits zeroed out) into the loss computation interferes with it.

3. When building the network, calling nn.Linear() directly inside forward raises an error when training on the GPU: it reports that the data is on the GPU but the model is not, and even model.to("cuda") doesn't fix it. The nn.Linear() layer has to be created as an attribute in the class's __init__ and then called inside forward (see the sketch after this list).
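
Here is a minimal sketch of that third pitfall (my own illustration, with a made-up single-layer head): a layer constructed inside forward is created fresh on the CPU at every call and is never registered as part of the module, so .to(device) and the optimizer cannot see its weights; a layer constructed in __init__ is registered as a submodule and moves with the model.

import torch
import torch.nn as nn

class BadHead(nn.Module):
    def forward(self, x):
        # Wrong: a brand-new Linear is created (on the CPU) every forward pass,
        # so net.to("cuda") never moves its weights -> device-mismatch error on GPU input.
        return nn.Linear(512, 3)(x)

class GoodHead(nn.Module):
    def __init__(self):
        super(GoodHead, self).__init__()
        self.fc = nn.Linear(512, 3)    # registered as a submodule in __init__

    def forward(self, x):
        return self.fc(x)              # reuse the registered layer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = GoodHead().to(device)            # moves self.fc together with the model
out = net(torch.randn(4, 512).to(device))
print(out.shape)                       # torch.Size([4, 3])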

Defining your own loss function: the class here just imitates nn.CrossEntropyLoss(), but you can equally well build your own loss class to compute the loss however you need.
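
For instance (my own sketch, not from the original post; the 0.1 smoothing value is an arbitrary choice), the same class pattern extends to a cross-entropy loss with label smoothing, built from log_softmax just like Compute_Loss above:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SmoothedCrossEntropy(nn.Module):
    """Cross-entropy with label smoothing, following the same pattern as Compute_Loss."""
    def __init__(self, smoothing=0.1):
        super(SmoothedCrossEntropy, self).__init__()
        self.smoothing = smoothing

    def forward(self, pred, target):
        num_classes = pred.size(1)
        log_prob = F.log_softmax(pred, dim=1)
        # Soft targets: 1 - smoothing on the true class, smoothing spread over the other classes.
        true_dist = torch.full_like(log_prob, self.smoothing / (num_classes - 1))
        true_dist.scatter_(1, target.unsqueeze(1), 1.0 - self.smoothing)
        return torch.mean(torch.sum(-true_dist * log_prob, dim=1))

criterion = SmoothedCrossEntropy(smoothing=0.1)
loss = criterion(torch.randn(4, 3), torch.tensor([0, 2, 1, 2]))
print(loss.item())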

Code:

import torch
from torchvision import datasets, transforms, models
import os
import matplotlib.pyplot as plt
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from PIL import Image, ImageFile
from my_resnet import MainNet
import torch.nn as nn
ImageFile.LOAD_TRUNCATED_IMAGES = True
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"


def train():
    # Train for one epoch and return the loss accumulated over all batches.
    running_loss = 0
    for batch_idx, (data, target) in enumerate(train_data):
        data, target = data.to(device), target.to(device)
        out = net(data)
        loss = criterion(out, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    return running_loss


def test():
    # Evaluate classification accuracy on the validation set.
    correct, total = 0, 0
    with torch.no_grad():
        for _, (data, target) in enumerate(val_data):
            data, target = data.to(device), target.to(device)
            out = net(data)
            out = F.softmax(out, dim=1)  # optional here: argmax over the raw logits gives the same prediction
            prediction = out.argmax(dim=1)
            # prediction = torch.max(out.data, dim=1)[1]
            total += target.size(0)
            correct += (prediction == target).sum().item()
        print('Accuracy on test set: (%d/%d)=%d %%' % (correct, total, 100 * correct / total))


class Compute_Loss(nn.Module):
    def __init__(self):
        super(Compute_Loss, self).__init__()

    def forward(self, pred, target):
        pred = pred.to(device)
        target = target.to(device)
        log_soft = F.log_softmax(pred, dim=1)
        loss = F.nll_loss(log_soft, target)
        return loss




if __name__ == '__main__':
    loss_list = []
    Epoches = 200
    Batch_Size = 4
    Image_Size = [256, 256]

    # 1. Load the data
    data_dir = r'D:\Code\python\完整项目放置\classify_project\multi_classification\my_dataset1'
    # 1.1 Define the transforms to apply to the data
    data_transform = {x: transforms.Compose([transforms.Resize(Image_Size), transforms.ToTensor()]) for x in
                      ["train", "valid"]}
    image_datasets = {x: datasets.ImageFolder(root=os.path.join(data_dir, x), transform=data_transform[x]) for x in
                      ["train", "valid"]}
    dataloader = {x: torch.utils.data.DataLoader(dataset=image_datasets[x], batch_size=Batch_Size, shuffle=True) for x in
                  ["train", "valid"]}
    train_data, val_data = dataloader["train"], dataloader["valid"]

    index_classes = image_datasets["train"].class_to_idx
    print(index_classes)
    example_classes = image_datasets["train"].classes
    print(example_classes)

    num_classes = 3
    net = MainNet(num_classes)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net.to(device)

    # 5. Define the loss function and the optimizer
    LR = 0.0001
    criterion = Compute_Loss()
    optimizer = optim.Adam(net.parameters(), lr=LR)

    best_loss = 100
    for epoch in range(Epoches):
        loss = train()
        loss_list.append(loss)
        print("第%d轮的loss为:%5f:" % (epoch, loss))
        test()

        if loss < best_loss:
            best_loss = loss
            torch.save(net, "best1.pth")
        torch.save(net, "last1.pth")


    plt.title("Graph")
    plt.plot(range(Epoches), loss_list)
    plt.ylabel("loss")
    plt.xlabel("epoch")
    plt.show()

Reposted from blog.csdn.net/m0_48095841/article/details/125751332