Is the attention mechanism better than matrix decomposition? (IS ATTENTION BETTER THAN MATRIX DECOMPOSITION?)

Original link: https://openreview.net/pdf?id=1FvkSpWosOl

Code repository: GitHub - Gsunshine/Enjoy-Hamburger: [ICLR 2021 top 3%] Is Attention Better Than Matrix Decomposition? https://github.com/Gsunshine/Enjoy-Hamburger

0. Abstract

        As an essential ingredient of modern deep learning, the attention mechanism, especially self-attention, plays a vital role in global correlation discovery. However, is hand-crafted attention irreplaceable when modeling the global context? Our intriguing finding is that, in terms of both performance and computational cost for encoding long-distance dependencies, self-attention is no better than the matrix decomposition (MD) models developed 20 years ago. We model the global context problem as a low-rank completion problem and show that its optimization algorithms can help design global information blocks. This paper then proposes a series of Hamburgers, in which we employ optimization algorithms for solving MD to factorize the input representations into sub-matrices and reconstruct a low-rank embedding. When the gradients back-propagated through MD are handled carefully, Hamburgers with different MDs perform favorably against self-attention, the popular global context module. We conduct comprehensive experiments on vision tasks where learning the global context is crucial, including semantic segmentation and image generation, and observe significant improvements over self-attention and its variants. Code is available.

1. Introduction

        Since self-attention and the Transformer (Vaswani et al., 2017) showed clear advantages in capturing long-distance dependencies, attention mechanisms have been widely used for global information mining in computer vision (Wang et al., 2018; Zhang et al., 2019a) and natural language processing (Devlin et al., 2019). However, is hand-crafted attention irreplaceable when modeling the global context? This article focuses on a new approach to designing global context modules. The key idea is that if we formulate an inductive bias such as the global context as an objective function, then the optimization algorithm that minimizes that objective defines a computational graph, which is the architecture we need in the network. We make this idea concrete by developing a counterpart to self-attention, the most representative global context module. Considering that extracting global information in a network is like finding a dictionary and the corresponding codes that capture the intrinsic correlations, we model context discovery as a low-rank completion problem on the input tensor and solve it via matrix decomposition. This paper proposes a global correlation module, Hamburger, which employs matrix decomposition to factorize the learned representation into sub-matrices and thereby recover a clean low-rank signal subspace. The iterative optimization algorithm that solves the matrix decomposition defines the central computational graph, i.e., the architecture of Hamburger. Our work uses matrix decomposition models as the basis of Hamburger, including Vector Quantization (VQ) (Gray & Neuhoff, 1998), Concept Decomposition (CD) (Dhillon & Modha, 2001), and Non-negative Matrix Factorization (NMF) (Lee & Seung, 1999).
In addition, instead of directly applying the BPTT (Back-Propagation Through Time) algorithm (Werbos et al., 1990) to the iterative optimization, we adopt a truncated BPTT algorithm, the one-step gradient, to back-propagate gradients effectively. We demonstrate the advantages of Hamburger on fundamental vision tasks where global information has proven crucial, including semantic segmentation and image generation. Experiments show that carefully designed Hamburgers can compete with state-of-the-art attention models while avoiding the unstable gradients back-propagated through MD's iterative computational graph. Hamburger sets new state-of-the-art records for semantic segmentation on PASCAL VOC (Everingham et al., 2010) and PASCAL Context (Mottaghi et al., 2014), and surpasses existing attention modules in large-scale image generation on ImageNet (Deng et al., 2009).

The contributions of this article are as follows:

  • We demonstrate a white-box approach to designing global information modules by transforming an optimization algorithm that minimizes an objective function into an architecture, where global correlation is modeled as a low-rank completion problem.
  • We propose Hamburger, a lightweight and powerful global context module with O(n) complexity that surpasses various attention modules on tasks such as semantic segmentation and image generation.
  • We find that the main obstacle to applying MD in networks is the unstable backward gradient through its iterative optimization algorithm. As a practical solution, our proposed one-step gradient makes it possible to train Hamburger with MD.

2. Methodology

2.1. Warm-up

        Matrix decomposition plays a key role in the proposed Hamburger, so we first review its main idea. A common view is that matrix decomposition factorizes the observed matrix into a product of several sub-matrices, as in singular value decomposition. A more illuminating view, however, is that by assuming a generation process, matrix decomposition acts as the inverse of generation, disassembling the atoms that make up complex data. By reconstructing the original matrix, matrix decomposition can recover the latent structure of the observed data. Suppose the given data are arranged as the columns of a large matrix $X = [x_1, \dots, x_n] \in \mathbb{R}^{d \times n}$. A generative model assumes that there exist a dictionary matrix $D = [d_1, \dots, d_r] \in \mathbb{R}^{d \times r}$ and corresponding codes $C = [c_1, \dots, c_n] \in \mathbb{R}^{r \times n}$ such that $X$ can be expressed as

$$X = \bar{X} + E = DC + E, \qquad (1)$$

where $\bar{X} \in \mathbb{R}^{d \times n}$ is the output low-rank reconstruction and $E \in \mathbb{R}^{d \times n}$ is the noise matrix to be discarded. We assume that the recovered matrix $\bar{X}$ is low-rank:

$$\operatorname{rank}(\bar{X}) \le \min(\operatorname{rank}(D), \operatorname{rank}(C)) \le r \ll \min(d, n). \qquad (2)$$

Different matrix decomposition methods are obtained by assuming different structures for $D$, $C$, and $E$ (Kolda & Bader, 2009; Udell et al., 2016). Matrix decomposition is usually formulated as an objective function with various constraints and solved by optimization algorithms. Classic applications include image denoising (Wright et al., 2009; Lu et al., 2014), inpainting (Mairal et al., 2010), and feature extraction (Zhang et al., 2012).
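To make the generative view concrete, here is a minimal NumPy sketch. It uses truncated SVD, the classic MD mentioned above, as a stand-in decomposition (it is not one of the VQ, CD, or NMF hams that Hamburger uses); the helper name `truncated_svd_md`, the sizes, and the noise level are assumptions for illustration. It builds a rank-r signal, adds small noise, and recovers a rank-r reconstruction split into a dictionary D and codes C.

```python
import numpy as np


def truncated_svd_md(X, r):
    """Hypothetical stand-in MD: best rank-r approximation of X via SVD, split into D (d x r) and C (r x n)."""
    U, s, Vt = np.linalg.svd(X, full_matrices=False)
    D = U[:, :r] * s[:r]      # dictionary: top-r left singular vectors scaled by singular values
    C = Vt[:r, :]             # codes: top-r right singular vectors
    return D, C


rng = np.random.default_rng(0)
d, n, r = 64, 256, 8
X_bar = rng.standard_normal((d, r)) @ rng.standard_normal((r, n))   # clean low-rank signal
E = 0.01 * rng.standard_normal((d, n))                              # noise to be discarded
X = X_bar + E                                                       # observed matrix X = X_bar + E

D, C = truncated_svd_md(X, r)
X_hat = D @ C                                                       # low-rank reconstruction
print(np.linalg.matrix_rank(X_bar), np.linalg.norm(X_hat - X_bar) / np.linalg.norm(X_bar))
```

Because `X_hat` is the best rank-r approximation of `X` and `X_bar` is itself rank r, the reconstruction error against the clean signal is at most twice the noise norm, which is the sense in which the decomposition "discards" E.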

2.2. Proposed method

We focus on building global context modules for networks without painstaking hand-design. Before starting, we briefly review a representative hand-designed context block, the self-attention mechanism. Attention mechanisms aim to find a group of concepts in a large amount of unconscious context for further conscious reasoning (Xu et al., 2015; Bengio, 2017; Goyal et al., 2019). As a representative example, self-attention (Vaswani et al., 2017) was proposed for learning long-range dependencies in machine translation:

$$\operatorname{Attention}(Q, K, V) = \operatorname{softmax}\left(\frac{QK^\top}{\sqrt{d}}\right)V, \qquad (3)$$

where $Q, K, V \in \mathbb{R}^{n \times d}$ are features projected from the input by linear transformations. Self-attention extracts global information by attending to all tokens simultaneously, rather than processing them one by one as a recurrent neural network does.
        Although self-attention and its variants have achieved great success, researchers face two problems: first, developing new global context modules based on self-attention usually requires hand-design; second, it is hard to explain how current attention models work. This article sidesteps both issues and finds a way to easily design global context modules with a well-defined white-box toolkit. We formulate human inductive biases (e.g., the global context) as an objective function and use the optimization algorithm that solves it to design the module's architecture. An optimization algorithm creates a computational graph that takes some inputs and finally outputs the solution; we use this computational graph as the core of our context module. Under this approach, we need to model the network's global information problem as an optimization problem. Take a convolutional neural network (CNN) as an example. After an image is fed into the network, the network outputs a tensor $X \in \mathbb{R}^{C \times H \times W}$.
Since the tensor can be viewed as a set of $HW$ $C$-dimensional superpixels, we unfold it into a matrix $X \in \mathbb{R}^{C \times HW}$. When a module learns long-range dependencies or the global context, the hidden assumption is that there are inherent correlations among the superpixels. For simplicity, we assume that the superpixels are linearly correlated, meaning that each superpixel in $X$ can be expressed as a linear combination of bases whose number is usually much smaller than $HW$. Ideally, the global information hidden in $X$ should therefore be low-rank. However, because conventional CNNs model the global context poorly (Wang et al., 2018; Zhang et al., 2019a), the learned $X$ is usually corrupted by redundant information or is incomplete. This analysis suggests a potential way to model the global context: complete the low-rank part $\bar{X}$ of the unfolded matrix $X$ and discard the noise part $E$ using the classic matrix decomposition models described in Section 2.1, which removes redundancy and incompleteness at the same time. We therefore model learning the global context as a low-rank completion problem and adopt matrix decomposition as its solution. Following Section 2.1, the general objective function of matrix decomposition is

$$\min_{D, C} \; \mathcal{L}(X, DC) + \mathcal{R}_1(D) + \mathcal{R}_2(C), \qquad (4)$$

where $\mathcal{L}$ is the reconstruction loss and $\mathcal{R}_1$ and $\mathcal{R}_2$ are regularization terms for the dictionary $D$ and the codes $C$. Denote the optimization algorithm that minimizes Eq. (4) by M; M is the core architecture of our global context module. To help readers further understand this modeling process, we give a more intuitive explanation in Appendix G. In the following sections, we introduce our global context block, Hamburger, discuss the MD models and optimization algorithms used as M, and finally address the problem of back-propagating gradients through matrix decomposition.
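To make the quadratic cost of self-attention concrete, the following minimal NumPy sketch unfolds a C×H×W feature map into HW tokens and applies softmax(QKᵀ/√d)V. The random projection matrices stand in for learned query, key, and value projections; they, the helper names, and the sizes are assumptions for illustration only.

```python
import numpy as np


def softmax(a, axis=-1):
    a = a - a.max(axis=axis, keepdims=True)     # subtract max for numerical stability
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)


def self_attention(Q, K, V):
    """softmax(Q K^T / sqrt(d)) V with Q, K, V in R^{n x d}; also returns the n x n map."""
    d = Q.shape[1]
    A = softmax(Q @ K.T / np.sqrt(d), axis=-1)  # n x n attention map
    return A @ V, A


rng = np.random.default_rng(0)
C_ch, H, W = 32, 16, 16
feat = rng.standard_normal((C_ch, H, W))        # CNN output tensor in R^{C x H x W}
X = feat.reshape(C_ch, H * W)                   # unfold: HW superpixels of dimension C
tokens = X.T                                    # n x C, one row per superpixel

# Hypothetical random projections standing in for learned W_Q, W_K, W_V.
d = 16
Wq, Wk, Wv = (rng.standard_normal((C_ch, d)) / np.sqrt(C_ch) for _ in range(3))
out, A = self_attention(tokens @ Wq, tokens @ Wk, tokens @ Wv)
print(out.shape, A.shape)   # the attention map is n x n with n = H * W
```

The intermediate `A` has (HW)² entries, so doubling the spatial resolution multiplies its size by 16; this is the memory cost that the MD-based approach avoids.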

2.2.1. Hamburger

        Hamburger consists of one slice of "ham" (matrix decomposition) and two slices of "bread" (linear transformations). As the name suggests, Hamburger first maps the input $Z \in \mathbb{R}^{d_z \times n}$ into a feature space with a linear transformation $W_l \in \mathbb{R}^{d \times d_z}$, the "lower bread"; it then uses the matrix decomposition M to solve for a low-rank signal subspace, the "ham"; finally, another linear transformation $W_u \in \mathbb{R}^{d_z \times d}$, the "upper bread", transforms the extracted signal into the output:

$$\mathcal{H}(Z) = W_u M(W_l Z). \qquad (5)$$

Here, the matrix decomposition M recovers a clear latent structure and acts as a global nonlinearity. The detailed architecture of M, i.e., the optimization algorithm for factorizing $X$, is discussed in Section 2.2.2. Figure 1 depicts Hamburger's architecture, which collaborates with the network through batch normalization (BN) (Ioffe & Szegedy, 2015) and a skip connection, and finally outputs $Y$:

$$Y = Z + \operatorname{BN}(\mathcal{H}(Z)). \qquad (6)$$
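The bread-ham-bread composition can be traced with a minimal NumPy sketch for a single sample. This is not the repository's implementation: the helper names `nmf_ham` and `hamburger` are hypothetical, the ham is a few NMF multiplicative updates with assumed rank, step count, and random initialization, and BN is simplified to per-channel standardization over the n positions without a learned affine transform.

```python
import numpy as np


def nmf_ham(X, r=8, steps=20, eps=1e-6, seed=0):
    """Hypothetical stand-in ham M: a few NMF multiplicative updates, returning DC."""
    rng = np.random.default_rng(seed)
    D = rng.random((X.shape[0], r))
    C = rng.random((r, X.shape[1]))
    for _ in range(steps):
        C *= (D.T @ X) / (D.T @ D @ C + eps)
        D *= (X @ C.T) / (D @ C @ C.T + eps)
    return D @ C


def hamburger(Z, W_l, W_u, r=8):
    """Y = Z + BN(W_u M(W_l Z)) for one sample Z in R^{d_z x n}."""
    X = np.maximum(W_l @ Z, 0.0)        # lower bread (+ ReLU so NMF sees a non-negative input)
    H = W_u @ nmf_ham(X, r)             # ham, then upper bread
    # Simplified BN: standardize each channel over the n positions, no learned affine.
    H = (H - H.mean(axis=1, keepdims=True)) / (H.std(axis=1, keepdims=True) + 1e-5)
    return Z + H                        # skip connection


rng = np.random.default_rng(0)
d_z, d, n = 32, 64, 16 * 16
Z = rng.standard_normal((d_z, n))
W_l = rng.standard_normal((d, d_z)) / np.sqrt(d_z)
W_u = rng.standard_normal((d_z, d)) / np.sqrt(d)
Y = hamburger(Z, W_l, W_u)
print(Y.shape)   # same shape as Z
```

The ham's output has rank at most r, so everything Hamburger adds on top of the skip connection passes through a rank-r bottleneck before the upper bread.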

2.2.2. Hams

        This section describes the structure of the "ham", i.e., M in Eq. (5). As discussed above, once global information discovery is formulated as an MD optimization problem, it is natural to use the algorithms that solve MD as M. M takes the output of the lower bread as its input, denoted $X$, and computes a low-rank reconstruction $\bar{X}$ as its output:

$$M(X) = \bar{X} = DC. \qquad (7)$$

We study two MD models, VQ and NMF, to solve for $D$ and $C$ and reconstruct $\bar{X}$, and leave CD to Appendix B. We introduce the selected MD models only briefly, because our focus is on the importance of the low-rank inductive bias and of optimization-based design for global context modules rather than on any specific MD model. In this article, it is preferable to treat the MD part as a whole, i.e., M, and to focus on how Hamburger as a whole demonstrates its advantages.

Vector Quantization (VQ) (Gray & Neuhoff, 1998) is a classic data compression algorithm that can be formulated as an optimization problem in the form of matrix decomposition:

$$\min_{D, C} \|X - DC\|_F \quad \text{s.t.} \quad c_i \in \{e_1, e_2, \dots, e_r\}, \qquad (8)$$

where $e_i$ is the canonical basis vector $e_i = [0, \dots, 1, \dots, 0]^\top$ with a 1 in the $i$-th position. The solution that minimizes Eq. (8) is the K-means algorithm (Gray & Neuhoff, 1998). However, to make VQ differentiable, we replace the hard $\arg\min$ and the Euclidean distance with softmax and cosine similarity, which yields Alg. 1. There, $\operatorname{cosine}(D, X)$ is a similarity matrix with entries $\operatorname{cosine}(D, X)_{ij} = \frac{d_i^\top x_j}{\|d_i\|\,\|x_j\|}$, softmax is applied column by column, and $T$ is the temperature. As $T \to 0$, the soft assignment becomes a hard, one-hot assignment.

If we impose non-negative constraints on the dictionary $D$ and the codes $C$, we obtain Non-negative Matrix Factorization (NMF) (Lee & Seung, 1999):

$$\min_{D, C} \|X - DC\|_F \quad \text{s.t.} \quad D_{ij} \ge 0, \; C_{jk} \ge 0. \qquad (9)$$

To satisfy the non-negative constraints, we apply a ReLU nonlinearity to $X$ before feeding it into NMF. We solve NMF with the Multiplicative Update rules (Lee & Seung, 2001), which guarantee convergence.

As white-box global context modules, VQ, CD, and NMF are straightforward and lightweight, and therefore highly efficient. They are turned into optimization algorithms that consist mainly of matrix multiplications with complexity $O(ndr)$, much lower than the $O(n^2 d)$ complexity of self-attention because $r \ll n$. All three MDs are memory-friendly because they avoid generating a large $n \times n$ matrix as an intermediate variable, such as the product of $Q$ and $K$ in self-attention (Eq. (3)). In later sections, our experiments show that although the architecture of M is created by optimization and looks different from classic dot-product self-attention, MD is at least on par with self-attention.
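The two hams can be sketched in NumPy as follows. This is a minimal sketch in the spirit of Alg. 1 (soft VQ) and the Lee & Seung multiplicative updates for NMF, not the repository's implementation: the function names, the dictionary initialization, the VQ dictionary update (soft cluster means, X Cᵀ diag(C·1)⁻¹), the temperature T = 0.1, the step counts, and the small `EPS` are all assumptions.

```python
import numpy as np

EPS = 1e-6  # guards divisions; value chosen for this sketch


def soft_vq(X, r, T=0.1, steps=6, seed=0):
    """Soft VQ: softmax over cosine similarity replaces the hard arg min of K-means."""
    rng = np.random.default_rng(seed)
    # Initialize the dictionary with r randomly chosen data columns (an assumption).
    D = X[:, rng.choice(X.shape[1], size=r, replace=False)]
    for _ in range(steps):
        Dn = D / (np.linalg.norm(D, axis=0, keepdims=True) + EPS)
        Xn = X / (np.linalg.norm(X, axis=0, keepdims=True) + EPS)
        logits = (Dn.T @ Xn) / T                      # r x n cosine similarities / T
        logits -= logits.max(axis=0, keepdims=True)   # numerical stability
        C = np.exp(logits)
        C /= C.sum(axis=0, keepdims=True)             # softmax applied column by column
        D = (X @ C.T) / (C.sum(axis=1) + EPS)         # soft cluster means: X C^T diag(C 1)^-1
    return D, C


def nmf_mu(X, r, steps=100, seed=0):
    """NMF with multiplicative updates for min ||X - DC||_F^2 subject to D, C >= 0."""
    rng = np.random.default_rng(seed)
    D = rng.random((X.shape[0], r))
    C = rng.random((r, X.shape[1]))
    for _ in range(steps):
        C *= (D.T @ X) / (D.T @ D @ C + EPS)
        D *= (X @ C.T) / (D @ C @ C.T + EPS)
    return D, C


rng = np.random.default_rng(1)
X = np.maximum(rng.standard_normal((64, 256)), 0.0)   # ReLU: NMF needs a non-negative input
D, C = nmf_mu(X, r=8)
X_bar = D @ C                                          # M(X) = DC, the low-rank reconstruction
Dv, Cv = soft_vq(X, r=8)
print(np.linalg.norm(X - X_bar) / np.linalg.norm(X))   # relative reconstruction error
```

Both loops consist only of products between d×n, d×r, and r×n matrices, which is where the O(ndr) cost per iteration comes from; neither ever forms an n×n matrix.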

2.3. One-step gradient

        Since M uses an optimization algorithm as its computational graph, the key to integrating it into a network is how to back-propagate gradients through the iterative algorithm. The RNN-like behavior of the optimization suggests Back-Propagation Through Time (BPTT) (Werbos et al., 1990) as the standard choice for differentiating the iterative process. In practice, however, the unstable gradients produced by BPTT hurt Hamburger's performance. We therefore build an abstract model to analyze the shortcomings of BPTT and look for a practical solution that takes into account the nature of MD as an optimization algorithm. As shown in Figure 2, $x$, $y$, and $h^t$ denote the input, the output, and the intermediate result at time step $t$, respectively, while $F$ and $G$ are operators. At each time step, the model receives the same input $x$ processed by the underlying network:

$$h^{i+1} = F(h^i, x), \quad i = 0, 1, \dots, t-1. \qquad (10)$$

All intermediate results $h^i$ are discarded, and only the output of the last step, $h^t$, is passed through $G$ to produce the output:

$$y = G(h^t). \qquad (11)$$

In BPTT, the chain rule gives the Jacobian from the output $y$ to the input $x$:

$$\frac{\partial y}{\partial x} = \sum_{i=0}^{t-1} \frac{\partial y}{\partial h^t} \left( \prod_{j=t-i}^{t-1} \frac{\partial h^{j+1}}{\partial h^j} \right) \frac{\partial h^{t-i}}{\partial x}. \qquad (12)$$

A thought experiment is to let $t \to \infty$, which leads to a fully converged result $h^*$ and infinitely many terms in Eq. (12). Suppose that $F$ is Lipschitz continuous with constant $L_h$ with respect to $h$ and $L_x$ with respect to $x$, that $G$ is Lipschitz continuous with constant $L_G$, and that $L_h < 1$. Note that these assumptions hold for many optimization and numerical methods. Then, as $t \to \infty$,

$$\left\|\frac{\partial y}{\partial h^0}\right\| \le L_G L_h^t \to 0, \qquad \left\|\frac{\partial y}{\partial x}\right\| \le \frac{L_G L_x}{1 - L_h}.$$

When $L_h$ is close to 0, the gradient with respect to $h^0$ easily vanishes; when $L_h$ is close to 1, the gradient with respect to $x$ easily explodes.
In addition, the Jacobian $\frac{\partial y}{\partial x}$ contains an ill-conditioned term $\left(I - \frac{\partial F}{\partial h^*}\right)^{-1}$ when the largest eigenvalue of $\frac{\partial F}{\partial h}$ (i.e., the Lipschitz constant of $F$ with respect to $h$) approaches 1 while its smallest eigenvalue typically stays near 0; this limits the ability of the gradient to search for well-generalized solutions in the parameter space. The irregular scale and spectrum of gradients back-propagated through the optimization algorithm indicate that applying BPTT directly to Hamburger is not feasible, which our experiments confirm under the same ablation settings as in Section 3.1 (see Table 1). This analysis points to a possible solution. In BPTT, the product of many Jacobian matrices $\frac{\partial h^j}{\partial h^{j-1}}$ and the summation of an infinite series make the scale of the gradient uncontrollable. This motivates us to drop the minor terms while keeping the dominant ones, so that the direction of the gradient remains roughly correct. Consider the terms of Eq. (12) as a series, $\left\{ \frac{\partial y}{\partial h^t} \left( \prod_{j=t-i}^{t-1} \frac{\partial h^{j+1}}{\partial h^j} \right) \frac{\partial h^{t-i}}{\partial x} \right\}_i$. If the scale of its terms, measured by the operator norm, decays exponentially, it makes sense to approximate the gradient by the first term of the series. The first term comes from the last step of the optimization, which gives the one-step gradient:

$$\frac{\partial y}{\partial x} \approx \frac{\partial y}{\partial h^t} \frac{\partial h^t}{\partial x}.$$

According to Proposition 2, as $t \to \infty$, the one-step gradient is a linear approximation of the BPTT gradient. It is easy to implement, requiring only a no_grad operation in PyTorch (Paszke et al., 2019) or a stop_gradient operation in TensorFlow (Abadi et al., 2016), and it reduces the time and space complexity of the backward pass from O(t) for BPTT to O(1). We tested adding more terms to the gradient, but performance was worse than with the one-step gradient.
According to the experimental results, the one-step gradient is adequate for back-propagating gradients through MD.
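The contrast between full BPTT and the one-step gradient can be checked on a scalar toy problem. In this minimal sketch, the operator F(h, x) = a·tanh(h + x) with a = 0.9 (so L_h = L_x = a < 1), the identity G, and the helper names are assumptions chosen for illustration, not Hamburger's MD. `bptt_grad` accumulates the chain rule through every step, while `one_step_grad` treats h^{t-1} as a constant and keeps only the last step's ∂F/∂x, which is the gradient PyTorch would produce if the first t − 1 iterations ran under `torch.no_grad()`.

```python
import math


def unroll(x, a=0.9, t=50):
    """Iterate h^{i+1} = F(h^i, x) = a * tanh(h^i + x) from h^0 = 0; return [h^0, ..., h^t]."""
    hs = [0.0]
    for _ in range(t):
        hs.append(a * math.tanh(hs[-1] + x))
    return hs


def bptt_grad(x, a=0.9, t=50):
    """Full BPTT: dy/dx for y = G(h^t) = h^t, chained through every iteration."""
    hs = unroll(x, a, t)
    g = 0.0                                   # dh^0/dx = 0
    for h in hs[:-1]:                         # h^0, ..., h^{t-1}
        s = 1.0 - math.tanh(h + x) ** 2       # dF/dh = dF/dx = a * s at this step
        g = a * s * (g + 1.0)                 # dh^{i+1}/dx = dF/dh * dh^i/dx + dF/dx
    return g


def one_step_grad(x, a=0.9, t=50):
    """One-step gradient: h^{t-1} is a constant; keep only dF/dx of the last step."""
    h_prev = unroll(x, a, t)[-2]
    return a * (1.0 - math.tanh(h_prev + x) ** 2)


x, eps = 0.3, 1e-6
finite_diff = (unroll(x + eps)[-1] - unroll(x - eps)[-1]) / (2 * eps)
print(bptt_grad(x), finite_diff, one_step_grad(x))
```

The BPTT value matches the finite-difference estimate and stays below the bound L_G·L_x/(1 − L_h) = 0.9/0.1; the one-step value is the first term of the series in Eq. (12), smaller in magnitude but with the same sign, and it needs no stored intermediate states.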

Table 2: Ablation experiments performed on Hamburger’s NMF Ham component.

3. Experiment

        In this section, we present experimental results demonstrating the effectiveness of the above techniques. We select two vision tasks that require global information and attention mechanisms: semantic segmentation (more than 50 papers use attention) and deep generative models such as GANs (since SAGAN (Zhang et al., 2019a), most state-of-the-art GANs use self-attention). Both tasks are highly competitive and therefore well suited to comparing Hamburger with self-attention. Ablation studies show the importance of MD in Hamburger and the necessity of the one-step gradient. We emphasize Hamburger's superiority in modeling the global context in terms of both performance and computational cost.

3.1. Ablation experiments

        We conduct all ablation experiments on the PASCAL VOC dataset (Everingham et al., 2010) for semantic segmentation and report the best (mean) mIoU on the validation set over 5 runs. In all ablation experiments, the backbone is ResNet-50 (He et al., 2016) with an output stride of 16. We use a 3 × 3 convolutional layer with BN (Ioffe & Szegedy, 2015) and ReLU to reduce the number of channels from 2048 to 512, and then add Hamburger at the position where attention modules are commonly placed in semantic segmentation. See Appendix E.1 for the detailed training setup.

Bread and Ham. We perform ablation experiments on each part of Hamburger. Removing MD (the ham) causes a severe performance drop, which demonstrates the importance of MD. Even adding only the parameter-free MD (ham only) improves performance significantly. Parameterization also helps Hamburger process the extracted features. The bread, especially the upper bread, contributes considerably to performance.

        Notably, there is no simple linear relationship between the latent dimensions d and r and the performance measured by mIoU, although d = 8r is a satisfactory choice. Experiments show that Hamburger performs well even with r = 8, suggesting that modeling the global context can be very cheap.

3.2. A closer look at Hamburger

        To understand Hamburger's behavior in the network, we visualize the spectra of the representations before and after Hamburger on the PASCAL VOC validation set. The input and output tensors are unfolded into matrices in R^{C×HW}. For each unfolded matrix, Figure 5 shows the cumulative ratio of the sum of squares of the largest r singular values to the sum of squares of all singular values. Because of the low-rank reconstruction, truncated spectra are usually observed in the results of classic matrix decomposition models. Within the network, Hamburger likewise increases energy concentration, while preserving informative details through the skip connection. Additionally, Figure 6 visualizes the feature maps before and after Hamburger. MD helps Hamburger learn interpretable global information by zeroing out uninformative channels, removing irregular noise, and completing details according to the context.
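The cumulative energy ratio described above can be computed as follows. This is a minimal NumPy sketch; the helper name, the synthetic low-rank-plus-noise tensor, and the sizes are assumptions chosen only to contrast a concentrated spectrum with a flat one, not data from the paper.

```python
import numpy as np


def cumulative_energy_ratio(feat):
    """Unfold a C x H x W tensor into C x HW and return, for r = 1..min(C, HW),
    the share of squared singular values captured by the top r."""
    C, H, W = feat.shape
    X = feat.reshape(C, H * W)
    s = np.linalg.svd(X, compute_uv=False)   # singular values in descending order
    energy = s ** 2
    return np.cumsum(energy) / energy.sum()


rng = np.random.default_rng(0)
C, H, W, rank = 64, 16, 16, 4
low_rank = (rng.standard_normal((C, rank)) @ rng.standard_normal((rank, H * W))).reshape(C, H, W)
noisy_low_rank = low_rank + 0.05 * rng.standard_normal((C, H, W))
pure_noise = rng.standard_normal((C, H, W))

ratio_lr = cumulative_energy_ratio(noisy_low_rank)
ratio_noise = cumulative_energy_ratio(pure_noise)
print(ratio_lr[rank - 1], ratio_noise[rank - 1])   # energy captured by the top 4 singular values
```

A curve that saturates after a few singular values, as for the low-rank tensor, is what "improved energy concentration" means; an unstructured tensor spreads its energy across many directions.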

3.3. Comparison with attention

        This section demonstrates the advantages of MD-based Hamburger over attention-related context modules in computational cost, memory consumption, and inference time. We compare Hamburger (Ham) with self-attention (SA) (Vaswani et al., 2017), the Dual Attention (DA) module in DANet (Fu et al., 2019), the Double Attention module in A2-Net (Chen et al., 2018b), the APC module in APCNet (He et al., 2019b), the DM module in DMNet (He et al., 2019a), and the ACF module in CFNet (Zhang et al., 2019b), and report in Table 3 the parameters and costs of processing a tensor Z of size 1 × 512 × 128 × 128. In practice, excessive memory usage is a key bottleneck when using attention, so we also report GPU memory load and inference time on an NVIDIA TITAN Xp. Overall, Hamburger is lightweight in both computation and memory compared with attention-related global context modules.
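A back-of-the-envelope count shows why the gap is large for a 1 × 512 × 128 × 128 tensor. This sketch counts multiply-accumulates (MACs) for the dominant products only; the number of bases r = 64 is an assumption (it follows the d = 8r choice from the ablation, not a value stated for Table 3), and the counts ignore softmax, projections, and the fact that an MD runs several iterations, each with a few products of this size.

```python
# Back-of-the-envelope costs for a tensor Z of size 1 x 512 x 128 x 128.
d = 512                 # channels
n = 128 * 128           # spatial positions (16384)
r = 64                  # ASSUMED number of bases (d = 8r); not stated for Table 3

qk_macs = n * n * d     # Q K^T in self-attention: O(n^2 d)
md_macs = n * d * r     # one D^T X or D C product in an MD iteration: O(n d r)
attn_map_bytes = n * n * 4   # float32 n x n attention map

print(f"Q K^T MACs:       {qk_macs:.3e}")
print(f"MD product MACs:  {md_macs:.3e}")
print(f"ratio:            {qk_macs / md_macs:.0f}x")        # equals n / r
print(f"n x n map (fp32): {attn_map_bytes / 2**30:.1f} GiB")
```

The ratio is n/r = 256 per product, and the single n × n attention map alone occupies 1 GiB in float32, which illustrates why memory, not just compute, is the bottleneck for attention at this resolution.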

3.4. Semantic segmentation

        We benchmark Hamburger on the PASCAL VOC dataset (Everingham et al., 2010) and the PASCAL Context dataset (Mottaghi et al., 2014) against state-of-the-art attention models. We use ResNet-101 (He et al., 2016) as the backbone, with an output stride of 8. The segmentation head is the same as in the ablation experiments. In the ablation studies, NMF generally performs better than CD and VQ (see Table 1), so we mainly test NMF in subsequent experiments. We use HamNet to denote a ResNet with Hamburger. Results on the PASCAL VOC test set and the PASCAL Context validation set are shown in Tables 4 and 5, respectively. All attention-based models are marked with ∗; in these models, various attention modules make up the segmentation head. Although semantic segmentation is a saturated task on which most recently published methods perform similarly, Hamburger shows considerable improvements over previous state-of-the-art attention modules.

3.5. Image generation

        In deep generative models such as GANs, attention serves as a block for describing the global context. Since SAGAN (Zhang et al., 2019a), most state-of-the-art GANs for conditional image generation integrate self-attention into their architectures, e.g., BigGAN (Brock et al., 2018), S3GAN (Lučić et al., 2019), and LOGAN (Wu et al., 2019). Benchmarking MD-based Hamburger on the challenging ImageNet (Deng et al., 2009) conditional image generation task therefore provides a convincing comparison. We compare Hamburger and self-attention on ImageNet at 128 × 128. Replacing self-attention with an NMF Hamburger in both the generator and the discriminator at the 32 × 32 feature resolution gives HamGAN-baby, which achieves a considerable improvement over SAGAN in Fréchet Inception Distance (FID) (Heusel et al., 2017). In addition, we compare Hamburger with a recently developed attention variant, Your Local GAN (YLG) (Daras et al., 2020), using its code base and the same training settings; we name this model HamGAN-strong. HamGAN-strong improves FID by over 5% while being 15% faster in total training time and 3.6× faster in module time (1.54 iters/sec for HamGAN, 1.31 iters/sec for YLG, and 1.65 iters/sec without any context module, averaged over 1000 iterations). All of these experiments are conducted on the same TPUv3 training platform.
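The 15% and 3.6× figures can be reproduced from the three throughputs quoted above. This short sketch assumes that "module time" means the extra seconds per iteration over the no-context baseline of 1.65 iters/sec; that definition is an interpretation, not something stated in the text.

```python
# Reported throughputs (iterations per second), averaged over 1000 iterations.
ham, ylg, no_ctx = 1.54, 1.31, 1.65

# Convert to seconds per iteration.
t_ham, t_ylg, t_none = 1 / ham, 1 / ylg, 1 / no_ctx

# Fraction of total training time saved by HamGAN relative to YLG.
total_saving = 1 - t_ham / t_ylg
# Module time = extra seconds per iteration over the no-context baseline (assumed definition).
module_speedup = (t_ylg - t_none) / (t_ham - t_none)

print(f"total training time saved: {total_saving:.1%}")
print(f"module-time speedup:       {module_speedup:.2f}x")
```

This gives about 14.9% and 3.63×, consistent with the rounded 15% and 3.6× above; the module speedup is much larger than the end-to-end saving because the context module is only a small part of each iteration.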

4. Related work

        In deep learning, the past five years have witnessed the great success of attention mechanisms (Bahdanau et al., 2015; Mnih et al., 2014; Xu et al., 2015; Luong et al., 2015). Roughly speaking, an attention mechanism adaptively generates weights over targets according to the current need. It comes in various architectures, the most famous being dot-product self-attention (Vaswani et al., 2017). Attention mechanisms have a wide range of applications, from single-source (Lin et al., 2017) to multi-source inputs (Luong et al., 2015; Parikh et al., 2016), and from global information discovery (Wang et al., 2018; Zhang et al., 2019a) to local feature extraction (Dai et al., 2017; Parmar et al., 2019). Previous researchers have tried to explain the effectiveness of attention from multiple perspectives. Capturing long-range dependencies (Wang et al., 2018), sequentially decomposing visual scenes (Eslami et al., 2016; Kosiorek et al., 2018), inferring part-whole relationships (Sabour et al., 2017; Hinton et al., 2018), modeling interactions between objects (Greff et al., 2017; van Steenkiste et al., 2018), and learning the dynamics of environments (Goyal et al., 2019) are often considered the underlying mechanisms of attention.

        From a biological perspective, a common view is that attention models the emergence of focus from many unconscious contexts (Xu et al., 2015). Some works try to interpret attention by visualizing or attacking attention weights (Serrano and Smith, 2019; Jain and Wallace, 2019; Wiegreffe and Pinter, 2019), while others formulate attention as a non-local operation (Wang et al., 2018) or a diffusion model (Tao et al., 2018; Lu et al., 2019), or build attention-like models on mixture models via the expectation-maximization algorithm (Greff et al., 2017; Hinton et al., 2018; Li et al., 2019) or variational inference (Eslami et al., 2016). The connection between Transformers and graph neural networks has also been discussed (Liang et al., 2018; Zhang et al., 2019c). Overall, discussions of attention are still far from reaching a consensus.

        Recent studies develop efficient attention modules using low-rank approximation in computer vision (Chen et al., 2018b; Zhu et al., 2019; Chen et al., 2019; Li et al., 2019) and natural language processing (Mehta et al., 2019; Katharopoulos et al., 2020; Wang et al., 2020; Song et al., 2020). Technically, low-rank approximation usually targets the correlation matrix, i.e., the product of Q and K after the softmax operation. The correlation matrix is approximated by the product of two smaller matrices, and the associative law is then applied to save memory and computation, where the approximation involves kernel functions or other similarity functions. Other studies (Babiloni et al., 2020; Ma et al., 2019) formulate attention in tensor form but may generate large intermediate variables. This article does not approximate attention or make it more efficient; instead, it treats modeling the global context as a low-rank completion problem. Computational and memory efficiency are by-products of the low-rank assumption on the clean signal subspace and of using optimization algorithms as architectures.
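For contrast with Hamburger, the associativity trick used by kernel-based efficient attention can be sketched in a few lines of NumPy. This illustrates the approach Hamburger does not take; the feature map elu(x) + 1 is the choice made by Katharopoulos et al. (2020), while the function names and sizes are assumptions. Both functions compute the same row-normalized φ(Q)φ(K)ᵀV, but the second reorders the products so that no n × n matrix is formed.

```python
import numpy as np


def feature_map(x):
    """Positive feature map elu(x) + 1."""
    return np.where(x > 0, x + 1.0, np.exp(x))


def kernel_attention_quadratic(Q, K, V):
    """Row-normalized phi(Q) phi(K)^T V; materializes the n x n similarity matrix."""
    A = feature_map(Q) @ feature_map(K).T            # n x n
    return (A / A.sum(axis=1, keepdims=True)) @ V


def kernel_attention_linear(Q, K, V):
    """Same result via associativity: phi(Q) (phi(K)^T V); only d x d_v intermediates."""
    phi_q, phi_k = feature_map(Q), feature_map(K)
    kv = phi_k.T @ V                                 # d x d_v
    z = phi_q @ phi_k.sum(axis=0)                    # row normalizers, length n
    return (phi_q @ kv) / z[:, None]


rng = np.random.default_rng(0)
n, d = 512, 32
Q, K, V = (rng.standard_normal((n, d)) for _ in range(3))
print(np.allclose(kernel_attention_quadratic(Q, K, V), kernel_attention_linear(Q, K, V)))
```

Here efficiency comes from rewriting the attention computation itself, whereas in Hamburger it follows from the low-rank assumption and the MD solver.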

        The combination of matrix decomposition and deep learning has a long history. Researchers reduce the number of parameters in networks by factorizing weights, including those of softmax layers (Sainath et al., 2013), convolutional layers (Zhong et al., 2019), and embedding layers (Lan et al., 2019). Tariyal et al. (2016) attempt to build deep dictionary learning for feature extraction and train the model greedily. This paper instead factorizes representations to recover a clean signal subspace as the global context, providing a new way to model long-range dependencies via matrix decomposition.

5. Conclusion

        This paper studies modeling long-range dependencies in networks. We formulate learning the global context as a low-rank completion problem. Inspired by this low-rank formulation, we develop the Hamburger module based on well-studied matrix decomposition models. Hamburger's core architecture is naturally defined by the computational graph of the optimization algorithm for a specific matrix decomposition objective. Hamburger learns an interpretable global context and improves spectral concentration by denoising and completing its input. Surprisingly, when backward gradients are handled carefully, even simple matrix decompositions proposed 20 years ago can be as powerful as self-attention on challenging vision tasks such as semantic segmentation and image generation, while being lightweight, fast, and memory-efficient. In future work, we plan to extend Hamburger to natural language processing by integrating positional information and designing a Transformer-like decoder, to establish a theoretical foundation for the one-step gradient trick, to find better ways to differentiate through matrix decompositions, and to integrate more advanced matrix decomposition methods.

# -*- coding: utf-8 -*-
"""
Hamburger for Pytorch

@author: Gsunshine
"""

from functools import partial

import numpy as np
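import settings  # project-level config module from the repository; supplies BN_MOM
import torch
from sync_bn.nn.modules import SynchronizedBatchNorm2d  # synchronized (multi-GPU) BatchNorm, external dependency
from torch import nn
from torch.nn import functional as F
from torch.nn.modules.batchnorm import _BatchNorm

# BatchNorm factory shared by the "bread" layers; momentum comes from the repo settings.
norm_layer = partial(SynchronizedBatchNorm2d, momentum=settings.BN_MOM)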


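class ConvBNReLU(nn.Module):
    """Conv2d -> norm_layer -> ReLU; used as the "cheese" layer in HamburgerV2."""

    @classmethod
    def _same_paddings(cls, kernel_size, dilation=1):
        # "Same" padding for odd kernel sizes at stride 1, taking dilation into account.
        if kernel_size % 2 == 0:
            raise ValueError('"same" padding requires an odd kernel_size')
        return dilation * (kernel_size - 1) // 2

    def __init__(self, in_c, out_c,
                 kernel_size=1, stride=1, padding='same',
                 dilation=1, groups=1):
        super().__init__()

        if padding == 'same':
            padding = self._same_paddings(kernel_size, dilation)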

        self.conv = nn.Conv2d(in_c, out_c,
                              kernel_size=kernel_size, stride=stride,
                              padding=padding, dilation=dilation,
                              groups=groups,
                              bias=False)
        self.bn = norm_layer(out_c)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.act(x)
        
        return x
# -*- coding: utf-8 -*-
"""
Hamburger for Pytorch

@author: Gsunshine
"""

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.modules.batchnorm import _BatchNorm

from .bread import ConvBNReLU, norm_layer
from .ham import get_hams


class HamburgerV1(nn.Module):
    def __init__(self, in_c, args=None):
        super().__init__()

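        ham_type = getattr(args, 'HAM_TYPE', 'NMF')

        # Latent dimension d of the matrix X that the ham factorizes (not the dictionary D).
        D = getattr(args, 'MD_D', 512)

        if ham_type == 'NMF':
            # NMF needs a non-negative input, so the lower bread ends with ReLU.
            self.lower_bread = nn.Sequential(nn.Conv2d(in_c, D, 1),
                                             nn.ReLU(inplace=True))
        else:
            self.lower_bread = nn.Conv2d(in_c, D, 1)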

        HAM = get_hams(ham_type)
        self.ham = HAM(args)
        
        self.upper_bread = nn.Sequential(nn.Conv2d(D, in_c, 1, bias=False),
                                         norm_layer(in_c))
        
        self.shortcut = nn.Sequential()
        
        self._init_weight()
        
        print('ham', HAM)

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                N = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, np.sqrt(2. / N))
            elif isinstance(m, _BatchNorm):
                m.weight.data.fill_(1)
                if m.bias is not None:
                    m.bias.data.zero_()

    def forward(self, x):
        shortcut = self.shortcut(x)

        x = self.lower_bread(x)
        x = self.ham(x)
        x = self.upper_bread(x)

        x = F.relu(x + shortcut, inplace=True)

        return x

    def online_update(self, bases):
        if hasattr(self.ham, 'online_update'):
            self.ham.online_update(bases)


class HamburgerV2(nn.Module):
    def __init__(self, in_c, args=None):
        super().__init__()

        ham_type = getattr(args, 'HAM_TYPE', 'NMF')

        C = getattr(args, 'MD_D', 512)

        if ham_type == 'NMF':
            self.lower_bread = nn.Sequential(nn.Conv2d(in_c, C, 1),
                                             nn.ReLU(inplace=True))
        else:
            self.lower_bread = nn.Conv2d(in_c, C, 1)

        HAM = get_hams(ham_type)
        self.ham = HAM(args)

        self.cheese = ConvBNReLU(C, C)
        self.upper_bread = nn.Conv2d(C, in_c, 1, bias=False)

        self.shortcut = nn.Sequential()

        self._init_weight()

        print('ham', HAM)

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                N = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, np.sqrt(2. / N))
            elif isinstance(m, _BatchNorm):
                m.weight.data.fill_(1)
                if m.bias is not None:
                    m.bias.data.zero_()

    def forward(self, x):
        shortcut = self.shortcut(x)

        x = self.lower_bread(x)
        x = self.ham(x)
        x = self.cheese(x)
        x = self.upper_bread(x)

        x = F.relu(x + shortcut, inplace=True)

        return x

    def online_update(self, bases):
        if hasattr(self.ham, 'online_update'):
            self.ham.online_update(bases)


class HamburgerV2Plus(nn.Module):
    def __init__(self, in_c, args=None):
        super().__init__()

        ham_type = getattr(args, 'HAM_TYPE', 'NMF')

        S = getattr(args, 'MD_S', 1)
        D = getattr(args, 'MD_D', 512)
        C = S * D

        self.dual = getattr(args, 'DUAL', True)
        if self.dual:
            C = 2 * C

        if ham_type == 'NMF':
            self.lower_bread = nn.Sequential(nn.Conv2d(in_c, C, 1),
                                             nn.ReLU(inplace=True))
        else:
            self.lower_bread = nn.Conv2d(in_c, C, 1)

        HAM = get_hams(ham_type)
        if self.dual:
            args.SPATIAL = True
            self.ham_1 = HAM(args)
            args.SPATIAL = False
            self.ham_2 = HAM(args)
        else:
            self.ham = HAM(args)

        factor = getattr(args, 'CHEESE_FACTOR', S)
        if self.dual:
            factor = 2 * factor

        self.cheese = ConvBNReLU(C, C // factor)
        self.upper_bread = nn.Conv2d(C // factor, in_c, 1, bias=False)

        zero_ham = getattr(args, 'ZERO_HAM', True)
        if zero_ham:
            coef_ham_init = 0.
        else:
            coef_ham_init = 1.

        self.coef_shortcut = nn.Parameter(torch.tensor([1.]))
        self.coef_ham = nn.Parameter(torch.tensor([coef_ham_init]))

        self.shortcut = nn.Sequential()

        self._init_weight()

        print('ham', HAM)
        print('dual', self.dual)

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                N = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, np.sqrt(2. / N))
            elif isinstance(m, _BatchNorm):
                m.weight.data.fill_(1)
                if m.bias is not None:
                    m.bias.data.zero_()

    def forward(self, x):
        shortcut = self.shortcut(x)

        x = self.lower_bread(x)

        if self.dual:
            x = x.view(x.shape[0], 2, x.shape[1] // 2, *x.shape[2:])
            x_1 = self.ham_1(x.narrow(1, 0, 1).squeeze(dim=1))
            x_2 = self.ham_2(x.narrow(1, 1, 1).squeeze(dim=1))
            x = torch.cat([x_1, x_2], dim=1)
        else:
            x = self.ham(x)
        x = self.cheese(x)
        x = self.upper_bread(x)
    
        x = self.coef_ham * x + self.coef_shortcut * shortcut
        x = F.relu(x, inplace=True)

        return x

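    def online_update(self, bases):
        # With DUAL=True there is no self.ham; ham_1 (spatial) and ham_2 (channel)
        # update their own bases inside their forward pass
        ham = getattr(self, 'ham', None)
        if ham is not None and hasattr(ham, 'online_update'):
            ham.online_update(bases)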


def get_hamburger(version):
    burgers = {'V1':HamburgerV1,
               'V2':HamburgerV2,
               'V2+': HamburgerV2Plus}

    assert version in burgers

    return burgers[version]
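In `HamburgerV2Plus`, the `coef_ham` / `coef_shortcut` pair means that with `ZERO_HAM=True` the block starts out as `relu(shortcut)`. The decomposition branch is then blended in as `coef_ham` is learned. The CPU sketch below checks this property. It is not the repository's code. `LowRankStandIn` (a truncated-SVD projection) is a hypothetical substitute for the VQ/CD/NMF hams, `MiniBurger` is a hypothetical name, and `nn.BatchNorm2d` replaces `norm_layer`.

```python
import torch
from torch import nn
from torch.nn import functional as F


class LowRankStandIn(nn.Module):
    """Hypothetical stand-in for a VQ/CD/NMF 'ham': projects each sample's
    (C, H*W) matrix onto its top-`rank` left singular vectors."""

    def __init__(self, rank):
        super().__init__()
        self.rank = rank

    def forward(self, x):
        B, C, H, W = x.shape
        m = x.reshape(B, C, H * W)                      # (B, C, N)
        U, _, _ = torch.linalg.svd(m, full_matrices=False)
        U = U[:, :, :self.rank]                         # (B, C, R)
        low_rank = U @ (U.transpose(1, 2) @ m)          # rank <= R
        return low_rank.reshape(B, C, H, W)


class MiniBurger(nn.Module):
    """Sketch of the HamburgerV2Plus path with DUAL=False and ZERO_HAM=True:
    lower bread -> ham -> cheese -> upper bread, blended with the shortcut."""

    def __init__(self, in_c, d=32, rank=4):
        super().__init__()
        # ReLU keeps the ham input non-negative, as the NMF branch requires
        self.lower_bread = nn.Sequential(nn.Conv2d(in_c, d, 1), nn.ReLU())
        self.ham = LowRankStandIn(rank)
        # nn.BatchNorm2d is assumed in place of the repository's norm_layer
        self.cheese = nn.Sequential(nn.Conv2d(d, d, 3, padding=1, bias=False),
                                    nn.BatchNorm2d(d), nn.ReLU())
        self.upper_bread = nn.Conv2d(d, in_c, 1, bias=False)
        self.coef_shortcut = nn.Parameter(torch.tensor([1.0]))
        self.coef_ham = nn.Parameter(torch.tensor([0.0]))  # zero-initialized branch

    def forward(self, x):
        out = self.upper_bread(self.cheese(self.ham(self.lower_bread(x))))
        return F.relu(self.coef_ham * out + self.coef_shortcut * x)


torch.manual_seed(0)
x = torch.randn(2, 16, 8, 8)
block = MiniBurger(16)
y = block(x)
print(torch.equal(y, F.relu(x)))  # True: at init the block reduces to relu(shortcut)
```

Starting the branch at zero lets a pretrained backbone's features pass through unchanged at the beginning of training, while the gradient with respect to `coef_ham` is still nonzero, so the low-rank branch can be switched on gradually.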
# -*- coding: utf-8 -*-
"""
Hamburger for Pytorch

@author: Gsunshine
"""

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.modules.batchnorm import _BatchNorm


class _MatrixDecomposition2DBase(nn.Module):
    def __init__(self, args):
        super().__init__()

        self.spatial = getattr(args, 'SPATIAL', True)

        self.S = getattr(args, 'MD_S', 1)
        self.D = getattr(args, 'MD_D', 512)
        self.R = getattr(args, 'MD_R', 64)

        self.train_steps = getattr(args, 'TRAIN_STEPS', 6)
        self.eval_steps  = getattr(args, 'EVAL_STEPS', 7)

        self.inv_t = getattr(args, 'INV_T', 100)
        self.eta   = getattr(args, 'ETA', 0.9)

        self.rand_init = getattr(args, 'RAND_INIT', True)

        print('spatial', self.spatial)
        print('S', self.S)
        print('D', self.D)
        print('R', self.R)
        print('train_steps', self.train_steps)
        print('eval_steps', self.eval_steps)
        print('inv_t', self.inv_t)
        print('eta', self.eta)
        print('rand_init', self.rand_init)

    def _build_bases(self, B, S, D, R, cuda=False):
        raise NotImplementedError

    def local_step(self, x, bases, coef):
        raise NotImplementedError

    @torch.no_grad()
    def local_inference(self, x, bases):
        # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
        coef = torch.bmm(x.transpose(1, 2), bases)
        coef = F.softmax(self.inv_t * coef, dim=-1)

        steps = self.train_steps if self.training else self.eval_steps
        for _ in range(steps):
            bases, coef = self.local_step(x, bases, coef)

        return bases, coef

    def compute_coef(self, x, bases, coef):
        raise NotImplementedError

    def forward(self, x, return_bases=False):
        B, C, H, W = x.shape

        # (B, C, H, W) -> (B * S, D, N)
        if self.spatial:
            D = C // self.S
            N = H * W
            x = x.view(B * self.S, D, N)
        else:
            D = H * W
            N = C // self.S
            x = x.view(B * self.S, N, D).transpose(1, 2)

        # Build bases on the input's device (CPU or GPU)
        if not self.rand_init and not hasattr(self, 'bases'):
            bases = self._build_bases(1, self.S, D, self.R, cuda=x.is_cuda)
            self.register_buffer('bases', bases)

        # (S, D, R) -> (B * S, D, R)
        if self.rand_init:
            bases = self._build_bases(B, self.S, D, self.R, cuda=x.is_cuda)
        else:
            bases = self.bases.repeat(B, 1, 1)

        bases, coef = self.local_inference(x, bases)

        # (B * S, N, R)
        coef = self.compute_coef(x, bases, coef)

        # (B * S, D, R) @ (B * S, N, R)^T -> (B * S, D, N)
        x = torch.bmm(bases, coef.transpose(1, 2))

        # (B * S, D, N) -> (B, C, H, W)
        if self.spatial:
            x = x.view(B, C, H, W)
        else:
            x = x.transpose(1, 2).view(B, C, H, W)

        # (B * S, D, R) -> (B, S, D, R)
        bases = bases.view(B, self.S, D, self.R)

        if not self.rand_init and not self.training and not return_bases:
            self.online_update(bases)

        # if not self.rand_init or return_bases:
        #     return x, bases
        # else:
        return x

    @torch.no_grad()
    def online_update(self, bases):
        # (B, S, D, R) -> (S, D, R)
        update = bases.mean(dim=0)
        self.bases += self.eta * (update - self.bases)
        self.bases = F.normalize(self.bases, dim=1)


class VQ2D(_MatrixDecomposition2DBase):
    def __init__(self, args):
        super().__init__(args)

    def _build_bases(self, B, S, D, R, cuda=False):
        if cuda:
            bases = torch.randn((B * S, D, R)).cuda()
        else:
            bases = torch.randn((B * S, D, R))

        bases = F.normalize(bases, dim=1)

        return bases

    @torch.no_grad()
    def local_step(self, x, bases, _):
        # (B * S, D, N), normalize x along D (for cosine similarity)
        std_x = F.normalize(x, dim=1)

        # (B * S, D, R), normalize bases along D (for cosine similarity)
        std_bases = F.normalize(bases, dim=1, eps=1e-6)

        # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
        coef = torch.bmm(std_x.transpose(1, 2), std_bases)

        # softmax along R
        coef = F.softmax(self.inv_t * coef, dim=-1)

        # normalize along N
        coef = coef / (1e-6 + coef.sum(dim=1, keepdim=True))

        # (B * S, D, N) @ (B * S, N, R) -> (B * S, D, R)
        bases = torch.bmm(x, coef)

        return bases, coef

    def compute_coef(self, x, bases, _):
        with torch.no_grad():
            # (B * S, D, N) -> (B * S, 1, N)
            x_norm = x.norm(dim=1, keepdim=True)

        # (B * S, D, N) / (B * S, 1, N) -> (B * S, D, N)
        std_x = x / (1e-6 + x_norm)

        # (B * S, D, R), normalize bases along D (for cosine similarity)
        std_bases = F.normalize(bases, dim=1, eps=1e-6)

        # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
        coef = torch.bmm(std_x.transpose(1, 2), std_bases)

        # softmax along R
        coef = F.softmax(self.inv_t * coef, dim=-1)

        return coef


class CD2D(_MatrixDecomposition2DBase):
    def __init__(self, args):
        super().__init__(args)

        self.beta = getattr(args, 'BETA', 0.1)
        print('beta', self.beta)

    def _build_bases(self, B, S, D, R, cuda=False):
        if cuda:
            bases = torch.randn((B * S, D, R)).cuda()
        else:
            bases = torch.randn((B * S, D, R))

        bases = F.normalize(bases, dim=1)

        return bases

    @torch.no_grad()
    def local_step(self, x, bases, _):
        # normalize x along D (for cosine similarity)
        std_x = F.normalize(x, dim=1)

        # (B * S, N, D) @ (B * S, D, R) -> (B * S, N, R)
        coef = torch.bmm(std_x.transpose(1, 2), bases)

        # softmax along R
        coef = F.softmax(self.inv_t * coef, dim=-1)

        # normalize along N
        coef = coef / (1e-6 + coef.sum(dim=1, keepdim=True))

        # (B * S, D, N) @ (B * S, N, R) -> (B * S, D, R)
        bases = torch.bmm(x, coef)

        # normalize along D
        bases = F.normalize(bases, dim=1, eps=1e-6)

        return bases, coef

    def compute_coef(self, x, bases, _):
        # [(B * S, D, R)^T @ (B * S, D, R) + beta * I]^(-1) -> (B * S, R, R)
        eye = torch.eye(self.R, device=x.device, dtype=x.dtype)
        temp = torch.bmm(bases.transpose(1, 2), bases) + self.beta * eye
        temp = torch.inverse(temp)

        # (B * S, D, N)^T @ (B * S, D, R) @ (B * S, R, R) -> (B * S, N, R)
        coef = x.transpose(1, 2).bmm(bases).bmm(temp)

        return coef


class NMF2D(_MatrixDecomposition2DBase):
    def __init__(self, args):
        super().__init__(args)

        # NMF uses temperature 1 when softmax-initializing coef in local_inference
        self.inv_t = 1

    def _build_bases(self, B, S, D, R, cuda=False):
        if cuda:
            bases = torch.rand((B * S, D, R)).cuda()
        else:
            bases = torch.rand((B * S, D, R))

        bases = F.normalize(bases, dim=1)

        return bases

    @torch.no_grad()
    def local_step(self, x, bases, coef):
        # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
        numerator = torch.bmm(x.transpose(1, 2), bases)
        # (B * S, N, R) @ [(B * S, D, R)^T @ (B * S, D, R)] -> (B * S, N, R)
        denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))
        # Multiplicative Update
        coef = coef * numerator / (denominator + 1e-6)

        # (B * S, D, N) @ (B * S, N, R) -> (B * S, D, R)
        numerator = torch.bmm(x, coef)
        # (B * S, D, R) @ [(B * S, N, R)^T @ (B * S, N, R)] -> (B * S, D, R)
        denominator = bases.bmm(coef.transpose(1, 2).bmm(coef))
        # Multiplicative Update
        bases = bases * numerator / (denominator + 1e-6)

        return bases, coef

    def compute_coef(self, x, bases, coef):
        # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
        numerator = torch.bmm(x.transpose(1, 2), bases)
        # (B * S, N, R) @ (B * S, D, R)^T @ (B * S, D, R) -> (B * S, N, R)
        denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))
        # Multiplicative update
        coef = coef * numerator / (denominator + 1e-6)

        return coef


def get_hams(key):
    hams = {'VQ':VQ2D,
            'CD':CD2D,
            'NMF':NMF2D}

    assert key in hams

    return hams[key]
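`NMF2D.local_step` and `compute_coef` implement the multiplicative updates of Lee & Seung (1999) for x ≈ bases · coefᵀ. The standalone sketch below runs the same two updates on random non-negative data and records the Frobenius reconstruction error. In exact arithmetic, and ignoring the small `eps` in the denominator, these updates are known not to increase that error. The sketch differs from `NMF2D` in three ways: it initializes both factors uniformly at random instead of softmax-initializing `coef`, it does not L2-normalize the initial bases, and it tracks the error. `nmf_multiplicative` is a hypothetical helper name.

```python
import torch


def nmf_multiplicative(x, rank, steps=50, eps=1e-6, seed=0):
    """Lee-Seung multiplicative updates in the batched layout of NMF2D:
    x (B, D, N) >= 0 is approximated by bases (B, D, R) @ coef (B, N, R)^T."""
    gen = torch.Generator().manual_seed(seed)
    B, D, N = x.shape
    bases = torch.rand(B, D, rank, generator=gen)
    coef = torch.rand(B, N, rank, generator=gen)
    errors = []
    for _ in range(steps):
        # coef <- coef * (x^T bases) / (coef bases^T bases)
        coef = coef * torch.bmm(x.transpose(1, 2), bases) / (
            coef.bmm(bases.transpose(1, 2).bmm(bases)) + eps)
        # bases <- bases * (x coef) / (bases coef^T coef)
        bases = bases * torch.bmm(x, coef) / (
            bases.bmm(coef.transpose(1, 2).bmm(coef)) + eps)
        # Frobenius reconstruction error ||x - bases coef^T||
        errors.append(torch.linalg.norm(x - bases.bmm(coef.transpose(1, 2))).item())
    return bases, coef, errors


torch.manual_seed(0)
x = torch.rand(2, 64, 100)  # non-negative, like the ReLU output of lower_bread
bases, coef, errors = nmf_multiplicative(x, rank=8)
print(f'{errors[0]:.3f} -> {errors[-1]:.3f}')
```

Both updates only multiply by ratios of non-negative quantities, so the factors stay non-negative without any projection step. This is why the NMF hamburger puts a ReLU in `lower_bread`.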
import math
import os.path as osp
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.batchnorm import _BatchNorm

import settings
from hamburger import ConvBNReLU, get_hamburger
from sync_bn.nn.modules import SynchronizedBatchNorm2d

norm_layer = partial(SynchronizedBatchNorm2d, momentum=settings.BN_MOM)


class Bottleneck(nn.Module):
    expansion = 4
    def __init__(self, inplanes, planes, stride=1, dilation=1,
                 downsample=None, previous_dilation=1):
        super().__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = norm_layer(planes)
        self.conv2 = nn.Conv2d(planes, planes, 3, stride, dilation, dilation,
                               bias=False)
        self.bn2 = norm_layer(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, 1, bias=False)
        self.bn3 = norm_layer(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.dilation = dilation
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=1000, stride=8):
        self.inplanes = 128
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False),
            norm_layer(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
            norm_layer(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False))

        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)

        if stride == 16:
            self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
            self.layer4 = self._make_layer(
                    block, 512, layers[3], stride=1, dilation=2, grids=[1,2,4])
        elif stride == 8:
            self.layer3 = self._make_layer(
                    block, 256, layers[2], stride=1, dilation=2)
            self.layer4 = self._make_layer(
                    block, 512, layers[3], stride=1, dilation=4, grids=[1,2,4])
        else:
            raise ValueError('=> unsupported output stride: {}'.format(stride))

        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, _BatchNorm):
                m.weight.data.fill_(1)
                if m.bias is not None:
                    m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1,
                    grids=None):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                norm_layer(planes * block.expansion))

        layers = []
        if grids is None:
            grids = [1] * blocks

        if dilation == 1 or dilation == 2:
            layers.append(block(self.inplanes, planes, stride, dilation=1,
                                downsample=downsample,
                                previous_dilation=dilation))
        elif dilation == 4:
            layers.append(block(self.inplanes, planes, stride, dilation=2,
                                downsample=downsample,
                                previous_dilation=dilation))
        else:
            raise RuntimeError('=> unknown dilation size: {}'.format(dilation))

        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes,
                                dilation=dilation*grids[i],
                                previous_dilation=dilation))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def resnet(n_layers, stride):
    layers = {
        50: [3, 4, 6, 3],
        101: [3, 4, 23, 3],
        152: [3, 8, 36, 3],
    }[n_layers]
    pretrained_path = {
        50:  osp.join(settings.MODEL_DIR, 'resnet50-ebb6acbb.pth'),
        101: osp.join(settings.MODEL_DIR, 'resnet101-2a57e44d.pth'),
        152: osp.join(settings.MODEL_DIR, 'resnet152-0d43d698.pth'),
    }[n_layers]

    net = ResNet(Bottleneck, layers=layers, stride=stride)
    state_dict = torch.load(pretrained_path)
    net.load_state_dict(state_dict, strict=False)

    return net


class CrossEntropyLoss2d(nn.Module):
    def __init__(self, weight=None, reduction='none', ignore_index=-1):
        super(CrossEntropyLoss2d, self).__init__()
        self.nll_loss = nn.NLLLoss(weight, reduction=reduction,
                                   ignore_index=ignore_index)

    def forward(self, inputs, targets):
        loss = self.nll_loss(F.log_softmax(inputs, dim=1), targets)
        return loss.mean(dim=2).mean(dim=1)


class HamNet(nn.Module):
    def __init__(self, n_classes, n_layers):
        super().__init__()
        backbone = resnet(n_layers, settings.STRIDE)
        self.backbone = nn.Sequential(
            backbone.conv1,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool,
            backbone.layer1,
            backbone.layer2,
            backbone.layer3,
            backbone.layer4)

        C = settings.CHANNELS

        self.squeeze = ConvBNReLU(2048, C, 3)

        Hamburger = get_hamburger(settings.VERSION)
        self.hamburger = Hamburger(C, settings)
        
        self.align = ConvBNReLU(C, 256, 3)
        self.fc = nn.Sequential(nn.Dropout2d(p=0.1),
                                nn.Conv2d(256, n_classes, 1))

        # Put the criterion inside the model to make GPU load balanced
        self.crit = CrossEntropyLoss2d(ignore_index=settings.IGNORE_LABEL,
                                       reduction='none')

    def forward(self, img, lbl=None, size=None):
        x = self.backbone(img)

        x = self.squeeze(x)
        x = self.hamburger(x)
        x = self.align(x)
        x = self.fc(x)

        if size is None:
            size = img.size()[-2:]

        pred = F.interpolate(x, size=size, mode='bilinear', align_corners=True)

        if self.training and lbl is not None:
            loss = self.crit(pred, lbl)
            return loss
        else:
            return pred


def test_net():
    model = HamNet(n_classes=21, n_layers=50)
    model.eval()
    print(list(model.named_children()))
    image = torch.randn(1, 3, 513, 513)
    label = torch.zeros(1, 513, 513).long()
    pred = model(image, label)
    print(pred.size())


if __name__ == '__main__':
    test_net()
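`CrossEntropyLoss2d` returns one loss per image: the per-pixel NLL of the log-softmax, averaged over H and W. Pixels equal to `ignore_index` contribute 0 but still count in that average, so images with many ignored pixels get a smaller loss. The sketch below reproduces this behavior with `F.cross_entropy(..., reduction='none')` and, for comparison, adds a variant that averages over valid pixels only. The helper names and `ignore_index=255` are assumptions: 255 is a common VOC convention, while the repository reads the value from `settings.IGNORE_LABEL`.

```python
import torch
from torch import nn
from torch.nn import functional as F

IGNORE = 255  # assumed value for settings.IGNORE_LABEL


def per_image_ce(logits, labels, ignore_index=IGNORE):
    # (B, K, H, W), (B, H, W) -> (B, H, W); ignored pixels get loss 0
    pixel_loss = F.cross_entropy(logits, labels, ignore_index=ignore_index,
                                 reduction='none')
    # Average over H and W as CrossEntropyLoss2d does (ignored pixels counted)
    return pixel_loss.mean(dim=(1, 2))


def per_image_ce_valid_only(logits, labels, ignore_index=IGNORE):
    pixel_loss = F.cross_entropy(logits, labels, ignore_index=ignore_index,
                                 reduction='none')
    # Divide by the number of labeled pixels instead of H * W
    n_valid = (labels != ignore_index).sum(dim=(1, 2)).clamp(min=1)
    return pixel_loss.sum(dim=(1, 2)) / n_valid


torch.manual_seed(0)
logits = torch.randn(2, 5, 4, 4)          # B=2 images, K=5 classes, 4x4 pixels
labels = torch.randint(0, 5, (2, 4, 4))
labels[0, :2] = IGNORE                    # half of image 0 is unlabeled

# Reference: the NLLLoss(log_softmax) formulation used by CrossEntropyLoss2d
ref = nn.NLLLoss(reduction='none', ignore_index=IGNORE)(
    F.log_softmax(logits, dim=1), labels).mean(dim=2).mean(dim=1)
print(torch.allclose(per_image_ce(logits, labels), ref))  # True
print(per_image_ce(logits, labels), per_image_ce_valid_only(logits, labels))
```

For image 0 the valid-only variant is exactly twice the `CrossEntropyLoss2d`-style value, because half of its pixels are ignored. For image 1, which has no ignored pixels, the two variants agree.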


Origin blog.csdn.net/ADICDFHL/article/details/133556103