兄弟萌,我咕里个咚今天又杀回来了,有几天时间可以不用驻场了,喜大普奔,终于可以在有网的地方码代码了,最近驻场也是又热又心累啊,抓紧这几天,再更新一点的新东西。
今天主要讲一下非监督学习,你可能要问了,什么是非监督学习,我的理解就是不会给样本标签的,它本质上是一个统计手段,在没有标签的数据里可以发现潜在的一些结构的一种训练方式。这个可以用来干什么,举个例子,在工业场景瑕疵检测的运用中,由于良品的数量远远高于不良品的数量,如果这个时候你要采用监督学习,那么收集样本的时间就多得吓人了,可能你样本还没有收集完全,产品都已经做完下线了,所以,你就挠头吧。于是,非监督学习就迎来了一片蓝海,但是即使是蓝海也要你的船能开才行,这里面也不乏调整,比如非监督学习的效果就不太好去评估。当然,我想了点办法,在工业领域也有所应用了,但是过杀有点高,大概5%左右,这5%的过杀,再通过后期的监督算法,其实可以解决很多问题了,这样工业上的大部分问题,都可以有所缓解了。我真棒。哈哈哈
1.非监督学习网络架构
先提供一下,我的非监督学习的网络架构,还是基于pytorch来写的。我给这个网络一个名称叫做
piercing eye。话不多说,上代码。
from torch import nn
import torch
class CBP(nn.Module):
"""
conv + batchnormal + prelu
"""
def __init__(self,inc,ouc):
super().__init__()
self.block1=nn.Sequential(
nn.Conv2d(inc,ouc,3,1,1),
nn.BatchNorm2d(ouc),
nn.PReLU()
)
def forward(self,y):
return self.block1(y)
class Up_Block(nn.Module):
def __init__(self, in_channel, out_channel):
super().__init__()
self.block1=nn.Sequential(
nn.ConvTranspose2d(in_channel, out_channel, 3, 2, 1, 1),
nn.BatchNorm2d(out_channel),
nn.PReLU()
)
def forward(self,y):
return self.block1(y)
class Down_Block(nn.Module):
def __init__(self, in_channel, out_channel):
super().__init__()
self.block1=nn.Sequential(
nn.Conv2d(in_channel, out_channel, 5, 2, padding=2),
nn.BatchNorm2d(out_channel),
nn.PReLU()
)
def forward(self,y):
return self.block1(y)
class PiercingEye(nn.Module):
def __init__(self):
super().__init__()
self.block1=nn.Sequential(
CBP(3, 4),
Down_Block(4, 8),
Down_Block(8, 16),
Down_Block(16, 32),
Down_Block(32, 64),
Down_Block(64, 128),
Down_Block(128, 256),
CBP(256, 256),
Up_Block(256, 128),
Up_Block(128, 64),
Up_Block(64, 32),
Up_Block(32, 16),
Up_Block(16, 8),
Up_Block(8, 4),
nn.Conv2d(4,3,1),
nn.Tanh()
)
def forward(self,y):
return self.block1(y)
if __name__ == '__main__':
net = PiercingEye()
x = torch.Tensor(2,3,512,512)
y = net(x)
print(y.shape)
简单的说明一下,其实就是做了6次下采样和6次上采样,也就是AE网络,中间没有任何跳跃连接,也可以理解成是一个生成网络。
2.数据集准备
我直接给大家一个百度云的链接,这也是一个开源的数据集,我稍微整理了一下,方便大家使用
链接:https://pan.baidu.com/s/1ir5xmYJWAX8QIXHb6_5zWw
提取码:5ph7
里面一共两个文件夹,data_train,data_val截图以示清白。
数据大概就是这个样子的,左边是ok的,右边是ng的,不是菊花,不是菊花,不是菊花,重要的事说三遍。
data_train里面一共784张图片,都是ok图片
data_val里面一共366张图片,150张ng图片和216张ok图片
我们训练只训练OK图片,看看能不能通过只训练OK图片来判断验证集里面的OK和NG
3.Dataset
有了网络,有了数据,就该来处理数据准备往网络里送了,上代码
import torch
import os
from torch.utils.data import Dataset
import random
import data_agumentation
import torchvision.transforms as tf
import cv2
transform = tf.Compose([tf.ToTensor(),tf.Normalize([0.5],[0.5])])
class train_data(Dataset):
def __init__(self, path):
print('start build_train_data')
self.path=path
self.imgs=[]
for i in os.listdir(path):
self.imgs.append(i)
def __len__(self):
return len(self.imgs)
def __getitem__(self, index):
random_num = random.randint(0,2)
img=cv2.imread(self.path+'/'+self.imgs[index])
if random_num == 1:
img = data_agumentation.augment_left_flip(img)
elif random_num == 2:
img = data_agumentation.augment_rotate(img,180)
img = transform(img)
return img,img
if __name__ == '__main__':
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data=train_data(r'D:\blog_project\guligedong_unsupervised\data\data_train')
train_loader = torch.utils.data.DataLoader(data, batch_size=3, shuffle=True)
for index,(img, label) in enumerate(train_loader):
print(img.size())
print(label.size())
print()
你应该很熟悉,因为上一篇也写过类似的,这你会看到,其实img和label是一样的, 我们的目的也是输入一张图片,让他生成一张一样的图片,这样是为什么呢?我的思路就是,因为我的训练集只有ok的,网络只能生成ok的特征,如果输入的是ng的图片,那么网络就不能生成ng的特征,这个时候就会有差异了。
def __getitem__(self, index):
random_num = random.randint(0,2)
img=cv2.imread(self.path+'/'+self.imgs[index])
if random_num == 1:
img = data_agumentation.augment_left_flip(img)
elif random_num == 2:
img = data_agumentation.augment_rotate(img,180)
img = transform(img)
return img,img
在这个代码段里,我用了随机的样本增强,这个样本增强也是我前面文章中提供给大家的,就是一个水平的镜像翻转和180度的旋转,当然你还可以增加90和270度的旋转,大家也可以看到,这个dataset就简单了很多,因为非监督学习没有标签,或者说,非监督学习样本的标签就是它本身。
4.训练代码
这就是一个常规的训练代码,你根据你的显卡情况,调整一下batch_size就可以把程序跑起来了
import torch
import torch.nn as nn
import os
import dataset
import numpy as np
import random
from net import *
import time
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
np.set_printoptions(precision=4, suppress=True)
def seed_everything(seed=2117):
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
seed_everything()
gst = torch.cuda.is_available() and int(torch.version.__version__.split(".")[1]) >= 5
if gst:
gs = torch.cuda.amp.GradScaler()
loss_f1 = nn.MSELoss()
loss_f2 = nn.L1Loss()
def train_once(opt, lrud, data_load, net,epoch,epoches):
losss = 0
ss = 0
iteration_num = len(data_load.dataset)
for s, (img, label) in enumerate(data_load):
# at=time.time()
img = img.to(device)
label = label.to(device)
# print(at)
img_ = net(img)
# print(time.time()-at)
loss1 = loss_f1(img_, label)
loss2 = loss_f2(img_, label)
loss = loss1 + loss2
opt.zero_grad()
if gst:
gs.scale(loss).backward()
gs.step(opt)
gs.update()
else:
loss.backward()
opt.step()
lrud.step()
losss += loss.item()
ss = ss + img.size(0)
# print(ss, '/', iteration_num, loss.item())
print('[epoch][%s/%s]----[batch][%s/%s]------------------loss=%.8f'%(epoch+1,epoches,s+1,len(data_load),loss.item()))
return losss
def train():
train_img_path = r'./data/data_train'
net_cache = r'./model'
if not os.path.exists(net_cache):
os.makedirs(net_cache)
net_path = net_cache + '/net.pth'
batch_size = 64
total_epochs = 1000
data = dataset.train_data(train_img_path)
if os.path.exists(net_path):
net = torch.load(net_path, map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu")).to(device)
else:
net = PiercingEye().to(device)
print('train_data_success!')
train_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True)
net.train()
opt = torch.optim.Adam(net.parameters())
lrud = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=5, eta_min=1e-6)
a = 0
loss_l = 100
for a in range(total_epochs):
losss = train_once(opt, lrud, train_loader, net,a,total_epochs)
if loss_l >= losss or a % 33 == 0:
loss_l = losss
torch.save(net, net_path)
# torch.save(net, net_cache + '/net_' + str(a) + '.pth')
print('save pth success')
print()
if __name__ == '__main__':
train()
跑起来之后就是这个样子的
5.推理代码
给个链接,自己看
https://blog.csdn.net/guligedong/article/details/120132264?spm=1001.2014.3001.5502
6.测试结果
奇怪的测试代码还没有写,居然有测试结果了,怎么可能,还我怎么可能没有写,写好了,只是讲起来这个推理逻辑就要大费周章了,所以先给大家看看结果(你可以先猜猜怎么做的,反正不是segmentation)
可以看到的是,模型效果也不是特别的优秀,92%的正确率,6%的过杀,还有1%的漏失,工业上过杀不可怕,可怕的就是漏失,我的推理代码里其实可以控制让漏失更小,但是相应的,过杀就会更高。具体场景看具体需求。这只是一个方法,毕竟非监督学习后面还可以加监督模型嘛,怕什么。
好了,今天就更新的到这里,明天星期六,我又要去驻场了。哎,不辛苦,命苦。
前两个这个文章到了人工智能榜的20名左右,我觉得还是不错的,感谢大家支持,没有截图,不信就算了,不信也得信。
至此,敬礼,salute!!!!