Video and AI, interacting with the process (2): training a minimal PyTorch model on your own dataset and using it for recognition

1 The learning task

Classify the segmented images, i.e. decide which category each one belongs to.

2 Using PyTorch

Training a model in PyTorch and loading it back for inference are both straightforward.

2.1 Prepare data

[Figure: the data directory, with the train and val folders and the train.txt and val.txt files]
As shown in the figure above, the training images go in the train folder and the validation images go in the val folder. train.txt and val.txt list each image's file name and class label, one name|label pair per line, so the code only needs to refer to the file names.

For this example I used just two classes, the Porsche Cayenne and construction vehicles, as shown in the figure below.
[Figure: sample images of the two classes]
The contents of train.txt and val.txt are shown in the figure below.
[Figure: contents of train.txt and val.txt]
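Each line of train.txt and val.txt has to hold a file name and a class label separated by |, because that is how LoadData (section 3.1) splits every line. The helpers below are a minimal sketch for producing and reading back that format; format_annotations and parse_annotations are hypothetical names, not part of PyTorch, and the labels assume the order of the classes list in the main program (0 = construction vehicle, 1 = Cayenne).

```python
def format_annotations(entries):
    """Turn (file_name, label) pairs into the text of train.txt / val.txt."""
    lines = []
    for name, label in entries:
        if '|' in name:
            # '|' is the separator, so it cannot appear inside a file name
            raise ValueError("file name may not contain '|': %r" % name)
        lines.append('%s|%d' % (name, int(label)))
    return '\n'.join(lines) + '\n'


def parse_annotations(text):
    """Read the same format back into (file_name, label) pairs, skipping blank lines."""
    entries = []
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue
        name, label = line.split('|')
        entries.append((name, int(label)))
    return entries


# Usage: '1.jpg' is a hypothetical file name; 102.jpg is the Cayenne image used later
entries = [('1.jpg', 0), ('102.jpg', 1)]
text = format_annotations(entries)
# with open('./data/train.txt', 'w', encoding='utf-8') as f:
#     f.write(text)
```

Rejecting | in file names up front avoids a confusing unpacking error later when the dataset is loaded.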

3 Show me the code

3.1 Data-loading class

Create a file named loaddata.py:

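import os
from PIL import Image
from torch.utils.data import Dataset


class LoadData(Dataset):
    """Reads (file name, label) pairs from a text file whose lines look like name|label."""

    def __init__(self, root, datatxt, transform=None, target_transform=None):
        super(LoadData, self).__init__()
        imgs = []
        with open(datatxt, 'r', encoding='utf-8') as file_txt:  # closed automatically
            for line in file_txt:
                line = line.strip()
                if not line:                  # skip blank lines (e.g. a trailing newline)
                    continue
                words = line.split('|')
                imgs.append((words[0], int(words[1])))

        self.imgs = imgs
        self.root = root
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        # No shuffling here: the DataLoader (shuffle=True) decides the order,
        # and the same index must always return the same sample.
        name, label = self.imgs[index]
        img = Image.open(os.path.join(self.root, name)).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return img, label

    def __len__(self):
        return len(self.imgs)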

LoadData inherits from torch.utils.data.Dataset and implements __getitem__ and __len__, which is all a DataLoader needs. It takes an optional transform, which here resizes each image and converts it into a normalized tensor.
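DataLoader(..., shuffle=True) in the main program already draws a fresh random permutation of indices every epoch, so __getitem__ should map an index to the same sample every time. Calling random.shuffle inside __getitem__ would reorder the list between lookups, so within one epoch some images could be returned twice and others skipped. The toy check below illustrates the intended behavior; RangeDataset and one_epoch are hypothetical stand-ins, not part of the article's code.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class RangeDataset(Dataset):
    """Hypothetical toy dataset: sample i is simply the integer i."""

    def __init__(self, n):
        self.n = n

    def __getitem__(self, index):
        return index          # deterministic: the same index always gives the same sample

    def __len__(self):
        return self.n


def one_epoch(dataset, batch_size=2, seed=0):
    """Collect every sample the DataLoader yields in one shuffled epoch."""
    generator = torch.Generator().manual_seed(seed)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, generator=generator)
    seen = []
    for batch in loader:      # the default collate turns a list of ints into a tensor
        seen.extend(batch.tolist())
    return seen


seen = one_epoch(RangeDataset(7))
print(seen)  # a permutation of 0..6: random order, but nothing repeated or missing
```

Leaving the ordering to the sampler also keeps shuffling reproducible through the generator argument.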

3.2 Network class

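Save the network as modelnet.py (the main program imports it from there). It is a small convolutional network whose last layer has only two outputs, one per class.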

# modelnet.py
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3)      # 3x152x152 -> 16x150x150
        self.pool = nn.MaxPool2d((2, 2))      # -> 16x75x75
        self.pool1 = nn.MaxPool2d((2, 2))     # second pooling, used after conv2
        self.conv2 = nn.Conv2d(16, 32, 3)     # 16x75x75 -> 32x73x73, pooled to 32x36x36
        self.fc1 = nn.Linear(36 * 36 * 32, 120)
        self.fc2 = nn.Linear(120, 60)
        self.fc3 = nn.Linear(60, 2)           # two outputs, one per class

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool1(F.relu(self.conv2(x)))
        x = x.view(-1, 36 * 36 * 32)          # flatten; only valid for 152x152 inputs
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)                       # raw logits; CrossEntropyLoss applies softmax itself
        return x
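The 36*36*32 input size of fc1 follows from the 152x152 input produced by the Resize step. Conv2d and MaxPool2d (dilation 1, default floor rounding) map a side length s to floor((s + 2*padding - kernel) / stride) + 1, and MaxPool2d((2, 2)) uses stride 2 by default. The plain-Python calculation below traces 152 -> 150 -> 75 -> 73 -> 36; out_size and fc1_in_features are illustrative names. If you change the Resize size, rerunning it gives the new fc1 size.

```python
def out_size(size, kernel, stride=1, padding=0):
    """Output side length of Conv2d / MaxPool2d (dilation 1, floor rounding)."""
    return (size + 2 * padding - kernel) // stride + 1


def fc1_in_features(input_size=152, channels=32):
    """Side length after conv1/pool/conv2/pool1 and the flattened feature count."""
    s = out_size(input_size, 3)       # conv1: 3x3, stride 1, no padding
    s = out_size(s, 2, stride=2)      # pool:  2x2, stride 2
    s = out_size(s, 3)                # conv2: 3x3
    s = out_size(s, 2, stride=2)      # pool1: 2x2, stride 2
    return s, s * s * channels


print(fc1_in_features(152))  # (36, 41472), i.e. 36*36*32
```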

3.3 Main program

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from loaddata import LoadData
from modelnet import Net

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)


classes = ['construction vehicle', 'Cayenne']   # label 0, label 1
transform = transforms.Compose(
    [transforms.Resize((152, 152)),             # must match the 36*36*32 size in Net
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
train_data = LoadData(root='./data/train/',
                      datatxt='./data/' + 'train.txt',
                      transform=transform)
test_data = LoadData(root='./data/val/',
                     datatxt='./data/' + 'val.txt',
                     transform=transform)
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=2, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=2)


def imshow(img):
    """Display a normalized image tensor (useful for checking the data)."""
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


net = Net().to(device)                          # the model must live on the same device as the data
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

for epoch in range(10):
    net.train()
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    # Print the average loss over all batches of this epoch
    print('[epoch %d] loss: %.3f' % (epoch + 1, running_loss / len(train_loader)))

print('Finished Training')

PATH = './test.pth'
torch.save(net.state_dict(), PATH)

# Reload the saved weights to check that they can be used for inference
net = Net().to(device)
net.load_state_dict(torch.load(PATH, map_location=device))
net.eval()

correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)    # index of the highest score = predicted class
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the test images: %d %%' % (
    100 * correct / total))

[Figure: training output and test accuracy]
As shown in the figure above, the accuracy was 80% after 5 epochs and 100% after 10 epochs. Don't take these numbers seriously: the images being recognized come from the training set, so this is not a real measure of accuracy.
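For a meaningful accuracy figure, the validation images must not also appear in the training set. A minimal sketch, assuming the (file_name, label) pairs have already been read from the annotation file: split them class by class into disjoint train and val lists with a fixed seed, then write each list to train.txt and val.txt and place the images in train/ and val/ accordingly. split_entries is a hypothetical helper; torch.utils.data.random_split is an alternative if you would rather split a Dataset object directly.

```python
import random
from collections import defaultdict


def split_entries(entries, val_fraction=0.2, seed=0):
    """Split (file_name, label) pairs into disjoint train/val lists, class by class."""
    by_label = defaultdict(list)
    for name, label in entries:
        by_label[label].append((name, label))

    rng = random.Random(seed)          # fixed seed -> reproducible split
    train, val = [], []
    for label in sorted(by_label):
        items = list(by_label[label])
        rng.shuffle(items)
        # at least one validation image per class, unless the class has a single image
        n_val = max(1, int(len(items) * val_fraction)) if len(items) > 1 else 0
        val.extend(items[:n_val])
        train.extend(items[n_val:])
    return train, val
```

Splitting per class keeps both categories represented in the validation set even when the classes have different sizes.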

3.4 Recognition code

import torch
import torchvision.transforms as transforms
from PIL import Image
from modelnet import Net

PATH = './test.pth'
classes = ['construction vehicle', 'Cayenne']   # same order as during training
# Preprocessing must match the training transform exactly
transform = transforms.Compose(
    [transforms.Resize((152, 152)), transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

net = Net()
net.load_state_dict(torch.load(PATH, map_location='cpu'))
net.eval()

img = Image.open("./data/val/102.jpg").convert('RGB')
img = transform(img).unsqueeze(0)   # add the batch dimension: (1, 3, 152, 152)
with torch.no_grad():
    outputs = net(img)
    _, predicted = torch.max(outputs, 1)
    print("the 102 img label is", predicted.item(), classes[predicted.item()])

As shown in the figure below, image 102 is classified as 1, i.e. Cayenne, which is correct.
[Figure: recognition output for 102.jpg]
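The printed value is just a class index. A small helper can map it to a class name and also report how confident the network is, by applying softmax to the logits. This is a sketch: predict and DemoNet are hypothetical names, DemoNet only stands in for the trained Net so the snippet runs on its own, and the class order is assumed to match the classes list used in training.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

CLASSES = ['construction vehicle', 'Cayenne']   # label 0, label 1


def predict(net, img_tensor, classes=CLASSES):
    """Classify one preprocessed (3, H, W) image; return (class name, probability)."""
    net.eval()
    with torch.no_grad():
        logits = net(img_tensor.unsqueeze(0))     # (1, 3, H, W) -> (1, num_classes)
        probs = F.softmax(logits, dim=1)          # logits -> probabilities summing to 1
        prob, idx = torch.max(probs, dim=1)
    return classes[idx.item()], prob.item()


class DemoNet(nn.Module):
    """Hypothetical stand-in for Net that always outputs the logits [0, 2]."""

    def forward(self, x):
        return torch.tensor([[0.0, 2.0]]).expand(x.shape[0], 2)


name, prob = predict(DemoNet(), torch.zeros(3, 152, 152))
print(name, round(prob, 3))
```

With the real model, pass the trained net and the transformed image instead of DemoNet and the zero tensor.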

Postscript

Next, we will classify images taken directly from video. Our tool VT decodes the video and shares the frames through memory, so the images no longer have to be loaded from disk; this is how our C++ decoding tool will interact with PyTorch.
The first article in this series: Video and AI, interacting with the process (1)
The VT tool is about to be open-sourced and will be released after the Dragon Boat Festival.


Origin blog.csdn.net/qianbo042311/article/details/131345310