Semi-Supervised Learning in Practice: Training on a Mix of Labeled and Pseudo-Labeled Data

1. Background

When labeled data is scarce, unlabeled data is plentiful, and annotation is expensive, semi-supervised training is worth considering. The idea is simple: first use a pseudo-labeling technique to assign pseudo-labels to the unlabeled images, then train the model on a mixture of labeled and pseudo-labeled data. The key point is to make sure that every mini-batch contains both real labels and pseudo-labels; this post walks through the code that does exactly that.

2. Approach and Steps

First, a quick look at the pseudo-labeling technique itself: a model trained on the labeled data is run on the unlabeled images, and its (sufficiently confident) predictions are kept as pseudo-labels (see the related post linked at the end for the theory).

3. Code Implementation

The first task is generating the pseudo-labels. For both classification and object detection this is fairly straightforward, so it is not covered in detail here; a short sketch follows below.

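Since the original post skips this step, here is a minimal sketch of pseudo-label generation for classification, assuming a model net already trained on the labeled data and a file unlabeled.txt with one image path per line; the function name, the file names, and the 0.9 confidence threshold are illustrative choices, not part of the original code.

import cv2
import torch
from PIL import Image
from torchvision import transforms as T

def generate_pseudo_labels(net, unlabeled_list, out_list, device, conf_thresh=0.9):
    # same preprocessing as the eval branch of the Dataset below
    transforms = T.Compose([
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    net.eval()
    with open(unlabeled_list, 'r') as fd:
        paths = [line.strip() for line in fd if line.strip()]

    with open(out_list, 'w') as out, torch.no_grad():
        for img_path in paths:
            img = cv2.imread(img_path)
            if img is None:
                continue
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = Image.fromarray(img).resize((224, 224))
            x = transforms(img).unsqueeze(0).to(device)

            prob = torch.softmax(net(x), dim=1)
            conf, pred = prob.max(dim=1)
            # keep only confident predictions as pseudo-labels
            if conf.item() >= conf_thresh:
                out.write('%s %d\n' % (img_path, pred.item()))

# usage (hypothetical file names):
# generate_pseudo_labels(net, 'unlabeled.txt', 'pseudo.txt', device)

The output file uses the same "image_path label" format that the Dataset below expects for its img_list1 argument.
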
What the rest of this section implements is: how to guarantee that every mini-batch contains both real labels and pseudo-labels, and how to control their ratio, using classification as the running example.

Step 1: modify the data-loading code, as follows:

import random

import cv2
import numpy as np
from PIL import Image
from torch.utils import data
from torchvision import transforms as T

# custom augmentation helpers from the author's data_augment module
from data_augment import gussian_blur, random_crop

class Dataset(data.Dataset):
    def __init__(self, img_list, img_list1, phase='train'):
        self.phase = phase

        # manually labeled samples: each line is "image_path label"
        with open(img_list, 'r') as fd:
            imgs = fd.readlines()
        imgs = [img.rstrip("\n") for img in imgs]
        random.shuffle(imgs)
        self.imgs = imgs

        # pseudo-labeled samples (simulated here), same "image_path label" format
        with open(img_list1, 'r') as fd:
            fake_imgs = fd.readlines()
        fake_imgs = [img.rstrip("\n") for img in fake_imgs]
        random.shuffle(fake_imgs)
        self.fake_imgs = fake_imgs


        normalize = T.Normalize(mean=[0.5, 0.5, 0.5],
                                 std=[0.5, 0.5, 0.5])

        if self.phase == 'train':
            self.transforms = T.Compose([
                T.RandomHorizontalFlip(),
                T.ToTensor(),
                normalize
            ])
        else:
            self.transforms = T.Compose([
                T.ToTensor(),
                normalize
            ])

    def __getitem__(self, index):
        sample = self.imgs[index]
        splits = sample.split()
        img_path = splits[0]

        # data augmentation on the labeled image (custom helpers from data_augment)
        data = cv2.imread(img_path)
        data = random_crop(data, 0.2)
        data = gussian_blur(data, 0.2)

        data = cv2.cvtColor(data, cv2.COLOR_BGR2RGB)
        data = Image.fromarray(data)
        
        data = data.resize((224, 224))
        data = self.transforms(data)
        label = np.int32(splits[1])

        # fetch pseudo-labeled images and pseudo-labels: 2 per real sample gives a fixed
        # 1:2 real-to-pseudo ratio in every mini-batch (change range(2) to adjust the ratio)
        fake_datas, fake_labels = [], []
        for i in range(2):
            fake_sample = self.fake_imgs[(index+i)%len(self.fake_imgs)]
            fake_splits = fake_sample.split()
            fake_img_path = fake_splits[0]

            fake_data = cv2.imread(fake_img_path)
            fake_data = cv2.cvtColor(fake_data, cv2.COLOR_BGR2RGB)
            fake_data = Image.fromarray(fake_data)
            fake_data = fake_data.resize((224, 224))
            fake_data = self.transforms(fake_data)

            fake_label = np.int32(fake_splits[1])

            fake_datas.append(fake_data.float())
            fake_labels.append(fake_label)

        return data.float(), label, fake_datas, fake_labels

    def __len__(self):
        return len(self.imgs)

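The post does not show how this Dataset is wired into a DataLoader, so here is a minimal sketch, assuming the labeled list is train.txt and the pseudo-label list is pseudo.txt (both file names and the batch size are illustrative). Because two pseudo samples are fetched per real sample, a DataLoader batch size of 16 becomes an effective mini-batch of 48 images (16 real + 32 pseudo) after the concatenation done in train() below.

from torch.utils.data import DataLoader

train_dataset = Dataset('train.txt', 'pseudo.txt', phase='train')
trainloader = DataLoader(train_dataset,
                         batch_size=16,   # number of real (labeled) images per batch
                         shuffle=True,
                         num_workers=4)

# each batch from trainloader is (inputs, targets, fake_inputs, fake_targets):
#   inputs       -> tensor [16, 3, 224, 224], real images
#   targets      -> tensor [16], real labels
#   fake_inputs  -> list of 2 tensors, each [16, 3, 224, 224], pseudo-labeled images
#   fake_targets -> list of 2 tensors, each [16], pseudo-labels
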
Step 2: the corresponding changes in the main training routine, as follows:

from tqdm import tqdm

def train(epoch, net, trainloader, optimizer, criterion):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    batch_id = 0
    for (inputs, targets, fake_inputs, fake_targets) in tqdm(trainloader):
    
        # merge the real mini-batch with the pseudo-labeled mini-batches
        # (fake_inputs is a list of batched tensors, so the real batch is appended and concatenated)
        fake_inputs.append(inputs)
        fake_targets.append(targets)

        inputs = torch.cat(fake_inputs, dim=0)
        targets = torch.cat(fake_targets, dim=0)

        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets.long())
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets.long()).sum().item()

        iters = epoch * len(trainloader) + batch_id
        if iters % 10 == 0:
            acc = predicted.eq(targets.long()).sum().item() * 1.0 / targets.shape[0]
            # CrossEntropyLoss already averages over the batch, so the loss is logged as-is
            los = loss.item()
            #tensor_board.visual_loss("train_loss", los, iters)
            #tensor_board.visual_acc("train_acc", acc, iters)
        batch_id += 1

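For completeness, here is a minimal sketch of how train() might be driven; the backbone, class count, learning rate, and epoch count below are placeholders rather than values from the original post.

import torch
import torch.nn as nn
import torchvision

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# placeholder model: any classifier with the right number of classes works here
net = torchvision.models.resnet18(num_classes=10).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

for epoch in range(30):
    train(epoch, net, trainloader, optimizer, criterion)
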
That's all there is to it. For the theoretical background, please refer to my other blog post.

Related: https://blog.csdn.net/p_lart/article/details/100128353

Reposted from blog.csdn.net/Guo_Python/article/details/107980688