[PyTorch Neural Network Practical Case] 37 Using the Deep InfoMax (DIM) model to search for the most and least similar images

An image searcher has two parts: extracting image features and matching them, and feature extraction is the key step. Here, feature extraction is implemented with an unsupervised model, namely the Deep InfoMax (DIM) method, which maximizes deep mutual information.

1 Introduction to the Deep InfoMax (DIM) model

The DIM model combines an autoencoder with an adversarial network, and its loss function combines the MINE and f-GAN methods. On top of this, the DIM model is trained with three losses: a global loss, a local loss, and a prior loss.

1.1 Principle of DIM model

A well-performing encoder should extract the most unique and characteristic information in a sample, rather than simply pursuing an ever-smaller reconstruction error. How much unique information a sample carries can be measured with mutual information (MI).

Therefore, in the DIM model the encoder's objective is not to minimize the MSE between the input and its reconstruction, but to maximize the mutual information between the input and the output.
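In symbols (a schematic formulation, not taken from the book: $E_\psi$ denotes the encoder with parameters $\psi$, and $\hat{x}$ a reconstruction produced by some decoder), the shift is from

$$\min_{\psi}\; \mathbb{E}\big[\lVert x - \hat{x} \rVert_2^2\big] \qquad \text{to} \qquad \max_{\psi}\; I\big(X;\, E_\psi(X)\big).$$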

1.2 The main idea of the DIM model

The mutual information estimation in the DIM model mainly follows the MINE method: a neural network estimates the mutual information between the input sample and the feature vector output by the encoder, and the model is trained by maximizing this estimate.
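For reference, MINE estimates mutual information through the Donsker-Varadhan lower bound (a standard formulation stated here for context; $T_\omega$ is a neural network that plays the role of a statistics/discriminator network):

$$I(X;Y) \;\ge\; \mathbb{E}_{p(x,y)}\big[T_\omega(x,y)\big] \;-\; \log \mathbb{E}_{p(x)p(y)}\big[e^{T_\omega(x,y)}\big].$$

Training maximizes the right-hand side with respect to both the encoder and $T_\omega$.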

1.2.1 Two constraints of the DIM model in unsupervised training

  1. Maximize the mutual information between the input and the high-level feature vector: if the low-dimensional features output by the model are to represent the input samples, the mutual information between the feature distribution and the input-sample distribution should be as large as possible.
  2. Adversarially match the prior distribution: the high-level features output by the encoder should be close to a Gaussian distribution, while the discriminator tries to distinguish the distribution generated by the encoder from that Gaussian distribution.

When implemented, the DIM model uses three discriminators, which constrain the encoder's output from three perspectives: local mutual information maximization, global mutual information maximization, and prior-distribution matching. (See the paper arXiv:1808.06670, 2018.)

1.3 The principle of local and global mutual information maximization constraints

Much representation learning operates only in the original data space (the pixel level). Such representations are of little use for training when only a small portion of the pixel-level data actually matters at the semantic level.
For images, correlations are largely local. Image recognition and classification proceed from local parts to the whole: global features are better suited to reconstruction, while local features are better suited to downstream tasks such as classification.
The local feature can be understood as the feature map obtained after convolution, and the global feature can be understood as the feature vector obtained by encoding the feature map.

The DIM model performs mutual information computation on the input and output from both local and global perspectives.

1.4 Prior distribution matching minimization constraint principle

The purpose of prior matching is to constrain the feature vectors produced by the encoder so that they are close to a Gaussian distribution.

The main idea behind the DIM encoder is that, while the input data is encoded into a feature vector, the feature vector should also follow a standard Gaussian distribution. This makes the encoding space more regular and even helps decouple features for subsequent learning, which is the same mission as the encoder in a variational autoencoder.

Therefore, the idea of the variational autoencoder is brought into the DIM model: a Gaussian distribution is taken as the prior distribution to constrain the vectors output by the encoder.

2 Structure of the DIM model

2.1 DIM model structure diagram

The DIM model consists of four sub-models: one encoder and three discriminators. The encoder's main job is to extract features from the image, while the three discriminators constrain the encoder's output from three perspectives: local, global, and prior matching.

2.2 Special features of the DIM model

In the actual implementation of the DIM model, the mutual information is not computed directly between the original input data and the feature vector output by the encoder. Instead, it is computed between an intermediate feature map of the encoder and the final feature vector.

Following the MINE method, estimating mutual information with a neural network can be turned into estimating the divergence between the joint distribution and the product of the marginal distributions of two variables: the discriminator's output on correctly paired feature maps and feature vectors is treated as samples of the joint distribution, while feeding shuffled feature maps together with the feature vectors into the discriminator yields samples of the marginal (product) distribution.

Concretely, the DIM model shuffles the batch order of the feature maps and pairs them with the original feature vectors output by the encoder as the discriminator's input, so that the feature map and feature vector fed to the discriminator are independent (their correspondence is destroyed); see the earlier introduction to mutual information neural estimation for the underlying principle.
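A minimal sketch of this trick (illustrative shapes and stand-in tensors only, not the article's code; the training script later achieves the same effect by rolling the batch by one position):

import torch

B, C, H, W = 4, 128, 26, 26          # assumed batch size and feature-map shape
M = torch.randn(B, C, H, W)          # stand-in for the encoder's feature maps
y = torch.randn(B, 64)               # stand-in for the encoder's feature vectors

# Joint-distribution samples: (M[i], y[i]) keep their original pairing.
# Marginal-distribution samples: shuffle the batch of feature maps so the pairing is destroyed.
perm = torch.randperm(B)
M_prime = M[perm]                    # each y[i] is now paired with a feature map from another image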

2.3 Global Discriminator Model

As shown in Figure 8-29, the global discriminator has two inputs: the feature map and the feature vector y. When computing the mutual information, the jointly distributed feature map and feature vector y both come from the encoder. For the marginal distribution, the feature map is obtained by shuffling the batch order of the feature maps, while the feature vector y still comes from the encoder, as shown in Figure 8-30.

In the global discriminator, the specific processing steps are as follows (a sketch of these steps follows the list).
(1) Process the feature map with convolutional layers to obtain a global feature.
(2) Concatenate this global feature with the feature vector y using the torch.cat() function.
(3) Feed the concatenated result into a fully connected network, which judges the pair of global features and outputs the final result (a single-value score).
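A compact sketch of these three steps (layer sizes are illustrative and chosen to match the shapes used later in this article; the actual implementation lives in the DeepInfoMaxLoss class below):

import torch
from torch import nn

class GlobalDiscriminatorSketch(nn.Module):
    # Steps (1)-(3): conv-encode the feature map, concatenate with y, score the pair with an MLP.
    def __init__(self, map_channels=128, y_dim=64):
        super().__init__()
        self.conv = nn.Sequential(                    # (1) convolutional layers over the feature map
            nn.Conv2d(map_channels, 64, kernel_size=3), nn.ReLU(True),
            nn.Conv2d(64, 32, kernel_size=3),
            nn.Flatten(),
        )
        self.fc = nn.Sequential(                      # (3) fully connected network producing one score
            nn.Linear(32 * 22 * 22 + y_dim, 512), nn.ReLU(True),
            nn.Linear(512, 1),
        )
    def forward(self, y, M):                          # M: [B, 128, 26, 26], y: [B, 64]
        h = self.conv(M)                              # global features extracted from the feature map
        h = torch.cat((y, h), dim=1)                  # (2) concatenate with the feature vector y
        return self.fc(h)                             # one scalar judgment per sample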

2.4 Local discriminator model

As shown in Figure 8-29, the input to the local discriminator is a special composite tensor: the feature vector y output by the encoder is copied m×m times according to the spatial size of the feature map, so that every pixel of the feature map is concatenated with the global feature vector y. What the discriminator then computes is the mutual information between each pixel and the global feature vector, which is why it is called the local discriminator.
In the local discriminator, the joint and marginal distributions used to compute the mutual information are formed in the same way as in the global discriminator. As shown in Figure 8-31, the local discriminator mainly uses 1×1 convolutions (with stride 1 as well). Because these convolutions do not change the spatial size of the feature map (they only change the number of channels), the discriminator's final output is also an m×m map of values.

The local discriminator finally reduces the number of channels to 1 through several 1×1 convolution layers, and that map is used as the final judgment result. The process can be understood as computing the mutual information between every pixel and the global feature at the same time.
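The shape flow can be sketched as follows (values are illustrative; the article's actual layers appear in the DeepInfoMaxLoss class below):

import torch
from torch import nn

B, C, m, y_dim = 4, 128, 26, 64
M = torch.randn(B, C, m, m)                                   # local feature map
y = torch.randn(B, y_dim)                                     # global feature vector

y_exp = y.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, m, m)    # copy y to every pixel: [B, 64, m, m]
y_M = torch.cat((M, y_exp), dim=1)                            # [B, 192, m, m]

local_d = nn.Sequential(                                      # 1x1 convolutions keep the m x m size
    nn.Conv2d(C + y_dim, 512, kernel_size=1), nn.ReLU(True),
    nn.Conv2d(512, 1, kernel_size=1),
)
scores = local_d(y_M)                                         # [B, 1, m, m]: one score per pixel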

2.5 Prior discriminator model

The prior discriminator's main role is to help push the vectors generated by the encoder toward a Gaussian distribution, in the same way as an ordinary adversarial network. Its output is simply 0 or 1: the discriminator is trained to judge data sampled from the Gaussian distribution as true (1) and the feature vectors output by the encoder as false (0), as shown in Figure 8-32.

The prior discriminator model is shown in Figure 8-32. The input to the prior discriminator model is only one feature vector. Its structure mainly uses a fully connected neural network, which will eventually output a "true" or "false" decision result.
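In standard GAN form (a sketch of the principle; $D$ is the prior discriminator, $z$ a sample drawn from the prior distribution, and $y = E(x)$ the encoder output), the prior discriminator is trained to maximize

$$\mathbb{E}_{z\sim p(z)}\big[\log D(z)\big] \;+\; \mathbb{E}_{x}\big[\log\big(1 - D(E(x))\big)\big],$$

while the encoder is trained adversarially against it so that its outputs become hard to distinguish from prior samples.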

2.6 Loss function

In the DIM model, the KL divergence used in the MINE method is replaced by the JS divergence as the measure of mutual information. The reason is that the JS divergence is bounded above while the KL divergence is not, so the JS divergence is better suited to a maximization task: it does not produce arbitrarily large values during computation, and its gradient estimate is unbiased.

The JS-divergence formulation can be found in the f-GAN framework; see formula (8-46) (the principle is explained in the hint below that formula).
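In the softplus form that the code below actually uses (a sketch of the estimator; $T_\omega$ is the discriminator network and $\operatorname{sp}(z) = \log(1 + e^{z})$):

$$\hat{I}^{(\mathrm{JSD})}(X;Y) \;=\; \mathbb{E}_{p(x,y)}\big[-\operatorname{sp}\big(-T_\omega(x,y)\big)\big] \;-\; \mathbb{E}_{p(x)p(y)}\big[\operatorname{sp}\big(T_\omega(x,y)\big)\big].$$

In the implementation, the two expectations appear as the variables Ej and Em, and the quantity Em - Ej is minimized, which is equivalent to maximizing the estimator.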

The loss function of the prior discriminator is very simple and matches that of the original GAN model (see the paper arXiv:1406.2661, 2014). The weighted sum of the three discriminators' losses gives the loss function of the whole DIM model.
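Using the weights that appear in the implementation below ($\alpha = 0.5$ for the global term, $\beta = 1.0$ for the local term, $\gamma = 0.1$ for the prior term), the total training loss can be written as

$$\mathcal{L}_{\mathrm{DIM}} \;=\; \alpha\big(E_m^{\mathrm{global}} - E_j^{\mathrm{global}}\big) \;+\; \beta\big(E_m^{\mathrm{local}} - E_j^{\mathrm{local}}\big) \;-\; \gamma\Big(\mathbb{E}\big[\log D(z)\big] + \mathbb{E}\big[\log\big(1 - D(y)\big)\big]\Big),$$

where $E_j$ and $E_m$ are the joint and marginal softplus terms from the previous formula, $z$ is a sample from the prior, $y$ is the encoder output, and $D$ is the prior discriminator. Minimizing this loss maximizes the two mutual-information estimates and trains the prior discriminator at the same time.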

3 Case introduction and code implementation (training the model)

Use the Deep InfoMax model to extract image features, and use the extracted low-dimensional features to build an image searcher.

3.1 CIFAR dataset

The dataset used in this example is CIFAR, which, like Fashion-MNIST, is a collection of images. CIFAR is more complex than Fashion-MNIST: it consists of color images, which are closer to the samples encountered in real scenes.

3.1.1 Composition of the CIFAR dataset

Because the original dataset divides the data into 10 categories (airplane, automobile, bird, cat, deer, dog, frog, horse, ship, and truck), this version of the dataset is commonly called CIFAR-10. It contains 60,000 color images of 32 × 32 pixels (50,000 training images and 10,000 test images), with no overlap between the categories. Since the images are in color, the dataset has three channels: R, G, and B.

CIFAR also has a more finely divided version, CIFAR-100, which, as the name suggests, splits the data into 100 categories. This finer division is of course a bigger challenge for image-recognition networks. With such ready-made data, we can devote all our energy to network optimization.

3.2 Get the dataset

The CIFAR dataset is distributed as packaged files, available both as a Python (pickle) package and as a binary bin package, so that different programs can read it conveniently. This example uses the Python package of the CIFAR-10 version, whose file name is "cifar-10-python.tar.gz". The file can be downloaded manually from the official website, or downloaded through PyTorch's built-in code in the same way as for Fashion-MNIST.

3.3 Load and display the CIFAR dataset------DIM_CIRFAR_train.py (Part 1)

import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torch.optim import Adam
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
from torchvision.datasets.cifar import CIFAR10
from matplotlib import pyplot as plt
import numpy as np
from tqdm import tqdm
from pathlib import Path
from torchvision.transforms import ToPILImage
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# 1.1 Get the dataset and display it
# Specify the computing device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
# Load the dataset
batch_size = 512
data_dir = r'./cifar10/'
# Download CIFAR10 locally. It contains three kinds of files: the label description file batches.meta,
# the training batches data_batch_x (five files, each holding 10,000 training samples), and the test batch test_batch.
train_dataset = CIFAR10(data_dir, download=True, transform=ToTensor())
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True,
                          pin_memory=torch.cuda.is_available())
print("Number of training samples:", len(train_dataset))
# Define a function for displaying images
def imshowrow(imgs, nrow):
    plt.figure(dpi=200)  # figsize=(9,4)
    # ToPILImage() calls PyTorch's conversion interface to turn a tensor into a PIL image.
    # It mainly does the following: (1) multiplies each element of the tensor by 255; (2) converts the dtype
    # from FloatTensor to uint8; (3) converts the tensor to a NumPy ndarray; (4) applies transpose(1,2,0) to
    # the ndarray; (5) uses Image.fromarray() to turn the ndarray into a PIL image; (6) returns the PIL image.
    _img = ToPILImage()(torchvision.utils.make_grid(imgs, nrow=nrow))  # ToPILImage() receives the tensor returned by torchvision.utils.make_grid
    plt.axis('off')
    plt.imshow(_img)
    plt.show()

# Define the labels and their corresponding class names
classes = ('airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# Fetch a few samples for display
sample = iter(train_loader)
images, labels = next(sample)  # the loader iterator has no .next() method in Python 3; use next(sample)
print("Sample shape:", np.shape(images))
print('Sample labels:', ','.join('%2d:%-5s' % (labels[j], classes[labels[j]]) for j in range(len(images[:10]))))
imshowrow(images[:10], nrow=10)

Output: the computing device, the number of training samples (50,000), the sample shape torch.Size([512, 3, 32, 32]), the labels of the first ten samples, and a row of ten example images.

3.5 Defining the DIM Model------DIM_CIRFAR_train.py (Part 2)

# 1.2 Define the DIM model
class Encoder(nn.Module):  # encode the input through several convolutional layers into a 64-dimensional feature vector
    def __init__(self):
        super().__init__()
        self.c0 = nn.Conv2d(3, 64, kernel_size=4, stride=1)     # output size 29
        self.c1 = nn.Conv2d(64, 128, kernel_size=4, stride=1)   # output size 26
        self.c2 = nn.Conv2d(128, 256, kernel_size=4, stride=1)  # output size 23
        self.c3 = nn.Conv2d(256, 512, kernel_size=4, stride=1)  # output size 20
        self.l1 = nn.Linear(512*20*20, 64)
        # Define the BN (BatchNorm) layers
        self.b1 = nn.BatchNorm2d(128)
        self.b2 = nn.BatchNorm2d(256)
        self.b3 = nn.BatchNorm2d(512)

    def forward(self, x):
        h = F.relu(self.c0(x))
        features = F.relu(self.b1(self.c1(h)))     # output shape [b, 128, 26, 26]
        h = F.relu(self.b2(self.c2(features)))
        h = F.relu(self.b3(self.c3(h)))
        encoded = self.l1(h.view(x.shape[0], -1))  # output shape [b, 64]
        return encoded, features

class DeepInfoMaxLoss(nn.Module):  # build the global, local and prior discriminators and merge their losses into the total loss
    def __init__(self, alpha=0.5, beta=1.0, gamma=0.1):
        super().__init__()
        # Initialize the weighting parameters of the loss terms
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        # Define the local discriminator
        self.local_d = nn.Sequential(
            nn.Conv2d(192, 512, kernel_size=1),
            nn.ReLU(True),
            nn.Conv2d(512, 512, kernel_size=1),
            nn.ReLU(True),
            nn.Conv2d(512, 1, kernel_size=1)
        )
        # Define the prior discriminator
        self.prior_d = nn.Sequential(
            nn.Linear(64, 1000),
            nn.ReLU(True),
            nn.Linear(1000, 200),
            nn.ReLU(True),
            nn.Linear(200, 1),
            nn.Sigmoid()  # The last layer uses Sigmoid, the standard choice in the original GAN model
                          # (it keeps the output in the range 0-1) and matches the loss function used below.
        )
        # Define the global discriminator
        self.global_d_M = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3),  # output shape [b, 64, 24, 24]
            nn.ReLU(True),
            nn.Conv2d(64, 32, kernel_size=3),   # output shape [b, 32, 22, 22]
            nn.Flatten(),
        )
        self.global_d_fc = nn.Sequential(
            nn.Linear(32*22*22+64, 512),
            nn.ReLU(True),
            nn.Linear(512, 512),
            nn.ReLU(True),
            nn.Linear(512, 1)
        )

    def GlobalD(self, y, M):
        h = self.global_d_M(M)
        h = torch.cat((y, h), dim=1)
        return self.global_d_fc(h)

    def forward(self, y, M, M_prime):
        # Copy the feature vector
        y_exp = y.unsqueeze(-1).unsqueeze(-1)
        y_exp = y_exp.expand(-1, -1, 26, 26)            # output shape [b, 64, 26, 26]
        # Concatenate the feature vector to every pixel of the feature map
        y_M = torch.cat((M, y_exp), dim=1)              # output shape [b, 192, 26, 26]
        y_M_prime = torch.cat((M_prime, y_exp), dim=1)  # output shape [b, 192, 26, 26]
        # Local mutual information
        Ej = -F.softplus(-self.local_d(y_M)).mean()       # joint-distribution term
        Em = F.softplus(self.local_d(y_M_prime)).mean()   # marginal-distribution term
        LOCAL = (Em - Ej) * self.beta    # Maximize mutual information: the sign is flipped so that the maximization
                                         # problem becomes a minimization problem handled by minimizing the loss.
        # Global mutual information
        Ej = -F.softplus(-self.GlobalD(y, M)).mean()      # joint-distribution term
        Em = F.softplus(self.GlobalD(y, M_prime)).mean()  # marginal-distribution term
        GLOBAL = (Em - Ej) * self.alpha  # Maximize mutual information: the sign is flipped, turning maximization into minimization.
        # Prior loss
        prior = torch.rand_like(y)  # random samples from the prior (torch.rand_like draws uniform values in [0,1);
                                    # torch.randn_like would give the Gaussian prior described in the text)
        term_a = torch.log(self.prior_d(prior)).mean()    # GAN loss
        term_b = torch.log(1.0 - self.prior_d(y)).mean()
        PRIOR = -(term_a + term_b) * self.gamma  # Maximize the target distribution: this is the discriminator loss.
                                                 # The discriminator aims to separate prior samples from generated
                                                 # vectors as far as possible, so the sign is also flipped and the
                                                 # objective is trained by minimization.
        return LOCAL + GLOBAL + PRIOR

# #### During training, gradients propagate through this loss function directly into the encoder for joint
# optimization, so no separate loss function needs to be defined for the encoder.
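As a quick sanity check (a sketch, not part of the original scripts), the encoder's two outputs can be verified with a dummy batch:

# Hypothetical shape check for the Encoder defined above.
import torch

enc = Encoder()
dummy = torch.randn(2, 3, 32, 32)      # two fake CIFAR-10-sized images
encoded, features = enc(dummy)
print(encoded.shape)                   # expected: torch.Size([2, 64])
print(features.shape)                  # expected: torch.Size([2, 128, 26, 26])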

3.6 Instantiate DIM model and train ------ DIM_CIRFAR_train.py (Part 3)

# 1.3 Instantiate the DIM model and train it for the specified number of epochs.
# When building the marginal-distribution samples, the first feature map of the batch is moved to the end,
# so that feature maps no longer correspond to their feature vectors -- equivalent to shuffling the batch order.
totalepoch = 100  # number of training epochs
if __name__ == '__main__':
    encoder = Encoder().to(device)
    loss_fn = DeepInfoMaxLoss().to(device)
    optim = Adam(encoder.parameters(), lr=1e-4)
    loss_optim = Adam(loss_fn.parameters(), lr=1e-4)

    epoch_loss = []
    for epoch in range(totalepoch + 1):
        batch = tqdm(train_loader, total=len(train_dataset)//batch_size)
        train_loss = []
        for x, target in batch:  # iterate over the dataset
            x = x.to(device)
            optim.zero_grad()
            loss_optim.zero_grad()
            y, M = encoder(x)  # generate the feature vector and feature map with the encoder
            # Build the marginal-distribution samples
            M_prime = torch.cat((M[1:], M[0].unsqueeze(0)), dim=0)
            loss = loss_fn(y, M, M_prime)  # compute the loss
            train_loss.append(loss.item())
            batch.set_description(str(epoch) + ' Loss:%.4f' % np.mean(train_loss[-20:]))
            loss.backward()
            optim.step()       # step the encoder optimizer
            loss_optim.step()  # step the discriminator optimizer
        if epoch % 10 == 0:  # save the model
            root = Path(r'./DIMmodel/')
            enc_file = root / Path('encoder' + str(epoch) + '.pth')
            loss_file = root / Path('loss' + str(epoch) + '.pth')
            enc_file.parent.mkdir(parents=True, exist_ok=True)
            torch.save(encoder.state_dict(), str(enc_file))
            torch.save(loss_fn.state_dict(), str(loss_file))
        epoch_loss.append(np.mean(train_loss[-20:]))  # collect the training loss
    plt.plot(np.arange(len(epoch_loss)), epoch_loss, 'r')  # visualize the loss
    plt.show()

Result: a progress bar with the running loss is printed for each epoch, and a curve of the training loss over the epochs is plotted at the end.

3.7 Load the model and search for images------DIM_CIRFAR_loadpath.py

import torch
import torch.nn.functional as F
from tqdm import tqdm
import random

# What this script does: load the trained encoder, encode every image in the sample set,
# pick a random image, and find the ten images closest to it and the ten farthest from it.
#
# Import the local code library
from DIM_CIRFAR_train import (train_loader, train_dataset, totalepoch, device, batch_size, imshowrow, Encoder)

# Load the model
model_path = r'./DIMmodel/encoder%d.pth' % (totalepoch)
encoder = Encoder().to(device)
encoder.load_state_dict(torch.load(model_path, map_location=device))

# Load the samples and run the encoder to generate feature vectors
batchesimg = []
batchesenc = []
batch = tqdm(train_loader, total=len(train_dataset)//batch_size)
for images, target in batch:
    images = images.to(device)
    with torch.no_grad():
        encoded, features = encoder(images)  # generate feature vectors with the encoder
    batchesimg.append(images)
    batchesenc.append(encoded)
# Concatenate the images and the generated vectors along the first dimension
batchesenc = torch.cat(batchesenc, axis=0)
batchesimg = torch.cat(batchesimg, axis=0)
# Verify the vector-based search
index = random.randrange(0, len(batchesenc))  # pick a random index as the target image
# Replicate the target image's feature vector so it can be compared against every vector
# (the same expression is used directly in the distance computation below).
batchesenc[index].repeat(len(batchesenc), 1)
# F.mse_loss() computes the squared L2 difference between feature vectors. Passing reduction='none'
# means no reduction is applied to the element-wise result; without it, the function would average
# over all elements by default (the usual behaviour when training a model).
l2_dis = F.mse_loss(batchesenc[index].repeat(len(batchesenc), 1), batchesenc, reduction='none').sum(1)  # L2 distance from the target image to every image
findnum = 10  # number of images to retrieve
# Use topk() to get the images with the smallest and largest L2 distances. It returns two values:
# the compared values themselves and their corresponding indices.
_, indices = l2_dis.topk(findnum, largest=False)  # the 10 most similar images
_, indices_far = l2_dis.topk(findnum,)            # the 10 least similar images
# Display the results
indices = torch.cat([torch.tensor([index]).to(device), indices])
indices_far = torch.cat([torch.tensor([index]).to(device), indices_far])
rel = torch.cat([batchesimg[indices], batchesimg[indices_far]], axis=0)
imshowrow(rel.cpu(), nrow=len(indices))
# The result has two rows; the first column of each row is the target image. The first row shows the
# search results closest to the target image, and the second row shows those farthest from it.
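As an aside (an alternative sketch, not part of the original script), the same distances can be computed in a single call with torch.cdist; since the Euclidean norm is a monotone function of the squared distance, the ranking used by topk() is unchanged:

# Alternative distance computation (assumes batchesenc, index and findnum from the script above).
dists = torch.cdist(batchesenc[index:index + 1], batchesenc).squeeze(0)  # Euclidean distance to every vector
_, nearest = dists.topk(findnum, largest=False)                          # same nearest-neighbour indices as before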

4 Code overview

4.1 Train the model: DIM_CIRFAR_train.py

import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torch.optim import Adam
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
from torchvision.datasets.cifar import CIFAR10
from matplotlib import pyplot as plt
import numpy as np
from tqdm import tqdm
from pathlib import Path
from torchvision.transforms import ToPILImage
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# 1.1 Get the dataset and display it
# Specify the computing device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
# Load the dataset
batch_size = 512
data_dir = r'./cifar10/'
# Download CIFAR10 locally. It contains three kinds of files: the label description file batches.meta,
# the training batches data_batch_x (five files, each holding 10,000 training samples), and the test batch test_batch.
train_dataset = CIFAR10(data_dir, download=True, transform=ToTensor())
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True,
                          pin_memory=torch.cuda.is_available())
print("Number of training samples:", len(train_dataset))
# Define a function for displaying images
def imshowrow(imgs, nrow):
    plt.figure(dpi=200)  # figsize=(9,4)
    # ToPILImage() calls PyTorch's conversion interface to turn a tensor into a PIL image.
    # It mainly does the following: (1) multiplies each element of the tensor by 255; (2) converts the dtype
    # from FloatTensor to uint8; (3) converts the tensor to a NumPy ndarray; (4) applies transpose(1,2,0) to
    # the ndarray; (5) uses Image.fromarray() to turn the ndarray into a PIL image; (6) returns the PIL image.
    _img = ToPILImage()(torchvision.utils.make_grid(imgs, nrow=nrow))  # ToPILImage() receives the tensor returned by torchvision.utils.make_grid
    plt.axis('off')
    plt.imshow(_img)
    plt.show()

# Define the labels and their corresponding class names
classes = ('airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# Fetch a few samples for display
sample = iter(train_loader)
images, labels = next(sample)  # the loader iterator has no .next() method in Python 3; use next(sample)
print("Sample shape:", np.shape(images))
print('Sample labels:', ','.join('%2d:%-5s' % (labels[j], classes[labels[j]]) for j in range(len(images[:10]))))
imshowrow(images[:10], nrow=10)

# 1.2 Define the DIM model
class Encoder(nn.Module):  # encode the input through several convolutional layers into a 64-dimensional feature vector
    def __init__(self):
        super().__init__()
        self.c0 = nn.Conv2d(3, 64, kernel_size=4, stride=1)     # output size 29
        self.c1 = nn.Conv2d(64, 128, kernel_size=4, stride=1)   # output size 26
        self.c2 = nn.Conv2d(128, 256, kernel_size=4, stride=1)  # output size 23
        self.c3 = nn.Conv2d(256, 512, kernel_size=4, stride=1)  # output size 20
        self.l1 = nn.Linear(512*20*20, 64)
        # Define the BN (BatchNorm) layers
        self.b1 = nn.BatchNorm2d(128)
        self.b2 = nn.BatchNorm2d(256)
        self.b3 = nn.BatchNorm2d(512)

    def forward(self, x):
        h = F.relu(self.c0(x))
        features = F.relu(self.b1(self.c1(h)))     # output shape [b, 128, 26, 26]
        h = F.relu(self.b2(self.c2(features)))
        h = F.relu(self.b3(self.c3(h)))
        encoded = self.l1(h.view(x.shape[0], -1))  # output shape [b, 64]
        return encoded, features

class DeepInfoMaxLoss(nn.Module):  # build the global, local and prior discriminators and merge their losses into the total loss
    def __init__(self, alpha=0.5, beta=1.0, gamma=0.1):
        super().__init__()
        # Initialize the weighting parameters of the loss terms
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        # Define the local discriminator
        self.local_d = nn.Sequential(
            nn.Conv2d(192, 512, kernel_size=1),
            nn.ReLU(True),
            nn.Conv2d(512, 512, kernel_size=1),
            nn.ReLU(True),
            nn.Conv2d(512, 1, kernel_size=1)
        )
        # Define the prior discriminator
        self.prior_d = nn.Sequential(
            nn.Linear(64, 1000),
            nn.ReLU(True),
            nn.Linear(1000, 200),
            nn.ReLU(True),
            nn.Linear(200, 1),
            nn.Sigmoid()  # The last layer uses Sigmoid, the standard choice in the original GAN model
                          # (it keeps the output in the range 0-1) and matches the loss function used below.
        )
        # Define the global discriminator
        self.global_d_M = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3),  # output shape [b, 64, 24, 24]
            nn.ReLU(True),
            nn.Conv2d(64, 32, kernel_size=3),   # output shape [b, 32, 22, 22]
            nn.Flatten(),
        )
        self.global_d_fc = nn.Sequential(
            nn.Linear(32*22*22+64, 512),
            nn.ReLU(True),
            nn.Linear(512, 512),
            nn.ReLU(True),
            nn.Linear(512, 1)
        )

    def GlobalD(self, y, M):
        h = self.global_d_M(M)
        h = torch.cat((y, h), dim=1)
        return self.global_d_fc(h)

    def forward(self, y, M, M_prime):
        # Copy the feature vector
        y_exp = y.unsqueeze(-1).unsqueeze(-1)
        y_exp = y_exp.expand(-1, -1, 26, 26)            # output shape [b, 64, 26, 26]
        # Concatenate the feature vector to every pixel of the feature map
        y_M = torch.cat((M, y_exp), dim=1)              # output shape [b, 192, 26, 26]
        y_M_prime = torch.cat((M_prime, y_exp), dim=1)  # output shape [b, 192, 26, 26]
        # Local mutual information
        Ej = -F.softplus(-self.local_d(y_M)).mean()       # joint-distribution term
        Em = F.softplus(self.local_d(y_M_prime)).mean()   # marginal-distribution term
        LOCAL = (Em - Ej) * self.beta    # Maximize mutual information: the sign is flipped so that the maximization
                                         # problem becomes a minimization problem handled by minimizing the loss.
        # Global mutual information
        Ej = -F.softplus(-self.GlobalD(y, M)).mean()      # joint-distribution term
        Em = F.softplus(self.GlobalD(y, M_prime)).mean()  # marginal-distribution term
        GLOBAL = (Em - Ej) * self.alpha  # Maximize mutual information: the sign is flipped, turning maximization into minimization.
        # Prior loss
        prior = torch.rand_like(y)  # random samples from the prior (torch.rand_like draws uniform values in [0,1);
                                    # torch.randn_like would give the Gaussian prior described in the text)
        term_a = torch.log(self.prior_d(prior)).mean()    # GAN loss
        term_b = torch.log(1.0 - self.prior_d(y)).mean()
        PRIOR = -(term_a + term_b) * self.gamma  # Maximize the target distribution: this is the discriminator loss.
                                                 # The discriminator aims to separate prior samples from generated
                                                 # vectors as far as possible, so the sign is also flipped and the
                                                 # objective is trained by minimization.
        return LOCAL + GLOBAL + PRIOR

# #### During training, gradients propagate through this loss function directly into the encoder for joint
# optimization, so no separate loss function needs to be defined for the encoder.

# 1.3 Instantiate the DIM model and train it for the specified number of epochs.
# When building the marginal-distribution samples, the first feature map of the batch is moved to the end,
# so that feature maps no longer correspond to their feature vectors -- equivalent to shuffling the batch order.
totalepoch = 100  # number of training epochs (kept consistent with Part 3 above)
if __name__ == '__main__':
    encoder = Encoder().to(device)
    loss_fn = DeepInfoMaxLoss().to(device)
    optim = Adam(encoder.parameters(), lr=1e-4)
    loss_optim = Adam(loss_fn.parameters(), lr=1e-4)

    epoch_loss = []
    for epoch in range(totalepoch + 1):
        batch = tqdm(train_loader, total=len(train_dataset)//batch_size)
        train_loss = []
        for x, target in batch:  # iterate over the dataset
            x = x.to(device)
            optim.zero_grad()
            loss_optim.zero_grad()
            y, M = encoder(x)  # generate the feature vector and feature map with the encoder
            # Build the marginal-distribution samples
            M_prime = torch.cat((M[1:], M[0].unsqueeze(0)), dim=0)
            loss = loss_fn(y, M, M_prime)  # compute the loss
            train_loss.append(loss.item())
            batch.set_description(str(epoch) + ' Loss:%.4f' % np.mean(train_loss[-20:]))
            loss.backward()
            optim.step()       # step the encoder optimizer
            loss_optim.step()  # step the discriminator optimizer
        if epoch % 10 == 0:  # save the model
            root = Path(r'./DIMmodel/')
            enc_file = root / Path('encoder' + str(epoch) + '.pth')
            loss_file = root / Path('loss' + str(epoch) + '.pth')
            enc_file.parent.mkdir(parents=True, exist_ok=True)
            torch.save(encoder.state_dict(), str(enc_file))
            torch.save(loss_fn.state_dict(), str(loss_file))
        epoch_loss.append(np.mean(train_loss[-20:]))  # collect the training loss
    plt.plot(np.arange(len(epoch_loss)), epoch_loss, 'r')  # visualize the loss
    plt.show()

4.2 Load the model: DIM_CIRFAR_loadpath.py

import torch
import torch.nn.functional as F
from tqdm import tqdm
import random

# What this script does: load the trained encoder, encode every image in the sample set,
# pick a random image, and find the ten images closest to it and the ten farthest from it.
#
# Import the local code library
from DIM_CIRFAR_train import (train_loader, train_dataset, totalepoch, device, batch_size, imshowrow, Encoder)

# Load the model
model_path = r'./DIMmodel/encoder%d.pth' % (totalepoch)
encoder = Encoder().to(device)
encoder.load_state_dict(torch.load(model_path, map_location=device))

# Load the samples and run the encoder to generate feature vectors
batchesimg = []
batchesenc = []
batch = tqdm(train_loader, total=len(train_dataset)//batch_size)
for images, target in batch:
    images = images.to(device)
    with torch.no_grad():
        encoded, features = encoder(images)  # generate feature vectors with the encoder
    batchesimg.append(images)
    batchesenc.append(encoded)
# Concatenate the images and the generated vectors along the first dimension
batchesenc = torch.cat(batchesenc, axis=0)
batchesimg = torch.cat(batchesimg, axis=0)
# Verify the vector-based search
index = random.randrange(0, len(batchesenc))  # pick a random index as the target image
# Replicate the target image's feature vector so it can be compared against every vector
# (the same expression is used directly in the distance computation below).
batchesenc[index].repeat(len(batchesenc), 1)
# F.mse_loss() computes the squared L2 difference between feature vectors. Passing reduction='none'
# means no reduction is applied to the element-wise result; without it, the function would average
# over all elements by default (the usual behaviour when training a model).
l2_dis = F.mse_loss(batchesenc[index].repeat(len(batchesenc), 1), batchesenc, reduction='none').sum(1)  # L2 distance from the target image to every image
findnum = 10  # number of images to retrieve
# Use topk() to get the images with the smallest and largest L2 distances. It returns two values:
# the compared values themselves and their corresponding indices.
_, indices = l2_dis.topk(findnum, largest=False)  # the 10 most similar images
_, indices_far = l2_dis.topk(findnum,)            # the 10 least similar images
# Display the results
indices = torch.cat([torch.tensor([index]).to(device), indices])
indices_far = torch.cat([torch.tensor([index]).to(device), indices_far])
rel = torch.cat([batchesimg[indices], batchesimg[indices_far]], axis=0)
imshowrow(rel.cpu(), nrow=len(indices))
# The result has two rows; the first column of each row is the target image. The first row shows the
# search results closest to the target image, and the second row shows those farthest from it.

Origin blog.csdn.net/qq_39237205/article/details/123791810