[PyTorch Deep Learning in 50 Posts] Part 4: [Segmentation][2] ----- Data Preparation and Training for Deeplab V3+

The previous post already published the DeeplabV3+ model code, so with the model in hand it's time to prepare the data. This time I'm reusing the dataset from the earlier face-recognition posts: I randomly picked 40-odd images from it and annotated them for training, with two classes, eyes and mouth. I have to say, annotation is exhausting work; no wonder it's called artificial intelligence, since the prerequisite for the intelligence is a great deal of manual labor. Luckily I only had to annotate 40 images, and photos of lovely, beautiful women at that. When I think of someone once telling me that what they had to annotate was ····, I nearly threw up.

3. Data Preparation

Alright, let's take a look at the dataset first (with only 40-odd images, the word 'dataset' is a bit of a stretch).

This is an original image:

This is the image produced after annotation:

The annotation tool I used is '精灵标注助手' (no, this is not an ad; they haven't paid me a cent).

Once all this prep work is done, it's coding time. Here is the code:

import os

from torch.utils.data import Dataset
import torch
import torchvision.transforms as transform
import cv2
import numpy as np

import config

# shared transform: HWC uint8 image -> CHW float tensor in [0, 1], then normalized
tf = transform.Compose([transform.ToTensor(), transform.Normalize([0.5], [0.5])])

class My_Dataset(Dataset):
    def __init__(self, data_path):
        # maps each original image path to the list of its per-class label image paths
        self.img_and_label = {}
        self.class_name = config.class_name
        # the annotation tool writes the label masks into outputs/attachments
        self.attachments = os.listdir(os.path.join(data_path, 'outputs', 'attachments'))
        self.imgs = []
        for i in os.listdir(data_path):
            if os.path.isfile(os.path.join(data_path, i)):
                self.imgs.append(os.path.join(data_path, i))
                label_list = []
                # a label file belongs to an image when its name before '_' matches the image name
                for j in self.attachments:
                    if j.split('_')[0] == i.split('.')[0]:
                        label_list.append(os.path.join(data_path, 'outputs', 'attachments', j))
                        self.img_and_label[os.path.join(data_path, i)] = label_list
        print(self.imgs)
        print(self.img_and_label)
        print('Data preparation finished')

    def __len__(self):
        return len(self.img_and_label)

    def __getitem__(self, index):
        img_path = self.imgs[index]
        label_paths = self.img_and_label[self.imgs[index]]

        # image: read (BGR), resize, then ToTensor + Normalize
        img = cv2.imread(img_path)
        img = cv2.resize(img, tuple(config.train_img_size))
        img = tf(img)

        # label: one channel per class; cv2.resize takes (w, h) while torch.zeros takes
        # (H, W), so this assumes train_img_size is square
        label = torch.zeros(len(self.class_name), *config.train_img_size)
        for c, label_file in enumerate(label_paths):
            label_cv = cv2.imread(label_file)
            label_cv = cv2.resize(label_cv, tuple(config.train_img_size))
            label_gray = cv2.cvtColor(label_cv, cv2.COLOR_BGR2GRAY)
            # binarize: anything above 100 becomes 255, then scale to {0, 1}
            _, label_bi = cv2.threshold(label_gray, 100, 255, cv2.THRESH_BINARY)
            label[c, :, :] = torch.tensor(label_bi / 255)

        return img, label


if __name__ == '__main__':
    # quick sanity check: build the dataset and print the batch shapes
    data_path = r'D:\blog_project\guligedong_segmentation\DATA'
    my_data = My_Dataset(data_path)
    train_loader = torch.utils.data.DataLoader(my_data, batch_size=1, num_workers=0, shuffle=True)
    for i, j in train_loader:
        print(i.shape)
        print(j.shape)

A quick explanation: My_Dataset's __init__ pairs each original image with the label files that belong to it, and __getitem__ returns the image tensor together with the label tensor. The pipeline, step by step:

The original image is resized and then converted with ToTensor (plus normalization).

Each label image is resized, converted to grayscale, binarized, divided by 255, and then turned into a tensor. Since my data has only two classes, the label tensor has just two channels.
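By the way, the code above imports a config module that isn't shown in this post. Below is a minimal sketch of what it could look like; the field names (class_name, train_img_size, batch_size, lr, epoches) come from how they are used in the code, but the values are only my assumptions and you would tune them for your own data:

# config.py -- hypothetical sketch; only the names are taken from the code above
class_name = ['eye', 'mouth']    # the two annotated classes
train_img_size = [256, 256]      # square size, so cv2.resize (w, h) and torch.zeros (H, W) agree
batch_size = 2                   # keep it above 1: train_once() below skips batches of size 1
lr = 1e-4
epoches = 100

With values like these, the quick test at the bottom of the dataset file would print torch.Size([1, 3, 256, 256]) for the image batch and torch.Size([1, 2, 256, 256]) for the label batch.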

4. Training

As for training, I feel it's just a series of routine steps, so there's no need to over-explain it. Here I used three loss functions, Dice loss, MSE, and BCE, added together directly without any weighting coefficients. Take a look at the code.
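(For reference: the Dice loss defined below works out to 1 - 2·Σ(pred·label) / (Σpred + Σlabel), computed on the network's outputs after the sigmoid, and the total loss for each batch is simply MSE + BCE + Dice.)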

import torch
import os
import random
import numpy as np
from deeplabv3plus import *
import dataset as ds
import time
import torch.nn as nn
import config

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
np.set_printoptions(precision=4, suppress=True)


def seed_everything(seed=2117):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True


seed_everything()


class DiceLoss(nn.Module):
    # soft Dice loss: 1 - 2 * sum(pred * label) / (sum(pred) + sum(label))
    def __init__(self):
        super().__init__()

    def forward(self, out, lab):
        # both prediction and label empty means perfect overlap (Dice = 1), so the
        # loss is 0; return a tensor so it adds cleanly to the other loss terms
        if torch.sum(out).item() == 0 and torch.sum(lab).item() == 0:
            return torch.tensor(0.0, device=out.device)
        else:
            n = torch.sum(out * lab) * 2
            m = torch.sum(out) + torch.sum(lab)
            return 1 - n / m


loss_f1 = nn.MSELoss()
loss_f2 = nn.BCELoss()
loss_f3 = DiceLoss()


class Train:
    def __init__(self, data_path, pre_train=False):
        train_data = ds.My_Dataset(data_path)
        self.train_loader = torch.utils.data.DataLoader(train_data,
                                                        batch_size=config.batch_size,
                                                        shuffle=True)

        if not pre_train:
            print('No pretrained model, training from scratch')
            self.net = DeepLabv3_plus(nInputChannels=3, n_classes=2, os=16, _print=True).to(device)
            self.net.train()
        else:
            # resumes from a checkpoint saved by a previous run
            print('Loading pretrained model')
            self.net = torch.load('./model/net.pth').to(device)
            self.net.train()
        self.opt = torch.optim.Adam(self.net.parameters(), lr=config.lr)

    def train_once(self):
        losss = 0
        ss = 0
        iteration_num = len(self.train_loader.dataset)
        for s, (img, label) in enumerate(self.train_loader):
            img = img.to(device)
            label = label.to(device)
            # skip a trailing batch of size 1 (BatchNorm in train mode needs more than one sample)
            if img.size(0) == 1:
                return losss

            img_ = self.net(img)
            # three losses, summed directly with no weighting coefficients
            loss1 = loss_f1(img_, label)
            loss2 = loss_f2(img_, label)
            loss3 = loss_f3(img_, label)
            loss = loss1 + loss2 + loss3

            self.opt.zero_grad()
            loss.backward()
            self.opt.step()

            losss += loss.item()
            ss = ss + img.size(0)
            print(ss, '/', iteration_num, loss.item())
        return losss

    def train(self):
        best_loss = 100
        for i in range(config.epoches):
            print('current epoch:', i + 1)
            losss = self.train_once()
            print('loss: ', losss)
            # save whenever the summed epoch loss improves on the best seen so far
            if best_loss >= losss:
                best_loss = losss
                torch.save(self.net, './model/net.pth')
                print('save pth success')
                print()


if __name__ == '__main__':
    data_path = r'D:\blog_project\guligedong_segmentation\DATA'
    # pre_train=True resumes from ./model/net.pth; use pre_train=False for a first run
    trainer = Train(data_path, pre_train=True)
    trainer.train()

Now let's happily get training.

As for the full project code and images, I'll package everything up for you once I finish the inference code tomorrow. If you can't wait, you can annotate your own images and train on those. Also, if you want to use the Dice loss, you may need to go back to the earlier post on the DeeplabV3+ model structure and add an nn.Sigmoid() to the model's last layer.
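If you'd rather not edit the model file from that post, an equivalent way to get outputs in [0, 1] (which both BCELoss and the Dice loss above expect) is to wrap the model. This is just a sketch of the idea under that assumption; the wrapper class name is mine, not from the original project:

import torch.nn as nn
from deeplabv3plus import DeepLabv3_plus

# Hypothetical wrapper (not part of the original post): applies a sigmoid to the raw
# DeepLabv3_plus logits so that BCELoss / DiceLoss receive values in [0, 1].
class DeepLabWithSigmoid(nn.Module):
    def __init__(self, n_classes=2):
        super().__init__()
        self.backbone = DeepLabv3_plus(nInputChannels=3, n_classes=n_classes, os=16, _print=False)
        self.act = nn.Sigmoid()

    def forward(self, x):
        return self.act(self.backbone(x))

You would then build this wrapper in Train.__init__ instead of the bare DeepLabv3_plus; everything else in the training loop stays the same.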

And that wraps up today's data preparation and training.

It's worth mentioning that yesterday's post made it to #26 on the site-wide trending list with 100k in heat. Thank you all~~

With that, a salute to you all, salute!!!

Bonus easter-egg photo: the model from my earlier posts, 咩咩狗, has grown up. Here she is for you all, hahaha.



Reposted from blog.csdn.net/guligedong/article/details/121224171