【CVPR2023】Conflict-Based Cross-View Consistency for Semi-Supervised Semantic Segmentation


Paper: https://arxiv.org/abs/2303.01276

Code: https://github.com/xiaoyao3302/CCVC/

Summary

Semi-supervised semantic segmentation (SSS) reduces the need for large-scale, fully annotated training data. Existing methods often suffer from confirmation bias when using pseudo-labels, which can be mitigated by a co-training framework. However, current co-training-based SSS methods rely on hand-crafted perturbations to prevent their subnetworks from collapsing into each other, and such artificial perturbations cannot guarantee an optimal solution. This paper proposes a novel conflict-based cross-view consistency (CCVC) method built on a two-branch co-training framework, which forces the two subnetworks to learn informative features from irrelevant views. A cross-view consistency (CVC) strategy is first proposed: a feature discrepancy loss encourages the two subnetworks to learn different features from the same input, while these decorrelated features are still expected to produce consistent prediction scores for the input. The CVC strategy prevents the two subnetworks from collapsing into each other. In addition, a conflict-based pseudo-labeling (CPL) method is proposed to make the model learn more useful information from conflicting predictions, which stabilizes the training process.

Introduction

Fully supervised semantic segmentation requires substantial effort to collect accurately annotated data. Semi-supervised learning can achieve semantic segmentation with a small amount of labeled data and a large amount of unlabeled data, but how to use the unlabeled data to assist the labeled data in training the model is the key issue.

Pseudo-labeling methods can be affected by confirmation bias, which degrades performance through unstable training. Approaches based on consistency regularization show better performance, but most rely on predictions from weakly perturbed inputs to generate pseudo-labels, which are then used to supervise predictions from strongly perturbed inputs; they are therefore also affected by confirmation bias.

Co-training enables different subnetworks to infer the same instance from different views and transfers the knowledge learned from one view to another through pseudo-labeling. In particular, co-training relies on multi-view references to increase the perception of the model and thus improve the reliability of the generated pseudo-labels. The key is to prevent the subnetworks from collapsing into each other, so that the model can make correct predictions based on inputs from different views. However, the manual perturbations used in most SSS methods cannot guarantee that heterogeneous features are learned, and therefore cannot effectively prevent the subnetworks from collapsing.

To address these problems, this paper proposes a new conflict-based cross-view consistency (CCVC) strategy for SSS, which ensures that the two jointly trained subnetworks infer the input from two irrelevant views, so that each subnetwork makes reliable and meaningful predictions.

  • First, a cross-view consistency (CVC) method with a discrepancy loss is proposed to minimize the similarity between the features extracted by the two subnetworks, encouraging them to extract different features and thus preventing the two subnetworks from collapsing into each other.
  • Then, the knowledge learned by one subnetwork is transferred to the other via cross pseudo-labeling, which improves each network's perception so that it can correctly reason about the same input from a different view and produce more reliable predictions.
  • However, the discrepancy loss may introduce too strong a perturbation to the model, so that the features extracted by a subnetwork contain less meaningful information for prediction, leading to inconsistent and unreliable predictions from the two subnetworks. This causes the confirmation bias problem, which harms the joint training of the subnetworks. To address this, a conflict-based pseudo-labeling (CPL) method is further proposed, which encourages the pseudo-labels generated from each subnetwork's conflicting predictions to provide stronger supervision on the other's predictions, enforcing consistent predictions from the two subnetworks while preserving their reliability. In this way, the impact of confirmation bias is expected to be reduced and the training process made more stable.

As shown in Figure 1, the similarity scores between the features extracted by the two subnetworks of a cross-consistency regularization (CCR) model remain high, which indicates that the inference views of CCR are still correlated. In contrast, the CVC approach ensures that the inference views are sufficiently different, producing more reliable predictions.

Figure 1. Comparison of the cosine similarity between the features extracted by the two subnetworks of the traditional cross-consistency regularization (CCR) method and of our CVC method, together with the prediction accuracy (mIoU) of the two methods. CVC prevents the two subnetworks from collapsing into each other and infers the input from irrelevant views, while CCR cannot guarantee that the introduced views are different. This shows that CVC increases the perception of the model and thus produces more reliable predictions.

Related Work

Semantic Segmentation

Semi-Supervised Semantic Segmentation

Co-training

Method

CCVC (Conflict-Based Cross-View Consistency)

Cross-View Consistency

The Cross-View Consistency (CVC) approach uses a jointly trained two-branch network, where the two subnetworks Ψ1 and Ψ2 have the same architecture but do not share parameters. Each subnetwork is divided into a feature extractor and a classifier. The goal is to let the two subnetworks reason about the input from different views, so the extracted features should be different; therefore, the cosine similarity between the features extracted by the two feature extractors is minimized.
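The discrepancy loss itself is an image in the original post; a plausible reconstruction, matching the `1 + cos_dis(feature1.detach(), feature2)` terms in the training code below ($\mathrm{sg}$ denotes stop-gradient, i.e. `detach()`), is:

$$\mathcal{L}_{dis} = \frac{1}{2}\Big[\big(1 + \cos(\mathrm{sg}(f_1),\, f_2)\big) + \big(1 + \cos(\mathrm{sg}(f_2),\, f_1)\big)\Big]$$

where $f_1$ and $f_2$ are the features extracted by the two subnetworks and $\cos(\cdot,\cdot)$ is the cosine similarity.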

Note that the constant 1 ensures that the loss is always non-negative. The two subnetworks are encouraged to output features with no mutual relationship, thereby forcing them to learn to reason about the input from two unrelated views.

Most SSS methods use a ResNet pre-trained on ImageNet as the backbone of DeepLabv3+ and fine-tune the backbone with only a small learning rate, which makes it difficult to maximize the feature discrepancy. To address this, the paper increases network heterogeneity by adding convolutional layers that map the extracted features into another feature space. The discrepancy loss is rewritten as:
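As a reconstruction in the same spirit, with the added mapping layers written as $\phi_1$ and $\phi_2$ (an assumed notation):

$$\mathcal{L}_{dis} = \frac{1}{2}\Big[\big(1 + \cos(\mathrm{sg}(\phi_1(f_1)),\, \phi_2(f_2))\big) + \big(1 + \cos(\mathrm{sg}(\phi_2(f_2)),\, \phi_1(f_1))\big)\Big]$$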

The discrepancy loss supervises both labeled and unlabeled data. The total discrepancy loss is:
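Matching `loss_dis = (loss_dis_labeled + loss_dis_unlabeled) / 2` in the code below, with superscripts $l$ and $u$ for the labeled and unlabeled parts:

$$\mathcal{L}_{dis} = \frac{1}{2}\big(\mathcal{L}^{l}_{dis} + \mathcal{L}^{u}_{dis}\big)$$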

Labeled data uses the ground-truth labels as supervision, and both subnetworks are supervised. The supervised loss on labeled data is:
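This is a per-pixel cross-entropy averaged over the two branches, matching `loss_CE = (loss_CE1 + loss_CE2) / 2` in the code, with $y$ the ground-truth mask:

$$\mathcal{L}_{s} = \frac{1}{2}\big[\mathrm{CE}(p_1, y) + \mathrm{CE}(p_2, y)\big]$$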

For unlabeled data, pseudo-labeling lets each subnetwork learn semantic information from the other. A cross-entropy loss is applied to fine-tune the model, with the two subnetworks supervising each other through pseudo-labels. The loss on unlabeled data (the cross-consistency loss) is:
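A plausible form, matching the default (`else`) branch of the training code, where $\hat{y}_m = \arg\max_c p_m$ is the detached pseudo-label of branch $m$:

$$\mathcal{L}_{con} = \frac{1}{2}\big[\mathrm{CE}(p_2, \hat{y}_1) + \mathrm{CE}(p_1, \hat{y}_2)\big]$$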

 

The total loss of the network is the weighted sum of the supervised loss, the consistency loss, and the discrepancy loss:
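With weights corresponding to `args.w_CE`, `args.w_con`, and `args.w_dis` in the code:

$$\mathcal{L} = \lambda_{CE}\,\mathcal{L}_{s} + \lambda_{con}\,\mathcal{L}_{con} + \lambda_{dis}\,\mathcal{L}_{dis}$$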

 

Conflict-Based Pseudo-Labeling

With the cross-view consistency (CVC) method, the two subnetworks learn semantic information from different views. However, if the feature discrepancy loss perturbs the model too strongly, training is not stable enough, and it becomes difficult to guarantee that the two subnetworks can learn useful semantic information from each other, which may further harm the reliability of the predictions.

Therefore, the paper proposes a conflict-based pseudo-labeling (CPL) method, which lets the two subnetworks learn more semantic information from conflicting predictions so as to reach consistent predictions, ensuring that both subnetworks generate reliable predictions and further stabilizing training. A binary value δ defines whether two predictions conflict.
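The defining equation is an image in the original post; a plausible reconstruction for pixel $i$ and branch pair $(m, n)$, consistent with the vote-based selection in the code, simply marks disagreeing pseudo-labels:

$$\delta_{mn,i} = \mathbb{1}\big[\hat{y}_{m,i} \neq \hat{y}_{n,i}\big]$$

so that δ = 1 for a conflicting prediction and 0 otherwise.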


The goal is to encourage the model to learn more semantic information from these conflicting predictions. Therefore, when these predictions are used to generate pseudo-labels for fine-tuning the model, higher weight ωc is assigned to the cross-entropy loss supervised by these pseudo-labels.

However, training may still suffer from confirmation bias, since some pseudo-labels may be wrong. The paper therefore divides conflicting predictions into two categories, conflicting and confident (CC) predictions and conflicting but unconfident (CU) predictions, and assigns ωc only to pseudo-labels generated by CC predictions.

CC predictions are marked by a binary indicator $\delta^{cc}_{mn,i}$, and $\delta^{e}_{mn,i}$ denotes the union of CU predictions and conflict-free predictions.
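A sketch consistent with the `vote_threshold` selection in the training code, with confidence threshold $\gamma$ (`conf_threshold`); reading the confidence from branch $m$ is an assumption:

$$\delta^{cc}_{mn,i} = \delta_{mn,i}\cdot\mathbb{1}\Big[\max_{c}\, p_{m,i,c} > \gamma\Big],\qquad \delta^{e}_{mn,i} = 1 - \delta^{cc}_{mn,i}$$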

Pseudo-labels generated by CU predictions are still used to fine-tune the model with normal weight rather than being discarded, since CU predictions can also contain latent information about inter-class relationships. The consistency loss is rewritten as:
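As a sketch, with $\Omega$ the set of valid (non-ignored) pixels and per-pixel weights $\omega_{m,i}$:

$$\mathcal{L}_{con} = \frac{1}{2\,|\Omega|}\sum_{i\in\Omega}\Big[\omega_{1,i}\,\mathrm{CE}\big(p_{2,i}, \hat{y}_{1,i}\big) + \omega_{2,i}\,\mathrm{CE}\big(p_{1,i}, \hat{y}_{2,i}\big)\Big]$$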

where $\omega_{m,i} = \omega_c$ if $\delta^{cc}_{mn,i} = 1$ and $\omega_{m,i} = 1$ otherwise.

The CCVC method can effectively encourage two subnetworks to reason about the same input from different perspectives, and the knowledge transfer between the two subnetworks can increase the perception of each subnetwork, thereby improving the reliability of predictions.

In the inference phase, only one branch of the network is needed to generate predictions.
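Below is a minimal sketch of single-branch inference. The `Discrepancy_DeepLabV3Plus` interface (a logits dict with 'pred1'/'pred2') is taken from the training code below; the function name and the rest are illustrative assumptions:

import torch

@torch.no_grad()
def predict_single_branch(model, image):
    # either branch can be used at test time; branch 1 ('pred1') is picked arbitrarily
    model.eval()
    logits = model(image)                  # dict with 'pred1', 'pred2', 'feature1', ...
    return logits['pred1'].argmax(dim=1)   # per-pixel class indices, shape (B, H, W)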

Experiment

Key Code

CCVC_no_aug.py

# https://github.com/xiaoyao3302/CCVC/blob/master/CCVC_no_aug.py
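# NOTE: this is an excerpt -- the imports, the argparse `parser` definition, and the
# repo helpers (SemiDataset, Discrepancy_DeepLabV3Plus, init_log_save, count_params,
# evaluate, evaluate_save, the *_label_selection utilities) are defined elsewhere
# in the repository.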

args = parser.parse_args()

args.world_size = args.gpus * args.nodes
args.ddp = True if args.gpus > 1 else False

def main(gpu, ngpus_per_node, cfg, args):

    args.local_rank = gpu

    if args.local_rank <= 0:
        os.makedirs(args.save_path, exist_ok=True)
        
    logger = init_log_save(args.save_path, 'global', logging.INFO)
    logger.propagate = 0

    if args.local_rank <= 0:
        tb_dir = args.save_path
        tb = SummaryWriter(log_dir=tb_dir)

    if args.ddp:
        dist.init_process_group(backend='nccl', rank=args.local_rank, world_size=args.world_size)
        # dist.init_process_group(backend='nccl', init_method='env://', rank=args.local_rank, world_size=args.world_size)

    if args.local_rank <= 0:
        logger.info('{}\n'.format(pprint.pformat(cfg)))

    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.benchmark = True

    model = Discrepancy_DeepLabV3Plus(args, cfg)
    if args.local_rank <= 0:
        logger.info('Total params: {:.1f}M\n'.format(count_params(model)))

    # pdb.set_trace()
    # TODO: check the parameter !!!

    optimizer = SGD([{'params': model.branch1.backbone.parameters(), 'lr': args.base_lr * args.lr_backbone}, 
                     {'params': model.branch2.backbone.parameters(), 'lr': args.base_lr * args.lr_backbone},
                     {'params': [param for name, param in model.named_parameters() if 'backbone' not in name],
                      'lr': args.base_lr * args.lr_network}], lr=args.base_lr, momentum=0.9, weight_decay=1e-4)

    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model.cuda(args.local_rank)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank, broadcast_buffers=False, find_unused_parameters=False)

    # # try to load saved model
    # try:   
    #     model.load_model(args.load_path)
    #     if args.local_rank <= 0:
    #         logger.info('load saved model')
    # except:
    #     if args.local_rank <= 0:
    #         logger.info('no saved model')

    # ---- #
    # loss #
    # ---- #
    # CE loss for labeled data
    criterion_l = nn.CrossEntropyLoss(reduction='mean', ignore_index=255).cuda(args.local_rank)
    
    # consistency loss for unlabeled data
    criterion_u = nn.CrossEntropyLoss(reduction='none').cuda(args.local_rank)

    # ------- #
    # dataset #
    # ------- #
    trainset_u = SemiDataset(cfg['dataset'], cfg['data_root'], 'train_u',
                             args.crop_size, args.unlabeled_id_path)
    trainset_l = SemiDataset(cfg['dataset'], cfg['data_root'], 'train_l',
                             args.crop_size, args.labeled_id_path, nsample=len(trainset_u.ids))
    valset = SemiDataset(cfg['dataset'], cfg['data_root'], 'val')

    if args.ddp:
        trainsampler_l = torch.utils.data.distributed.DistributedSampler(trainset_l)
    else:
        trainsampler_l = None
    trainloader_l = DataLoader(trainset_l, batch_size=args.batch_size,
                               pin_memory=True, num_workers=args.num_workers, drop_last=True, sampler=trainsampler_l)

    if args.ddp:
        trainsampler_u = torch.utils.data.distributed.DistributedSampler(trainset_u)
    else:
        trainsampler_u = None
    
    trainloader_u = DataLoader(trainset_u, batch_size=args.batch_size,
                               pin_memory=True, num_workers=args.num_workers, drop_last=True, sampler=trainsampler_u)

    valloader = DataLoader(valset, batch_size=1, pin_memory=True, num_workers=args.num_workers, drop_last=False)
    # if args.ddp:
    #     valsampler = torch.utils.data.distributed.DistributedSampler(valset)
    # else:
    #     valsampler = None
    
    # valloader = DataLoader(valset, batch_size=1, pin_memory=True, num_workers=args.num_workers, drop_last=False, sampler=valsampler)

    total_iters = len(trainloader_u) * args.epochs
    previous_best = 0.0
    previous_best1 = 0.0
    previous_best2 = 0.0

    # can change with epochs, add SPL here
    conf_threshold = args.conf_threshold

    for epoch in range(args.epochs):
        if args.local_rank <= 0:
            logger.info('===========> Epoch: {:}, backbone1 LR: {:.4f}, backbone2 LR: {:.4f}, segmentation LR: {:.4f}'.format(
                epoch, optimizer.param_groups[0]['lr'], optimizer.param_groups[1]['lr'], optimizer.param_groups[-1]['lr']))
            logger.info('===========> Epoch: {:}, Previous best of ave: {:.2f}, Previous best of branch1: {:.2f}, Previous best of branch2: {:.2f}'.format(
                epoch, previous_best, previous_best1, previous_best2))

        total_loss, total_loss_CE, total_loss_con, total_loss_dis = 0.0, 0.0, 0.0, 0.0
        total_mask_ratio = 0.0

        trainloader_l.sampler.set_epoch(epoch)
        trainloader_u.sampler.set_epoch(epoch)

        loader = zip(trainloader_l, trainloader_u)

        total_labeled = 0
        total_unlabeled = 0

        for i, ((labeled_img, labeled_img_mask), (unlabeled_img, ignore_img_mask, cutmix_box)) in enumerate(loader):

            labeled_img, labeled_img_mask = labeled_img.cuda(args.local_rank), labeled_img_mask.cuda(args.local_rank)
            unlabeled_img, ignore_img_mask, cutmix_box = unlabeled_img.cuda(args.local_rank), ignore_img_mask.cuda(args.local_rank), cutmix_box.cuda(args.local_rank)

            model.train()

            optimizer.zero_grad()

            dist.barrier()

            num_lb, num_ulb = labeled_img.shape[0], unlabeled_img.shape[0]

            total_labeled += num_lb
            total_unlabeled += num_ulb

            # =========================================================================================
            # labeled data: labeled_img, labeled_img_mask
            # =========================================================================================
            labeled_logits = model(labeled_img)

            # =========================================================================================
            # unlabeled data: unlabeled_img, ignore_img_mask, cutmix_box
            # =========================================================================================
            unlabeled_logits = model(unlabeled_img)

            # to count the confident predictions
            unlabeled_pred_confidence1 = unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[0]
            unlabeled_pred_confidence2 = unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[0]

            # =========================================================================================
            # calculate loss
            # =========================================================================================

            # -------------
            # labeled
            # -------------
            # CE loss
            labeled_pred1 = labeled_logits['pred1']
            labeled_pred2 = labeled_logits['pred2']

            loss_CE1 = criterion_l(labeled_pred1, labeled_img_mask)
            loss_CE2 = criterion_l(labeled_pred2, labeled_img_mask)

            loss_CE = (loss_CE1 + loss_CE2) / 2
            loss_CE = loss_CE * args.w_CE

            # -------------
            # unlabeled
            # -------------
            # consistency loss
            unlabeled_pred1 = unlabeled_logits['pred1']
            unlabeled_pred2 = unlabeled_logits['pred2']
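            # Weighting modes for the cross pseudo-label (CPL) loss:
            #   'normal'         - keep only pixels whose pseudo-label confidence exceeds conf_threshold
            #   'soft'           - confident pixels weighted 1, unconfident pixels weighted args.w_unconfident
            #   'vote'           - pixels where the branches disagree weighted args.w_confident, agreeing pixels 1
            #   'vote_threshold' - conflicting AND confident (CC) pixels weighted args.w_confident, the rest 1
            #   'vote_soft'      - agreeing 1, conflicting+confident args.w_confident, conflicting+unconfident args.w_unconfident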

            if args.mode_confident == 'normal':
                loss_con1 = criterion_u(unlabeled_pred2, unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[1].detach().long()) * ((unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[0] > conf_threshold) & (ignore_img_mask != 255))
                loss_con1 = torch.sum(loss_con1) / torch.sum(ignore_img_mask != 255).item()
                loss_con2 = criterion_u(unlabeled_pred1, unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[1].detach().long()) * ((unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[0] > conf_threshold) & (ignore_img_mask != 255))
                loss_con2 = torch.sum(loss_con2) / torch.sum(ignore_img_mask != 255).item()
                
            elif args.mode_confident == 'soft':
                confident_pred1, confident_pred2, unconfident_pred1, unconfident_pred2 = soft_label_selection(unlabeled_pred1, unlabeled_pred2, conf_threshold)

                loss_con1_confident = criterion_u(unlabeled_pred2, unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[1].detach().long()) * (confident_pred1 & (ignore_img_mask != 255))
                loss_con2_confident = criterion_u(unlabeled_pred1, unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[1].detach().long()) * (confident_pred2 & (ignore_img_mask != 255))

                loss_con1_unconfident = args.w_unconfident * criterion_u(unlabeled_pred2, unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[1].detach().long()) * (unconfident_pred1 & (ignore_img_mask != 255))
                loss_con2_unconfident = args.w_unconfident * criterion_u(unlabeled_pred1, unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[1].detach().long()) * (unconfident_pred2 & (ignore_img_mask != 255))

                loss_con1 = (torch.sum(loss_con1_confident) + torch.sum(loss_con1_unconfident)) / torch.sum(ignore_img_mask != 255).item()
                loss_con2 = (torch.sum(loss_con2_confident) + torch.sum(loss_con2_unconfident)) / torch.sum(ignore_img_mask != 255).item()

            elif args.mode_confident == 'vote':
                same_pred, different_pred = vote_label_selection(unlabeled_pred1, unlabeled_pred2)

                loss_con1_same = criterion_u(unlabeled_pred2, unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[1].detach().long()) * (same_pred & (ignore_img_mask != 255))
                loss_con2_same = criterion_u(unlabeled_pred1, unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[1].detach().long()) * (same_pred & (ignore_img_mask != 255))

                loss_con1_different = args.w_confident * criterion_u(unlabeled_pred2, unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[1].detach().long()) * (different_pred & (ignore_img_mask != 255))
                loss_con2_different = args.w_confident * criterion_u(unlabeled_pred1, unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[1].detach().long()) * (different_pred & (ignore_img_mask != 255))

                loss_con1 = (torch.sum(loss_con1_same) + torch.sum(loss_con1_different)) / torch.sum(ignore_img_mask != 255).item()
                loss_con2 = (torch.sum(loss_con2_same) + torch.sum(loss_con2_different)) / torch.sum(ignore_img_mask != 255).item()

            elif args.mode_confident == 'vote_threshold':
                different1_confident, different1_else, different2_confident, different2_else = vote_threshold_label_selection(unlabeled_pred1, unlabeled_pred2, conf_threshold)

                loss_con1_else = criterion_u(unlabeled_pred2, unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[1].detach().long()) * (different1_else & (ignore_img_mask != 255))
                loss_con2_else = criterion_u(unlabeled_pred1, unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[1].detach().long()) * (different2_else & (ignore_img_mask != 255))

                loss_con1_cc = args.w_confident * criterion_u(unlabeled_pred2, unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[1].detach().long()) * (different1_confident & (ignore_img_mask != 255))
                loss_con2_cc = args.w_confident * criterion_u(unlabeled_pred1, unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[1].detach().long()) * (different2_confident & (ignore_img_mask != 255))

                loss_con1 = (torch.sum(loss_con1_else) + torch.sum(loss_con1_cc)) / torch.sum(ignore_img_mask != 255).item()
                loss_con2 = (torch.sum(loss_con2_else) + torch.sum(loss_con2_cc)) / torch.sum(ignore_img_mask != 255).item()

            elif args.mode_confident == 'vote_soft':
                same_pred, different_confident_pred1, different_confident_pred2, different_unconfident_pred1, different_unconfident_pred2 = vote_soft_label_selection(unlabeled_pred1, unlabeled_pred2, conf_threshold)

                loss_con1_same = criterion_u(unlabeled_pred2, unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[1].detach().long()) * (same_pred & (ignore_img_mask != 255))
                loss_con2_same = criterion_u(unlabeled_pred1, unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[1].detach().long()) * (same_pred & (ignore_img_mask != 255))

                loss_con1_different_confident = args.w_confident * criterion_u(unlabeled_pred2, unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[1].detach().long()) * (different_confident_pred1 & (ignore_img_mask != 255))
                loss_con2_different_confident = args.w_confident * criterion_u(unlabeled_pred1, unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[1].detach().long()) * (different_confident_pred2 & (ignore_img_mask != 255))

                loss_con1_different_unconfident = args.w_unconfident * criterion_u(unlabeled_pred2, unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[1].detach().long()) * (different_unconfident_pred1 & (ignore_img_mask != 255))
                loss_con2_different_unconfident = args.w_unconfident * criterion_u(unlabeled_pred1, unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[1].detach().long()) * (different_unconfident_pred2 & (ignore_img_mask != 255))

                loss_con1 = (torch.sum(loss_con1_same) + torch.sum(loss_con1_different_confident) + torch.sum(loss_con1_different_unconfident)) / torch.sum(ignore_img_mask != 255).item()
                loss_con2 = (torch.sum(loss_con2_same) + torch.sum(loss_con2_different_confident) + torch.sum(loss_con2_different_unconfident)) / torch.sum(ignore_img_mask != 255).item()

            else:
                loss_con1 = criterion_u(unlabeled_pred2, unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[1].detach().long()) * (ignore_img_mask != 255)
                loss_con1 = torch.sum(loss_con1) / torch.sum(ignore_img_mask != 255).item()
                loss_con2 = criterion_u(unlabeled_pred1, unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[1].detach().long()) * (ignore_img_mask != 255)
                loss_con2 = torch.sum(loss_con2) / torch.sum(ignore_img_mask != 255).item()

            loss_con = (loss_con1 + loss_con2) / 2
            loss_con = loss_con * args.w_con

            # -------------
            # both
            # -------------
            # discrepancy loss
            cos_dis = nn.CosineSimilarity(dim=1, eps=1e-6)

            # labeled
            labeled_feature1 = labeled_logits['feature1']
            labeled_feature2 = labeled_logits['feature2']
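            # detach() applies a stop-gradient to one side of the cosine similarity,
            # so each term only pushes one branch's features away from the other's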
            loss_dis_labeled1 = 1 + cos_dis(labeled_feature1.detach(), labeled_feature2).mean()
            loss_dis_labeled2 = 1 + cos_dis(labeled_feature2.detach(), labeled_feature1).mean()
            loss_dis_labeled = (loss_dis_labeled1 + loss_dis_labeled2) / 2

            # unlabeled
            unlabeled_feature1 = unlabeled_logits['feature1']
            unlabeled_feature2 = unlabeled_logits['feature2']
            loss_dis_unlabeled1 = 1 + cos_dis(unlabeled_feature1.detach(), unlabeled_feature2).mean()
            loss_dis_unlabeled2 = 1 + cos_dis(unlabeled_feature2.detach(), unlabeled_feature1).mean()
            loss_dis_unlabeled = (loss_dis_unlabeled1 + loss_dis_unlabeled2) / 2

            loss_dis = (loss_dis_labeled + loss_dis_unlabeled) / 2
            loss_dis = loss_dis * args.w_dis

            # -------------
            # total
            # -------------
            loss = loss_CE
            if args.use_con:
                loss = loss + loss_con
            if args.use_dis:
                loss = loss + loss_dis

            dist.barrier()

            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total_loss_CE += loss_CE.item()
            total_loss_con += loss_con.item()
            total_loss_dis += loss_dis.item()

            total_confident = (((unlabeled_pred_confidence1 >= 0.95) & (ignore_img_mask != 255)).sum().item() + ((unlabeled_pred_confidence2 >= 0.95) & (ignore_img_mask != 255)).sum().item()) / 2
            total_mask_ratio += total_confident / (ignore_img_mask != 255).sum().item()

            iters = epoch * len(trainloader_u) + i

            # update lr

            backbone_lr = args.base_lr * (1 - iters / total_iters) ** args.mul_scheduler
            backbone_lr = backbone_lr * args.lr_backbone

            seg_lr = args.base_lr * (1 - iters / total_iters) ** args.mul_scheduler
            seg_lr = seg_lr * args.lr_network
                
            optimizer.param_groups[0]["lr"] = backbone_lr
            optimizer.param_groups[1]["lr"] = backbone_lr
            for ii in range(2, len(optimizer.param_groups)):
                optimizer.param_groups[ii]['lr'] = seg_lr

            if (i % (len(trainloader_u) // 8) == 0) and (args.local_rank <= 0):
                tb.add_scalar('train_loss_total', total_loss / (i+1), iters)
                tb.add_scalar('train_loss_CE', total_loss_CE / (i+1), iters)
                tb.add_scalar('train_loss_con', total_loss_con / (i+1), iters)
                tb.add_scalar('train_loss_dis', total_loss_dis / (i+1), iters)

            if (i % (len(trainloader_u) // 8) == 0) and (args.local_rank <= 0):
                logger.info('Iters: {:}, Total loss: {:.3f}, Loss CE: {:.3f}, '
                            'Loss consistency: {:.3f}, Loss discrepancy: {:.3f}, Mask: {:.3f}'.format(
                    i, total_loss / (i+1), total_loss_CE / (i+1), total_loss_con / (i+1), total_loss_dis / (i+1), 
                    total_mask_ratio / (i+1)))

        if args.use_SPL:
            conf_threshold += 0.01
            if conf_threshold >= 0.95:
                conf_threshold = 0.95

        if cfg['dataset'] == 'cityscapes':
            eval_mode = 'center_crop' if epoch < args.epochs - 20 else 'sliding_window'
        else:
            eval_mode = 'original'
        
        dist.barrier()

        # test with different branches
        if args.local_rank <= 0:
            # save qualitative results at a few fixed epochs (5/10/20/40), otherwise just evaluate
            if epoch in (4, 9, 19, 39):
                evaluate_result = evaluate_save(cfg['dataset'], args.save_path, args.local_rank, model, valloader, eval_mode, args, cfg, idx_epoch=epoch + 1)
            else:
                evaluate_result = evaluate(args.local_rank, model, valloader, eval_mode, args, cfg)


            mIOU1 = evaluate_result['IOU1']
            mIOU2 = evaluate_result['IOU2']
            mIOU_ave = evaluate_result['IOU_ave']

            tb.add_scalar('meanIOU_branch1', mIOU1, epoch)
            tb.add_scalar('meanIOU_branch2', mIOU2, epoch)
            tb.add_scalar('meanIOU_ave', mIOU_ave, epoch)

            logger.info('***** Evaluation with branch 1 {} ***** >>>> meanIOU: {:.2f}\n'.format(eval_mode, mIOU1))
            logger.info('***** Evaluation with branch 2 {} ***** >>>> meanIOU: {:.2f}\n'.format(eval_mode, mIOU2))
            logger.info('***** Evaluation with two branches {} ***** >>>> meanIOU: {:.2f}\n'.format(eval_mode, mIOU_ave))

            if mIOU1 > previous_best1:
                if previous_best1 != 0:
                    os.remove(os.path.join(args.save_path, 'branch1_%s_%.2f.pth' % (args.backbone, previous_best1)))
                previous_best1 = mIOU1
                torch.save(model.module.state_dict(),
                        os.path.join(args.save_path, 'branch1_%s_%.2f.pth' % (args.backbone, mIOU1)))
            
            if mIOU2 > previous_best2:
                if previous_best2 != 0:
                    os.remove(os.path.join(args.save_path, 'branch2_%s_%.2f.pth' % (args.backbone, previous_best2)))
                previous_best2 = mIOU2
                torch.save(model.module.state_dict(),
                        os.path.join(args.save_path, 'branch2_%s_%.2f.pth' % (args.backbone, mIOU2)))

            if mIOU_ave > previous_best:
                if previous_best != 0:
                    os.remove(os.path.join(args.save_path, 'ave_%s_%.2f.pth' % (args.backbone, previous_best)))
                previous_best = mIOU_ave
                torch.save(model.module.state_dict(),
                        os.path.join(args.save_path, 'ave_%s_%.2f.pth' % (args.backbone, mIOU_ave)))

CCVC_aug.py

# https://github.com/xiaoyao3302/CCVC/blob/master/CCVC_aug.py
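# NOTE: excerpt, as above -- imports, the argparse `parser`, and repo helpers
# (including ProbOhemCrossEntropy2d) are defined elsewhere in the repository.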

args = parser.parse_args()

args.world_size = args.gpus * args.nodes
args.ddp = True if args.gpus > 1 else False

def main(gpu, ngpus_per_node, cfg, args):

    args.local_rank = gpu

    if args.local_rank <= 0:
        os.makedirs(args.save_path, exist_ok=True)
        
    logger = init_log_save(args.save_path, 'global', logging.INFO)
    logger.propagate = 0

    if args.local_rank <= 0:
        tb_dir = args.save_path
        tb = SummaryWriter(log_dir=tb_dir)

    if args.ddp:
        dist.init_process_group(backend='nccl', rank=args.local_rank, world_size=args.world_size)
        # dist.init_process_group(backend='nccl', init_method='env://', rank=args.local_rank, world_size=args.world_size)

    if args.local_rank <= 0:
        logger.info('{}\n'.format(pprint.pformat(cfg)))

    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.benchmark = True

    model = Discrepancy_DeepLabV3Plus(args, cfg)
    if args.local_rank <= 0:
        logger.info('Total params: {:.1f}M\n'.format(count_params(model)))

    # pdb.set_trace()
    # TODO: check the parameter !!!

    optimizer = SGD([{'params': model.branch1.backbone.parameters(), 'lr': args.base_lr * args.lr_backbone}, 
                     {'params': model.branch2.backbone.parameters(), 'lr': args.base_lr * args.lr_backbone},
                     {'params': [param for name, param in model.named_parameters() if 'backbone' not in name],
                      'lr': args.base_lr * args.lr_network}], lr=args.base_lr, momentum=0.9, weight_decay=1e-4)

    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model.cuda(args.local_rank)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank, broadcast_buffers=False, find_unused_parameters=False)

    # # try to load saved model
    # try:   
    #     model.load_model(args.load_path)
    #     if args.local_rank <= 0:
    #         logger.info('load saved model')
    # except:
    #     if args.local_rank <= 0:
    #         logger.info('no saved model')

    # ---- #
    # loss #
    # ---- #
    # CE loss for labeled data
    if args.mode_criterion == 'CE':
        criterion_l = nn.CrossEntropyLoss(reduction='mean', ignore_index=255).cuda(args.local_rank)
    else:
        criterion_l = ProbOhemCrossEntropy2d(**cfg['criterion']['kwargs']).cuda(args.local_rank)
    
    # consistency loss for unlabeled data
    criterion_u = nn.CrossEntropyLoss(reduction='none').cuda(args.local_rank)

    # ------- #
    # dataset #
    # ------- #
    trainset_u = SemiDataset(cfg['dataset'], cfg['data_root'], 'train_u',
                             args.crop_size, args.unlabeled_id_path)
    trainset_l = SemiDataset(cfg['dataset'], cfg['data_root'], 'train_l',
                             args.crop_size, args.labeled_id_path, nsample=len(trainset_u.ids))
    valset = SemiDataset(cfg['dataset'], cfg['data_root'], 'val')

    if args.ddp:
        trainsampler_l = torch.utils.data.distributed.DistributedSampler(trainset_l)
    else:
        trainsampler_l = None
    trainloader_l = DataLoader(trainset_l, batch_size=args.batch_size,
                               pin_memory=True, num_workers=args.num_workers, drop_last=True, sampler=trainsampler_l)

    if args.ddp:
        trainsampler_u = torch.utils.data.distributed.DistributedSampler(trainset_u)
    else:
        trainsampler_u = None
    
    trainloader_u = DataLoader(trainset_u, batch_size=args.batch_size,
                               pin_memory=True, num_workers=args.num_workers, drop_last=True, sampler=trainsampler_u)

    valloader = DataLoader(valset, batch_size=1, pin_memory=True, num_workers=args.num_workers, drop_last=False)
    # if args.ddp:
    #     valsampler = torch.utils.data.distributed.DistributedSampler(valset)
    # else:
    #     valsampler = None
    
    # valloader = DataLoader(valset, batch_size=1, pin_memory=True, num_workers=args.num_workers, drop_last=False, sampler=valsampler)

    total_iters = len(trainloader_u) * args.epochs
    previous_best = 0.0
    previous_best1 = 0.0
    previous_best2 = 0.0

    # can change with epochs, add SPL here
    conf_threshold = args.conf_threshold

    for epoch in range(args.epochs):
        if args.local_rank <= 0:
            logger.info('===========> Epoch: {:}, backbone1 LR: {:.4f}, backbone2 LR: {:.4f}, segmentation LR: {:.4f}'.format(
                epoch, optimizer.param_groups[0]['lr'], optimizer.param_groups[1]['lr'], optimizer.param_groups[-1]['lr']))
            logger.info('===========> Epoch: {:}, Previous best of ave: {:.2f}, Previous best of branch1: {:.2f}, Previous best of branch2: {:.2f}'.format(
                epoch, previous_best, previous_best1, previous_best2))

        total_loss, total_loss_CE, total_loss_con, total_loss_dis = 0.0, 0.0, 0.0, 0.0
        total_mask_ratio = 0.0

        trainloader_l.sampler.set_epoch(epoch)
        trainloader_u.sampler.set_epoch(epoch)

        loader = zip(trainloader_l, trainloader_u)

        total_labeled = 0
        total_unlabeled = 0

        for i, ((labeled_img, labeled_img_mask), (unlabeled_img, aug_unlabeled_img, ignore_img_mask, unlabeled_cutmix_box)) in enumerate(loader):

            labeled_img, labeled_img_mask = labeled_img.cuda(args.local_rank), labeled_img_mask.cuda(args.local_rank)
            unlabeled_img, aug_unlabeled_img, ignore_img_mask, unlabeled_cutmix_box = unlabeled_img.cuda(args.local_rank), aug_unlabeled_img.cuda(args.local_rank), ignore_img_mask.cuda(args.local_rank), unlabeled_cutmix_box.cuda(args.local_rank)

            optimizer.zero_grad()

            dist.barrier()

            num_lb, num_ulb = labeled_img.shape[0], unlabeled_img.shape[0]

            total_labeled += num_lb
            total_unlabeled += num_ulb

            # =========================================================================================
            # labeled data: labeled_img, labeled_img_mask
            # =========================================================================================
            model.train()
            labeled_logits = model(labeled_img)

            # =========================================================================================
            # unlabeled data: unlabeled_img, ignore_img_mask, cutmix_box
            # =========================================================================================
            # first feed the data into the model with no grad to get gt
            with torch.no_grad():
                model.eval()
                unlabeled_logits = model(unlabeled_img)

                unlabeled_pseudo_label1 = unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[1].detach().long()
                unlabeled_pseudo_label2 = unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[1].detach().long()
                unlabeled_pred1 = unlabeled_logits['pred1'].detach()
                unlabeled_pred2 = unlabeled_logits['pred2'].detach()

            # then perform cutmix
            aug_unlabeled_img_for_mix = aug_unlabeled_img.clone()
            aug_unlabeled_ignore_img_mask_for_mix = ignore_img_mask.clone()
            aug_unlabeled_pseudo_label1_for_mix = unlabeled_pseudo_label1.clone()
            aug_unlabeled_pseudo_label2_for_mix = unlabeled_pseudo_label2.clone()
            aug_unlabeled_pred1_for_mix = unlabeled_pred1.clone()
            aug_unlabeled_pred2_for_mix = unlabeled_pred2.clone()

            # the initial cutmix is applied with probability 0.5, increasing data diversity
            aug_unlabeled_img_for_mix[unlabeled_cutmix_box.unsqueeze(1).expand(aug_unlabeled_img_for_mix.shape) == 1] = aug_unlabeled_img_for_mix.flip(0)[unlabeled_cutmix_box.unsqueeze(1).expand(aug_unlabeled_img_for_mix.shape) == 1]
            aug_unlabeled_ignore_img_mask_for_mix[unlabeled_cutmix_box == 1] = aug_unlabeled_ignore_img_mask_for_mix.flip(0)[unlabeled_cutmix_box == 1]
            aug_unlabeled_pseudo_label1_for_mix[unlabeled_cutmix_box == 1] = aug_unlabeled_pseudo_label1_for_mix.flip(0)[unlabeled_cutmix_box == 1]
            aug_unlabeled_pseudo_label2_for_mix[unlabeled_cutmix_box == 1] = aug_unlabeled_pseudo_label2_for_mix.flip(0)[unlabeled_cutmix_box == 1]
            aug_unlabeled_pred1_for_mix[unlabeled_cutmix_box.unsqueeze(1).expand(aug_unlabeled_pred1_for_mix.shape) == 1] = aug_unlabeled_pred1_for_mix.flip(0)[unlabeled_cutmix_box.unsqueeze(1).expand(aug_unlabeled_pred1_for_mix.shape) == 1]
            aug_unlabeled_pred2_for_mix[unlabeled_cutmix_box.unsqueeze(1).expand(aug_unlabeled_pred2_for_mix.shape) == 1] = aug_unlabeled_pred2_for_mix.flip(0)[unlabeled_cutmix_box.unsqueeze(1).expand(aug_unlabeled_pred2_for_mix.shape) == 1]

            # finally feed the mixed data into the model
            model.train()
            cutmixed_aug_unlabeled_logits = model(aug_unlabeled_img_for_mix)

            # one extra branch: unlabeled data
            raw_unlabeled_logits = model(unlabeled_img)


            # to count the confident predictions
            unlabeled_pred_confidence1 = raw_unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[0]
            unlabeled_pred_confidence2 = raw_unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[0]

            # =========================================================================================
            # calculate loss
            # =========================================================================================

            # -------------
            # labeled
            # -------------
            # CE loss
            labeled_pred1 = labeled_logits['pred1']
            labeled_pred2 = labeled_logits['pred2']

            loss_CE1 = criterion_l(labeled_pred1, labeled_img_mask)
            loss_CE2 = criterion_l(labeled_pred2, labeled_img_mask)

            loss_CE = (loss_CE1 + loss_CE2) / 2
            loss_CE = loss_CE * args.w_CE

            # -------------
            # unlabeled
            # -------------
            # consistency loss
            raw_unlabeled_pred1 = raw_unlabeled_logits['pred1']
            raw_unlabeled_pred2 = raw_unlabeled_logits['pred2']

            cutmixed_aug_unlabeled_pred1 = cutmixed_aug_unlabeled_logits['pred1']
            cutmixed_aug_unlabeled_pred2 = cutmixed_aug_unlabeled_logits['pred2']
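            # Same CPL weighting modes as in CCVC_no_aug.py, applied to both the cutmixed
            # strongly-augmented view and the raw view; note that some per-mode weights
            # are hard-coded here (0.5 / 1.5) rather than taken from args.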

            if args.mode_confident == 'normal':
                loss_con1 = criterion_u(cutmixed_aug_unlabeled_pred2, aug_unlabeled_pseudo_label1_for_mix) * ((aug_unlabeled_pred1_for_mix.softmax(dim=1).max(dim=1)[0] > conf_threshold) & (aug_unlabeled_ignore_img_mask_for_mix != 255))
                loss_con1 = torch.sum(loss_con1) / torch.sum(aug_unlabeled_ignore_img_mask_for_mix != 255).item()
                loss_con2 = criterion_u(cutmixed_aug_unlabeled_pred1, aug_unlabeled_pseudo_label2_for_mix) * ((aug_unlabeled_pred2_for_mix.softmax(dim=1).max(dim=1)[0] > conf_threshold) & (aug_unlabeled_ignore_img_mask_for_mix != 255))
                loss_con2 = torch.sum(loss_con2) / torch.sum(aug_unlabeled_ignore_img_mask_for_mix != 255).item()

                loss_raw_con1 = criterion_u(raw_unlabeled_pred2, raw_unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[1].detach().long()) * ((raw_unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[0] > conf_threshold) & (ignore_img_mask != 255))
                loss_raw_con1 = torch.sum(loss_raw_con1) / torch.sum(ignore_img_mask != 255).item()
                loss_raw_con2 = criterion_u(raw_unlabeled_pred1, raw_unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[1].detach().long()) * ((raw_unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[0] > conf_threshold) & (ignore_img_mask != 255))
                loss_raw_con2 = torch.sum(loss_raw_con2) / torch.sum(ignore_img_mask != 255).item()

            elif args.mode_confident == 'soft':
                confident_pred1, confident_pred2, unconfident_pred1, unconfident_pred2 = soft_label_selection(aug_unlabeled_pred1_for_mix, aug_unlabeled_pred2_for_mix, conf_threshold)

                loss_con1_confident = criterion_u(cutmixed_aug_unlabeled_pred2, aug_unlabeled_pseudo_label1_for_mix) * (confident_pred1 & (aug_unlabeled_ignore_img_mask_for_mix != 255))
                loss_con2_confident = criterion_u(cutmixed_aug_unlabeled_pred1, aug_unlabeled_pseudo_label2_for_mix) * (confident_pred2 & (aug_unlabeled_ignore_img_mask_for_mix != 255))

                loss_con1_unconfident = 0.5 * criterion_u(cutmixed_aug_unlabeled_pred2, aug_unlabeled_pseudo_label1_for_mix) * (unconfident_pred1 & (aug_unlabeled_ignore_img_mask_for_mix != 255))
                loss_con2_unconfident = 0.5 * criterion_u(cutmixed_aug_unlabeled_pred1, aug_unlabeled_pseudo_label2_for_mix) * (unconfident_pred2 & (aug_unlabeled_ignore_img_mask_for_mix != 255))

                loss_con1 = (torch.sum(loss_con1_confident) + torch.sum(loss_con1_unconfident)) / torch.sum(aug_unlabeled_ignore_img_mask_for_mix != 255).item()
                loss_con2 = (torch.sum(loss_con2_confident) + torch.sum(loss_con2_unconfident)) / torch.sum(aug_unlabeled_ignore_img_mask_for_mix != 255).item()

                raw_confident_pred1, raw_confident_pred2, raw_unconfident_pred1, raw_unconfident_pred2 = soft_label_selection(raw_unlabeled_pred1, raw_unlabeled_pred2, conf_threshold)

                loss_raw_con1_confident = criterion_u(raw_unlabeled_pred2, raw_unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[1].detach().long()) * (raw_confident_pred1 & (ignore_img_mask != 255))
                loss_raw_con2_confident = criterion_u(raw_unlabeled_pred1, raw_unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[1].detach().long()) * (raw_confident_pred2 & (ignore_img_mask != 255))

                loss_raw_con1_unconfident = 0.5 * criterion_u(raw_unlabeled_pred2, raw_unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[1].detach().long()) * (raw_unconfident_pred1 & (ignore_img_mask != 255))
                loss_raw_con2_unconfident = 0.5 * criterion_u(raw_unlabeled_pred1, raw_unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[1].detach().long()) * (raw_unconfident_pred2 & (ignore_img_mask != 255))

                loss_raw_con1 = (torch.sum(loss_raw_con1_confident) + torch.sum(loss_raw_con1_unconfident)) / torch.sum(ignore_img_mask != 255).item()
                loss_raw_con2 = (torch.sum(loss_raw_con2_confident) + torch.sum(loss_raw_con2_unconfident)) / torch.sum(ignore_img_mask != 255).item()

            elif args.mode_confident == 'vote':
                same_pred, different_pred = vote_label_selection(unlabeled_pred1, unlabeled_pred2)

                loss_con1_same = criterion_u(cutmixed_aug_unlabeled_pred2, aug_unlabeled_pseudo_label1_for_mix) * (same_pred & (aug_unlabeled_ignore_img_mask_for_mix != 255))
                loss_con2_same = criterion_u(cutmixed_aug_unlabeled_pred1, aug_unlabeled_pseudo_label2_for_mix) * (same_pred & (aug_unlabeled_ignore_img_mask_for_mix != 255))

                loss_con1_different = 1.5 * criterion_u(cutmixed_aug_unlabeled_pred2, aug_unlabeled_pseudo_label1_for_mix) * (different_pred & (aug_unlabeled_ignore_img_mask_for_mix != 255))
                loss_con2_different = 1.5 * criterion_u(cutmixed_aug_unlabeled_pred1, aug_unlabeled_pseudo_label2_for_mix) * (different_pred & (aug_unlabeled_ignore_img_mask_for_mix != 255))

                loss_con1 = (torch.sum(loss_con1_same) + torch.sum(loss_con1_different)) / torch.sum(aug_unlabeled_ignore_img_mask_for_mix != 255).item()
                loss_con2 = (torch.sum(loss_con2_same) + torch.sum(loss_con2_different)) / torch.sum(aug_unlabeled_ignore_img_mask_for_mix != 255).item()

                raw_same_pred, raw_different_pred = vote_label_selection(raw_unlabeled_pred1, raw_unlabeled_pred2)

                loss_raw_con1_same = criterion_u(raw_unlabeled_pred2, raw_unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[1].detach().long()) * (raw_same_pred & (ignore_img_mask != 255))
                loss_raw_con2_same = criterion_u(raw_unlabeled_pred1, raw_unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[1].detach().long()) * (raw_same_pred & (ignore_img_mask != 255))

                loss_raw_con1_different = 1.5 * criterion_u(raw_unlabeled_pred2, raw_unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[1].detach().long()) * (raw_different_pred & (ignore_img_mask != 255))
                loss_raw_con2_different = 1.5 * criterion_u(raw_unlabeled_pred1, raw_unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[1].detach().long()) * (raw_different_pred & (ignore_img_mask != 255))

                loss_raw_con1 = (torch.sum(loss_raw_con1_same) + torch.sum(loss_raw_con1_different)) / torch.sum(ignore_img_mask != 255).item()
                loss_raw_con2 = (torch.sum(loss_raw_con2_same) + torch.sum(loss_raw_con2_different)) / torch.sum(ignore_img_mask != 255).item()

            elif args.mode_confident == 'vote_threshold':
                different1_confident, different1_else, different2_confident, different2_else = vote_threshold_label_selection(unlabeled_pred1, unlabeled_pred2, conf_threshold)

                loss_con1_else = criterion_u(cutmixed_aug_unlabeled_pred2, aug_unlabeled_pseudo_label1_for_mix) * (different1_else & (aug_unlabeled_ignore_img_mask_for_mix != 255))
                loss_con2_else = criterion_u(cutmixed_aug_unlabeled_pred1, aug_unlabeled_pseudo_label2_for_mix) * (different2_else & (aug_unlabeled_ignore_img_mask_for_mix != 255))

                loss_con1_cc = args.w_confident * criterion_u(cutmixed_aug_unlabeled_pred2, aug_unlabeled_pseudo_label1_for_mix) * (different1_confident & (aug_unlabeled_ignore_img_mask_for_mix != 255))
                loss_con2_cc = args.w_confident * criterion_u(cutmixed_aug_unlabeled_pred1, aug_unlabeled_pseudo_label2_for_mix) * (different2_confident & (aug_unlabeled_ignore_img_mask_for_mix != 255))

                loss_con1 = (torch.sum(loss_con1_else) + torch.sum(loss_con1_cc)) / torch.sum(aug_unlabeled_ignore_img_mask_for_mix != 255).item()
                loss_con2 = (torch.sum(loss_con2_else) + torch.sum(loss_con2_cc)) / torch.sum(aug_unlabeled_ignore_img_mask_for_mix != 255).item()

                raw_different1_confident, raw_different1_else, raw_different2_confident, raw_different2_else = vote_threshold_label_selection(raw_unlabeled_pred1, raw_unlabeled_pred2, conf_threshold)

                loss_raw_con1_else = criterion_u(raw_unlabeled_pred2, raw_unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[1].detach().long()) * (raw_different1_else & (ignore_img_mask != 255))
                loss_raw_con2_else = criterion_u(raw_unlabeled_pred1, raw_unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[1].detach().long()) * (raw_different2_else & (ignore_img_mask != 255))

                loss_raw_con1_cc = args.w_confident * criterion_u(raw_unlabeled_pred2, raw_unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[1].detach().long()) * (raw_different1_confident & (ignore_img_mask != 255))
                loss_raw_con2_cc = args.w_confident * criterion_u(raw_unlabeled_pred1, raw_unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[1].detach().long()) * (raw_different2_confident & (ignore_img_mask != 255))

                loss_raw_con1 = (torch.sum(loss_raw_con1_else) + torch.sum(loss_raw_con1_cc)) / torch.sum(ignore_img_mask != 255).item()
                loss_raw_con2 = (torch.sum(loss_raw_con2_else) + torch.sum(loss_raw_con2_cc)) / torch.sum(ignore_img_mask != 255).item()

            elif args.mode_confident == 'vote_soft':
                same_pred, different_confident_pred1, different_confident_pred2, different_unconfident_pred1, different_unconfident_pred2 = vote_soft_label_selection(unlabeled_pred1, unlabeled_pred2, conf_threshold)

                loss_con1_same = criterion_u(cutmixed_aug_unlabeled_pred2, aug_unlabeled_pseudo_label1_for_mix) * (same_pred & (aug_unlabeled_ignore_img_mask_for_mix != 255))
                loss_con2_same = criterion_u(cutmixed_aug_unlabeled_pred1, aug_unlabeled_pseudo_label2_for_mix) * (same_pred & (aug_unlabeled_ignore_img_mask_for_mix != 255))

                loss_con1_different_confident = 1.5 * criterion_u(cutmixed_aug_unlabeled_pred2, aug_unlabeled_pseudo_label1_for_mix) * (different_confident_pred1 & (aug_unlabeled_ignore_img_mask_for_mix != 255))
                loss_con2_different_confident = 1.5 * criterion_u(cutmixed_aug_unlabeled_pred1, aug_unlabeled_pseudo_label2_for_mix) * (different_confident_pred2 & (aug_unlabeled_ignore_img_mask_for_mix != 255))

                loss_con1_different_unconfident = 0.5 * criterion_u(cutmixed_aug_unlabeled_pred2, aug_unlabeled_pseudo_label1_for_mix) * (different_unconfident_pred1 & (aug_unlabeled_ignore_img_mask_for_mix != 255))
                loss_con2_different_unconfident = 0.5 * criterion_u(cutmixed_aug_unlabeled_pred1, aug_unlabeled_pseudo_label2_for_mix) * (different_unconfident_pred2 & (aug_unlabeled_ignore_img_mask_for_mix != 255))

                loss_con1 = (torch.sum(loss_con1_same) + torch.sum(loss_con1_different_confident) + torch.sum(loss_con1_different_unconfident)) / torch.sum(aug_unlabeled_ignore_img_mask_for_mix != 255).item()
                loss_con2 = (torch.sum(loss_con2_same) + torch.sum(loss_con2_different_confident) + torch.sum(loss_con2_different_unconfident)) / torch.sum(aug_unlabeled_ignore_img_mask_for_mix != 255).item()

                raw_same_pred, raw_different_confident_pred1, raw_different_confident_pred2, raw_different_unconfident_pred1, raw_different_unconfident_pred2 = vote_soft_label_selection(unlabeled_pred1, unlabeled_pred2, conf_threshold)

                loss_raw_con1_same = criterion_u(raw_unlabeled_pred2, raw_unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[1].detach().long()) * (raw_same_pred & (ignore_img_mask != 255))
                loss_raw_con2_same = criterion_u(raw_unlabeled_pred1, raw_unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[1].detach().long()) * (raw_same_pred & (ignore_img_mask != 255))

                loss_raw_con1_different_confident = 1.5 * criterion_u(raw_unlabeled_pred2, raw_unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[1].detach().long()) * (raw_different_confident_pred1 & (ignore_img_mask != 255))
                loss_raw_con2_different_confident = 1.5 * criterion_u(raw_unlabeled_pred1, raw_unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[1].detach().long()) * (raw_different_confident_pred2 & (ignore_img_mask != 255))

                loss_raw_con1_different_unconfident = 0.5 * criterion_u(raw_unlabeled_pred2, raw_unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[1].detach().long()) * (raw_different_unconfident_pred1 & (ignore_img_mask != 255))
                loss_raw_con2_different_unconfident = 0.5 * criterion_u(raw_unlabeled_pred1, raw_unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[1].detach().long()) * (raw_different_unconfident_pred2 & (ignore_img_mask != 255))

                loss_raw_con1 = (torch.sum(loss_raw_con1_same) + torch.sum(loss_raw_con1_different_confident) + torch.sum(loss_raw_con1_different_unconfident)) / torch.sum(ignore_img_mask != 255).item()
                loss_raw_con2 = (torch.sum(loss_raw_con2_same) + torch.sum(loss_raw_con2_different_confident) + torch.sum(loss_raw_con2_different_unconfident)) / torch.sum(ignore_img_mask != 255).item()

            else:
                loss_con1 = criterion_u(cutmixed_aug_unlabeled_pred2, aug_unlabeled_pseudo_label1_for_mix) * (aug_unlabeled_ignore_img_mask_for_mix != 255)
                loss_con1 = torch.sum(loss_con1) / torch.sum(aug_unlabeled_ignore_img_mask_for_mix != 255).item()
                loss_con2 = criterion_u(cutmixed_aug_unlabeled_pred1, aug_unlabeled_pseudo_label2_for_mix) * (aug_unlabeled_ignore_img_mask_for_mix != 255)
                loss_con2 = torch.sum(loss_con2) / torch.sum(aug_unlabeled_ignore_img_mask_for_mix != 255).item()

                loss_raw_con1 = criterion_u(raw_unlabeled_pred2, raw_unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[1].detach().long()) * (ignore_img_mask != 255)
                loss_raw_con1 = torch.sum(loss_raw_con1) / torch.sum(ignore_img_mask != 255).item()
                loss_raw_con2 = criterion_u(raw_unlabeled_pred1, raw_unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[1].detach().long()) * (ignore_img_mask != 255)
                loss_raw_con2 = torch.sum(loss_raw_con2) / torch.sum(ignore_img_mask != 255).item()

            loss_con = (loss_con1 + loss_con2 + loss_raw_con1 + loss_raw_con2) / 4
            loss_con = loss_con * args.w_con

            # -------------
            # both
            # -------------
            # discrepancy loss
            cos_dis = nn.CosineSimilarity(dim=1, eps=1e-6)

            # labeled
            labeled_feature1 = labeled_logits['feature1']
            labeled_feature2 = labeled_logits['feature2']
            loss_dis_labeled1 = 1 + cos_dis(labeled_feature1.detach(), labeled_feature2).mean()
            loss_dis_labeled2 = 1 + cos_dis(labeled_feature2.detach(), labeled_feature1).mean()
            loss_dis_labeled = (loss_dis_labeled1 + loss_dis_labeled2) / 2

            # unlabeled
            cutmixed_aug_unlabeled_feature1 = cutmixed_aug_unlabeled_logits['feature1']
            cutmixed_aug_unlabeled_feature2 = cutmixed_aug_unlabeled_logits['feature2']
            loss_dis_cutmixed_aug_unlabeled1 = 1 + cos_dis(cutmixed_aug_unlabeled_feature1.detach(), cutmixed_aug_unlabeled_feature2).mean()
            loss_dis_cutmixed_aug_unlabeled2 = 1 + cos_dis(cutmixed_aug_unlabeled_feature2.detach(), cutmixed_aug_unlabeled_feature1).mean()
            loss_dis_cutmixed_aug_unlabeled = (loss_dis_cutmixed_aug_unlabeled1 + loss_dis_cutmixed_aug_unlabeled2) / 2

            raw_unlabeled_feature1 = raw_unlabeled_logits['feature1']
            raw_unlabeled_feature2 = raw_unlabeled_logits['feature2']
            loss_dis_raw_unlabeled1 = 1 + cos_dis(raw_unlabeled_feature1.detach(), raw_unlabeled_feature2).mean()
            loss_dis_raw_unlabeled2 = 1 + cos_dis(raw_unlabeled_feature2.detach(), raw_unlabeled_feature1).mean()
            loss_dis_raw_unlabeled = (loss_dis_raw_unlabeled1 + loss_dis_raw_unlabeled2) / 2

            loss_dis_unlabeled = (loss_dis_cutmixed_aug_unlabeled + loss_dis_raw_unlabeled) / 2

            loss_dis = (loss_dis_labeled + loss_dis_unlabeled) / 2
            loss_dis = loss_dis * args.w_dis

            # -------------
            # total
            # -------------
            loss = loss_CE + loss_con + loss_dis

            dist.barrier()

            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total_loss_CE += loss_CE.item()
            total_loss_con += loss_con.item()
            total_loss_dis += loss_dis.item()

            total_confident = (((unlabeled_pred_confidence1 >= 0.95) & (ignore_img_mask != 255)).sum().item() + ((unlabeled_pred_confidence2 >= 0.95) & (ignore_img_mask != 255)).sum().item()) / 2
            total_mask_ratio += total_confident / (ignore_img_mask != 255).sum().item()

            iters = epoch * len(trainloader_u) + i

            # update lr

            backbone_lr = args.base_lr * (1 - iters / total_iters) ** args.mul_scheduler
            backbone_lr = backbone_lr * args.lr_backbone

            seg_lr = args.base_lr * (1 - iters / total_iters) ** args.mul_scheduler
            seg_lr = seg_lr * args.lr_network
                
            optimizer.param_groups[0]['lr'] = backbone_lr
            optimizer.param_groups[1]['lr'] = backbone_lr
            for ii in range(2, len(optimizer.param_groups)):
                optimizer.param_groups[ii]['lr'] = seg_lr

            if (i % (len(trainloader_u) // 8) == 0) and (args.local_rank <= 0):
                tb.add_scalar('train_loss_total', total_loss / (i+1), iters)
                tb.add_scalar('train_loss_CE', total_loss_CE / (i+1), iters)
                tb.add_scalar('train_loss_con', total_loss_con / (i+1), iters)
                tb.add_scalar('train_loss_dis', total_loss_dis / (i+1), iters)

            if (i % (len(trainloader_u) // 8) == 0) and (args.local_rank <= 0):
                logger.info('Iters: {:}, Total loss: {:.3f}, Loss CE: {:.3f}, '
                            'Loss consistency: {:.3f}, Loss discrepancy: {:.3f}, Mask: {:.3f}'.format(
                    i, total_loss / (i+1), total_loss_CE / (i+1), total_loss_con / (i+1), total_loss_dis / (i+1), 
                    total_mask_ratio / (i+1)))

        if args.use_SPL:
            conf_threshold += 0.01
            if conf_threshold >= 0.95:
                conf_threshold = 0.95

        if cfg['dataset'] == 'cityscapes':
            eval_mode = 'center_crop' if epoch < args.epochs - 20 else 'sliding_window'
        else:
            eval_mode = 'original'
        
        dist.barrier()

        # test with different branches  
        if args.local_rank <= 0:
            # save qualitative results at a few fixed epochs (5/10/20/40), otherwise just evaluate
            if epoch in (4, 9, 19, 39):
                evaluate_result = evaluate_save(cfg['dataset'], args.save_path, args.local_rank, model, valloader, eval_mode, args, cfg, idx_epoch=epoch + 1)
            else:
                evaluate_result = evaluate(args.local_rank, model, valloader, eval_mode, args, cfg)

            mIOU1 = evaluate_result['IOU1']
            mIOU2 = evaluate_result['IOU2']
            mIOU_ave = evaluate_result['IOU_ave']

            tb.add_scalar('meanIOU_branch1', mIOU1, epoch)
            tb.add_scalar('meanIOU_branch2', mIOU2, epoch)
            tb.add_scalar('meanIOU_ave', mIOU_ave, epoch)

            logger.info('***** Evaluation with branch 1 {} ***** >>>> meanIOU: {:.2f}\n'.format(eval_mode, mIOU1))
            logger.info('***** Evaluation with branch 2 {} ***** >>>> meanIOU: {:.2f}\n'.format(eval_mode, mIOU2))
            logger.info('***** Evaluation with two branches {} ***** >>>> meanIOU: {:.2f}\n'.format(eval_mode, mIOU_ave))

            if mIOU1 > previous_best1:
                if previous_best1 != 0:
                    os.remove(os.path.join(args.save_path, 'branch1_%s_%.2f.pth' % (args.backbone, previous_best1)))
                previous_best1 = mIOU1
                torch.save(model.module.state_dict(),
                        os.path.join(args.save_path, 'branch1_%s_%.2f.pth' % (args.backbone, mIOU1)))
            
            if mIOU2 > previous_best2:
                if previous_best2 != 0:
                    os.remove(os.path.join(args.save_path, 'branch2_%s_%.2f.pth' % (args.backbone, previous_best2)))
                previous_best2 = mIOU2
                torch.save(model.module.state_dict(),
                        os.path.join(args.save_path, 'branch2_%s_%.2f.pth' % (args.backbone, mIOU2)))

            if mIOU_ave > previous_best:
                if previous_best != 0:
                    os.remove(os.path.join(args.save_path, 'ave_%s_%.2f.pth' % (args.backbone, previous_best)))
                previous_best = mIOU_ave
                torch.save(model.module.state_dict(),
                        os.path.join(args.save_path, 'ave_%s_%.2f.pth' % (args.backbone, mIOU_ave)))


Origin: https://blog.csdn.net/m0_61899108/article/details/131148390