Paper Reading (83): MuRCL: Multi-instance Reinforcement Contrastive Learning for Whole Slide Image (Medical Image)

1 Overview

1.1 Title

2022: MuRCL: Multi-instance reinforcement contrastive learning for whole slide image classification

1.2 Summary

Multiple-instance learning (MIL) is widely used for automatic whole slide image (WSI) analysis. Its processing pipeline can usually be divided into two stages:

  1. instance feature extraction;
  2. feature aggregation.

However, slide-level labels provide only weak supervision, so MIL models usually overfit severely during training. It is therefore crucial to extract more information from the limited slide-level annotated data.

Existing methods improve generalization by improving instance feature extraction. This paper instead explores the latent relationships between different instances (blocks, i.e., image patches) to improve the generalization ability of the model. Specifically, MuRCL addresses the problem from the following perspectives:

  1. The model is first trained in a self-supervised manner and then fine-tuned with WSI slide-level labels. The self-supervised stage uses contrastive learning (CL), which builds positive/negative pairs of discriminative feature sets from the same block-level feature bag of a WSI;
  2. To facilitate CL training, a reinforcement learning-based agent is designed that progressively updates the selection of discriminative feature sets according to an online reward for slide-level feature aggregation. The labeled WSI data are then used to fine-tune the model and the learned features to obtain the final WSI classification.

Experiments are performed on three public WSI classification datasets: Camelyon16, TCGA-Lung, and TCGA-Kidney. The results demonstrate the effectiveness of MuRCL, particularly on TCGA-Lung.

Figure 1 shows the difference between MuRCL and general MIL.

Figure 1: Comparison of MuRCL and general MIL methods: (a) general methods rely on image-level labels to perform block extraction, block selection, and block aggregation; (b) MuRCL exploits the intrinsic relationships between patches. It is first trained by maximizing the agreement between two discriminative feature sets drawn from the same WSI and is then fine-tuned to predict the WSI label.

1.3 Code

PyTorch: https://github.com/wwu98934/MuRCL

1.4 References

@article{Zhu:2022:113,
  author  = {Zhong Hang Zhu and Le Quan Yu and Wei Wu and Rong Shan Yu and De Fu Zhang and Lian Sheng Wang},
  title   = {{MuRCL}: {M}ulti-instance reinforcement contrastive learning for whole slide image classification},
  journal = {{IEEE} Transactions on Medical Imaging},
  pages   = {1--13},
  year    = {2022},
  doi     = {10.1109/TMI.2022.3227066}
}

2 Methods


  Figure 2: MuRCL's self-supervised learning process. (a) MuRCL: given an input WSI feature bag (WSI-Fbag) $x$, the outputs of the two RL-MIL branches form a positive pair for the contrastive loss. (b) RL-MIL: a reward-guided agent $\boldsymbol{R}$ selects a discriminative feature set (WSI-Fset) $\tilde{x}$ from $x$. $\tilde{x}$ is first randomly initialized and then updated by $\boldsymbol{R}$. $\tilde{x}$ is passed to the MIL aggregator to produce a feature embedding $v$, and the projection head $f(\cdot)$ outputs $p$, which is used to compute the contrastive loss. (c) RL selection: $s_t$ is used to select features from each cluster of the constructed association graph, starting at the feature indexed by $s_t^k$.

Figure 2(a) presents the general framework of MuRCL. An agent selects two independent discriminative sets from the input feature bag, and these form the positive/negative pairs for contrastive learning. The MIL aggregator $M(\cdot)$ and the projection head $f(\cdot)$ are then trained with a contrastive loss that maximizes the agreement between the positive discriminative sets. Each branch of MuRCL is an RL-MIL.

Figure 2(b) shows the sequential decision process of RL-MIL. Given an input bag, RL-MIL iteratively generates a sequence of discriminative sets and, using the MIL aggregator and projection head, outputs a sequence of feature vectors. Specifically, at each step the agent (the discriminative set network) determines a set of feature indices. The next discriminative set is then constructed by combining the indexed features from each bag cluster, as shown in Figure 2(c). Finally, the self-supervised pretrained model is fine-tuned with WSI labels to obtain representations for prediction.

2.1 Multi-instance contrastive learning

Multi-instance contrastive learning takes WSI-Fbags as input. Each bag contains block-level embeddings extracted by a ResNet18 pretrained on ImageNet. A key step in CL is constructing meaningful positive/negative pairs for training, i.e., semantically similar/dissimilar instances. Existing strategies build these pairs with image augmentation. Here, we instead sample different discriminative sets of a WSI (smaller WSI-Fsets) from each WSI-Fbag and construct set-based positive/negative pairs for CL training.

Each WSI-Fset combines subsets drawn from multiple feature clusters of the WSI-Fbag. Specifically, a given WSI-Fbag $x$ is first divided into $K$ clusters $C_k$, $k\in\{1,2,\dots,K\}$. A subset is then sampled from each $C_k$, and the WSI-Fset $\tilde{x}$ is the concatenation of all these subsets. Every cluster uses the same sampling rate, so $\tilde{x}$ always contains a constant number of instance embeddings.
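The cluster-balanced sampling described above can be sketched in numpy as follows. This is a minimal illustration, not the authors' code. It makes three assumptions: cluster labels are already available (random labels stand in for K-means here), subsets are drawn uniformly at random within each cluster, and the per-cluster rate is $1024/u$ from Section 3.2. The function name `sample_wsi_fset` is hypothetical.

```python
import numpy as np

def sample_wsi_fset(bag, cluster_labels, rate, rng):
    """Randomly build a WSI-Fset by sampling the same fraction `rate` from every cluster.

    bag:            (u, d) array of block features (the WSI-Fbag x)
    cluster_labels: (u,) array of cluster ids (assumed to come from K-means)
    rate:           per-cluster sampling rate, e.g. 1024 / u
    """
    parts = []
    for k in np.unique(cluster_labels):
        idx = np.flatnonzero(cluster_labels == k)      # members of cluster C_k
        m = max(1, int(round(rate * len(idx))))        # same rate for every cluster
        chosen = rng.choice(idx, size=min(m, len(idx)), replace=False)
        parts.append(bag[chosen])
    return np.concatenate(parts, axis=0)               # x_tilde = concatenation of subsets

rng = np.random.default_rng(0)
u, d, K = 4096, 512, 10
bag = rng.standard_normal((u, d)).astype(np.float32)
labels = rng.integers(0, K, size=u)                    # stand-in for K-means labels
x_tilde = sample_wsi_fset(bag, labels, rate=1024 / u, rng=rng)
print(x_tilde.shape)  # (m, 512), with m within a few rows of 1024 because of per-cluster rounding
```

Calling the function twice on the same bag gives two different WSI-Fsets of almost the same size. Two such sets can serve as the "different perspectives" of one WSI discussed next.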

$\tilde{x}$ is then passed through the MIL aggregator $M(\cdot)$ and the projection head $f(\cdot)$ to obtain the WSI-level feature embedding $p$. Because the sampling is random, repeated sampling yields different $\tilde{x}$ and correspondingly different embeddings. These different $\tilde{x}$ can be regarded as different views of the same WSI and used to construct positive pairs for CL.

$\{\tilde{x}_n\}_{n=1}^{N}$ denotes a group of $N$ WSI-Fsets. In this group, $\tilde{x}_i$ and $\tilde{x}_j$ are sampled from the same $x$, and the others are sampled from different WSI-Fbags. The CL loss is then computed as:

$$L_{i,j}=-\log\frac{\exp(\mathrm{sim}(p_i,p_j)/\tau)}{\sum_{n=1}^{N}\mathbb{1}(n\neq i)\exp(\mathrm{sim}(p_i,p_n)/\tau)}, \tag{1}$$

where $\tau$ is the temperature parameter and $\mathrm{sim}(\cdot,\cdot)$ is the cosine similarity between two vectors. $\mathbb{1}(n\neq i)\in\{0,1\}$ is an indicator function that equals 1 if $n\neq i$ and 0 otherwise. This is the NT-Xent loss, which maximizes the similarity between positive pairs and minimizes the similarity between negative pairs. In this way, the MIL aggregator learns aggregation knowledge that supports accurate classification.
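Equation (1) can be written directly in numpy, as in the sketch below. It is an illustrative implementation of NT-Xent, not the repository's code. The `partner` array, which gives the index $j$ of each anchor's positive, is our own convention. For numerical stability, the row maximum is subtracted before exponentiation, which leaves the ratio in Eq. (1) unchanged.

```python
import numpy as np

def nt_xent(p, partner, tau=1.0):
    """Mean of L_{i,j} from Eq. (1) over all anchors i.

    p:       (N, D) projection-head outputs for N WSI-Fsets
    partner: (N,) int array; partner[i] = j is the positive of anchor i
    """
    z = p / np.linalg.norm(p, axis=1, keepdims=True)   # unit vectors -> dot product = cosine
    logits = (z @ z.T) / tau                           # sim(p_i, p_n) / tau
    n = len(p)
    mask = ~np.eye(n, dtype=bool)                      # indicator 1(n != i)
    shifted = logits - np.where(mask, logits, -np.inf).max(axis=1, keepdims=True)
    exp = np.exp(np.where(mask, shifted, -np.inf))     # exp(-inf) = 0 removes n == i
    pos = exp[np.arange(n), partner]                   # numerator: positive pair
    return float(np.mean(-np.log(pos / exp.sum(axis=1))))

# Two positive pairs: (0, 1) share direction e1, and (2, 3) share direction e2.
p = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])
partner = np.array([1, 0, 3, 2])
loss = nt_xent(p, partner, tau=1.0)  # each anchor gives log(1 + 2/e) ≈ 0.5514
```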

2.2 RL-based Discriminative Set Construction

As mentioned above, a key step in MuRCL is constructing the WSI discriminative feature set (WSI-Fset). We therefore propose a reinforcement learning-based strategy, denoted RL-MIL, which builds a WSI-Fset from the WSI-Fbag using an agent $R$ (a recurrent neural network) trained with reinforcement learning. As shown in Figure 2(b), WSI-Fset construction is a sequential decision process. At each step, the MIL aggregator $M(\cdot)$ and the projection head $f(\cdot)$ take the current WSI-Fset as input and produce the corresponding semantic prediction $p$. At the same time, $R$ takes the feature vector $v$ as input and generates the next WSI-Fset proposal $s$.

Specifically, for an input $x$, RL-MIL iteratively generates a sequence $\{\tilde{x}_0,\dots,\tilde{x}_t,\dots\}$. At iteration $t$, $M(\cdot)$ receives the current $\tilde{x}_t$ and outputs the feature vector $v_t$. $f(\cdot)$ then maps $v_t$ to the slide-level embedding $p_t$, which enters Equation (1). Meanwhile, the WSI-Fset proposal agent $R$ takes $v_t$ as input and chooses the action $s_{t+1}$, i.e., the feature indices used to select the next set $\tilde{x}_{t+1}$. $\tilde{x}_{t+1}$ is then selected from $x$ as shown in Figure 2(c), which completes one iteration.
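A toy version of the iterative loop above is sketched below. Every network in it is a hypothetical stand-in, not the paper's architecture. $M$ is a single attention-pooling layer, $f$ is a linear map, and $R$ is a tanh RNN whose outputs in (0, 1) are read as relative start positions of a contiguous window in each (already sorted) cluster. The sketch only illustrates the data flow $\tilde{x}_t \to v_t \to p_t$ and $v_t \to s_{t+1} \to \tilde{x}_{t+1}$. The random contiguous windows used for $\tilde{x}_0$ are also an assumption.

```python
import numpy as np

def M(x_tilde, w_att):
    """Stand-in aggregator: attention-weighted mean of the instances -> v_t."""
    scores = x_tilde @ w_att
    a = np.exp(scores - scores.max())
    a /= a.sum()
    return a @ x_tilde

def f(v, w_proj):
    """Stand-in projection head: v_t -> p_t."""
    return v @ w_proj

def R(v, h, w_h, w_v, w_s):
    """Stand-in agent: the hidden state h carries earlier inputs; s holds K positions in (0, 1)."""
    h = np.tanh(h @ w_h + v @ w_v)
    s = 1.0 / (1.0 + np.exp(-(h @ w_s)))
    return s, h

def select(clusters, s, rate):
    """Take a contiguous window of rate * |C_k| features from each sorted cluster."""
    parts = []
    for c, frac in zip(clusters, s):
        m = max(1, int(round(rate * len(c))))
        start = int(frac * (len(c) - m))               # keeps the window inside C_k
        parts.append(c[start:start + m])
    return np.concatenate(parts, axis=0)

def rollout(clusters, rate, T, params, rng):
    x_t = select(clusters, rng.random(len(clusters)), rate)   # random x_tilde_0
    h = np.zeros(params["w_h"].shape[0])
    ps = []
    for t in range(T + 1):
        v = M(x_t, params["w_att"])
        ps.append(f(v, params["w_proj"]))              # p_t is used in Eq. (1)
        if t < T:
            s, h = R(v, h, params["w_h"], params["w_v"], params["w_s"])
            x_t = select(clusters, s, rate)            # x_tilde_{t+1}
    return np.stack(ps)

rng = np.random.default_rng(0)
d, K, hidden, proj = 512, 10, 64, 128
clusters = [rng.standard_normal((int(rng.integers(50, 200)), d)) for _ in range(K)]
params = {
    "w_att": rng.standard_normal(d) * 0.01,
    "w_proj": rng.standard_normal((d, proj)) * 0.01,
    "w_h": rng.standard_normal((hidden, hidden)) * 0.1,
    "w_v": rng.standard_normal((d, hidden)) * 0.01,
    "w_s": rng.standard_normal((hidden, K)) * 0.1,
}
ps = rollout(clusters, rate=0.25, T=5, params=params, rng=rng)   # shape (T + 1, 128)
```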

$M(\cdot)$ is instantiated with ABMIL and CLAM. $f(\cdot)$ and the proposal agent $R$ are both recurrent neural networks, so each maintains a hidden state ($h^{f}_{t}$ and $h^{R}_{t-1}$, respectively) that summarizes all previously seen inputs. Note that the method is not a full RL framework: it borrows an optimization strategy from RL to optimize its own model. In MuRCL, RL assists the construction of the MIL discriminative set. During set construction, the agent scans the WSI multiple times to locate discriminative features, computes rewards, and updates itself for the next decision. In this sense, the paper draws on the idea of reinforcement learning.

2.2.1 RL selection

At each step of RL discriminative set construction, feature indices are used to select the WSI-Fset from the WSI-Fbag for the subsequent step. To make it easier for the agent to generate spatially connected WSI-Fset proposals, the features of $x$ are first ordered by their cluster labels, so features with the same cluster label receive adjacent indices. Within each cluster $C_k$, the features are then rearranged according to the coordinates of their corresponding blocks. The reordered WSI-Fbag is called an association graph, as shown in Figure 2(c). A WSI-Fset is then assembled from the association graph according to the action $s$ predicted by the agent $R$. The action $s$ is formulated as a set of feature indices into the rearranged clusters. Specifically, the feature index vector at step $t$, $s_t\in\mathbb{R}^{K\times 1}$, is produced by $R$ at the previous step, and its element $s_t^k$ gives the feature index in the $k$-th rearranged cluster. For cluster $C_k$, a contiguous sequence is sampled starting from the $s_t^k$-th feature. Its length equals the sampling rate multiplied by the number of features in $C_k$. The sequences from all clusters are then concatenated to obtain $\tilde{x}_t$.
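The association graph ordering and the window selection can be sketched with `np.lexsort`. This sketch is based on three assumptions: block coordinates are (row, col) grid positions, "rearranged along the coordinates" means raster order (row first, then column), and $s_t^k$ is an integer start index that is clipped so the window stays inside $C_k$. Both function names are hypothetical. The result is a set of feature indices; $\tilde{x}_t$ is `bag[indices]`.

```python
import numpy as np

def build_association_graph(cluster_labels, coords):
    """Feature order grouped by cluster id, then by block position (row, then col).

    cluster_labels: (u,) cluster id of each feature, assumed to be 0..K-1
    coords:         (u, 2) (row, col) grid position of each block
    """
    # np.lexsort uses the LAST key as the primary key: cluster, then row, then col.
    return np.lexsort((coords[:, 1], coords[:, 0], cluster_labels))

def select_from_graph(order, cluster_labels, s_t, rate):
    """For each cluster k, take a contiguous run of rate * |C_k| features that starts
    at the s_t[k]-th feature of the rearranged cluster."""
    sorted_labels = cluster_labels[order]
    picked = []
    for k, start in enumerate(s_t):
        members = order[sorted_labels == k]            # cluster k in graph order (non-empty)
        m = max(1, int(round(rate * len(members))))
        start = min(int(start), len(members) - m)      # clip so the run fits in C_k
        picked.append(members[start:start + m])
    return np.concatenate(picked)

labels = np.array([1, 0, 1, 0, 0, 1])
coords = np.array([[0, 1], [1, 0], [0, 0], [0, 2], [0, 0], [1, 1]])
order = build_association_graph(labels, coords)                  # [4, 3, 1, 2, 0, 5]
indices = select_from_graph(order, labels, s_t=[1, 0], rate=2 / 3)   # [3, 1, 2, 0]
```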

2.2.2 Rewards

The WSI-Fset proposal agent $R$ is trained with the policy gradient method. During training, the reward function determines the optimization direction of $R$. Here, the change in similarity between the two WSI-Fsets of a positive pair is used as the reward that guides $R$ to locate informative features. Specifically, at step $t$ the reward function is:

$$r_{i,j;t}=\mathrm{sim}(p_{i;t-1},p_{j;t-1})-\mathrm{sim}(p_{i;t},p_{j;t}), \tag{2}$$

where $p_i$ and $p_j$ come from $\tilde{x}_i$ and $\tilde{x}_j$, respectively. The agent is rewarded for increasing the cosine distance between the positive WSI-Fsets. This forces the MIL model to learn latent aggregation knowledge while the CL loss is minimized.
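Equation (2) reduces to differences of consecutive cosine similarities between the two branches. The numpy sketch below assumes that the projections of each branch are stacked over $t=0,\dots,T$. A positive reward means the new positive pair is less similar than the previous one.

```python
import numpy as np

def cosine(a, b):
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

def step_rewards(p_i, p_j):
    """Eq. (2): r_t = sim(p_{i;t-1}, p_{j;t-1}) - sim(p_{i;t}, p_{j;t}), t = 1..T.

    p_i, p_j: (T + 1, D) projection outputs of the two branches for t = 0..T
    """
    sims = np.array([cosine(a, b) for a, b in zip(p_i, p_j)])
    return sims[:-1] - sims[1:]

# The pair starts identical (sim = 1) and becomes orthogonal (sim = 0): reward 1.
r = step_rewards(np.array([[1.0, 0.0], [1.0, 0.0]]),
                 np.array([[1.0, 0.0], [0.0, 1.0]]))
```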

2.2.3 Discriminative set mixing

To introduce more perturbations during MIL aggregator training, the paper uses an efficient set-level mixing strategy called set-mixup to increase the diversity of WSI-Fsets. For the WSI-Fsets in a training batch, $\tilde{x}_l$ is mixed into $\tilde{x}_q$ to generate an augmented representation $\overline{x}_q$:

$$\overline{x}_q=\lambda\tilde{x}_q+(1-\lambda)\tilde{x}_l, \tag{3}$$

where $\lambda$ is sampled from the uniform distribution $U(\alpha, 1.0)$ with $\alpha=0.9$. This mixing is used to strengthen semantic concept learning.
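Because every WSI-Fset has the same number of instances, Eq. (3) can be applied elementwise to a batch tensor. The sketch below assumes that each $\tilde{x}_q$ is paired with a randomly permuted partner $\tilde{x}_l$ and that one $\lambda$ is drawn per set. The source does not specify either choice. A permutation may occasionally pair a set with itself, in which case that set is left unchanged.

```python
import numpy as np

def set_mixup(batch, alpha=0.9, rng=None):
    """Eq. (3) on a batch of WSI-Fsets.

    batch: (B, n, d) WSI-Fsets with a common instance count n
    Returns the mixed batch, the lambdas, and the partner index l of each q.
    """
    rng = np.random.default_rng() if rng is None else rng
    B = batch.shape[0]
    lam = rng.uniform(alpha, 1.0, size=(B, 1, 1))      # lambda ~ U(alpha, 1.0)
    partner = rng.permutation(B)                       # assumed pairing q -> l
    mixed = lam * batch + (1.0 - lam) * batch[partner]
    return mixed, lam.ravel(), partner

batch = np.random.default_rng(0).standard_normal((4, 8, 16))
mixed, lam, partner = set_mixup(batch, rng=np.random.default_rng(1))
```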

2.3 RL-MIL training strategy

As shown in Figure 2(a), the contrastive learning framework has two branches, which share the parameters of $M(\cdot)$, $f(\cdot)$, and $R$. In a training batch, two WSI-Fsets are first randomly sampled from each WSI-Fbag to initialize a positive pair. The agent, trained with the RL policy, then reconstructs this pair. For clarity, the subscripts $(\cdot)_i$ and $(\cdot)_j$ denote the two branches, and $\{(\cdot)_t\}_{t=0}^{T}$ denotes the sequence generated by each branch. Here $T=5$, meaning the RNN runs five times in each branch.

At the start, the positive pair $\tilde{x}_{i;0}$ and $\tilde{x}_{j;0}$ is passed through $M(\cdot)$ to produce two feature vectors, $v_{i;0}$ and $v_{j;0}$. The WSI-level embeddings $p_{i;0}$ and $p_{j;0}$ are then computed and used in the CL loss $L_0=L_{i,j;0}(p_{i;0},p_{j;0})$, where WSI-Fsets from different WSI-Fbags serve as negative pairs. Next, the first iteration passes $v_{i;0}$ and $v_{j;0}$ to $R$ as its initial state to generate a new positive pair, and the same computation is repeated for that pair. Over the five iterations, $f(\cdot)$ processes both branches in parallel, and its outputs $\{p_{i;t},p_{j;t}\}_{t=0}^{5}$ are used to compute the CL loss $L=\sum_{t=0}^{5}L_{i,j}(p_{i;t},p_{j;t})$.

In addition, $R$ is trained to maximize the discounted reward $\sum_{t=1}^{5}\gamma^{t-1}r_{i,j;t}(p_{i;t},p_{j;t})$ with $\gamma=0.1$. With these sampling strategies, $2N$ WSI-Fsets are generated from $N$ WSI-Fbags. The procedure is given in Algorithm 1.
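The discounted objective for $R$ and a matching policy-gradient surrogate are sketched below. The text says only that a policy gradient method is used. The plain REINFORCE form here (negative log-probability weighted by the discounted return-to-go) is an assumed, standard choice. In a real framework, `log_probs` would be autograd tensors rather than floats.

```python
import numpy as np

GAMMA = 0.1  # discount factor given in the text

def discounted_objective(rewards, gamma=GAMMA):
    """sum_{t=1}^{T} gamma^(t-1) * r_t, the quantity R is trained to maximize."""
    rewards = np.asarray(rewards, dtype=float)
    return float(np.sum(gamma ** np.arange(len(rewards)) * rewards))

def reinforce_loss(log_probs, rewards, gamma=GAMMA):
    """REINFORCE surrogate (assumed): -sum_t log pi(s_t) * G_t, where G_t is the
    discounted return from step t onward."""
    rewards = np.asarray(rewards, dtype=float)
    returns = np.zeros_like(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return float(-np.sum(np.asarray(log_probs, dtype=float) * returns))

J = discounted_objective([1.0] * 5)                    # 1 + 0.1 + 0.01 + 0.001 + 0.0001
loss = reinforce_loss([-1.0] * 5, [1.0] * 5)
```

With $\gamma=0.1$, later steps contribute very little, so the agent is driven mainly by its first few selections.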

The training process consists of three stages:

  1. Train $M(\cdot)$ and $f(\cdot)$ on randomly sampled WSI-Fsets. This stage ensures that the model can handle input sequences of arbitrary size;
  2. Freeze $M(\cdot)$ and $f(\cdot)$, randomly initialize $R$, and train it;
  3. Freeze $R$ and fine-tune $M(\cdot)$ and $f(\cdot)$.

2.4 Fine-tuning and inference

The multi-instance contrastive learning framework explores the semantic relationships between the blocks that make up a WSI at the slide level. For the final slide-level predictions, the framework is fine-tuned with labeled WSIs. At this stage, the output dimension of $f(\cdot)$ is changed from 128 to the number of classes. Fine-tuning also has three stages. The difference is that $f(\cdot)$ is followed by a softmax that yields a confidence score, and the increase in this score is used as the reward of $R$, i.e., $\hat{r}_t=\hat{p}_t-\hat{p}_{t-1}$. Here $\hat{(\cdot)}$ denotes the corresponding variable during fine-tuning, and $\hat{p}$ is the softmax predicted probability.
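The fine-tuning reward $\hat{r}_t=\hat{p}_t-\hat{p}_{t-1}$ can be computed from the per-step logits, as in the sketch below. It assumes that $\hat{p}$ is the softmax probability of the ground-truth class. The text says only "the softmax predicted probability".

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)              # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def finetune_rewards(logits, label):
    """r_hat_t = p_hat_t - p_hat_{t-1}, for t = 1..T.

    logits: (T + 1, C) outputs of f(.) at steps t = 0..T
    label:  ground-truth class (assumed target of p_hat)
    """
    p_hat = softmax(np.asarray(logits, dtype=float))[:, label]
    return p_hat[1:] - p_hat[:-1]

# The true-class probability rises from 0.5 to 0.75, so the reward is 0.25.
r_hat = finetune_rewards([[0.0, 0.0], [np.log(3.0), 0.0]], label=0)
```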

The testing process of MuRCL is consistent with the fine-tuning process:

  1. Given a test WSI-Fbag, a WSI-Fset is randomly sampled, and $M(\cdot)$ provides the initial state;
  2. $R$ determines the WSI-Fset, which is processed by $M(\cdot)$ and $f(\cdot)$. In this stage, the agent iteratively generates state vectors. The agent's last output is used as the final WSI-Fset proposal, from which the classification prediction is produced.

The contrastive loss pulls together the features of similar samples and pushes apart the features of dissimilar samples.

3 Experiments

3.1 Dataset

  1. Camelyon16: breast cancer detection data with 270 training WSIs and 129 test WSIs. The task is to determine whether a slide contains cancer and to localize it. After preprocessing, it contains 2.7 million blocks, with an average of 6881 blocks per bag;
  2. TCGA-Lung: two subprojects, TCGA-LUSC and TCGA-LUAD, with 1041 diagnostic slides in total (529 LUAD and 512 LUSC), used for WSI subtype classification and survival analysis. After preprocessing, each bag contains 11540 blocks on average;
  3. TCGA-Kidney: three subprojects, TCGA-KICH, TCGA-KIRC, and TCGA-KIRP, with 734 WSIs in total (92 KICH, 411 KIRC, and 231 KIRP), used for multi-class classification;
  4. TCGA-Esca: two subtypes, 156 WSIs in total, of which 90 are squamous cell carcinoma and 66 are adenocarcinoma.

Preprocessing:
All experiments use WSIs at 20× magnification. Each WSI is cropped into multiple 256 × 256 blocks, and blocks with less than 35% tissue area are discarded. For Camelyon16, 20% of the training set is held out as the validation set. For TCGA, the split is train:validation:test = 3:1:1.
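The block filtering and the 3:1:1 split can be sketched as follows. This sketch assumes that a boolean tissue mask at 20× is already available. How the mask is obtained (e.g., by thresholding) is not specified in the source. The function names are hypothetical, and the tiles are non-overlapping.

```python
import numpy as np

def tile_and_filter(tissue_mask, patch=256, min_tissue=0.35):
    """(row, col) offsets of non-overlapping patch x patch blocks whose tissue
    fraction is at least min_tissue; blocks below 35% tissue are discarded."""
    H, W = tissue_mask.shape
    keep = []
    for r in range(0, H - patch + 1, patch):
        for c in range(0, W - patch + 1, patch):
            if tissue_mask[r:r + patch, c:c + patch].mean() >= min_tissue:
                keep.append((r, c))
    return keep

def split_311(slide_ids, rng):
    """Random train/val/test split with ratio 3:1:1 (as used for TCGA)."""
    ids = rng.permutation(np.asarray(slide_ids))
    n_train, n_val = len(ids) * 3 // 5, len(ids) // 5
    return ids[:n_train], ids[n_train:n_train + n_val], ids[n_train + n_val:]

mask = np.zeros((512, 512), dtype=bool)
mask[:256, :256] = True                 # block (0, 0): 100% tissue, kept
mask[256:320, 256:] = True              # block (256, 256): 25% tissue, discarded
blocks = tile_and_filter(mask)          # [(0, 0)]
train, val, test = split_311(range(10), np.random.default_rng(0))
```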

3.2 Implementation Details and Evaluation Metrics

  1. Each block is embedded into a 512-dimensional feature vector by a pretrained model. Camelyon16 and TCGA-Lung use SimCLR and ResNet18, and the other datasets use ResNet18 only;
  2. Each WSI feature bag is first clustered into 10 clusters with K-means. The sampling rate of each cluster $C_k$ is set to $1024/u$, where $u$ is the number of features in the bag;
  3. Batch size $N=128$ and temperature $\tau=1$;
  4. The first training stage uses Adam with a learning rate of 1e-4 for $M(\cdot)$ and 1e-5 for $f(\cdot)$, and a weight decay of 1e-5;
  5. In the second stage, the initial learning rate of the agent $R$ is set to 1e-5;
  6. In the third stage, $M(\cdot)$ and $f(\cdot)$ are jointly optimized with Adam at learning rates of 5e-5 and 1e-5, respectively, with the weight decay unchanged;
  7. The three stages are trained for 100, 30, and 100 epochs, respectively;
  8. The evaluation metrics are ACC, AUC, and F1.
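For reference, the settings listed above can be collected into one plain-Python configuration. The constant and function names are our own. The length unit of each stage (100/30/100) is assumed to be epochs. The per-stage entries also encode the three-stage freeze schedule from Section 2.3.

```python
# Values from the implementation details above; names and layout are illustrative only.
N_CLUSTERS = 10        # K-means clusters per WSI-Fbag
FSET_SIZE = 1024       # target number of instances per WSI-Fset
BATCH_SIZE = 128       # N
TAU = 1.0              # NT-Xent temperature
WEIGHT_DECAY = 1e-5    # Adam weight decay, unchanged across stages

# (trainable modules, Adam learning rates, training length in assumed epochs)
TRAIN_STAGES = (
    ({"M", "f"}, {"M": 1e-4, "f": 1e-5}, 100),   # stage 1: random WSI-Fsets
    ({"R"}, {"R": 1e-5}, 30),                    # stage 2: M and f frozen, R from scratch
    ({"M", "f"}, {"M": 5e-5, "f": 1e-5}, 100),   # stage 3: R frozen, M and f fine-tuned
)

def per_cluster_rate(u, fset_size=FSET_SIZE):
    """Sampling rate 1024 / u, where u is the number of features in the bag."""
    return fset_size / u

def fixed_modules(stage_idx, all_modules=("M", "f", "R")):
    """Modules that are not updated in the given (0-based) stage."""
    trainable = TRAIN_STAGES[stage_idx][0]
    return {m for m in all_modules if m not in trainable}
```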


Origin blog.csdn.net/weixin_44575152/article/details/128288963