金属瑕疵检测（基于 PyTorch 框架）

**金属材料表面会产生裂纹、起皮、划伤等瑕疵**

这些瑕疵会严重影响材料的质量。为保证产品质量，过去需要依靠人工肉眼检查。然而，人工检查十分费力，难以及时、准确地判断出表面瑕疵，质检效率也难以把控。
近年来，深度学习在图像识别等领域取得了突飞猛进的成果。钢铁、铝型材制造商迫切希望采用最新的 AI 技术革新现有质检流程，自动完成质检任务，降低漏检率。

瑕疵衡量标准
1.型材表面应整洁,不允许有裂纹、起皮、腐蚀和气泡等缺陷存在。
2. 型材表面允许有轻微的压坑、碰伤、擦伤存在，其允许深度：装饰面 ≯ 0.03 mm，非装饰面 ≯ 0.07 mm；模具挤压痕深度 ≯ 0.03 mm。
3.型材端头允许有因锯切产生的局部变形,其纵向长度不应超过10mm。
4. 在工业生产中，不够明显的瑕疵也会按无瑕疵图片处理，因此不必拘泥于无瑕疵图片中存在的轻微瑕疵。

具体实现:

AlexNet

# -*- encoding: utf-8 -*-
import os

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data.dataloader as dataloader
from PIL import Image
from torch import optim
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import transforms
from tqdm import tqdm
import torch.nn.functional as F


class AlNetDataset(Dataset):
    """Image-folder dataset laid out as ``<root>/<train|test>/<label>/<image>``.

    Every sub-directory of the split folder is one class; its name must be an
    integer class index in ``[0, num_classes)`` because it is passed to
    ``int()`` and used to index the one-hot target. Non-directory entries in
    the split folder are ignored.

    Args:
        root: dataset root containing ``train`` and ``test`` folders.
        is_train: select the ``train`` split if True, else ``test``.
        transform: optional callable applied to the loaded PIL RGB image.
        num_classes: length of the one-hot target vector (default 2).
    """

    def __init__(self, root, is_train=True, transform=None, num_classes=2):
        super().__init__()
        self.dataset = []  # list of (image_path, label_directory_name)
        self.transform = transform
        self.num_classes = num_classes
        self.train_or_test = "train" if is_train else "test"
        path = os.path.join(root, self.train_or_test)
        # sorted() makes the sample order independent of the filesystem's
        # directory-listing order.
        for label in sorted(os.listdir(path)):
            label_dir = os.path.join(path, label)
            if not os.path.isdir(label_dir):
                continue  # stray files (e.g. Thumbs.db) are not classes
            for img_name in sorted(os.listdir(label_dir)):
                self.dataset.append((os.path.join(label_dir, img_name), label))

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        """Return ``(image, one_hot)``; ``one_hot`` is a float32 numpy vector."""
        img_path, label = self.dataset[index]
        # The context manager guarantees the file handle is released.
        with Image.open(img_path) as im:
            img = im.convert('RGB')
        if self.transform:
            img = self.transform(img)
        one_hot = np.zeros(self.num_classes, dtype=np.float32)
        one_hot[int(label)] = 1
        return img, one_hot

    def remove_NoneType(self):
        """Delete image files OpenCV cannot decode and drop them from the index.

        A file counts as undecodable when ``cv2.imread`` does not return a
        numpy array for it. Each deleted path is printed. The in-memory index
        is rebuilt so later ``__getitem__`` calls never hit a deleted file.
        """
        kept = []
        for img_path, label in self.dataset:
            if isinstance(cv2.imread(img_path), np.ndarray):
                kept.append((img_path, label))
            else:
                print(img_path)
                os.remove(img_path)
        self.dataset = kept


class ResBlock(nn.Module):
    """Bottleneck residual block: 1x1 reduce -> 3x3 -> 1x1 expand, plus identity.

    The bottleneck uses ``in_channels // 2`` channels; the output has the same
    shape as the input. ReLU follows each convolution and the residual sum.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(in_channels=in_channels, out_channels=in_channels // 2, kernel_size=1)
        self.conv2 = torch.nn.Conv2d(in_channels=in_channels // 2, out_channels=in_channels // 2, kernel_size=3,
                                     padding=1)
        self.conv3 = torch.nn.Conv2d(in_channels=in_channels // 2, out_channels=in_channels, kernel_size=1)

    def forward(self, x):
        return F.relu(self.conv3(F.relu(self.conv2(F.relu(self.conv1(x))))) + x)


class AlexNet(nn.Module):
    """AlexNet variant with BatchNorm and a ResBlock after every convolution.

    Parameter counts (num_classes=2): 58289538 for the plain network,
    58308674 with BatchNorm2d, 59725826 with the ResBlocks (this class).

    An adaptive average pool fixes the feature map at 6x6 before the
    classifier, so inputs other than 227x227 are accepted. At 227x227 the
    pool is an identity, and it has no parameters, so checkpoints saved
    without it still load.

    Args:
        num_classes: number of output logits (default 2).
    """

    def __init__(self, num_classes=2):
        super().__init__()
        # The Sequential layout (and therefore the state_dict keys) is kept
        # exactly as before so existing checkpoints remain loadable.
        self.features = nn.Sequential(
            # input x: [b, 3, 227, 227]
            nn.Conv2d(3, 96, kernel_size=(11, 11), stride=4, padding=1), nn.BatchNorm2d(96), ResBlock(96),
            # [b, 96, 55, 55]
            nn.ReLU(inplace=False),
            # [b, 96, 55, 55]
            nn.MaxPool2d(kernel_size=(3, 3), stride=2),
            # [b, 96, 27, 27]
            nn.Conv2d(96, 256, kernel_size=(5, 5), stride=1, padding=2), nn.BatchNorm2d(256), ResBlock(256),
            # [b, 256, 27, 27]
            nn.ReLU(inplace=False),
            # [b, 256, 27, 27]
            nn.MaxPool2d(kernel_size=(3, 3), stride=2),
            # [b, 256, 13, 13]
            nn.Conv2d(256, 384, kernel_size=(3, 3), stride=1, padding=1), nn.BatchNorm2d(384), ResBlock(384),
            # [b, 384, 13, 13]
            nn.ReLU(inplace=False),
            # [b, 384, 13, 13]
            nn.Conv2d(384, 384, kernel_size=(3, 3), stride=1, padding=1), nn.BatchNorm2d(384), ResBlock(384),
            # [b, 384, 13, 13]
            nn.ReLU(inplace=False),
            # [b, 384, 13, 13]
            nn.Conv2d(384, 256, kernel_size=(3, 3), padding=1), nn.BatchNorm2d(256), ResBlock(256),
            # [b, 256, 13, 13]
            nn.ReLU(inplace=False),
            # [b, 256, 13, 13]
            nn.MaxPool2d(kernel_size=(3, 3), stride=2),
            # [b, 256, 6, 6]
        )
        # Guarantees the 256*6*6 input the first Linear layer expects,
        # whatever the image size (the training transforms do not resize).
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096), nn.BatchNorm1d(4096),
            nn.ReLU(inplace=False),
            nn.Dropout(),
            nn.Linear(4096, 4096), nn.BatchNorm1d(4096),
            nn.ReLU(inplace=False),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        """Map a [b, 3, H, W] batch to [b, num_classes] logits."""
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x


class Trainer:
    """Train and evaluate AlexNet on the aluminium-profile defect dataset.

    ``<params_dir>/alex.pth`` is overwritten after every epoch and
    ``<params_dir>/best_alex.pth`` keeps the weights with the best test score.
    TensorBoard scalars are written to ``./logs``.

    Args:
        root: dataset root containing ``train`` and ``test`` folders.
        params_dir: checkpoint directory, used for both saving and loading.
        epochs: number of training epochs.
    """

    def __init__(self, root=r'D:\documents\data\al_material_data', params_dir="params", epochs=10000):
        normalize = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        train_tf = transforms.Compose([
            transforms.RandomHorizontalFlip(),  # augmentation: more variety, less overfitting
            transforms.RandomVerticalFlip(),  # augmentation: more variety, less overfitting
            transforms.ToTensor(),
            normalize,
        ])
        test_tf = transforms.Compose([transforms.ToTensor(), normalize])
        self.train_set = AlNetDataset(root=root, is_train=True, transform=train_tf)
        self.test_set = AlNetDataset(root=root, is_train=False, transform=test_tf)
        self.train_loader = dataloader.DataLoader(dataset=self.train_set, batch_size=128, shuffle=True)
        # A fixed order keeps the per-batch score comparable between epochs.
        self.test_loader = dataloader.DataLoader(dataset=self.test_set, batch_size=64, shuffle=False)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = AlexNet().to(self.device)
        self.opt = optim.Adam(self.model.parameters())
        # Targets are one-hot float vectors, used as class probabilities.
        self.fc_loss = nn.CrossEntropyLoss()
        self.summerWriter = SummaryWriter("./logs")
        self.epochs = epochs
        self.params_dir = params_dir
        self.last_ckpt = os.path.join(params_dir, "alex.pth")
        self.best_ckpt = os.path.join(params_dir, "best_alex.pth")

    @staticmethod
    def _batch_score(output, labels):
        """Fraction of the batch whose arg-max prediction matches the one-hot label."""
        return torch.eq(torch.argmax(output, dim=1), torch.argmax(labels, dim=1)).float().mean().item()

    def train(self):
        """Train for ``self.epochs`` epochs, saving the latest and the best checkpoint."""
        os.makedirs(self.params_dir, exist_ok=True)  # torch.save does not create directories
        best_cnn = 0
        n_batches = max(len(self.train_loader), 1)
        for epoch in range(1, self.epochs + 1):
            # test() leaves the model in eval mode; switch back so BatchNorm
            # updates its statistics and Dropout is active while training.
            self.model.train()
            sum_loss = 0.
            sum_score = 0.
            t = tqdm(self.train_loader, total=len(self.train_loader))
            t.set_description(f'{epoch}/{self.epochs}')
            for images, labels in t:
                images = images.to(self.device)
                labels = labels.to(self.device)
                output = self.model(images)
                loss = self.fc_loss(output, labels)
                self.opt.zero_grad()
                loss.backward()
                self.opt.step()
                sum_loss += loss.item()
                sum_score += self._batch_score(output, labels)
            torch.save(self.model.state_dict(), self.last_ckpt)
            self.summerWriter.add_scalar('alex_train平均损失', sum_loss / n_batches, epoch)
            self.summerWriter.add_scalar('alex_train平均得分', sum_score / n_batches, epoch)
            test_score = self.test(epoch, is_test=False)
            if test_score > best_cnn:
                torch.save(self.model.state_dict(), self.best_ckpt)
                best_cnn = test_score

    def test(self, epoch, is_test):
        """Evaluate on the test split.

        Args:
            epoch: epoch number, used only for logging and printing.
            is_test: True for standalone evaluation, which loads the best
                checkpoint first and skips TensorBoard logging. False when
                called from train(), which evaluates the in-memory weights
                (identical to the checkpoint just saved) and logs the averages.

        Returns:
            The sum over test batches of each batch's accuracy.
        """
        if is_test:
            self.model.load_state_dict(torch.load(self.best_ckpt, map_location=self.device))
        self.model.eval()
        sum_score = 0.
        sum_loss = 0.
        with torch.no_grad():  # no autograd graph is needed for evaluation
            for img, target in tqdm(self.test_loader, total=len(self.test_loader)):
                img, target = img.to(self.device), target.to(self.device)
                out = self.model(img)
                sum_loss += self.fc_loss(out, target).item()
                sum_score += self._batch_score(out, target)
        n_batches = max(len(self.test_loader), 1)
        if not is_test:
            self.summerWriter.add_scalar('alex_test平均损失', sum_loss / n_batches, epoch)
            self.summerWriter.add_scalar('alex_test平均得分', sum_score / n_batches, epoch)
        print('')
        print('============================================================')
        print(f"第{epoch}轮, test平均得分: %s" % (sum_score / n_batches))
        print(f"第{epoch}轮, test平均损失: %s" % (sum_loss / n_batches))
        return sum_score


if __name__ == '__main__':
    trainer = Trainer()
    # By default, score the best checkpoint from a previous run; flip the
    # flag to train from scratch instead. Model size can be inspected with
    # sum(item.numel() for item in AlexNet().parameters()).
    evaluate_only = True
    if evaluate_only:
        trainer.test(epoch=1, is_test=True)
    else:
        trainer.train()

MobileNet_v2（注：下面名为 MobileNetV2 的实现实际是普通 3×3 卷积堆叠，并未使用 MobileNetV2 的深度可分离卷积和倒残差结构）

# -*- encoding: utf-8 -*-
import os

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data.dataloader as dataloader
from PIL import Image
from torch import optim
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import transforms
from tqdm import tqdm


class AlNetDataset(Dataset):
    """Image-folder dataset laid out as ``<root>/<train|test>/<label>/<image>``.

    Every sub-directory of the split folder is one class; its name must be an
    integer class index in ``[0, num_classes)`` because it is passed to
    ``int()`` and used to index the one-hot target. Non-directory entries in
    the split folder are ignored.

    Args:
        root: dataset root containing ``train`` and ``test`` folders.
        is_train: select the ``train`` split if True, else ``test``.
        transform: optional callable applied to the loaded PIL RGB image.
        num_classes: length of the one-hot target vector (default 2).
    """

    def __init__(self, root, is_train=True, transform=None, num_classes=2):
        super().__init__()
        self.dataset = []  # list of (image_path, label_directory_name)
        self.transform = transform
        self.num_classes = num_classes
        self.train_or_test = "train" if is_train else "test"
        path = os.path.join(root, self.train_or_test)
        # sorted() makes the sample order independent of the filesystem's
        # directory-listing order.
        for label in sorted(os.listdir(path)):
            label_dir = os.path.join(path, label)
            if not os.path.isdir(label_dir):
                continue  # stray files (e.g. Thumbs.db) are not classes
            for img_name in sorted(os.listdir(label_dir)):
                self.dataset.append((os.path.join(label_dir, img_name), label))

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        """Return ``(image, one_hot)``; ``one_hot`` is a float32 numpy vector."""
        img_path, label = self.dataset[index]
        # The context manager guarantees the file handle is released.
        with Image.open(img_path) as im:
            img = im.convert('RGB')
        if self.transform:
            img = self.transform(img)
        one_hot = np.zeros(self.num_classes, dtype=np.float32)
        one_hot[int(label)] = 1
        return img, one_hot

    def remove_NoneType(self):
        """Delete image files OpenCV cannot decode and drop them from the index.

        A file counts as undecodable when ``cv2.imread`` does not return a
        numpy array for it. Each deleted path is printed. The in-memory index
        is rebuilt so later ``__getitem__`` calls never hit a deleted file.
        """
        kept = []
        for img_path, label in self.dataset:
            if isinstance(cv2.imread(img_path), np.ndarray):
                kept.append((img_path, label))
            else:
                print(img_path)
                os.remove(img_path)
        self.dataset = kept


class MobileNetV2(nn.Module):
    """Plain 15-layer 3x3 Conv-BN-ReLU stack, registered as "MobileNetV2".

    NOTE(review): despite the name, this network has no depthwise-separable
    convolutions or inverted residual blocks; every stage is a dense 3x3
    convolution. It has 25957090 parameters with num_classes=2, versus
    3504872 for torchvision's ``mobilenet_v2``.

    Args:
        num_classes: number of output logits (default 2).
    """

    def __init__(self, num_classes=2):
        super().__init__()
        # (in_channels, out_channels, stride) for each Conv-BN-ReLU stage,
        # in order; stride 2 halves the spatial resolution.
        stages = (
            (3, 32, 2), (32, 32, 1),
            (32, 64, 2), (64, 64, 1),
            (64, 128, 2), (128, 128, 1),
            (128, 256, 2), (256, 256, 1),
            (256, 512, 2), (512, 512, 1), (512, 512, 1), (512, 512, 1), (512, 512, 1),
            (512, 1024, 2), (1024, 1024, 1),
        )
        layers = []
        for c_in, c_out, step in stages:
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=step, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        self.features = nn.Sequential(*layers)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(1024, num_classes),
        )

    def forward(self, x):
        """Map a [b, 3, H, W] batch to [b, num_classes] logits."""
        pooled = self.avgpool(self.features(x))  # [b, 1024, 1, 1]
        return self.classifier(pooled.flatten(1))


class Trainer:
    """Train and evaluate the "MobileNetV2" model on the aluminium-profile defect dataset.

    ``<params_dir>/mobilenetv2.pth`` is overwritten after every epoch and
    ``<params_dir>/best_mobilenetv2.pth`` keeps the weights with the best
    test score. TensorBoard scalars are written to ``./logs``.

    Args:
        root: dataset root containing ``train`` and ``test`` folders.
        params_dir: checkpoint directory, used for both saving and loading.
        epochs: number of training epochs.
    """

    def __init__(self, root=r'D:\documents\data\al_material_data', params_dir="params", epochs=1000):
        normalize = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        train_tf = transforms.Compose([
            transforms.RandomHorizontalFlip(),  # augmentation: more variety, less overfitting
            transforms.RandomVerticalFlip(),  # augmentation: more variety, less overfitting
            transforms.ToTensor(),
            normalize,
        ])
        test_tf = transforms.Compose([transforms.ToTensor(), normalize])
        self.train_set = AlNetDataset(root=root, is_train=True, transform=train_tf)
        self.test_set = AlNetDataset(root=root, is_train=False, transform=test_tf)
        self.train_loader = dataloader.DataLoader(dataset=self.train_set, batch_size=32, shuffle=True)
        # A fixed order keeps the per-batch score comparable between epochs.
        self.test_loader = dataloader.DataLoader(dataset=self.test_set, batch_size=16, shuffle=False)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = MobileNetV2().to(self.device)
        self.opt = optim.Adam(self.model.parameters(), lr=0.0001)
        # Targets are one-hot float vectors, used as class probabilities.
        self.fc_loss = nn.CrossEntropyLoss()
        self.summerWriter = SummaryWriter("./logs")
        self.epochs = epochs
        self.params_dir = params_dir
        self.last_ckpt = os.path.join(params_dir, "mobilenetv2.pth")
        self.best_ckpt = os.path.join(params_dir, "best_mobilenetv2.pth")

    @staticmethod
    def _batch_score(output, labels):
        """Fraction of the batch whose arg-max prediction matches the one-hot label."""
        return torch.eq(torch.argmax(output, dim=1), torch.argmax(labels, dim=1)).float().mean().item()

    def train(self):
        """Train for ``self.epochs`` epochs, saving the latest and the best checkpoint."""
        os.makedirs(self.params_dir, exist_ok=True)  # torch.save does not create directories
        best_cnn = 0
        n_batches = max(len(self.train_loader), 1)
        for epoch in range(1, self.epochs + 1):
            # test() leaves the model in eval mode; switch back so BatchNorm
            # updates its statistics and Dropout is active while training.
            self.model.train()
            sum_loss = 0.
            sum_score = 0.
            t = tqdm(self.train_loader, total=len(self.train_loader))
            t.set_description(f'{epoch}/{self.epochs}')
            for images, labels in t:
                images = images.to(self.device)
                labels = labels.to(self.device)
                output = self.model(images)
                loss = self.fc_loss(output, labels)
                self.opt.zero_grad()
                loss.backward()
                self.opt.step()
                sum_loss += loss.item()
                sum_score += self._batch_score(output, labels)
            torch.save(self.model.state_dict(), self.last_ckpt)
            self.summerWriter.add_scalar('mobilenetv2_train平均损失', sum_loss / n_batches, epoch)
            self.summerWriter.add_scalar('mobilenetv2_train平均得分', sum_score / n_batches, epoch)
            test_score = self.test(epoch, is_test=False)
            if test_score > best_cnn:
                torch.save(self.model.state_dict(), self.best_ckpt)
                best_cnn = test_score

    def test(self, epoch, is_test):
        """Evaluate on the test split.

        Args:
            epoch: epoch number, used only for logging and printing.
            is_test: True for standalone evaluation, which loads the best
                checkpoint first and skips TensorBoard logging. False when
                called from train(), which evaluates the in-memory weights
                (identical to the checkpoint just saved) and logs the averages.

        Returns:
            The sum over test batches of each batch's accuracy.
        """
        if is_test:
            self.model.load_state_dict(torch.load(self.best_ckpt, map_location=self.device))
        self.model.eval()
        sum_score = 0.
        sum_loss = 0.
        with torch.no_grad():  # no autograd graph is needed for evaluation
            for img, target in tqdm(self.test_loader, total=len(self.test_loader)):
                img, target = img.to(self.device), target.to(self.device)
                out = self.model(img)
                sum_loss += self.fc_loss(out, target).item()
                sum_score += self._batch_score(out, target)
        n_batches = max(len(self.test_loader), 1)
        if not is_test:
            self.summerWriter.add_scalar('mobilenetv2_test平均损失', sum_loss / n_batches, epoch)
            self.summerWriter.add_scalar('mobilenetv2_test平均得分', sum_score / n_batches, epoch)
        print('')
        print('============================================================')
        print(f"第{epoch}轮, test平均得分: %s" % (sum_score / n_batches))
        print(f"第{epoch}轮, test平均损失: %s" % (sum_loss / n_batches))
        return sum_score


if __name__ == '__main__':
    trainer = Trainer()
    # By default, score the best checkpoint from a previous run; flip the
    # flag to train from scratch instead. Model size can be inspected with
    # sum(item.numel() for item in MobileNetV2().parameters()); for comparison,
    # torchvision.models.mobilenet_v2 has 3504872 parameters.
    evaluate_only = True
    if evaluate_only:
        trainer.test(epoch=1, is_test=True)
    else:
        trainer.train()

VGGNet-16

# -*- encoding: utf-8 -*-
import os

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data.dataloader as dataloader
from PIL import Image
from torch import optim
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import transforms
from tqdm import tqdm



class AlNetDataset(Dataset):
    """Image-folder dataset laid out as ``<root>/<train|test>/<label>/<image>``.

    Every sub-directory of the split folder is one class; its name must be an
    integer class index in ``[0, num_classes)`` because it is passed to
    ``int()`` and used to index the one-hot target. Non-directory entries in
    the split folder are ignored.

    Args:
        root: dataset root containing ``train`` and ``test`` folders.
        is_train: select the ``train`` split if True, else ``test``.
        transform: optional callable applied to the loaded PIL RGB image.
        num_classes: length of the one-hot target vector (default 2).
    """

    def __init__(self, root, is_train=True, transform=None, num_classes=2):
        super().__init__()
        self.dataset = []  # list of (image_path, label_directory_name)
        self.transform = transform
        self.num_classes = num_classes
        self.train_or_test = "train" if is_train else "test"
        path = os.path.join(root, self.train_or_test)
        # sorted() makes the sample order independent of the filesystem's
        # directory-listing order.
        for label in sorted(os.listdir(path)):
            label_dir = os.path.join(path, label)
            if not os.path.isdir(label_dir):
                continue  # stray files (e.g. Thumbs.db) are not classes
            for img_name in sorted(os.listdir(label_dir)):
                self.dataset.append((os.path.join(label_dir, img_name), label))

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        """Return ``(image, one_hot)``; ``one_hot`` is a float32 numpy vector."""
        img_path, label = self.dataset[index]
        # The context manager guarantees the file handle is released.
        with Image.open(img_path) as im:
            img = im.convert('RGB')
        if self.transform:
            img = self.transform(img)
        one_hot = np.zeros(self.num_classes, dtype=np.float32)
        one_hot[int(label)] = 1
        return img, one_hot

    def remove_NoneType(self):
        """Delete image files OpenCV cannot decode and drop them from the index.

        A file counts as undecodable when ``cv2.imread`` does not return a
        numpy array for it. Each deleted path is printed. The in-memory index
        is rebuilt so later ``__getitem__`` calls never hit a deleted file.
        """
        kept = []
        for img_path, label in self.dataset:
            if isinstance(cv2.imread(img_path), np.ndarray):
                kept.append((img_path, label))
            else:
                print(img_path)
                os.remove(img_path)
        self.dataset = kept


class VGGNet16(nn.Module):
    """VGG-16 (configuration "D"): thirteen 3x3 conv + ReLU layers in five
    max-pooled stages, adaptive average pooling to 7x7, then a three-layer
    classifier with ReLU and Dropout.

    It has 134268738 parameters with num_classes=2.

    Args:
        num_classes: number of output logits (default 2).
    """

    def __init__(self, num_classes=2):
        super().__init__()
        # Integers are conv output channels; "M" is a 2x2 stride-2 max-pool.
        plan = (64, 64, "M",
                128, 128, "M",
                256, 256, 256, "M",
                512, 512, 512, "M",
                512, 512, 512, "M")
        modules = []
        channels = 3
        for entry in plan:
            if entry == "M":
                modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            modules.append(nn.Conv2d(channels, entry, kernel_size=3, stride=1, padding=1))
            modules.append(nn.ReLU(inplace=True))
            channels = entry
        self.features = nn.Sequential(*modules)
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        """Map a [b, 3, H, W] batch to [b, num_classes] logits."""
        pooled = self.avgpool(self.features(x))  # [b, 512, 7, 7]
        return self.classifier(torch.flatten(pooled, 1))


class Trainer:
    """Train and evaluate VGG-16 on the aluminium-profile defect dataset.

    ``<params_dir>/vgg.pth`` is overwritten after every epoch and
    ``<params_dir>/best_vgg.pth`` keeps the weights with the best test score.
    TensorBoard scalars are written to ``./logs``.

    Args:
        root: dataset root containing ``train`` and ``test`` folders.
        params_dir: checkpoint directory, used for both saving and loading.
        epochs: number of training epochs.
    """

    def __init__(self, root=r'D:\documents\data\al_material_data', params_dir="params", epochs=1000):
        normalize = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        train_tf = transforms.Compose([
            transforms.RandomHorizontalFlip(),  # augmentation: more variety, less overfitting
            transforms.RandomVerticalFlip(),  # augmentation: more variety, less overfitting
            transforms.ToTensor(),
            normalize,
        ])
        test_tf = transforms.Compose([transforms.ToTensor(), normalize])
        self.train_set = AlNetDataset(root=root, is_train=True, transform=train_tf)
        self.test_set = AlNetDataset(root=root, is_train=False, transform=test_tf)
        # Author's note: batch_size=128 ran out of memory and 64 nearly filled
        # it (about 01:22 per epoch), hence 32.
        self.train_loader = dataloader.DataLoader(dataset=self.train_set, batch_size=32, shuffle=True)
        # A fixed order keeps the per-batch score comparable between epochs.
        self.test_loader = dataloader.DataLoader(dataset=self.test_set, batch_size=16, shuffle=False)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = VGGNet16().to(self.device)
        self.opt = optim.Adam(self.model.parameters(), lr=0.0001)
        # Targets are one-hot float vectors, used as class probabilities.
        self.fc_loss = nn.CrossEntropyLoss()
        self.summerWriter = SummaryWriter("./logs")
        self.epochs = epochs
        self.params_dir = params_dir
        self.last_ckpt = os.path.join(params_dir, "vgg.pth")
        self.best_ckpt = os.path.join(params_dir, "best_vgg.pth")

    @staticmethod
    def _batch_score(output, labels):
        """Fraction of the batch whose arg-max prediction matches the one-hot label."""
        return torch.eq(torch.argmax(output, dim=1), torch.argmax(labels, dim=1)).float().mean().item()

    def train(self):
        """Train for ``self.epochs`` epochs, saving the latest and the best checkpoint."""
        os.makedirs(self.params_dir, exist_ok=True)  # torch.save does not create directories
        best_cnn = 0
        n_batches = max(len(self.train_loader), 1)
        for epoch in range(1, self.epochs + 1):
            # test() leaves the model in eval mode; switch back so Dropout is
            # active while training.
            self.model.train()
            sum_loss = 0.
            sum_score = 0.
            t = tqdm(self.train_loader, total=len(self.train_loader))
            t.set_description(f'{epoch}/{self.epochs}')
            for images, labels in t:
                images = images.to(self.device)
                labels = labels.to(self.device)
                output = self.model(images)
                loss = self.fc_loss(output, labels)
                self.opt.zero_grad()
                loss.backward()
                self.opt.step()
                sum_loss += loss.item()
                sum_score += self._batch_score(output, labels)
            torch.save(self.model.state_dict(), self.last_ckpt)
            self.summerWriter.add_scalar('vgg_train平均损失', sum_loss / n_batches, epoch)
            self.summerWriter.add_scalar('vgg_train平均得分', sum_score / n_batches, epoch)
            test_score = self.test(epoch, is_test=False)
            if test_score > best_cnn:
                torch.save(self.model.state_dict(), self.best_ckpt)
                best_cnn = test_score

    def test(self, epoch, is_test):
        """Evaluate on the test split.

        Args:
            epoch: epoch number, used only for logging and printing.
            is_test: True for standalone evaluation, which loads the best
                checkpoint first and skips TensorBoard logging. False when
                called from train(), which evaluates the in-memory weights
                (identical to the checkpoint just saved) and logs the averages.

        Returns:
            The sum over test batches of each batch's accuracy.
        """
        if is_test:
            self.model.load_state_dict(torch.load(self.best_ckpt, map_location=self.device))
        self.model.eval()
        sum_score = 0.
        sum_loss = 0.
        with torch.no_grad():  # no autograd graph is needed for evaluation
            for img, target in tqdm(self.test_loader, total=len(self.test_loader)):
                img, target = img.to(self.device), target.to(self.device)
                out = self.model(img)
                sum_loss += self.fc_loss(out, target).item()
                sum_score += self._batch_score(out, target)
        n_batches = max(len(self.test_loader), 1)
        if not is_test:
            self.summerWriter.add_scalar('vgg_test平均损失', sum_loss / n_batches, epoch)
            self.summerWriter.add_scalar('vgg_test平均得分', sum_score / n_batches, epoch)
        print('')
        print('============================================================')
        print(f"第{epoch}轮, test平均得分: %s" % (sum_score / n_batches))
        print(f"第{epoch}轮, test平均损失: %s" % (sum_loss / n_batches))
        return sum_score


if __name__ == '__main__':
    trainer = Trainer()
    # By default, score the best checkpoint from a previous run; flip the
    # flag to train from scratch instead. Model size can be inspected with
    # sum(item.numel() for item in VGGNet16().parameters()).
    evaluate_only = True
    if evaluate_only:
        trainer.test(epoch=1, is_test=True)
    else:
        trainer.train()

猜你喜欢

转载自blog.csdn.net/weixin_44659309/article/details/130441985