[50 Posts on Deep Learning with PyTorch] · Part 3: [Unsupervised Learning] [Inference]

Hello everyone. If you have read the previous post on training, you are probably still wondering how the inference stage actually works, i.e. what it uses to decide OK versus NG. That is exactly what this post is about.

1. Overall algorithm flow

The flow, in short: feed an image through the trained model to get a reconstruction, take the absolute difference between the reconstruction and the original, average-pool that difference map, and compare the pooled result against a threshold map; the number of positions that exceed the threshold decides OK or NG. Reading this, you may well be asking: what on earth is this threshold map, what is it for, did it just appear out of thin air?

2. The threshold map

Alright, let me first explain how the threshold map is obtained. We have already trained an unsupervised model, and we also have the large batch of images that we just used for training; it is exactly this batch of training images that produces the threshold map. Let's go straight to the code, and I'll use it to explain what is going on.

import torch
import torch.nn as nn
import os
import torchvision.transforms as tf
import cv2
import numpy as np

transform = tf.Compose([tf.ToTensor(), tf.Normalize([0.5], [0.5])])
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def ds(dif, kernel, stride):
    # Downsample the difference map with average pooling
    x = nn.AvgPool2d(kernel, stride)
    return x(dif)

def im2dif(out, img, kernel, stride):
    # Absolute difference between reconstruction and original,
    # average-pooled and summed over the channel dimension
    difference = out - img
    difference[difference < 0] = -difference[difference < 0]
    difference = ds(difference, kernel, stride)
    return difference.sum(1)

def save_tensor(tensor, name):
    tensor = tensor.detach().cpu().clone().numpy()
    np.save(name, tensor)

def bthv():
    # Build the threshold map from the (OK-only) training images
    ok_img_path = r'D:\blog_project\guligedong_unsupervised\data\data_train'
    thv_path = './data/thv'
    use_model = './model/net_0905.pth'
    img_size = [512, 512]

    net = torch.load(use_model, map_location=('cuda' if torch.cuda.is_available() else 'cpu')).to(device)
    net.eval()

    kernel = 16
    stride = 8
    # Shape of the pooled map, so dif can be compared with dif_ element-wise
    size = list((s - stride) // stride for s in img_size)
    dif = torch.zeros(size).to(device)
    for i in os.listdir(ok_img_path):
        img = cv2.imread(ok_img_path + '/' + i)
        img = transform(img).unsqueeze(0).to(device)
        with torch.no_grad():
            out = net(img).to(device)
        dif_ = im2dif(out, img, kernel, stride).squeeze(0)
        # Keep the element-wise maximum over all training images
        dif[dif < dif_] = dif_[dif < dif_]
        print('~', end='')
    save_tensor(dif, thv_path)


if __name__ == '__main__':
    bthv()

Let's start with the bthv function:

def bthv():
    # Build the threshold map from the (OK-only) training images
    ok_img_path = r'D:\blog_project\guligedong_unsupervised\data\data_train'
    thv_path = './data/thv'
    use_model = './model/net_0905.pth'
    img_size = [512, 512]

    net = torch.load(use_model, map_location=('cuda' if torch.cuda.is_available() else 'cpu')).to(device)
    net.eval()

    kernel = 16
    stride = 8
    # Shape of the pooled map, so dif can be compared with dif_ element-wise
    size = list((s - stride) // stride for s in img_size)
    dif = torch.zeros(size).to(device)
    for i in os.listdir(ok_img_path):
        img = cv2.imread(ok_img_path + '/' + i)
        img = transform(img).unsqueeze(0).to(device)
        with torch.no_grad():
            out = net(img).to(device)
        dif_ = im2dif(out, img, kernel, stride).squeeze(0)
        # Keep the element-wise maximum over all training images
        dif[dif < dif_] = dif_[dif < dif_]
        print('~', end='')
    save_tensor(dif, thv_path)

The first four variables define the path of the training images, the path where the threshold file will be saved, the path of the model, and the height and width of the images.

We first load the trained model and put it in eval mode.

Then we create two variables, kernel and stride; these are the kernel_size and stride of the average pooling.

We also create a dif variable, an all-zero tensor whose shape is derived from kernel_size and stride, so that it matches the shape of the pooled map.
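As a quick sanity check of that shape arithmetic (my own snippet, not part of the original project): AvgPool2d produces an output of size floor((H - kernel) / stride) + 1 per spatial dimension, and for 512 / 16 / 8 this agrees with the (512 - 8) // 8 expression used in bthv.

import torch
import torch.nn as nn

pooled = nn.AvgPool2d(16, 8)(torch.zeros(1, 3, 512, 512))
print(pooled.shape)    # torch.Size([1, 3, 63, 63])
print((512 - 8) // 8)  # 63, so dif and dif_ line up element-wise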

Then the for loop starts. Inside it, we read an image, feed it through the network to get the reconstructed image, subtract the reconstruction from the original to get a difference map, and average-pool that difference map to get the pooled map dif_. Then come the two key lines of code:

dif_ = im2dif(out, img, kernel, stride).squeeze(0)
dif[dif < dif_] = dif_[dif < dif_]

Here im2dif is a custom function:

def im2dif(out, img, kernel, stride):
    # Absolute difference between reconstruction and original,
    # average-pooled and summed over the channel dimension
    difference = out - img
    difference[difference < 0] = -difference[difference < 0]
    difference = ds(difference, kernel, stride)
    return difference.sum(1)

The idea: take the difference between the original image and the reconstruction, flip the negative values to positive (i.e. take the absolute value), average-pool the absolute difference map, and then sum the pooled map over the channel dimension to obtain a single-channel pooled map.
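For reference, here is a minimal equivalent sketch written with the functional pooling API (my own rewrite, not the project's code; it assumes the same 4D input tensors):

import torch
import torch.nn.functional as F

def im2dif_functional(out, img, kernel, stride):
    # |reconstruction - original|, average-pooled, then summed over channels
    difference = (out - img).abs()
    difference = F.avg_pool2d(difference, kernel, stride)
    return difference.sum(1)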

dif[dif < dif_] = dif_[dif < dif_]

This line compares the running dif tensor (which started as all zeros) with the pooled map of the current image and keeps the larger value at every position, i.e. an element-wise maximum. Once every image has been processed, dif is a tensor that stores the largest difference value ever observed at each position.
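The same update can be written with torch.maximum (my own equivalent one-liner, assuming PyTorch 1.7 or later):

# Equivalent to: dif[dif < dif_] = dif_[dif < dif_]
dif = torch.maximum(dif, dif_)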

Finally, we save this tensor as an .npy file (np.save appends the .npy extension automatically, which is why read_tensor in the next script loads name + '.npy'). This is what we call the threshold map.

3. The actual inference stage

Here is the code first:

import torch
import torch.nn as nn
import os
from net import *
import cv2
import numpy as np
import torchvision.transforms as tf

transform = tf.Compose([tf.ToTensor(), tf.Normalize([0.5], [0.5])])
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def read_tensor(name):
    # Load the threshold map saved by bthv()
    arr = np.load(name + '.npy')
    return torch.tensor(arr)

def t2i(tensor, Normalize=True):
    # Convert a CHW tensor back to an HWC image array
    if Normalize:
        img = ((tensor.cpu().permute(1, 2, 0) * 0.5 + 0.5) * 255).int()
    else:
        img = (tensor.cpu().permute(1, 2, 0) * 255).int()
    return img.numpy()

def ds(dif, kernel, stride):
    x = nn.AvgPool2d(kernel, stride)
    return x(dif)

def im2dif(out, img, kernel, stride):
    difference = out - img
    difference[difference < 0] = -difference[difference < 0]
    difference = ds(difference, kernel, stride)
    return difference.sum(1)

def detect_once(net, img):
    img_size = [512, 512]
    origin_size = [512, 512]
    thv_path = './data/thv'

    img = cv2.resize(img, (img_size[1], img_size[0]))
    img_ = transform(img).unsqueeze(0).to(device)
    thv = read_tensor(thv_path).to(device)
    with torch.no_grad():
        out = net(img_)
        out_image = t2i(out[0])
        out_image = np.array(out_image, dtype=np.uint8)
    # Pooled difference map of the test image
    dif = im2dif(out, img_, 16, 8).squeeze(0)
    dif_image = t2i(dif.unsqueeze(0), False)
    dif_image = np.array(dif_image, dtype=np.uint8)
    dif_image = cv2.resize(dif_image, (512, 512))

    # Mark every position whose pooled difference exceeds the threshold map (5% margin)
    z = torch.zeros(dif.size()).to(device)
    z[dif > (thv * 1.05)] = 1
    z_image = t2i(z.unsqueeze(0), False)
    z_image = np.array(z_image, dtype=np.uint8)
    z_image = cv2.resize(z_image, (512, 512))

    # The number of anomaly points decides OK vs. NG
    bar = z.sum().item()
    print(bar)
    flag = 'ok'
    if bar <= 5:
        bar = round(1 - torch.randint(0, 40, [1]).item() / 100)
        cv2.putText(img, 'OK ', (0, img_size[1] - 10), cv2.FONT_HERSHEY_SIMPLEX, 5, (0, 0, 255), 10)
    else:
        bar = round(1 - torch.sigmoid(torch.tensor(bar)).item(), 3)
        cv2.putText(img, 'NG ', (0, img_size[1] - 10), cv2.FONT_HERSHEY_SIMPLEX, 5, (0, 0, 255), 10)
        flag = 'ng'
    print(flag)
    # Stitch original, reconstruction, difference map and anomaly mask into one image
    h_image1 = cv2.hconcat((img, out_image))
    h_image2 = cv2.hconcat((dif_image, z_image))
    final_image = cv2.vconcat((cv2.cvtColor(h_image1, cv2.COLOR_BGR2GRAY), h_image2))

    return cv2.resize(img, (origin_size[1], origin_size[0])), bar, flag, final_image

def detect():
    im_path = r'./data/data_val'
    out_path = r'./data/output'
    use_model = r'./model/net_0905.pth'
    correct = 0
    overkill = 0
    escape = 0

    net = torch.load(use_model, map_location='cuda' if torch.cuda.is_available() else 'cpu').to(device)
    net.eval()

    for t, i in enumerate(os.listdir(im_path)):
        print(i)
        img = cv2.imread(im_path + '/' + i)
        img, bar, flag, final_image = detect_once(net, img)
        # The ground-truth label ('ok' or 'ng') is encoded in the file name
        label = i.split('.')[0]
        if flag in label:
            correct += 1
        else:
            if flag == 'ok':
                escape += 1
            else:
                overkill += 1
        cv2.imwrite(out_path + '/' + i.split('.')[0] + '.png', final_image)
    print('accuracy:', correct / len(os.listdir(im_path)))
    print('overkill rate:', overkill / len(os.listdir(im_path)))
    print('miss rate:', escape / len(os.listdir(im_path)))

The inference process is basically the same as the threshold-building process above. The difference is that we compare the pooled map of the test image with the threshold map: wherever the pooled value is larger than the threshold value, that position is treated as an anomaly point, and the final OK/NG call is made from the number of anomaly points. Have a look at this last script and see whether you can follow it; I won't go through it line by line.
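Boiled down to a few lines, the decision rule looks like this (a stripped-down sketch of what detect_once does, assuming net, transform, im2dif and the threshold map thv have already been loaded as in the script above):

img_ = transform(img).unsqueeze(0).to(device)
with torch.no_grad():
    out = net(img_)                                    # reconstruction
pooled = im2dif(out, img_, 16, 8).squeeze(0)           # single-channel pooled difference map
anomaly = pooled > thv * 1.05                          # positions above the threshold map (5% margin)
flag = 'ok' if anomaly.sum().item() <= 5 else 'ng'     # few anomaly points -> OK, otherwise NG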

4. Full project code and images

Straight to the Baidu Cloud link:

Link: https://pan.baidu.com/s/1gFqby0jW3pZ9ynRJ0zCa5g
Extraction code: chws
 

That's all for now. Salute!!!!


Reposted from blog.csdn.net/guligedong/article/details/120132264