MTCNN网络的训练

MTCNN三个网络分别训练,其中置信度和偏置是用不同的样本进行训练,置信度用正样本和负样本进行训练,偏移用正样本和部分样本进行训练:

import torch
from torch import nn
from torch import optim
from DataSet import MyDataSet
from Net import Pnet,Rnet,Onet
import DataSet
from torch.utils import data

class Train:
    def __init__(self,p_textpath,n_textpath,t_textpath,p_imgpath,n_imgpath,t_imgpath,net):
        self.p_textpath = p_textpath
        self.p_imgpath = p_imgpath
        self.n_textpath = n_textpath
        self.n_imgpath = n_imgpath
        self.t_textpath = t_textpath
        self.t_imgpath = t_imgpath
        self.net = net
        #创建训练数据集
        dataset = MyDataSet(p_textpath,n_textpath,t_textpath,p_imgpath,n_imgpath,t_imgpath)
        self.dataloader = data.DataLoader(dataset,batch_size=10,shuffle=True)

    def train(self):
        if self.net == 'pnet':
            net = Pnet()
        elif self.net == 'rnet':
            net = Rnet()
        elif self.net == 'onet':
            net = Onet()
        optimizer = optim.Adam(net.parameters())
        conf_loss_fun = nn.BCELoss()
        off_loos_fun = nn.MSELoss()
        for epoch in range(1000):
            imgdata, conf, offset = DataSet.GetIter(self.dataloader)
            confidence,offset_out = net(imgdata)
            #置信度的损失需要正负样本
            #获得置信度小于2的掩码
            conn_mask = torch.lt(conf,2)
            #得到符合条件的置信度
            conf_ = conf[conn_mask]
            confidence_ = confidence[conn_mask]
            #偏移的损失需要正样本和部分样本
            #得到置信度大于0的掩码
            off_mask = torch.gt(conf,0)
            #得到符合条件的偏移
            offset = offset[off_mask[:,0]]
            offset_out = offset_out[off_mask[:,0]]
            conf_loss = conf_loss_fun(confidence_,conf_)
            off_loss = off_loos_fun(offset_out,offset)
            loss = conf_loss + off_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print(loss)

train=Train(DataSet.p_48txtpath,DataSet.n_48txtpath,DataSet.t_48txtpath,DataSet.p_48imgpath,DataSet.n_48imgpath,DataSet.t_48imgpath,'onet')
train.train()

猜你喜欢

转载自blog.csdn.net/weixin_38241876/article/details/91959102