Train a simple ResNet with PyTorch in 20 lines of code


Preface

Recently I had to judge whether a project was feasible. It is a 7-class classification problem, and the label distribution of the data is [500, 85, 37, 58, 116, 8, 19].

The data is extremely imbalanced, and there is not much of it. Fortunately, both my advisor and my senior labmates are very nice, so I am first exploring whether the project is feasible.
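To put numbers on "extremely imbalanced", here is a quick calculation over the label counts listed above (plain Python; nothing is assumed beyond those seven counts):

```python
# Label counts for the 7 classes, copied from the distribution above
counts = [500, 85, 37, 58, 116, 8, 19]

total = sum(counts)
shares = [c / total for c in counts]   # fraction of samples in each class
imbalance = max(counts) / min(counts)  # largest class vs. smallest class

# Dropping the two smallest classes (indices 5 and 6) leaves a 5-class set
kept = [c for i, c in enumerate(counts) if i < 5]

print("total samples:", total)
print("class shares:", [round(s, 3) for s in shares])
print("largest/smallest ratio:", imbalance)
print("samples after dropping the two smallest classes:", sum(kept))
```

There are 823 images in total, the largest class is 62.5 times the smallest, and dropping the two smallest classes leaves 796 images.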

I have uploaded the relevant code to GitHub: https://github.com/XinzeWu/ResNet


20 lines of code

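import torch.nn as nn
import torchvision
import torch
# data, label: the tensors prepared in step 3 below
# default head has 1000 outputs; resnet101(num_classes=7) would match the task exactly
net = torchvision.models.resnet101()
net = net.cuda()  # the model must be on the GPU too, not just the data
epochs = 1000
lr = 0.001
loss_fun = nn.CrossEntropyLoss()
opt_SGD = torch.optim.SGD(net.parameters(), lr=lr)
data = data.cuda()
label = label.cuda()
net.train()
for epoch in range(epochs):
    opt_SGD.zero_grad()
    pre = net(data.float())  # one forward pass over the whole data set (full batch)
    loss = loss_fun(pre, label.long().squeeze())  # squeeze: (N, 1) -> (N,)
    loss.backward()
    opt_SGD.step()
    print("Epoch%03d: Training_loss = %.5f" % (epoch + 1, loss.item()))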

Not many lines, not many at all.

Now let me explain it step by step:

I. What is ResNet?

I'll just link to an expert's detailed write-up:

Link: ResNet explained in detail
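The key idea of ResNet is the residual (shortcut) connection: a block learns F(x) and outputs F(x) + x, so the stacked layers only have to learn a correction to their input. Below is a minimal sketch of that idea as a PyTorch module. The class name is illustrative, and this is the simplified "basic" block; ResNet-101 itself stacks the "bottleneck" variant (1×1, 3×3, 1×1 convolutions), so treat it as an illustration rather than torchvision's implementation.

```python
import torch
import torch.nn as nn


class BasicResidualBlock(nn.Module):
    """Simplified residual block: out = ReLU(F(x) + x), with F = conv-BN-ReLU-conv-BN.

    Channels and spatial size stay unchanged, so the identity shortcut can be added directly.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = x                              # the shortcut branch
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))           # the residual branch F(x)
        return self.relu(out + identity)          # add the shortcut, then activate


block = BasicResidualBlock(64)
y = block(torch.randn(2, 64, 56, 56))
print(y.shape)  # torch.Size([2, 64, 56, 56])
```

Because the output has the same shape as the input, many such blocks can be stacked; when the shape changes, real ResNets use a 1×1 convolution on the shortcut instead of the plain identity.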

II. Training steps

1. Import the libraries

I use PyTorch and the 101-layer ResNet (ResNet-101). If you have a GPU, move the model and the data to it with .cuda() to speed up training.

The code is as follows (example):

import torch.nn as nn
import torchvision
import torch
# ResNet's default input image size is 224*224*3
class ResNet(nn.Module):
    def __init__(self):
        super(ResNet, self).__init__()
        self.resnet = torchvision.models.resnet101(pretrained=False)
        # pretrained=True means using the pre-trained model (ImageNet weights);
        # newer torchvision versions deprecate pretrained= in favor of weights=
    def forward(self, x):
        x = self.resnet(x)
        return x

In fact, it doesn't have to be that complicated; the following code also works:

import torch.nn as nn
import torchvision
import torch
resnet = torchvision.models.resnet101()

So why write a class at all?

Because ResNet expects 3*224*224 input by default, you may need an extra convolution inside the class to transform your data into that shape first.

If you don't care about data processing and just want to see the network train, skip step 2.

2. Process the data

First, convert my images to grayscale and resize them to a common size so they can be fed to the network later. The code is as follows (example):

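import os
import cv2

pic_path = "your_path"   # placeholder: folder with one subfolder per class, named 0..6
save_path = "your_path"  # placeholder: output folder
for i in range(7):
    src_dir = os.path.join(pic_path, str(i))
    dst_dir = os.path.join(save_path, str(i))
    os.makedirs(dst_dir, exist_ok=True)  # cv2.imwrite just returns False if the folder is missing
    count = 0
    for j in os.listdir(src_dir):
        img = cv2.imread(os.path.join(src_dir, j), cv2.IMREAD_GRAYSCALE)
        if img is None:  # skip files OpenCV cannot read
            continue
        img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_AREA)
        # PNG with compression level 0, i.e. uncompressed
        cv2.imwrite(os.path.join(dst_dir, "{}.png".format(count)), img, [cv2.IMWRITE_PNG_COMPRESSION, 0])
        count = count + 1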

After that, I saved everything directly in .npy format:

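import os
import cv2
import numpy as np

pic_path = "your_path"  # placeholder: the save_path from the previous step
data = []
label = []
for i in range(5):  # only classes 0..4; the two smallest classes are dropped
    for name in os.listdir(os.path.join(pic_path, str(i))):
        # default flag IMREAD_COLOR: the grayscale PNG comes back as 3 identical channels, shape (224, 224, 3)
        img = cv2.imread(os.path.join(pic_path, str(i), name))
        data.append(img)
        label.append(i)
data = np.array(data)    # shape (N, 224, 224, 3)
label = np.array(label)  # shape (N,)
np.save('data', data)    # writes data.npy
np.save("label", label)  # writes label.npy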

3. Load the data and train the network

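import numpy as np
import torch

data = np.load("data.npy")    # shape (796, 224, 224, 3): N x H x W x C
label = np.load("label.npy")
# PyTorch expects N x C x H x W. Use transpose, not reshape: reshape((796,3,224,224))
# keeps the memory order and scrambles the pixels instead of moving the channel axis.
data = np.ascontiguousarray(data.transpose(0, 3, 1, 2))
label = label.reshape((-1, 1))
data = torch.from_numpy(data)  # convert to tensors so they can be fed to the network
label = torch.from_numpy(label)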
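cv2.imread returns each image as height × width × channel, so the stacked array is N × H × W × C, while PyTorch convolutions expect N × C × H × W. A reshape to (N, 3, 224, 224) only reinterprets the memory; moving the channel axis needs a transpose. The toy check below uses a tiny made-up array to show the difference:

```python
import numpy as np

# A tiny fake batch: 1 image, 2x2 pixels, 3 channels, stored as N x H x W x C
hwc = np.arange(12).reshape(1, 2, 2, 3)

wrong = hwc.reshape(1, 3, 2, 2)    # only reinterprets the memory layout
right = hwc.transpose(0, 3, 1, 2)  # actually moves the channel axis to position 1

# Channel 0 should hold the first value of every pixel: 0, 3, 6, 9
print(right[0, 0].ravel())  # [0 3 6 9]
print(wrong[0, 0].ravel())  # [0 1 2 3]  (mixes channels of different pixels)
```

Both arrays have the same shape, which is why the reshape version runs without any error while feeding the network scrambled images.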

Train the network. At the very beginning, don't use any tricks; just let it overfit the training data. This tells you whether the network can learn from the data at all.

If building a deep network directly doesn't give good results, you can consider:

  • Giving up
  • Switching to a different project
  • Data augmentation
  • A more complex, deeper network

Seriously, why is giving up listed first? More on that later.
If your computer is powerful enough, you can throw a lot of data at it directly; if not, I recommend Huawei's cloud platform.

Early stage of training:

net = ResNet().cuda()  # the ResNet class from step 1, moved to the GPU
epochs = 1000
lr = 0.001
loss_fun = nn.CrossEntropyLoss()
opt_SGD = torch.optim.SGD(net.parameters(), lr=lr)
data = data.cuda()
label = label.cuda()
net.train()
for epoch in range(epochs):
    opt_SGD.zero_grad()
    pre = net(data.float())  # one forward pass over the whole data set (full batch)
    loss = loss_fun(pre, label.long().squeeze())  # squeeze: (N, 1) -> (N,)
    loss.backward()
    opt_SGD.step()
    print("Epoch%03d: Training_loss = %.5f" % (epoch + 1, loss.item()))
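The loop above pushes all 796 images through ResNet-101 in a single forward pass (full-batch gradient descent), which needs a lot of GPU memory. If that runs out of memory, a common alternative is mini-batch training with DataLoader. The sketch below is that swapped-in variant, not the author's code; train_minibatch and its default arguments are hypothetical:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def train_minibatch(net, data, label, epochs=10, batch_size=16, lr=0.001, device="cpu"):
    """Train `net` on (data, label) in mini-batches; returns the mean loss of each epoch."""
    net = net.to(device)
    loader = DataLoader(TensorDataset(data.float(), label.long().view(-1)),
                        batch_size=batch_size, shuffle=True)
    loss_fun = nn.CrossEntropyLoss()
    opt_SGD = torch.optim.SGD(net.parameters(), lr=lr)
    history = []
    net.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)  # only one batch on the device at a time
            opt_SGD.zero_grad()
            loss = loss_fun(net(xb), yb)
            loss.backward()
            opt_SGD.step()
            running_loss += loss.item() * xb.size(0)  # sum of per-sample losses
        history.append(running_loss / len(loader.dataset))
        print("Epoch%03d: Training_loss = %.5f" % (epoch + 1, history[-1]))
    return history
```

With the tensors from step 3 and the ResNet class from step 1, a call would look like `history = train_minibatch(ResNet(), data, label, epochs=1000, device="cuda")`. Here running_loss actually accumulates over the batches of an epoch, which is what the name suggests.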

(Screenshot: initial training run)
Once the loss has dropped to about 1, switch to the following:

flag = 1  # 1 until the learning rate has been reduced once
for epoch in range(epochs):
    if epoch > 10 and flag == 1:
        flag = 0
        # rebuild the optimizer with a much smaller learning rate
        # (plain SGD without momentum keeps no internal state, so nothing is lost)
        opt_SGD = torch.optim.SGD(net.parameters(), lr=lr*0.001)
    opt_SGD.zero_grad()
    pre = net(data.float())
    loss = loss_fun(pre, label.long().squeeze())
    loss.backward()
    opt_SGD.step()
    print("Epoch%03d: Training_loss = %.5f" % (epoch + 1, loss.item()))

I added a learning-rate reduction. The schedule can be spelled out a bit more:

  • At the beginning: lr = 0.001
  • Once the loss is around 1: lr = lr * 0.1
  • Whenever the loss stops changing: multiply by 0.1 again

What if the learning rate has become very small and the loss is still large?
Note: this is important!!!
You can raise the learning rate again to shake things up, then lower it and try again. At that point the network has probably settled into a local optimum, and a larger learning rate can help it escape. Some people use cosine annealing for this, but I prefer tuning by hand.
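The hand-tuned schedule can be applied without rebuilding the optimizer by editing `optimizer.param_groups` directly. The cosine method is available in PyTorch as `torch.optim.lr_scheduler.CosineAnnealingWarmRestarts`; its periodic jumps back to the base learning rate automate the "raise, then lower again" trick. A minimal sketch; `scale_lr`, the stand-in model, `T_0=10` and `eta_min=1e-6` are illustrative assumptions:

```python
import torch
import torch.nn as nn


def scale_lr(optimizer, factor):
    """Multiply the learning rate of every parameter group in place (e.g. factor=0.1)."""
    for group in optimizer.param_groups:
        group["lr"] *= factor


model = nn.Linear(4, 2)  # stand-in model for the demo
opt = torch.optim.SGD(model.parameters(), lr=0.001)

# Hand-tuned schedule: decay when the loss plateaus, raise again to escape, then decay again
scale_lr(opt, 0.1)  # lr: 0.001 -> 0.0001
scale_lr(opt, 10)   # back up to 0.001 to shake the network out of a local optimum

# Automatic alternative: cosine annealing with a warm restart every T_0 epochs
sched = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=10, eta_min=1e-6)
lrs = []
for epoch in range(20):
    opt.step()    # the real training step for this epoch would go here
    sched.step()  # move along the cosine curve; jumps back to the base lr every 10 epochs
    lrs.append(opt.param_groups[0]["lr"])
print(["%.6f" % lr for lr in lrs])
```

Within each 10-epoch cycle the learning rate decays smoothly from 0.001 toward 1e-6, then restarts at 0.001.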

(Screenshot: final training run)


III. Summary

  • I ran this on the Huawei Cloud platform for an afternoon, and the loss went down from 7 to 1.05.
  • From my experience, this project is fairly difficult.
  • Data is the food of a neural network, and right now the network is underfed: there simply isn't enough data.
  • Even with 5 classes (I dropped the two smallest classes) and a 101-layer residual network, training still fails to converge.
  • So what can be done?
    • Data augmentation, e.g. rotation, transparency adjustment, translation
    • Expanding the data set so that the class distribution is more even


Original post: blog.csdn.net/qq_44647796/article/details/109339214