Metal defect detection (based on the PyTorch framework)

Metal surfaces can develop defects such as cracks, peeling, and scratches.

These imperfections can seriously degrade material quality, so products must be inspected before they ship. Traditional manual visual inspection, however, is laborious: surface defects are difficult to judge promptly and accurately, and inspection throughput is hard to control.
In recent years, deep learning has made rapid progress in areas such as image recognition. Manufacturers of steel and aluminum profiles are therefore keen to apply it to modernize the existing quality inspection process, automate inspection tasks, and reduce the rate of missed defects.

Defect acceptance criteria
1. The profile surface should be clean and free of defects such as cracks, peeling, corrosion, and air bubbles.
2. Slight pressure pits, bruises, and scratches are allowed on the profile surface, with an allowable depth of no more than 0.03 mm on decorative surfaces and no more than 0.07 mm on non-decorative surfaces; die extrusion marks may be no deeper than 0.03 mm.
3. The end of a profile may show local deformation caused by sawing, extending no more than 10 mm in the longitudinal direction.
4. In production, defects that are too faint to be noticeable are also treated as defect-free samples, so do not worry if barely visible flaws appear in the defect-free images (see the assumed data layout sketch below).
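
The dataset classes below assume a class-per-folder layout, where the folder name is the integer class index (the mapping 0 = defect-free, 1 = defective is an assumption here; check your own data). A minimal sketch of the expected structure and a quick sanity check, with the root path as a placeholder:

# Assumed layout (folder names double as integer class labels):
#
#   al_material_data/
#       train/
#           0/   defect-free images
#           1/   defective images
#       test/
#           0/
#           1/
import os

root = r"D:\documents\data\al_material_data"  # adjust to your local data root
for split in ("train", "test"):
    for label in sorted(os.listdir(os.path.join(root, split))):
        count = len(os.listdir(os.path.join(root, split, label)))
        print(f"{split}/{label}: {count} images")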

Implementation:

AlexNet

# -*- encoding: utf-8 -*-
import os

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data.dataloader as dataloader
from PIL import Image
from torch import optim
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import transforms
from tqdm import tqdm
import torch.nn.functional as F


class AlNetDataset(Dataset):
    def __init__(self, root, is_train=True, transform=None):
        super().__init__()
        self.dataset = []  # (img, label)
        self.transform = transform
        self.train_or_test = "train" if is_train else "test"
        path = f"{
      
      root}//{
      
      self.train_or_test}"
        for label in os.listdir(path):
            for img_path in os.listdir(f"{
      
      path}//{
      
      label}"):
                self.dataset.append((f"{
      
      path}//{
      
      label}//{
      
      img_path}", label))

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        data = self.dataset[index]
        img = Image.open(data[0]).convert('RGB')
        if self.transform:
            img = self.transform(img)
        one_hot = np.zeros(2)
        one_hot[int(data[1])] = 1
        return img, np.float32(one_hot)

    def remove_NoneType(self):
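        """Delete image files on disk that OpenCV cannot decode (cv2.imread returns None for them)."""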
        for data in self.dataset:
            img = cv2.imread(data[0])
            if not isinstance(img, np.ndarray):
                print(data[0])
                os.remove(data[0])


class ResBlock(nn.Module):
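    """Bottleneck residual block: 1x1 reduce -> 3x3 -> 1x1 expand, with an identity skip
    connection; ReLU follows each conv and the final addition. Input and output shapes are
    identical, so the block can be inserted between existing layers without changing them."""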
    def __init__(self, in_channels):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(in_channels=in_channels, out_channels=in_channels // 2, kernel_size=1)
        self.conv2 = torch.nn.Conv2d(in_channels=in_channels // 2, out_channels=in_channels // 2, kernel_size=3,
                                     padding=1)
        self.conv3 = torch.nn.Conv2d(in_channels=in_channels // 2, out_channels=in_channels, kernel_size=1)

    def forward(self, x):
        return F.relu(self.conv3(F.relu(self.conv2(F.relu(self.conv1(x))))) + x)


class AlexNet(nn.Module):
    """
    58289538个参数
    58308674 + BatchNorm2d
    59725826 + ResBlock
    """

    def __init__(self):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            # input x: [b,3,227,227]
            nn.Conv2d(3, 96, kernel_size=(11, 11), stride=4, padding=1), nn.BatchNorm2d(96), ResBlock(96),
            # [b,96,55,55]
            nn.ReLU(inplace=False),
            # [b,96,55,55]
            nn.MaxPool2d(kernel_size=(3, 3), stride=2),
            # [b,96,27,27]
            nn.Conv2d(96, 256, kernel_size=(5, 5), stride=1, padding=2), nn.BatchNorm2d(256), ResBlock(256),
            # [b,256,27,27]
            nn.ReLU(inplace=False),
            # [b,256,27,27]
            nn.MaxPool2d(kernel_size=(3, 3), stride=2),
            # [b,256,13,13]
            nn.Conv2d(256, 384, kernel_size=(3, 3), stride=1, padding=1), nn.BatchNorm2d(384), ResBlock(384),
            # [b,384,13,13]
            nn.ReLU(inplace=False),
            # [b,384,13,13]
            nn.Conv2d(384, 384, kernel_size=(3, 3), stride=1, padding=1), nn.BatchNorm2d(384), ResBlock(384),
            # [b,384,13,13]
            nn.ReLU(inplace=False),
            # [b,384,13,13]
            nn.Conv2d(384, 256, kernel_size=(3, 3), padding=1), nn.BatchNorm2d(256), ResBlock(256),
            # [b,256,13,13]
            nn.ReLU(inplace=False),
            # [b,256,13,13]
            nn.MaxPool2d(kernel_size=(3, 3), stride=2),
            # [b,256,6,6]
        )
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096), nn.BatchNorm1d(4096),
            nn.ReLU(inplace=False),
            nn.Dropout(),
            nn.Linear(4096, 4096), nn.BatchNorm1d(4096),
            nn.ReLU(inplace=False),
            nn.Linear(4096, 2),
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x


class Trainer:
    def __init__(self):
        tf = transforms.Compose([
            transforms.RandomHorizontalFlip(),  # data augmentation: adds diversity, reduces overfitting risk
            transforms.RandomVerticalFlip(),  # data augmentation: adds diversity, reduces overfitting risk
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])
        self.train_set = AlNetDataset(root=r'D:\documents\data\al_material_data', is_train=True, transform=tf)
        self.test_set = AlNetDataset(root=r'D:\documents\data\al_material_data', is_train=False,
                                     transform=transforms.Compose([
                                         transforms.ToTensor(),
                                         transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
                                     ]))
        self.train_loader = dataloader.DataLoader(dataset=self.train_set, batch_size=128, shuffle=True)
        self.test_loader = dataloader.DataLoader(dataset=self.test_set, batch_size=64, shuffle=True)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = AlexNet().to(self.device)
        self.opt = optim.Adam(self.model.parameters())
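        # Note: nn.CrossEntropyLoss accepts the float one-hot targets returned by AlNetDataset
        # (probability targets are supported in PyTorch 1.10 and later).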
        self.fc_loss = nn.CrossEntropyLoss()
        self.summerWriter = SummaryWriter("./logs")

    def train(self):
        best_cnn = 0
        for epoch in range(1, 10001):
            t = tqdm(enumerate(self.train_loader), total=len(self.train_loader))
            sum_loss = 0.
            sum_score = 0
            for i, (images, labels) in t:
                images = images.to(self.device)
                labels = labels.to(self.device)
                output = self.model(images)
                loss = self.fc_loss(output, labels)
                self.opt.zero_grad()
                loss.backward()
                self.opt.step()
                sum_loss += loss.item()
                a = torch.argmax(output, dim=1)
                b = torch.argmax(labels, dim=1)
                score = torch.mean(torch.eq(a, b).float())
                sum_score += score.item()
                t.set_description(f'{epoch}/10000')
            torch.save(self.model.state_dict(), f"params/alex.pth")
            self.summerWriter.add_scalar('alex_train_avg_loss', sum_loss / len(self.train_loader), epoch)
            self.summerWriter.add_scalar('alex_train_avg_score', sum_score / len(self.train_loader), epoch)
            sum_score = self.test(epoch, is_test=False)
            if sum_score > best_cnn:
                torch.save(self.model.state_dict(), f"params/best_alex.pth")
                best_cnn = sum_score

    def test(self, epoch, is_test):
        sum_score = 0.
        sum_loss = 0.
        if is_test:
            self.model.load_state_dict(torch.load(r"D:\documents\start_ai\projects\al_material\params\best_alex.pth"))
        else:
            self.model.load_state_dict(torch.load(r"D:\documents\start_ai\projects\al_material\params\alex.pth"))
        t = tqdm(enumerate(self.test_loader), total=len(self.test_loader))
        for i, (img, target) in t:
            self.model.eval()
            img, target = img.to(self.device), target.to(self.device)
            out = self.model(img)
            loss = self.fc_loss(out, target)
            sum_loss += loss.item()
            a = torch.argmax(out, dim=1)
            b = torch.argmax(target, dim=1)
            score = torch.mean(torch.eq(a, b).float())
            sum_score += score.item()
        if not is_test:
            self.summerWriter.add_scalar('alex_test_avg_loss', sum_loss / len(self.test_loader), epoch)
            self.summerWriter.add_scalar('alex_test_avg_score', sum_score / len(self.test_loader), epoch)
        print('')
        print('============================================================')
        print(f"第{
      
      epoch}轮, test平均得分: %s" % (sum_score / len(self.test_loader)))
        print(f"第{
      
      epoch}轮, test平均损失: %s" % (sum_loss / len(self.test_loader)))

        return sum_score


if __name__ == '__main__':
    trainer = Trainer()
    # trainer.train()
    trainer.test(epoch=1, is_test=True)
    # a = AlexNet()
    # print(a)
    # print(sum([item.numel() for item in a.parameters()]))
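
For quick spot checks, a minimal single-image inference sketch against the trained AlexNet variant is given below. The module name alexnet_train, the image file name, and the assumption that input images are already 227x227 (as in the training data) are placeholders, not part of the original project:

# -*- encoding: utf-8 -*-
# Single-image inference sketch (assumed file names and paths).
import torch
from PIL import Image
from torchvision.transforms import transforms

from alexnet_train import AlexNet  # assuming the script above is saved as alexnet_train.py

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AlexNet().to(device)
model.load_state_dict(torch.load("params/best_alex.pth", map_location=device))
model.eval()  # BatchNorm/Dropout switched to inference behaviour

tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# The image is assumed to already be 227x227, matching the training data.
img = tf(Image.open("sample.jpg").convert("RGB")).unsqueeze(0).to(device)
with torch.no_grad():
    pred = torch.argmax(model(img), dim=1).item()
print("predicted class index:", pred)  # which index means "defective" depends on your folder labels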

MobileNet_v2

# -*- encoding: utf-8 -*-
import os

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data.dataloader as dataloader
from PIL import Image
from torch import optim
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import transforms
from tqdm import tqdm


class AlNetDataset(Dataset):
    def __init__(self, root, is_train=True, transform=None):
        super().__init__()
        self.dataset = []  # (img, label)
        self.transform = transform
        self.train_or_test = "train" if is_train else "test"
        path = f"{
      
      root}//{
      
      self.train_or_test}"
        for label in os.listdir(path):
            for img_path in os.listdir(f"{
      
      path}//{
      
      label}"):
                self.dataset.append((f"{
      
      path}//{
      
      label}//{
      
      img_path}", label))

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        data = self.dataset[index]
        img = Image.open(data[0]).convert('RGB')
        if self.transform:
            img = self.transform(img)
        one_hot = np.zeros(2)
        one_hot[int(data[1])] = 1
        return img, np.float32(one_hot)

    def remove_NoneType(self):
        for data in self.dataset:
            img = cv2.imread(data[0])
            if not isinstance(img, np.ndarray):
                print(data[0])
                os.remove(data[0])


class MobileNetV2(nn.Module):
    """
    25957090个参数
    """

    def __init__(self, num_classes=2):
        super(MobileNetV2, self).__init__()

        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False), nn.BatchNorm2d(32), nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, bias=False), nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1, bias=False), nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False), nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False), nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False), nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=False), nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False), nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False), nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False), nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False), nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False), nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False), nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1, bias=False), nn.BatchNorm2d(1024),
            nn.ReLU(inplace=True),
            nn.Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, bias=False), nn.BatchNorm2d(1024),
            nn.ReLU(inplace=True)
        )
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(1024, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)  # flatten everything after the batch dimension
        x = self.classifier(x)
        return x


class Trainer:
    def __init__(self):
        tf = transforms.Compose([
            transforms.RandomHorizontalFlip(),  # data augmentation: adds diversity, reduces overfitting risk
            transforms.RandomVerticalFlip(),  # data augmentation: adds diversity, reduces overfitting risk
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])
        self.train_set = AlNetDataset(root=r'D:\documents\data\al_material_data', is_train=True, transform=tf)
        self.test_set = AlNetDataset(root=r'D:\documents\data\al_material_data', is_train=False,
                                     transform=transforms.Compose([
                                         transforms.ToTensor(),
                                         transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
                                     ]))
        self.train_loader = dataloader.DataLoader(dataset=self.train_set, batch_size=32, shuffle=True)
        self.test_loader = dataloader.DataLoader(dataset=self.test_set, batch_size=16, shuffle=True)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = MobileNetV2().to(self.device)
        self.opt = optim.Adam(self.model.parameters(), lr=0.0001)
        self.fc_loss = nn.CrossEntropyLoss()
        self.summerWriter = SummaryWriter("./logs")

    def train(self):
        best_cnn = 0
        for epoch in range(1, 1001):
            t = tqdm(enumerate(self.train_loader), total=len(self.train_loader))
            sum_score = 0
            sum_loss = 0.
            for i, (images, labels) in t:
                images = images.to(self.device)
                labels = labels.to(self.device)
                output = self.model(images)
                loss = self.fc_loss(output, labels)
                self.opt.zero_grad()
                loss.backward()
                self.opt.step()
                sum_loss += loss.item()
                a = torch.argmax(output, dim=1)
                b = torch.argmax(labels, dim=1)
                score = torch.mean(torch.eq(a, b).float())
                sum_score += score.item()
                t.set_description(f'{epoch}/1000')
            torch.save(self.model.state_dict(), f"params/mobilenetv2.pth")
            self.summerWriter.add_scalar('mobilenetv2_train_avg_loss', sum_loss / len(self.train_loader), epoch)
            self.summerWriter.add_scalar('mobilenetv2_train_avg_score', sum_score / len(self.train_loader), epoch)
            sum_score = self.test(epoch, is_test=False)
            if sum_score > best_cnn:
                torch.save(self.model.state_dict(), f"params/best_mobilenetv2.pth")
                best_cnn = sum_score

    def test(self, epoch, is_test):
        sum_score = 0.
        sum_loss = 0.
        if is_test:
            self.model.load_state_dict(
                torch.load(r"D:\documents\start_ai\projects\al_material\params\best_mobilenetv2.pth"))
        else:
            self.model.load_state_dict(torch.load(r"D:\documents\start_ai\projects\al_material\params\mobilenetv2.pth"))
        t = tqdm(enumerate(self.test_loader), total=len(self.test_loader))
        for i, (img, target) in t:
            self.model.eval()
            img, target = img.to(self.device), target.to(self.device)
            out = self.model(img)
            loss = self.fc_loss(out, target)
            sum_loss += loss.item()
            a = torch.argmax(out, dim=1)
            b = torch.argmax(target, dim=1)
            score = torch.mean(torch.eq(a, b).float())
            sum_score += score.item()
        if not is_test:
            self.summerWriter.add_scalar('mobilenetv2_test_avg_loss', sum_loss / len(self.test_loader), epoch)
            self.summerWriter.add_scalar('mobilenetv2_test_avg_score', sum_score / len(self.test_loader), epoch)
        print('')
        print('============================================================')
        print(f"第{
      
      epoch}轮, test平均得分: %s" % (sum_score / len(self.test_loader)))
        print(f"第{
      
      epoch}轮, test平均损失: %s" % (sum_loss / len(self.test_loader)))

        return sum_score


if __name__ == '__main__':
    trainer = Trainer()
    # trainer.train()
    trainer.test(epoch=1, is_test=True)
    # a = MobileNetV2()
    # print(a)
    # print(sum([item.numel() for item in a.parameters()]))
    # from torchvision.models import mobilenet_v2  # 3504872 parameters
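
The commented import above points at torchvision's real MobileNetV2, which is far smaller (about 3.5M parameters) because it uses depthwise-separable inverted-residual blocks. A hedged sketch of swapping it in for the hand-written network, adapted to 2 classes (whether to load pretrained weights is left open):

# Sketch: torchvision MobileNetV2 as a drop-in replacement for the network above.
import torch.nn as nn
from torchvision.models import mobilenet_v2

model = mobilenet_v2(weights=None)  # torchvision >= 0.13; older versions use pretrained=False
model.classifier[1] = nn.Linear(model.last_channel, 2)  # replace the 1000-class ImageNet head
# Assign this `model` in Trainer.__init__ instead of MobileNetV2() to train the real architecture.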

VGGNet-16

# -*- encoding: utf-8 -*-
import os

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data.dataloader as dataloader
from PIL import Image
from torch import optim
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import transforms
from tqdm import tqdm



class AlNetDataset(Dataset):
    def __init__(self, root, is_train=True, transform=None):
        super().__init__()
        self.dataset = []  # (img, label)
        self.transform = transform
        self.train_or_test = "train" if is_train else "test"
        path = f"{
      
      root}//{
      
      self.train_or_test}"
        for label in os.listdir(path):
            for img_path in os.listdir(f"{
      
      path}//{
      
      label}"):
                self.dataset.append((f"{
      
      path}//{
      
      label}//{
      
      img_path}", label))

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        data = self.dataset[index]
        img = Image.open(data[0]).convert('RGB')
        if self.transform:
            img = self.transform(img)
        one_hot = np.zeros(2)
        one_hot[int(data[1])] = 1
        return img, np.float32(one_hot)

    def remove_NoneType(self):
        for data in self.dataset:
            img = cv2.imread(data[0])
            if not isinstance(img, np.ndarray):
                print(data[0])
                os.remove(data[0])


class VGGNet16(nn.Module):
    """
    134268738个参数
    """

    def __init__(self, num_classes=2):
        super(VGGNet16, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x


class Trainer:
    def __init__(self):
        tf = transforms.Compose([
            transforms.RandomHorizontalFlip(),  # data augmentation: adds diversity, reduces overfitting risk
            transforms.RandomVerticalFlip(),  # data augmentation: adds diversity, reduces overfitting risk
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])
        self.train_set = AlNetDataset(root=r'D:\documents\data\al_material_data', is_train=True, transform=tf)
        self.test_set = AlNetDataset(root=r'D:\documents\data\al_material_data', is_train=False,
                                     transform=transforms.Compose([
                                         transforms.ToTensor(),
                                         transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
                                     ]))
        # batch_size=128 exhausts GPU memory; 64 nearly fills it; one epoch takes about 1 min 22 s
        self.train_loader = dataloader.DataLoader(dataset=self.train_set, batch_size=32, shuffle=True)
        self.test_loader = dataloader.DataLoader(dataset=self.test_set, batch_size=16, shuffle=True)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = VGGNet16().to(self.device)
        self.opt = optim.Adam(self.model.parameters(), lr=0.0001)
        self.fc_loss = nn.CrossEntropyLoss()
        self.summerWriter = SummaryWriter("./logs")

    def train(self):
        best_cnn = 0
        for epoch in range(1, 1001):
            t = tqdm(enumerate(self.train_loader), total=len(self.train_loader))
            sum_loss = 0.
            sum_score = 0
            for i, (images, labels) in t:
                images = images.to(self.device)
                labels = labels.to(self.device)
                output = self.model(images)
                loss = self.fc_loss(output, labels)
                self.opt.zero_grad()
                loss.backward()
                self.opt.step()
                sum_loss += loss.item()
                a = torch.argmax(output, dim=1)
                b = torch.argmax(labels, dim=1)
                score = torch.mean(torch.eq(a, b).float())
                sum_score += score.item()
                t.set_description(f'{epoch}/1000')
            torch.save(self.model.state_dict(), f"params/vgg.pth")
            self.summerWriter.add_scalar('vgg_train_avg_loss', sum_loss / len(self.train_loader), epoch)
            self.summerWriter.add_scalar('vgg_train_avg_score', sum_score / len(self.train_loader), epoch)
            sum_score = self.test(epoch, is_test=False)
            if sum_score > best_cnn:
                torch.save(self.model.state_dict(), f"params/best_vgg.pth")
                best_cnn = sum_score

    def test(self, epoch, is_test):
        sum_score = 0.
        sum_loss = 0.
        if is_test:
            self.model.load_state_dict(torch.load(r"D:\documents\start_ai\projects\al_material\params\best_vgg.pth"))
        else:
            self.model.load_state_dict(torch.load(r"D:\documents\start_ai\projects\al_material\params\vgg.pth"))
        t = tqdm(enumerate(self.test_loader), total=len(self.test_loader))
        for i, (img, target) in t:
            self.model.eval()
            img, target = img.to(self.device), target.to(self.device)
            out = self.model(img)
            loss = self.fc_loss(out, target)
            sum_loss += loss.item()
            a = torch.argmax(out, dim=1)
            b = torch.argmax(target, dim=1)
            score = torch.mean(torch.eq(a, b).float())
            sum_score += score.item()
        if not is_test:
            self.summerWriter.add_scalar('vgg_test_avg_loss', sum_loss / len(self.test_loader), epoch)
            self.summerWriter.add_scalar('vgg_test_avg_score', sum_score / len(self.test_loader), epoch)
        print('')
        print('============================================================')
        print(f"第{
      
      epoch}轮, test平均得分: %s" % (sum_score / len(self.test_loader)))
        print(f"第{
      
      epoch}轮, test平均损失: %s" % (sum_loss / len(self.test_loader)))

        return sum_score


if __name__ == '__main__':
    trainer = Trainer()
    # trainer.train()
    trainer.test(epoch=1, is_test=True)
    # a = VGGNet16()
    # print(sum([item.numel() for item in a.parameters()]))
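
A common refinement to all three test() methods above (a suggestion, not part of the original code): call model.eval() once before the loop and run inference under torch.no_grad(), so no autograd state is kept and evaluation uses less memory. A minimal sketch that computes the same averages:

# Evaluation helper sketch; works with any of the three Trainer models above.
import torch

def evaluate(model, loader, loss_fn, device):
    model.eval()  # set eval mode once, outside the batch loop
    sum_loss, sum_score = 0.0, 0.0
    with torch.no_grad():  # no gradients are needed for evaluation
        for img, target in loader:
            img, target = img.to(device), target.to(device)
            out = model(img)
            sum_loss += loss_fn(out, target).item()
            sum_score += torch.mean(torch.eq(torch.argmax(out, dim=1),
                                             torch.argmax(target, dim=1)).float()).item()
    return sum_score / len(loader), sum_loss / len(loader)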

Origin blog.csdn.net/weixin_44659309/article/details/130441985