One Shot Learning with Siamese Networks

One Shot Learning with Siamese Networks

今天正好看懂了一篇关于one shot learning的论文,参考了小伙伴们的实现方式,在此Mark以下,以免日后心生疑虑。
本文主要分为以下几部分的内容:

  • 图片显示的辅助函数
  • pytorcch的数据的存储方式以及读取方式
  • 网络的构造
  • 网络的训练
  • 小规模数据即实验以及分析

在分步骤解说之前,我先列一下代码里面需要用到的包,如果大家最后要通读代码的话可以把这些包也带上,直接和后面各个部分的代码一起,放到pytorch这个框架下面运行就好了,来,给你。

%matplotlib inline
import torchvision
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader,Dataset
import matplotlib.pyplot as plt
import torchvision.utils
import numpy as np
import random
from PIL import Image
import torch
from torch.autograd import Variable
import PIL.ImageOps    
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

一、图片显示的辅助函数

刚开始对pytorch这个框架不是很熟悉,所以看一个简单的图片处理的封装函数就很费劲,所以在此Mark一下,用来纪念我纯纯的智商。放代码块:

def imshow(img,text,should_save=False):
    npimg = img.numpy()  # 一开始不明白为什么非要用img.numpy()将img转化为numpy数组的形式,最后才发现。原来在pytorch中图片数据刚传进来的时候,是torch.Tensor的数据格式,如果要用plt.imshow()来进行显示的话就必须是numpy数组的形式,所以在此做了转换。
    plt.axis("off")  # 不显示图片的坐标轴
    if text:    # 如果有要显示的文字,则调用plt.text()进行显示
        plt.text(75, 8, text, style='italic',fontweight='bold',
            bbox={'facecolor':'white', 'alpha':0.8, 'pad':10})
    # plt.text()各个参数的解释,(文字显示的横坐标x,纵坐标y,文字内容,斜体,粗体,给显示文字加边框(背景色:白色,透明度:0.8,上下左右填充:10))
    plt.imshow(np.transpose(npimg, (1, 2, 0)))  # 这个方面也是很纳闷啊,一直不知道np.transpose(npimg,(1,2,0))中的参数(1,2,0)是干什么吃的
    #最后才发现,plt.imshow()索要显示的图片格式是(imagesize,imagesize,channels)但是我们读入的batch图片数据是(channels,imagesize,imagesize),所以需要参数(1,2,0,将(channels,imagesize,imagesize)------->(imagesize,imagesize,channels)转化,这样plt.imshow()就可以正常显示了。
    plt.show()    
    # 到最后有时间再举个例子。

def show_plot(iteration,loss):
    plt.plot(iteration,loss)
    plt.show()

class Config():
    training_dir = "./data/faces/training/"
    testing_dir = "./data/faces/testing/"
    train_batch_size = 64
    train_number_epochs = 100

个人觉得,既然是学习和做笔记,就应该把代码说的清清楚楚明明白白,要不然下次再思考的时候还是不明白,又要浪费大量的时间在查找旧知识,那样有点不合适吧。所以我对代码进行了详细的解释。如果您比较熟悉,您可以略过呢。

二、 pytorcch的数据的存储方式以及读取方式

对于读取数据的代码段,猜虽然可以猜出来是什么意思,但是还是想把它弄清楚。本次读取数据的工具主要涉及到以下几种。

  1. torch.utils.data.Dataset与Dataloader
  2. torchvision ImageFolder
1、torch.utils.dataDataset与Dataloader

1、在pytorch 中提供了一个数据的读取方法,其由两个类构成:torch.utils.data.Dataset和torch,utils.data.DataLoader
2、我们要定义自己的数据读取方法就要继承torch.utils.data.Dataset类,并将其封装到DataLoader中
3、torch.utils.data.Dataset表示的数据集的类,继承该类可以重载其中的方法,实现多种数据的读取以及预处理的方式。
4、torch.utils.data.DataLoader封装了data对象,实现单(多)进程的迭代输出(负责每次取出一个batch进行训练)

2、如何继承torch.utils.dataDataset类并重写其方法呢

1、要自定义自己的Dataset类,至少要重载两个方法。getitemlen
2、 len 是返回数据集的大小
3、getitem 实现索引数据集中的某一个数据(这个数据的组合形式是你自己可以确定的)
4、除了这两个基本功能,在getitem 中还可以对数据进行预处理,或者通过imageFolder对象在硬盘中读取数据,对于超大的数据集还可以使用lmdb进行处理。
下面我们来看代码:

#按照上述理解,这里定义一个继承自DataSet的类,叫做SiameseNetworkDataset.
class SiameseNetworkDataset(Dataset):  

    def __init__(self,imageFolderDataset,transform=None,should_invert=True):  # 构造函数
        self.imageFolderDataset = imageFolderDataset    #这里留到下一段再解释  
        self.transform = transform  
        self.should_invert = should_invert

    # 这里的__getitem__主要是在SiameseNetworkDataset被封装为DataLoader后,相应的DataLoader用来取数据的。这里如何定义,如何返回数据决定了DataLoader返回数据
    def __getitem__(self,index):   
        img0_tuple = random.choice(self.imageFolderDataset.imgs)
        #we need to make sure approx 50% of images are in the same class
        should_get_same_class = random.randint(0,1)  ##这里就是如何保证在孪生网络中保证数据的正负样本基本保持1:1的关键所在。
        if should_get_same_class:  # 这里保证数据属于同一类别
            while True:
                #keep looping till the same class image is found
                img1_tuple = random.choice(self.imageFolderDataset.imgs) 
                if img0_tuple[1] == img1_tuple[1]:   # (imagespath,target)---->(0,1)  若target的值相等,则表示这两张图片属于同一类
                    break
        else:   # 这里选取数据属于不同的类别,严格来说这里还要判断  img0_tuple[1] != img1_tuple[1]才行,但是鉴于相等的概率太小了,所以这里就简略这样写好了
            img1_tuple = random.choice(self.imageFolderDataset.imgs)   

        img0 = Image.open(img0_tuple[0])  ## 读到这里就可以发现,img0_tuple[0]里面存的其实就是img0这张图片的地址,而img0_tuple[1]里面存的是这张图片的标签。啦啦啦
        img1 = Image.open(img1_tuple[0]) 
        img0 = img0.convert("L")  # 将 img0转化为灰度图
        img1 = img1.convert("L")

        if self.should_invert: # 是否将图片进行反转
            img0 = PIL.ImageOps.invert(img0)  
            img1 = PIL.ImageOps.invert(img1)

        if self.transform is not None:  # 由传入进来的transform函数做用在这两张图片上。
            img0 = self.transform(img0)
            img1 = self.transform(img1)
        # 返回值的形式(图片0,图片1,0或者1)
        return img0, img1 , torch.from_numpy(np.array([int(img1_tuple[1]!=img0_tuple[1])],dtype=np.float32))

    def __len__(self): # 返回数据集的长度,当封装为DataLoader时。这里的长度为batch_size的大小。这个点值得注意。
        return len(self.imageFolderDataset.imgs)  
3、第二部中,用imageFolder来管理数据集的文件名和标签名,我们来看看imageFolder的功能及其用法:

torch.datasets包中的ImageFolder支持我们直接从硬盘中按照固定路径格式载入每张数据,其格式如下:

根目录/类别/图像: 对应于存盘文件的类型 文件夹/文件夹/图像的名称+后缀

 root/0/000.png   
 root/0/001.png
 root/0/002.png 
 ......
 root/8/120.png 
 root/8/121.png
 root/8/122.png 

那么经过imageFolder后的数据格式为:
图片格式:(路径 标签)

 (root/0/000.png   0 )
 (root/0/001.png   0 )
 (root/0/002.png   0 )
 ......
 (root/8/120.png   8 )
 (root/8/121.png   8 )
 (root/8/122.png   8 )

关于其使用方法,请看以下的代码:

扫描二维码关注公众号,回复: 194499 查看本文章
folder_dataset = dset.ImageFolder(root=Config.training_dir) # ImageFolder的作用是将training_dir下面的图片的(路径,标签)存到folder_dataset.imgs中 
siamese_dataset = SiameseNetworkDataset(imageFolderDataset=folder_dataset,
                                        transform=transforms.Compose([transforms.Scale((100,100)), transforms.ToTensor()]),
                                        should_invert=False)
4、将定义好的SiameseNetworkDataSet封装到DataLoader中,看代码:
# DataLoader的参数分别为(DataSet对象、是否需要打乱数据、进程的并发数、每次读取的batch_size的大小)
# 这里注释以下为什么需要shuffle=True,因为在后续的网络定义过程中,需要用到batchNormlize,即保持数据的独立同分布,所以就需要读取的数据是尽量的混合均匀的,不不是很均匀,batchNormalize的均值和方差会出现漂移的,进而影响训练的效果。
 train_dataloader = DataLoader(siamese_dataset,shuffle=True, num_workers=8, batch_size=Config.train_batch_size)

三、网络的构造+损失函数

网络的构造,看代码及注释:

class SiameseNetwork(nn.Module):
    def __init__(self):
        super(SiameseNetwork, self).__init__()
        self.cnn1 = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(1, 4, kernel_size=3),
            nn.ReLU(inplace=True),   # 这里注意如果这里的激活函数如果是sigmod的话,那么BatchNorm2d应该在nn.sigmod函数之前,具体为什么请看看考网址5
            nn.BatchNorm2d(4),
            nn.Dropout2d(p=.2),

            nn.ReflectionPad2d(1),
            nn.Conv2d(4, 8, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(8),
            nn.Dropout2d(p=.2),

            nn.ReflectionPad2d(1),
            nn.Conv2d(8, 8, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(8),
            nn.Dropout2d(p=.2),
        )
        # 全连接层
        self.fc1 = nn.Sequential( 
            nn.Linear(8*100*100, 500),
            nn.ReLU(inplace=True),

            nn.Linear(500, 500),
            nn.ReLU(inplace=True),

            nn.Linear(500, 5))

    def forward_once(self, x):
        output = self.cnn1(x)
        output = output.view(output.size()[0], -1)
        output = self.fc1(output)
        return output

    def forward(self, input1, input2):
        output1 = self.forward_once(input1)
        output2 = self.forward_once(input2)
        return output1, output2


class ContrastiveLoss(torch.nn.Module):
    """
    Contrastive loss function.
    Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
    """

    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        euclidean_distance = F.pairwise_distance(output1, output2)
        loss_contrastive = torch.mean((1-label) * torch.pow(euclidean_distance, 2) +
                                      (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2)) 

     # 以上函数若有不懂的可以查看参考网址里面的链接
        return loss_contrastive

四、网络的训练

代码如下:

#第一步,将数据拿来
train_dataloader = DataLoader(siamese_dataset,
                        shuffle=True,
                        num_workers=8,
                        batch_size=Config.train_batch_size)
# 第二部,把图拿来
net = SiameseNetwork().cuda()   # 使用cuda函数以后的所有运算均会调用gpu来运算。
criterion = ContrastiveLoss()
optimizer = optim.Adam(net.parameters(),lr = 0.0005 )
#第三部,定义好咱们训练过程中需要的观察和记录的量   
counter = []
loss_history = [] 
iteration_number= 0
# 叮咚~~~~开始训练了。
for epoch in range(0,Config.train_number_epochs):
    for i, data in enumerate(train_dataloader,0):
        img0, img1 , label = data
        img0, img1 , label = Variable(img0).cuda(), Variable(img1).cuda() , Variable(label).cuda()
        # forward + backward +optimize
        output1,output2 = net(img0,img1)  # 1、这里forward函数是自动调用的
        optimizer.zero_grad()  # 因为梯度是累加的,随意这里的梯度最开始要设置为0
        loss_contrastive = criterion(output1,output2,label)  # 这里forward函数是自动调用的
        # 2、反向传播  loss_contrastive的叶子节点可以返回求导之后的值 
        # 这里也可以看出,backward 这个方法也是 释放一些资源的 的一个标志,如果不需要 backward 的话,一定要记得设置网络为eval
        loss_contrastive.backward() 
        optimizer.step()  # 3、利用优化器更新参数
        if i %10 == 0 :
            print("Epoch number {}\n Current loss {}\n".format(epoch,loss_contrastive.data[0]))
            iteration_number +=10
            counter.append(iteration_number)
            loss_history.append(loss_contrastive.data[0])  # 这里的data[0]是指什么我还没弄明白。
show_plot(counter,loss_history)

今天就学习到这里,下次我想用tensorflow把它实现


参考网址
1、pytorch torchvision transform
2、Deep Learning」理解Pytorch中的「torch.nn」 ReflectionPad2d
3、Pytorch笔记05-自定义数据读取方式orch.utils.data.Dataset与Dataloader
4、pytorch的官方文档
5、Batch Normalization导读
6、深度学习中批归一化的陷阱
7、多通道(比如RGB三通道)卷积过程
8、F.pairwise_distance
9、torch.clamp
10、ImageFolder返回值及其作用

猜你喜欢

转载自blog.csdn.net/xiongchengluo1129/article/details/79082334