ECKPN: Explicit Class Knowledge Propagation Network for Transductive Few-shot Learning

Published at CVPR 2021!
Paper link: https://arxiv.org/pdf/2106.08523.pdf
Code link: None

1. Problem

In recent years, transductive graph-based methods have achieved great success in few-shot classification. However, most existing methods neglect class-level knowledge, which humans can easily learn from just a few samples.

2. Contributions

1) Proposes, for the first time, a graph-based end-to-end few-shot learning architecture that explicitly learns rich class-level knowledge to guide the graph-based reasoning over query samples.
2) Builds multi-head sample relations to explore fine-grained pairwise comparisons between samples, which helps learn richer class-level knowledge from the pairwise relations.
3) Uses the semantic embeddings of class names to construct multi-modal knowledge representations of the different classes, providing more discriminative knowledge for reasoning over query samples.
4) Conducts extensive experiments on four benchmarks (miniImageNet, tieredImageNet, CIFAR-FS and CUB-200-2011); the results show that the method achieves better classification performance.

3. Method

Consider how to explicitly learn richer class-level knowledge to guide graph-based reasoning over query samples. As shown in Figure 1, if we only use sample representations and pairwise relations to perform the few-shot classification task, we may misclassify the query sample q as class 2. However, if we explicitly learn class-level knowledge representations to guide the reasoning process, we can classify q correctly, because q is closer to the representation of class 1.

3.1 Problem description

In the standard N-way K-shot setting, each episode contains a support set S with K labeled samples for each of N classes, and a query set Q whose labels need to be predicted. In the transductive setting considered here, all query samples of an episode are classified jointly, so relations among both support and query samples can be exploited. (A small episode-sampling sketch is given below.)
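To make the episodic setting concrete, here is a small sketch of how one N-way K-shot episode could be sampled and how the initial node labels could be built, with support nodes getting their true one-hot label and query nodes a uniform 1/N pseudo one-hot as described in Section 3.2. The dataset layout (a dict from class name to a list of feature tensors) is an assumption for illustration only, not the authors' code.

```python
import random
import torch
import torch.nn.functional as F

def sample_episode(data_by_class, n_way=5, k_shot=1, q_per_class=15):
    """Sample one N-way K-shot episode from a dict {class_name: [feature tensors]}."""
    classes = random.sample(list(data_by_class.keys()), n_way)
    support, query, s_labels, q_labels = [], [], [], []
    for idx, cls in enumerate(classes):
        samples = random.sample(data_by_class[cls], k_shot + q_per_class)
        support += samples[:k_shot]
        query += samples[k_shot:]
        s_labels += [idx] * k_shot
        q_labels += [idx] * q_per_class
    return (torch.stack(support), torch.tensor(s_labels),
            torch.stack(query), torch.tensor(q_labels))

def initial_one_hots(s_labels, num_query, n_way):
    """Support nodes keep their true one-hot label; query nodes get a uniform 1/N vector."""
    support_oh = F.one_hot(s_labels, n_way).float()
    query_oh = torch.full((num_query, n_way), 1.0 / n_way)
    return torch.cat([support_oh, query_oh], dim=0)
```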

3.2 Explicit class knowledge propagation network

Framework overview:
1) First, build an instance-level graph from the support and query samples.
2) Then, the comparison module updates the sample representations based on the pairwise node relations in the instance-level graph. In this module, the authors construct multi-head relations to model fine-grained sample relationships and learn richer sample representations.
3) Next, the instance-level graph is compressed into a class-level graph to explicitly explore class-level visual knowledge.
4) In the calibration module, class-level message passing is performed according to the relations between classes to update the class-level knowledge representations. Since the semantic word embeddings of class names provide rich prior knowledge, they are combined with the class-level visual knowledge to form multi-modal class knowledge representations before the calibration module's message passing.
5) Finally, the class-level knowledge representations are combined with the instance-level sample representations to guide the reasoning over query samples. (A minimal code skeleton of this pipeline is sketched below.)
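To make the data flow of these five steps concrete, here is a minimal PyTorch-style skeleton of the forward pass. It is only an illustrative sketch: since no official code is released, the sub-module names and their interfaces are assumptions for readability, not the authors' implementation.

```python
import torch
import torch.nn as nn

class ECKPNSketch(nn.Module):
    """Illustrative sketch of the ECKPN forward pass (not the official code)."""

    def __init__(self, backbone, comparison, compression, calibration, classifier):
        super().__init__()
        self.backbone = backbone        # CNN feature extractor
        self.comparison = comparison    # instance-level multi-head message passing
        self.compression = compression  # instance-level graph -> class-level graph
        self.calibration = calibration  # class-level message passing with word embeddings
        self.classifier = classifier    # predicts query labels from the final features

    def forward(self, images, one_hot_labels, class_word_embeddings):
        # 1) build instance-level node features: CNN feature + (pseudo) one-hot label
        feats = self.backbone(images)                    # (r, d_cnn)
        v0 = torch.cat([feats, one_hot_labels], dim=-1)  # (r, d)

        # 2) comparison module: update sample features with multi-head relations
        v_l, adj_set = self.comparison(v0)

        # 3) compression module: squeeze the instance graph into class-level knowledge
        assign_p, v_c = self.compression(v_l, adj_set)

        # 4) calibration module: fuse word embeddings, run class-level message passing,
        #    then map the class knowledge back to the instances
        v_f = self.calibration(v_l, v_c, assign_p, class_word_embeddings)

        # 5) inference: predict query labels from the final sample representations
        return self.classifier(v_f, one_hot_labels)
```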

  • Comparison module: multi-head relations for instance-level message passing
    For image i, a deep CNN model is used as the backbone to extract its visual features. (We follow the existing literature and concatenate the visual features of the image with its corresponding one-hot label encoding as the initial node feature $v^{(0)}_i \in \mathbb{R}^d$. Since the labels of query samples are not available at inference time, we set the elements of their one-hot encodings to 1/N, where N is the number of classes.)
    In each episode, we take the support-set and query-set samples as nodes to build the graph $G = (V^{(0)}, A^{(0)})$, where $V^{(0)}$ is the initial node feature matrix and $A^{(0)}$ is the set of initial adjacency matrices representing the sample relations. Existing work has shown that visual features often contain concepts that can be grouped, i.e., feature dimensions in the same group represent similar knowledge. However, existing graph-based few-shot learning methods usually use the global visual features directly to compute sample similarities and construct the adjacency matrices, which cannot represent fine-grained relations well. In this paper, we split the visual features into K chunks (i.e., $V^{(l)} = [V^{(l)}_1, V^{(l)}_2, \ldots, V^{(l)}_K] \in \mathbb{R}^{r \times d}$) and compute the similarity of each chunk to explore the multi-head relations of the samples (i.e., K adjacency matrices $A^{(l)}_1, A^{(l)}_2, \ldots, A^{(l)}_K \in \mathbb{R}^{r \times r}$), where r is the number of samples in each episode, $[\ast, \ast]$ is the concatenation operation, and l indexes the l-th layer of the graph network. Note that each chunk $V^{(l)}_i$ has dimension $d/K$. We also compute a global relation matrix $A^{(l)}_g \in \mathbb{R}^{r \times r}$ from the un-chunked visual features.
    We jointly utilize the global ($A^{(l)}_g$) and multi-head ($\{A^{(l)}_i\}_{i=1}^{K}$) relations (i.e., $A^{(l)} = \{A^{(l)}_g, A^{(l)}_1, \ldots, A^{(l)}_K\}$) to propagate information across the instance-level graph and update the sample representations. In this way, we can explore the relations between samples more fully and learn richer sample representations. At the l-th layer, we use the updated sample representations $V^{(l)}$ to construct the new adjacency matrices $A^{(l)}_g$ and $A^{(l)}_i$, as follows:
    [Equation image: construction of $A^{(l)}_g$ and $A^{(l)}_i$ from the updated sample features]
    Inspired by the success of TRPN in the few-shot classification task, we use the following matrix to mask the adjacency matrix:
    [Equation image: label-based mask matrix]
    where m and n are samples in $S \cup Q$ and $y_m$ is the label of sample m. This ensures that, for two samples from different classes, higher feature similarity leads to lower commonality during message passing, while for two samples of the same class the effect is exactly the opposite. (A toy sketch of the multi-head adjacencies and the label mask is given after this module.)
    [Equation image: masked message passing on the instance-level graph]
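A toy sketch of the multi-head adjacency computation and the label-based mask described above. Assumptions: cosine similarity stands in for the paper's learned relation function (given only as an equation image here), and the mask uses the +1/-1 same-class/different-class convention implied by the TRPN-style description; the exact formulas in the paper may differ.

```python
import torch
import torch.nn.functional as F

def multi_head_adjacencies(v, num_heads):
    """Split node features into K chunks and compute one adjacency per chunk,
    plus a global adjacency from the un-chunked features.
    v: (r, d) node feature matrix, with d divisible by num_heads."""
    def sim(x):
        # cosine similarity between all node pairs (a stand-in for the learned relation)
        x = F.normalize(x, dim=-1)
        return x @ x.t()

    a_global = sim(v)                             # A_g, shape (r, r)
    chunks = torch.chunk(v, num_heads, dim=-1)    # K chunks of size d/K
    a_heads = [sim(c) for c in chunks]            # A_1 ... A_K, each (r, r)
    return a_global, a_heads

def label_mask(labels):
    """TRPN-style mask from ground-truth labels: +1 for same-class pairs,
    -1 for different-class pairs (the exact values are an assumption)."""
    same = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
    return 2.0 * same - 1.0

# toy usage: an episode with r = 10 nodes, d = 64, K = 4 heads
v = torch.randn(10, 64)
a_g, a_k = multi_head_adjacencies(v, num_heads=4)
mask = label_mask(torch.randint(0, 5, (10,)))
```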

  • Compression module: Class-level visual knowledge learning
    To obtain class-level knowledge representations, we compress the instance-level graph to generate a class-level graph, where each node represents the visual knowledge of one class. For example, in a 5-way classification task we compress the nodes of the instance-level graph into 5 clusters/nodes to obtain the visual knowledge of the 5 classes. Specifically, we first use the ground truth to supervise the generation of an assignment matrix, and then compress the samples according to this assignment matrix to obtain the class-level knowledge representation $V_c \in \mathbb{R}^{r_1 \times d}$, where $r_1$ denotes the number of classes in each episode. For simplicity, we feed $V^{(L)}$ and $A^{(L)}_g$ into a standard graph neural network to compute the assignment matrix $P \in \mathbb{R}^{r \times r_1}$:
    [Equation image: assignment matrix $P$ computed from $V^{(L)}$, $A^{(L)}_g$ and a trainable weight with a row-wise softmax]
    where $W \in \mathbb{R}^{d \times r_1}$ is a trainable weight matrix, and the softmax is applied row-wise. Each element $P_{uv}$ of the assignment matrix $P$ represents the probability that node u in the original graph is assigned to node v in the class-level graph. After generating the assignment matrix $P$, we generate the initial class-level knowledge representation with the following equation:
    [Equation image: $V_c$ obtained by aggregating $V^{(L)}$ with the transposed assignment matrix]
    where $T$ denotes the transpose operation. In the class-level graph, each node feature can be viewed as a weighted sum of the instance-level node features that share the same label. In this way, we obtain a class-level visual knowledge representation, which helps model the relations between different classes in the calibration module. (A short sketch of this compression step follows.)
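A sketch of this compression step under the definitions above: a single graph-convolution-style layer with the trainable weight $W$ followed by a row-wise softmax produces the assignment matrix $P$, and the class-level knowledge $V_c$ is the $P$-weighted aggregation of the instance features. The exact GNN layer used by the authors may differ.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CompressionModule(nn.Module):
    """Compress the instance-level graph into r_1 class-level nodes (sketch)."""

    def __init__(self, d, num_classes):
        super().__init__()
        # trainable weight W in R^{d x r_1}
        self.w = nn.Parameter(torch.empty(d, num_classes))
        nn.init.xavier_uniform_(self.w)

    def forward(self, v, a_g):
        # v:   (r, d)  final instance-level features V^(L)
        # a_g: (r, r)  global adjacency A_g^(L)
        # one graph-convolution-style step, then row-wise softmax -> assignment matrix P
        p = F.softmax(a_g @ v @ self.w, dim=-1)   # (r, r_1)
        # class-level knowledge: per-class weighted sum of the instance features (P^T V)
        v_c = p.t() @ v                           # (r_1, d)
        return p, v_c
```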

  • Calibration module: class-level message passing with multi-modal knowledge
    Since class word embeddings can provide information that may not be contained in the visual content, we combine them with the generated class-level visual knowledge to build multi-modal knowledge representations. Specifically, we first leverage GloVe (pre-trained on a large text corpus with a self-supervised objective) to obtain a $d_1$-dimensional semantic embedding for each class label. This article uses the Common Crawl version of GloVe, trained on 840B tokens. After obtaining the word embedding $e_i \in \mathbb{R}^{d_1}$ of the i-th class, we use a mapping network $g: \mathbb{R}^{d_1} \rightarrow \mathbb{R}^{d}$ to map it into a semantic space with the same dimension as the visual knowledge representation, i.e., $z_i = g(e_i) \in \mathbb{R}^d$. Finally, we obtain the following multi-modal class representation:
    [Equation image: multi-modal class representation formed from $V_c$ and the semantic embeddings]
    where $Z \in \mathbb{R}^{r_1 \times d}$ is the semantic word-embedding matrix. This results in a richer class-level knowledge representation.
    The adjacency matrix of the class-level graph ($A_c$) represents the relations between the class representations, and its values indicate the connection strength of each class pair. We compute the adjacency matrix $A_c$ and the new class-level knowledge representation $V''_c$ with the following formulas:
    [Equation images: computation of $A_c$ and the updated class-level knowledge $V''_c$]
    where $W' \in \mathbb{R}^{2d \times 2d}$ is a trainable weight matrix. So that each sample contains the class knowledge learned in Eq. (7), we utilize the assignment matrix to map the class knowledge back to the instance-level graph as follows:
    [Equation image: class knowledge mapped back to the instance-level graph via the assignment matrix]
    where $V_r \in \mathbb{R}^{r \times 2d}$ denotes the refined features. Finally, we concatenate $V_r$ with $V^{(L)}$ to generate the sample representation $V_f$. (A sketch of this calibration step follows.)
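A sketch of the calibration step under the stated definitions: the mapping network $g$ is realised as a linear layer, the class-level adjacency $A_c$ is approximated by a cosine similarity over the multi-modal representations, and one round of message passing with the trainable $W'$ is applied before mapping the class knowledge back through $P$. The paper's exact formulas (equation images above) may differ.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CalibrationModule(nn.Module):
    """Fuse word embeddings with class-level visual knowledge, run class-level
    message passing, and map the result back to the instances (sketch)."""

    def __init__(self, d, d_word):
        super().__init__()
        self.g = nn.Linear(d_word, d)                        # g: R^{d_1} -> R^d
        self.w_prime = nn.Linear(2 * d, 2 * d, bias=False)   # W' in R^{2d x 2d}

    def forward(self, v_inst, v_c, p, word_emb):
        # v_inst: (r, d) instance features V^(L);  v_c: (r_1, d) class visual knowledge
        # p: (r, r_1) assignment matrix;  word_emb: (r_1, d_word) GloVe class embeddings
        z = self.g(word_emb)                                  # (r_1, d)
        v_mm = torch.cat([v_c, z], dim=-1)                    # multi-modal class knowledge, (r_1, 2d)

        # class-level adjacency A_c (cosine similarity is an assumption)
        v_norm = F.normalize(v_mm, dim=-1)
        a_c = v_norm @ v_norm.t()                             # (r_1, r_1)

        # one step of class-level message passing with the trainable W'
        v_c2 = F.relu(a_c @ self.w_prime(v_mm))               # V''_c, (r_1, 2d)

        # map the class knowledge back to the instances and fuse with V^(L)
        v_r = p @ v_c2                                        # (r, 2d)
        return torch.cat([v_inst, v_r], dim=-1)               # V_f, (r, 3d)
```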

3.3 Inference

To infer the class labels of the query samples, we use $V_f$ to compute the corresponding adjacency matrix $A_f$ as follows:
[Equation image: computation of $A_f$ from pairwise final sample representations]
where $V_{f;m}$ and $V_{f;n}$ denote the representations of the m-th and n-th samples, respectively, and $f_l: \mathbb{R}^{3d} \rightarrow \mathbb{R}^{1}$ is a mapping function. For each query sample, we leverage the class labels of the support samples to predict its label:
[Equation image: query label prediction from $A_f$ and the support labels]
where one-hot(·) denotes the one-hot encoder. (A sketch of this inference step follows.)
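A sketch of this inference step. The mapping $f_l$ is realised as a small MLP applied to the absolute difference of the pairwise final representations (the absolute-difference input is an assumption consistent with the stated $\mathbb{R}^{3d}$ input dimension), and each query is classified by an $A_f$-weighted vote over the support one-hot labels.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class QueryClassifier(nn.Module):
    """Predict query labels from the final sample representations V_f (sketch)."""

    def __init__(self, d_f):
        super().__init__()
        # f_l: R^{3d} -> R^1, realised here as a small MLP (an assumption)
        self.f_l = nn.Sequential(nn.Linear(d_f, d_f // 2), nn.ReLU(),
                                 nn.Linear(d_f // 2, 1))

    def forward(self, v_f, support_one_hot, num_support):
        # v_f: (r, 3d); the first num_support rows are the support samples
        diff = (v_f.unsqueeze(1) - v_f.unsqueeze(0)).abs()    # (r, r, 3d) pairwise |V_f;m - V_f;n|
        a_f = torch.sigmoid(self.f_l(diff)).squeeze(-1)       # adjacency A_f, (r, r)

        # each query score: A_f-weighted sum of the support one-hot labels
        a_qs = a_f[num_support:, :num_support]                # (num_query, num_support)
        scores = a_qs @ support_one_hot                       # (num_query, N)
        return F.softmax(scores, dim=-1)
```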

3.4 Loss function

The overall framework of the proposed ECKPN can be optimized in an end-to-end form through the following loss function:
$\mathcal{L} = \lambda_0 \mathcal{L}_0 + \lambda_1 \mathcal{L}_1 + \lambda_2 \mathcal{L}_2$
where $\lambda_0$, $\lambda_1$ and $\lambda_2$ are hyper-parameters, set to 1.0, 0.5 and 1.0 in the experiments. $\mathcal{L}_0$, $\mathcal{L}_1$ and $\mathcal{L}_2$ are the adjacency loss, the assignment loss and the classification loss, respectively (a combined sketch is given after the individual definitions below). Specifically,

  • Adjacency Loss
    For each graph network layer $l = \{1, \ldots, L\}$ in the comparison module, we have multiple adjacency matrices $A^{(l)}_g$ and $\{A^{(l)}_i\}_{i=1}^{K}$ used for message passing among the support and query samples. Furthermore, we have the adjacency matrix $A_f$ used for query inference. To ensure that these adjacency matrices capture the correct sample relations, we use the following loss functions:
    [Equation images: adjacency losses on $A^{(l)}_g$, $\{A^{(l)}_i\}$ and $A_f$]
    where m and n denote nodes in the graph.

  • Assignment Loss
    To ensure that the assignment matrix $P$ computed in the compression module is able to correctly cluster samples with the same label, we utilize the following cross-entropy loss function:
    [Equation image: cross-entropy assignment loss on $P$]

  • Classification Loss
    To constrain the proposed ECKPN to predict the correct query labels, we use the following loss function:
    [Equation image: classification loss on the query predictions]
    where $\mathcal{L}_{ce}$ denotes the cross-entropy loss function.
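The per-term equations are given only as images above, so the sketch below makes explicit assumptions: the adjacency loss is taken as a binary cross-entropy between each predicted adjacency (assumed to lie in [0, 1]) and the ground-truth same-class indicator, the assignment loss is the cross-entropy between $P$ and the true class of each node, and the classification loss is the cross-entropy on the query predictions; the terms are combined with the weights 1.0, 0.5 and 1.0 reported above.

```python
import torch
import torch.nn.functional as F

def eckpn_loss(adj_preds, same_class, p, node_labels, query_probs, query_labels,
               lambdas=(1.0, 0.5, 1.0)):
    """Sketch of the overall objective L = lambda_0*L0 + lambda_1*L1 + lambda_2*L2.

    adj_preds:    list of predicted adjacency matrices, each (r, r), values in [0, 1]
    same_class:   (r, r) float matrix, 1 if two nodes share a label, else 0
    p:            (r, r_1) assignment matrix (row-wise probabilities)
    node_labels:  (r,) ground-truth class index of every node
    query_probs:  (num_query, N) predicted class probabilities for the queries
    query_labels: (num_query,) ground-truth query labels
    """
    l0, l1, l2 = lambdas
    eps = 1e-8

    # L0: adjacency loss, averaged over all predicted adjacency matrices (assumed BCE form)
    adj_loss = sum(F.binary_cross_entropy(a, same_class) for a in adj_preds) / len(adj_preds)

    # L1: assignment loss, cross-entropy between P and the true class of each node
    assign_loss = F.nll_loss(torch.log(p + eps), node_labels)

    # L2: classification loss on the query predictions
    cls_loss = F.nll_loss(torch.log(query_probs + eps), query_labels)

    return l0 * adj_loss + l1 * assign_loss + l2 * cls_loss
```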

4. Some experimental results

  • Comparison with state-of-the-art methods
    1) miniImageNet
    [Table: comparison with state-of-the-art methods on miniImageNet]
    2) tieredImageNet
    [Table: comparison with state-of-the-art methods on tieredImageNet]
    3) CUB-200-2011 and CIFAR-FS
    [Table: comparison with state-of-the-art methods on CUB-200-2011 and CIFAR-FS]

  • Semi-supervised few-shot classification performance
    [Table: semi-supervised few-shot classification results]

  • Parameter analysis
    [Figure/Table: parameter analysis]

  • Ablation study
    [Table: ablation study results]

5. Conclusion

1) Graph-based meta-learning methods follow two settings: transductive and inductive. Transductive methods model the relations among the support-set and query-set samples and predict all query labels jointly, and generally perform better than inductive methods, which learn the network only from the relations within the support set and classify each query sample independently.
2) This paper follows the transductive setting.
3) This paper is the first to propose a graph-based end-to-end few-shot learning architecture that explicitly learns class-level knowledge.


Origin blog.csdn.net/weixin_43994864/article/details/123345574