1. Background
When labeled data is scarce, unlabeled data is plentiful, and labeling is expensive, semi-supervised training is worth considering. First, use a pseudo-labeling technique to assign pseudo-labels to the unlabeled images, and then train the model on a mixture of the labeled data and the pseudo-labeled data. It is worth noting that every mini-batch should contain both real labels and pseudo-labels; this article walks you through implementing that in code.
2. Implementation methods and steps
First, take a look at the pseudo-labeling technique (see the linked reference), as shown in the following figure:
3. Code implementation
The first task is to generate the pseudo-labels, which is relatively simple for both classification and object detection, so I won't go into details here.
What follows shows how to ensure that real labels and pseudo-labels are both present in every mini-batch, and how to control their ratio, using classification as the example.
The first step is the modified data-loading code, as follows:
import os
import torch
from torch.utils import data
import numpy as np
from torchvision import transforms as T
import torchvision
import cv2
import sys
import random
from PIL import Image
from data_augment import gussian_blur, random_crop
class Dataset(data.Dataset):
    """Labeled image dataset that attaches pseudo-labeled samples to each item.

    Every item consists of one manually labeled sample plus ``num_fake``
    pseudo-labeled samples, so that after default batching each mini-batch
    contains real and pseudo labels in a fixed ``1 : num_fake`` ratio.

    Both list files are plain text with one sample per line in the form
    ``<image_path> <integer_label>`` separated by whitespace. Blank lines
    are ignored. Both lists are shuffled once at construction time.

    Args:
        img_list: Path to the list file of manually labeled samples.
        img_list1: Path to the list file of pseudo-labeled samples.
        phase: ``'train'`` enables random augmentation (crop, blur,
            horizontal flip); any other value yields deterministic
            preprocessing only.
        num_fake: Number of pseudo-labeled samples returned alongside each
            real sample. Defaults to 2.

    Raises:
        ValueError: If ``num_fake`` is negative, or if it is positive while
            the pseudo-label list is empty.
    """

    def __init__(self, img_list, img_list1, phase='train', num_fake=2):
        super().__init__()
        if num_fake < 0:
            raise ValueError("num_fake must be >= 0, got %r" % (num_fake,))
        self.phase = phase
        self.num_fake = num_fake

        # Manually annotated samples.
        self.imgs = self._read_list(img_list)
        # Pseudo-labeled (simulated) samples.
        self.fake_imgs = self._read_list(img_list1)
        if num_fake > 0 and not self.fake_imgs:
            # Fail early: indexing an empty list modulo its length would
            # otherwise raise ZeroDivisionError on the first fetch.
            raise ValueError("pseudo-label list is empty: %s" % (img_list1,))

        normalize = T.Normalize(mean=[0.5, 0.5, 0.5],
                                std=[0.5, 0.5, 0.5])
        steps = [T.ToTensor(), normalize]
        if self.phase == 'train':
            # Random flipping is a train-time augmentation only.
            steps.insert(0, T.RandomHorizontalFlip())
        self.transforms = T.Compose(steps)

    @staticmethod
    def _read_list(list_path):
        """Return the shuffled, non-blank lines of the list file."""
        with open(list_path, 'r') as fd:
            entries = [line.strip() for line in fd]
        # A trailing or stray blank line would break ``split()[0]`` later.
        entries = [entry for entry in entries if entry]
        random.shuffle(entries)
        return entries

    @staticmethod
    def _parse_sample(line):
        """Split a ``<path> <label>`` line into ``(path, np.int32 label)``."""
        fields = line.split()
        return fields[0], np.int32(fields[1])

    @staticmethod
    def _read_bgr(img_path):
        """Read an image with OpenCV, raising instead of returning None."""
        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise OSError("cannot read image: %s" % (img_path,))
        return image

    def _to_tensor(self, bgr_image):
        """Convert a BGR array to a normalized float tensor of a 224x224 RGB image."""
        rgb = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
        pil_image = Image.fromarray(rgb).resize((224, 224))
        return self.transforms(pil_image).float()

    def _pseudo_indices(self, index):
        """Return the pseudo-sample indices paired with real sample ``index``.

        Consecutive real indices receive disjoint windows of ``num_fake``
        pseudo samples (wrapping around), so pseudo samples are not
        duplicated between neighbouring items.
        """
        start = index * self.num_fake
        count = len(self.fake_imgs)
        return [(start + offset) % count for offset in range(self.num_fake)]

    def __getitem__(self, index):
        """Return ``(image, label, fake_images, fake_labels)`` for ``index``.

        ``image`` is a float tensor and ``label`` an ``np.int32``;
        ``fake_images`` / ``fake_labels`` are lists of length ``num_fake``
        holding the pseudo-labeled tensors and their labels.

        Raises:
            OSError: If any referenced image cannot be read.
        """
        img_path, label = self._parse_sample(self.imgs[index])
        image = self._read_bgr(img_path)
        if self.phase == 'train':
            # Data augmentation; the 0.2 argument is passed straight to the
            # data_augment helpers (presumably a probability -- confirm).
            image = random_crop(image, 0.2)
            image = gussian_blur(image, 0.2)
        real_data = self._to_tensor(image)

        # Fetch the pseudo data and pseudo labels.
        fake_datas, fake_labels = [], []
        for fake_index in self._pseudo_indices(index):
            fake_path, fake_label = self._parse_sample(self.fake_imgs[fake_index])
            fake_datas.append(self._to_tensor(self._read_bgr(fake_path)))
            fake_labels.append(fake_label)
        return real_data, label, fake_datas, fake_labels

    def __len__(self):
        """Return the number of manually labeled samples."""
        return len(self.imgs)
The second step is the corresponding change in the main training program, as follows:
def train(epoch, net, trainloader, optimizer, criterion):
    """Train ``net`` for one epoch on mixed real and pseudo-labeled batches.

    Each batch from ``trainloader`` is expected to be a 4-tuple
    ``(inputs, targets, fake_inputs, fake_targets)`` where the last two are
    sequences of tensors. The pseudo-labeled tensors and the real ones are
    concatenated along dim 0 (pseudo first, real last) and trained on as a
    single mini-batch.

    Relies on the module-level names ``device`` and ``tqdm`` being defined
    elsewhere in the script.

    Args:
        epoch: Zero-based epoch number; used for the banner and to derive
            the global iteration count.
        net: Model to train; switched to training mode.
        trainloader: Sized iterable of batches as described above.
        optimizer: Optimizer stepping the parameters of ``net``.
        criterion: Loss callable taking ``(outputs, targets)`` and returning
            a scalar tensor.

    Returns:
        A tuple ``(mean_loss, accuracy)`` for the epoch: the loss averaged
        over batches and the fraction of correct predictions over all
        samples seen. Both are 0.0 when the loader yields no batches.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    for batch_id, (inputs, targets, fake_inputs, fake_targets) in enumerate(tqdm(trainloader)):
        # Merge the real labels with the pseudo labels. New lists are built
        # so the batch object handed out by the loader is not mutated.
        inputs = torch.cat(list(fake_inputs) + [inputs], dim=0)
        targets = torch.cat(list(fake_targets) + [targets], dim=0)
        inputs = inputs.to(device)
        targets = targets.to(device).long()

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # .item() yields a plain float so no autograd graph is kept alive.
        batch_loss = loss.item()
        _, predicted = outputs.max(1)
        batch_size = targets.size(0)
        batch_correct = predicted.eq(targets).sum().item()

        train_loss += batch_loss
        total += batch_size
        correct += batch_correct
        num_batches += 1

        iters = epoch * len(trainloader) + batch_id
        if iters % 10 == 0:
            # Per-batch metrics, kept as the hook point for the optional
            # logging below.
            acc = batch_correct * 1.0 / batch_size
            # NOTE(review): if criterion already averages over the batch,
            # this divides by the batch size a second time -- confirm.
            los = batch_loss * 1.0 / batch_size
            # tensor_board.visual_loss("train_loss", los, iters)
            # tensor_board.visual_acc("train_acc", acc, iters)

    if num_batches == 0 or total == 0:
        return 0.0, 0.0
    return train_loss / num_batches, correct * 1.0 / total
It's that simple. For the theoretical background, please refer to my other blog post.
Related: https://blog.csdn.net/p_lart/article/details/100128353