One Shot Learning with Siamese Networks

Today I read a paper on one-shot learning and worked through a friend's implementation of it. I'm writing it down here so that I don't get confused about it later.
This article is mainly divided into the following parts:

  • Helper function for image display
  • How to store and read data in PyTorch
  • The structure of the network
  • Training the network
  • Small-scale experiments and analysis

Before the step-by-step explanation, here are the packages the code uses. If you want to run the code, import these packages, combine them with the code from the following sections, and run everything with PyTorch. All right, let's begin.

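# Jupyter notebook magic that renders plots inline; drop this line when running as a plain .py script
%matplotlib inline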
import torchvision
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader,Dataset
import matplotlib.pyplot as plt
import torchvision.utils
import numpy as np
import random
from PIL import Image
import torch
from torch.autograd import Variable
import PIL.ImageOps    
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

1. Helper functions for image display

When I started out I was not very familiar with PyTorch, so even this simple image-display helper took me a while to understand. I'm recording it here as a monument to my early confusion. The code:

def imshow(img,text,should_save=False):  # should_save is accepted but not used in this version
    npimg = img.numpy()  # images arrive as torch.Tensor, but plt.imshow() needs a NumPy array, so convert here (the tensor must be on the CPU)
    plt.axis("off")  # hide the axes
    if text:    # if there is text to show, draw it with plt.text()
        plt.text(75, 8, text, style='italic',fontweight='bold',
            bbox={'facecolor':'white', 'alpha':0.8, 'pad':10})
    # plt.text() arguments: x position, y position, text, italic style, bold weight,
    # and a box behind the text (white background, opacity 0.8, padding 10)
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    # plt.imshow() expects (height, width, channels), but PyTorch image tensors are (channels, height, width),
    # so np.transpose(npimg, (1, 2, 0)) moves the channel axis to the end
    plt.show()
    # (I'll add a usage example later if I have time.)

def show_plot(iteration,loss):
    plt.plot(iteration,loss)
    plt.show()

class Config():
    training_dir = "./data/faces/training/"
    testing_dir = "./data/faces/testing/"
    train_batch_size = 64
    train_number_epochs = 100
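The following small NumPy check shows what np.transpose(npimg, (1, 2, 0)) in imshow does. It is a sketch in which a dummy array stands in for an image tensor already converted with .numpy(). A (channels, height, width) array becomes (height, width, channels), which is the layout plt.imshow() expects, and each pixel keeps its channel values.

```python
import numpy as np

# Dummy "image" in PyTorch layout: (channels, height, width)
chw = np.zeros((3, 100, 200), dtype=np.float32)
chw[0] = 1.0  # fill the first (red) channel

# Move the channel axis to the end: (height, width, channels)
hwc = np.transpose(chw, (1, 2, 0))

print(chw.shape)  # (3, 100, 200)
print(hwc.shape)  # (100, 200, 3)
print(hwc[0, 0])  # [1. 0. 0.]: the pixel at row 0, column 0 keeps its three channel values
```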

Since these are study notes, I think the code should be explained clearly. Otherwise, the next time I look at it I still won't understand it and will waste time relearning old material. So I have commented the code in detail. If you are already familiar with it, feel free to skip ahead.

2. How data is stored and read in PyTorch

I could roughly guess what the data-loading code does, but I wanted to understand it properly. The data-loading tools used here are:

  1. torch.utils.data.Dataset and DataLoader
  2. torchvision.datasets.ImageFolder
1. torch.utils.data.Dataset and DataLoader

1. PyTorch provides a data-loading mechanism built from two classes: torch.utils.data.Dataset and torch.utils.data.DataLoader.
2. To define your own way of reading data, subclass torch.utils.data.Dataset and then wrap an instance of it in a DataLoader.
3. torch.utils.data.Dataset is the abstract class that represents a dataset. By overriding its methods in a subclass you decide how each sample is read and preprocessed.
4. torch.utils.data.DataLoader wraps the dataset object and iterates over it, using either a single process or several worker processes, and hands out one batch per training step.

2. How to subclass torch.utils.data.Dataset and override its methods

1. A custom Dataset class must override at least two methods: __getitem__ and __len__.
2. __len__ returns the size of the dataset.
3. __getitem__ returns the sample at a given index (what a sample consists of is up to you).
4. Besides returning a sample, __getitem__ can also preprocess it or read it from disk, for example through an ImageFolder object. For very large datasets, LMDB can be used.
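The two lists above can be illustrated with a toy dataset. The sketch below is hypothetical: SquaresDataset is made up for illustration and is not part of the original code. In it, __len__ reports the number of samples, __getitem__ builds one sample, and DataLoader groups the samples into batches using its default collate function.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is (i, i**2) as float tensors."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # Number of samples in the dataset (not the batch size)
        return self.n

    def __getitem__(self, index):
        # Build (and, if needed, preprocess) the sample at this index
        x = torch.tensor([float(index)])
        y = torch.tensor([float(index) ** 2])
        return x, y


dataset = SquaresDataset(10)
loader = DataLoader(dataset, batch_size=4, shuffle=False)

print(len(dataset))  # 10 samples
print(len(loader))   # 3 batches: 4 + 4 + 2
for x, y in loader:
    print(x.shape, y.shape)  # samples are stacked along dim 0: (4, 1), (4, 1), then (2, 1)
```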
Now let's look at the Siamese dataset code:

# Following the points above, define a Dataset subclass called SiameseNetworkDataset.
class SiameseNetworkDataset(Dataset):

    def __init__(self,imageFolderDataset,transform=None,should_invert=True):  # constructor
        self.imageFolderDataset = imageFolderDataset    # an ImageFolder object (explained in the next subsection)
        self.transform = transform
        self.should_invert = should_invert

    # Once SiameseNetworkDataset is wrapped in a DataLoader, the DataLoader calls __getitem__ to fetch samples,
    # so what this method returns determines what each batch from the DataLoader contains.
    # Note that index is ignored: every call draws a fresh random pair.
    def __getitem__(self,index):
        img0_tuple = random.choice(self.imageFolderDataset.imgs)
        #we need to make sure approx 50% of images are in the same class
        should_get_same_class = random.randint(0,1)  # this coin flip keeps positive and negative pairs at roughly 1:1
        if should_get_same_class:  # both images must come from the same class
            while True:
                #keep looping till the same class image is found
                img1_tuple = random.choice(self.imageFolderDataset.imgs)
                if img0_tuple[1] == img1_tuple[1]:   # tuples are (image_path, class_index); equal indices mean the same class
                    break
        else:
            # Meant to be a different-class pair. Strictly we should also check img0_tuple[1] != img1_tuple[1],
            # but with balanced classes a match only happens with probability of about 1/num_classes, so the check is skipped.
            # The label returned below is computed from the actual classes, so it is still correct.
            img1_tuple = random.choice(self.imageFolderDataset.imgs)

        # img0_tuple[0] holds the path of image 0, and img0_tuple[1] holds its class label
        img0 = Image.open(img0_tuple[0])
        img1 = Image.open(img1_tuple[0])
        img0 = img0.convert("L")  # convert to grayscale (1 channel)
        img1 = img1.convert("L")

        if self.should_invert: # optionally invert the pixel values (negative image)
            img0 = PIL.ImageOps.invert(img0)
            img1 = PIL.ImageOps.invert(img1)

        if self.transform is not None:  # apply the transform passed to the constructor to both images
            img0 = self.transform(img0)
            img1 = self.transform(img1)
        # returns (image 0, image 1, label): label is 1.0 if the classes differ and 0.0 if they are the same
        return img0, img1 , torch.from_numpy(np.array([int(img1_tuple[1]!=img0_tuple[1])],dtype=np.float32))

    def __len__(self): # number of images in the dataset; this is NOT the batch size (len(DataLoader) gives the number of batches)
        return len(self.imageFolderDataset.imgs)
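If you want the different-class branch to be strict, the pair-sampling logic can be pulled out into a small function. The sketch below uses a hypothetical helper, sample_pair, which is not part of the original code. It works on a plain list of (path, class_index) tuples shaped like imageFolderDataset.imgs, so it can be tested without any image files. It assumes at least two classes exist, because otherwise the different-class loop never ends. With the strict check, exactly half of the pairs are negative in expectation.

```python
import random


def sample_pair(imgs, rng=random):
    """Hypothetical helper: draw a pair from a list of (path, class_index) tuples.

    With probability 0.5 the pair is same-class; otherwise the second image is
    resampled until its class really differs. Returns (tuple0, tuple1, label),
    where label is 1.0 for different classes and 0.0 for the same class.
    Requires at least two classes in imgs.
    """
    img0 = rng.choice(imgs)
    if rng.randint(0, 1):
        while True:  # same-class pair (always terminates: img0 itself qualifies)
            img1 = rng.choice(imgs)
            if img1[1] == img0[1]:
                break
    else:
        while True:  # different-class pair, with the strict check
            img1 = rng.choice(imgs)
            if img1[1] != img0[1]:
                break
    return img0, img1, float(img0[1] != img1[1])


# Usage with made-up paths
imgs = [("root/0/000.png", 0), ("root/0/001.png", 0),
        ("root/1/000.png", 1), ("root/1/001.png", 1)]
rng = random.Random(0)
labels = [sample_pair(imgs, rng)[2] for _ in range(1000)]
print(sum(labels) / len(labels))  # close to 0.5
```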
3. The dataset class above uses ImageFolder to manage the file paths and labels of the dataset. Let's look at what ImageFolder does and how to use it:

ImageFolder in torchvision.datasets loads images directly from disk, as long as they are stored in the following directory layout:

root directory / class folder / image file (file name + extension)

 root/0/000.png   
 root/0/001.png
 root/0/002.png 
 ......
 root/8/120.png 
 root/8/121.png
 root/8/122.png 

ImageFolder then stores each image as a (path, label) tuple in its imgs attribute, where the label is the index of the image's class folder in sorted order:

 (root/0/000.png   0 )
 (root/0/001.png   0 )
 (root/0/002.png   0 )
 ......
 (root/8/120.png   8 )
 (root/8/121.png   8 )
 (root/8/122.png   8 )

Here is how it is used:

folder_dataset = dset.ImageFolder(root=Config.training_dir) # ImageFolder stores the (path, label) pairs of the images under training_dir in folder_dataset.imgs
# transforms.Resize replaces the deprecated transforms.Scale, which newer torchvision versions no longer provide
siamese_dataset = SiameseNetworkDataset(imageFolderDataset=folder_dataset,
                                        transform=transforms.Compose([transforms.Resize((100,100)), transforms.ToTensor()]),
                                        should_invert=False)
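4. Wrap the SiameseNetworkDataset defined above in a DataLoader:
# DataLoader arguments: (Dataset object, whether to shuffle, number of worker processes, batch size)
# Why shuffle=True: the network uses BatchNorm, which assumes the samples in a batch are roughly i.i.d.
# If the batches are not well mixed, the batch mean and variance drift, which hurts training.
# (This particular __getitem__ ignores index and samples pairs at random anyway, so shuffling changes little here.)
train_dataloader = DataLoader(siamese_dataset,shuffle=True, num_workers=8, batch_size=Config.train_batch_size)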

3. Network architecture and loss function

The network definition, with comments:

class SiameseNetwork(nn.Module):
    def __init__(self):
        super(SiameseNetwork, self).__init__()
        self.cnn1 = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(1, 4, kernel_size=3),
            nn.ReLU(inplace=True),   # note: if this activation were a sigmoid, BatchNorm2d should go before the sigmoid instead; see reference 5 for why
            nn.BatchNorm2d(4),
            nn.Dropout2d(p=.2),

            nn.ReflectionPad2d(1),
            nn.Conv2d(4, 8, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(8),
            nn.Dropout2d(p=.2),

            nn.ReflectionPad2d(1),
            nn.Conv2d(8, 8, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(8),
            nn.Dropout2d(p=.2),
        )
        # fully connected layers
        self.fc1 = nn.Sequential( 
            nn.Linear(8*100*100, 500),
            nn.ReLU(inplace=True),

            nn.Linear(500, 500),
            nn.ReLU(inplace=True),

            nn.Linear(500, 5))

    def forward_once(self, x):
        output = self.cnn1(x)
        output = output.view(output.size()[0], -1)
        output = self.fc1(output)
        return output

    def forward(self, input1, input2):
        output1 = self.forward_once(input1)
        output2 = self.forward_once(input2)
        return output1, output2
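To see where 8*100*100 in the first nn.Linear comes from, you can push a dummy batch through the convolutional part. The sketch below keeps only the padding and convolution layers of cnn1, because ReLU, BatchNorm2d and Dropout2d do not change tensor shapes. It assumes 100×100 single-channel inputs, as produced by the transform used for the dataset.

```python
import torch
import torch.nn as nn

# Padding + convolution layers of cnn1 only; the other layers keep the shape unchanged
convs = nn.Sequential(
    nn.ReflectionPad2d(1), nn.Conv2d(1, 4, kernel_size=3),
    nn.ReflectionPad2d(1), nn.Conv2d(4, 8, kernel_size=3),
    nn.ReflectionPad2d(1), nn.Conv2d(8, 8, kernel_size=3),
)

x = torch.randn(2, 1, 100, 100)  # dummy batch: 2 grayscale 100x100 images
out = convs(x)
print(out.shape)   # torch.Size([2, 8, 100, 100]): (100 + 2*1 - 3) + 1 = 100 at every stage

flat = out.view(out.size(0), -1)  # same flattening as forward_once
print(flat.shape)  # torch.Size([2, 80000]), i.e. 8*100*100 features per image
```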


class ContrastiveLoss(torch.nn.Module):
    """
    Contrastive loss function.
    Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
    """

    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
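        # keepdim=True keeps the distance shape (N, 1) so that it lines up with label (N, 1);
        # without it the distance has shape (N,) and (1 - label) * distance silently broadcasts to (N, N)
        euclidean_distance = F.pairwise_distance(output1, output2, keepdim=True)
        loss_contrastive = torch.mean((1-label) * torch.pow(euclidean_distance, 2) +
                                      (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))
        # if any of the functions above are unclear, see references 8 (F.pairwise_distance) and 9 (torch.clamp)
        return loss_contrastive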
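The loss above is mean((1 − y)·d² + y·max(0, m − d)²). Here d is the Euclidean distance between the two embeddings, m is the margin, and y is 0 for a same-class pair and 1 for a different-class pair. The sketch below evaluates this formula by hand on two made-up pairs. It also shows why the distance shape matters: F.pairwise_distance returns shape (N,) by default, while the label from SiameseNetworkDataset has shape (N, 1).

```python
import torch
import torch.nn.functional as F

margin = 2.0
out1 = torch.tensor([[0.0, 0.0], [0.0, 0.0]])
out2 = torch.tensor([[3.0, 4.0], [1.0, 0.0]])  # distances: about 5 and about 1
label = torch.tensor([[0.0], [1.0]])           # shape (N, 1): pair 0 same class, pair 1 different classes

d_flat = F.pairwise_distance(out1, out2)                # shape (2,)
d_keep = F.pairwise_distance(out1, out2, keepdim=True)  # shape (2, 1)
print(((1 - label) * d_flat).shape)  # torch.Size([2, 2]): silently broadcast to N x N
print(((1 - label) * d_keep).shape)  # torch.Size([2, 1]): one term per pair, as intended

loss = torch.mean((1 - label) * torch.pow(d_keep, 2) +
                  label * torch.pow(torch.clamp(margin - d_keep, min=0.0), 2))
# The same-class pair at distance 5 costs 5**2 = 25. The different-class pair at distance 1
# is inside the margin and costs (2 - 1)**2 = 1. The mean is therefore about 13
# (pairwise_distance adds a tiny eps, so the value is not exactly 13).
print(loss.item())
```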

4. Training the network

The code:

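# Step 1: get the data
train_dataloader = DataLoader(siamese_dataset,
                        shuffle=True,
                        num_workers=8,
                        batch_size=Config.train_batch_size)
# Step 2: build the model, the loss and the optimizer
net = SiameseNetwork().cuda()   # .cuda() moves the model's parameters to the GPU; the inputs must be moved there too (see below)
criterion = ContrastiveLoss()
optimizer = optim.Adam(net.parameters(),lr = 0.0005 )
# Step 3: define the quantities to observe and record during training
counter = []
loss_history = []
iteration_number= 0
# Start training
for epoch in range(0,Config.train_number_epochs):
    for i, data in enumerate(train_dataloader,0):
        img0, img1 , label = data
        # Variable wrappers are no longer needed since PyTorch 0.4; plain tensors work directly
        img0, img1 , label = img0.cuda(), img1.cuda() , label.cuda()
        # forward + backward + optimize
        output1,output2 = net(img0,img1)  # 1. calling net(...) runs forward() automatically
        optimizer.zero_grad()  # gradients accumulate across backward() calls, so reset them first
        loss_contrastive = criterion(output1,output2,label)  # calling criterion(...) also runs forward() automatically
        # 2. backpropagation: compute the gradient of loss_contrastive with respect to every leaf tensor that requires grad
        #    (the network parameters); by default backward() also frees the graph's intermediate buffers.
        #    When you don't need gradients (e.g. at test time), call net.eval() and run the code inside torch.no_grad().
        loss_contrastive.backward()
        optimizer.step()  # 3. let the optimizer update the parameters
        if i %10 == 0 :
            print("Epoch number {}\n Current loss {}\n".format(epoch,loss_contrastive.item()))
            iteration_number +=10
            counter.append(iteration_number)
            loss_history.append(loss_contrastive.item())  # .item() turns the 0-dim loss tensor into a Python float (older PyTorch used loss.data[0])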
show_plot(counter,loss_history)
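The comment about zero_grad() is easy to verify on a single scalar. Here is a minimal sketch in which w is a made-up parameter, not part of the network. Calling backward() twice without resetting adds the second gradient to the first.

```python
import torch

# A single made-up scalar parameter
w = torch.tensor(1.0, requires_grad=True)

(2 * w).backward()      # d(2w)/dw = 2
first = w.grad.item()   # 2.0

(2 * w).backward()      # second backward() without resetting
second = w.grad.item()  # 4.0: the new gradient was added to the old one

w.grad.zero_()          # manual reset; optimizer.zero_grad() does this (or sets .grad to None) for its parameters
after_reset = w.grad.item()

print(first, second, after_reset)  # 2.0 4.0 0.0
```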

That's all for today. Next time I'd like to implement this in TensorFlow.


References
1. PyTorch torchvision transforms
2. Understanding torch.nn in PyTorch for deep learning: ReflectionPad2d
3. PyTorch Notes 05: custom data loading with torch.utils.data.Dataset and DataLoader
4. PyTorch official documentation
5. Batch Normalization guide
6. Batch normalization pitfalls in deep learning
7. Multi-channel (e.g. RGB three-channel) convolution process
8. F.pairwise_distance
9. torch.clamp
10. ImageFolder return value and its role
