【CVPR2023】Consistência de visão cruzada baseada em conflito para segmentação semântica semi-supervisionada

Consistência de visão cruzada baseada em conflito para segmentação semântica semi-supervisionada, CVPR2023

Artigo: https://arxiv.org/abs/2303.01276

Código: https://github.com/xiaoyao3302/CCVC/

Resumo

A segmentação semântica semi-supervisionada (SSS) pode reduzir a necessidade de dados de treinamento totalmente anotados em grande escala. Os métodos existentes geralmente sofrem de viés de confirmação ao lidar com pseudo-rótulos, que podem ser mitigados por uma estrutura de treinamento conjunta. Os métodos SSS atuais baseados em treinamento conjunto dependem de perturbações feitas à mão para evitar que diferentes sub-redes entrem em colapso, mas perturbações artificiais são difíceis de obter soluções ótimas. Neste artigo, propomos um novo método de consistência de visão cruzada baseado em conflito (CCVC) baseado em uma estrutura de treinamento conjunto de duas ramificações, que visa forçar duas sub-redes a aprender recursos informativos de visualizações não relacionadas. Primeiro, propomos uma nova estratégia de consistência de visão cruzada (CVC), que incentiva duas sub-redes a aprender recursos diferentes da mesma entrada, introduzindo uma perda de diferença de recurso, enquanto espera-se que esses recursos descontínuos gerem pontuações de previsão consistentes para a entrada. A estratégia CVC ajuda a evitar que ambas as sub-redes entrem em colapso. Além disso, um método de pseudo-rotulagem baseado em conflito (CPL) é proposto para garantir que o modelo aprenda informações mais úteis das previsões de conflito para garantir a estabilidade do processo de treinamento.

introdução

A segmentação semântica totalmente supervisionada requer muito esforço para coletar dados de anotação precisos. O aprendizado semi-supervisionado pode alcançar a segmentação semântica usando uma pequena quantidade de dados rotulados e uma grande quantidade de dados não rotulados. Mas como usar dados não rotulados para auxiliar dados rotulados para treinamento de modelo é uma questão fundamental.

O método de uso de pseudo-rótulos pode ser afetado por viés de confirmação, o que causará degradação do desempenho devido ao treinamento instável. Abordagens baseadas na regularização de consistência mostram melhor desempenho, mas a maioria depende de previsões de entradas fracamente perturbadas para gerar pseudo-rótulos, que são então usados ​​como supervisão para previsões de entradas fortemente perturbadas. Também é afetado pelo viés de confirmação .

O co-treinamento permite que diferentes sub-redes infiram a mesma instância a partir de diferentes visualizações e transfiram o conhecimento aprendido de uma visualização para outra por meio de pseudo-rotulagem. Em particular, o co-treinamento depende de referências de múltiplas visualizações para aumentar a percepção do modelo e, assim, melhorar a confiabilidade dos pseudo-rótulos gerados. A chave é como evitar que diferentes sub-redes entrem em colapso para que o modelo possa fazer previsões corretas com base nas entradas de diferentes visualizações. No entanto, a perturbação manual usada na maioria dos métodos SSS não garante o aprendizado de recursos heterogêneos, evitando efetivamente que as sub-redes entrem em colapso.

Diante dos problemas acima, este artigo propõe uma nova estratégia de consistência de visão cruzada baseada em conflito (CCVC) para SSS, que garante que as duas sub-redes no modelo possam aprender diferentes previsões confiáveis ​​das visões de , para treinamento conjunto , o que aumenta ainda mais permite que cada sub-rede faça previsões confiáveis ​​e significativas.

  • Em primeiro lugar, um método de consistência de visão cruzada (CVC) com perda de diferença é proposto para minimizar a semelhança entre os recursos extraídos por duas sub-redes, incentivando-os a extrair recursos diferentes, evitando assim que as duas sub-redes entrem em colapso.
  • Em seguida, o conhecimento aprendido de uma sub-rede é transferido para outra usando pseudo-rotulagem cruzada para melhorar a conscientização da rede para raciocinar corretamente sobre a mesma entrada de diferentes visualizações, resultando em previsões mais confiáveis.
  • No entanto, a perda de diferença pode introduzir uma perturbação muito forte no modelo, de modo que os recursos extraídos pela sub-rede possam conter informações menos significativas para a previsão, resultando em previsões inconsistentes e não confiáveis ​​das duas sub-redes. Isso levará ao problema do viés de confirmação, que prejudicará o treinamento conjunto das sub-redes. Para resolver esse problema, um método de pseudo-rótulo baseado em colisão (CPL) é proposto para encorajar os pseudo-rótulos gerados pelas previsões conflitantes de cada sub-rede para fornecer supervisão mais forte nas previsões de cada um para impor previsões consistentes de ambas as sub-redes. , para preservar a previsão e a confiabilidade da previsão. Dessa forma, espera-se reduzir o impacto do viés de confirmação e tornar o processo de treinamento mais estável.

Conforme mostrado na Figura 1, os escores de similaridade entre as características extraídas das duas sub-redes do modelo de regularização de consistência cruzada (CCR) são mantidos em um nível alto, o que indica que a perspectiva de inferência do CCR está um pouco relacionada. Em contraste, a abordagem CVC garante que as visualizações reanalisadas sejam suficientemente diferentes para produzir previsões mais confiáveis.

figura 1. São comparados os valores de similaridade de cosseno entre as características extraídas pelas duas sub-redes do método tradicional de regularização de consistência cruzada (CCR) e nosso método CVC. A precisão da previsão dos dois métodos medidos com mIoU também é comparada. Os métodos CVC podem impedir que duas sub-redes entrem em colapso e inferir entradas de visualizações irrelevantes, enquanto o CCR não pode garantir que as visualizações introduzidas sejam diferentes. Mostra-se que o CVC pode aumentar a percepção do modelo e, assim, produzir previsões mais confiáveis.

 

Trabalho relatado

segmentação semântica

Segmentação semântica semi-supervisionada

treinamento colaborativo

método

CCVC ( consistência de visão cruzada baseada em conflito )

Consistência entre visualizações

Abordagem Cross-View Consistency (CVC). É utilizada uma rede biramal baseada em treinamento conjunto, onde as duas sub-redes (Ψ1 e Ψ2), possuem arquiteturas semelhantes, mas não compartilham parâmetros. Divida cada sub-rede em extratores e classificadores de recursos. O objetivo é permitir que as duas sub-redes raciocinem sobre as entradas de diferentes visualizações, portanto, os recursos extraídos devem ser diferentes. Portanto, a similaridade de cosseno entre as características extraídas pelos dois extratores de características é minimizada.

Observe que o fator de 1 é para garantir que o valor da perda seja sempre não negativo. As duas sub-redes são encorajadas a produzir recursos que não têm relacionamento comum, forçando assim as duas sub-redes a aprender a raciocinar sobre as entradas de duas visualizações não relacionadas.

A maioria dos métodos SSS usa ResNet pré-treinado no ImageNet como o backbone do DeepLabv3+ e apenas ajusta o backbone com uma pequena taxa de aprendizado, o que dificulta a operação de maximização da diferença de recursos neste artigo. Para resolver esse problema, este artigo atinge a heterogeneidade da rede usando camadas convolucionais para mapear os recursos extraídos para outro espaço de recursos. A perda por diferença é reescrita como:

Ambos os dados rotulados e não rotulados são supervisionados com perda de diferença. A perda total da diferença é: 

Os dados rotulados usam rótulos de verdade como supervisão e ambas as sub-redes são supervisionadas. A perda supervisionada para dados rotulados é:

A pseudo-rotulação de dados não rotulados permite que cada sub-rede aprenda informações semânticas da outra. Aplique a perda de entropia cruzada para ajustar o modelo. As duas sub-redes supervisionam uma à outra como pseudo-rótulos. A perda de dados não rotulados é (perda de consistência cruzada):

 

A perda total de toda a rede: a soma ponderada da perda de supervisão, perda de consistência e perda de diferença. 

 

 Pseudo-rotulação baseada em conflito

Usando o método de consistência de visualização cruzada (CVC), duas sub-redes aprenderão informações semânticas de visualizações diferentes. No entanto, se a perda de diferença de características introduzir uma perturbação muito forte no modelo, o treinamento não será estável o suficiente. Portanto, é difícil garantir que as duas sub-redes possam aprender informações semânticas úteis uma da outra, o que pode afetar ainda mais a confiabilidade das previsões.

Portanto, o artigo propõe um método de pseudo-rotulagem baseado em conflito (CPL), que permite que as duas sub-redes aprendam mais informações semânticas de previsões conflitantes para fazer previsões consistentes, garantindo assim que as duas sub-redes possam gerar as mesmas previsões confiáveis ​​e maior estabilização de treinamento. Use o valor binário δ para definir se as previsões entram em conflito.


O objetivo é encorajar o modelo a aprender mais informações semânticas dessas previsões conflitantes. Portanto, quando essas previsões são usadas para gerar pseudo-rótulos para ajuste fino do modelo, maior peso ωc é atribuído à perda de entropia cruzada supervisionada por esses pseudo-rótulos.

No entanto, durante o treinamento, o treinamento também pode sofrer viés de confirmação, pois alguns pseudo-rótulos podem estar errados. O artigo ainda divide as previsões de conflito em duas categorias, previsões conflitantes e confiáveis ​​(CC), previsões conflitantes, mas não confiáveis ​​(CU) , e apenas atribui ωc a pseudo-rótulos gerados por previsões CC.

Definir previsões CC usando valores binários \delta ^{cc}_{mn,i}. Use \delta ^{e}_{mn,i}a união que representa previsões de CU e previsões sem conflito

Os pseudo-rótulos gerados pelas previsões CU ainda são usados ​​para ajustar o modelo com pesos normais em vez de descartá-los diretamente, uma vez que essas previsões CU também podem conter informações latentes sobre relacionamentos interclasses. A perda de consistência é reescrita como:

em, 

 

 

O método CCVC pode efetivamente encorajar duas sub-redes a raciocinar sobre a mesma entrada de diferentes perspectivas, e a transferência de conhecimento entre as duas sub-redes pode aumentar a percepção de cada sub-rede, melhorando assim a confiabilidade das previsões.

Na fase de inferência, apenas um ramo da rede é necessário para gerar previsões.

experimentar

 

 

Código chave

CCVC_no_aug.py

# https://github.com/xiaoyao3302/CCVC/blob/master/CCVC_no_aug.py

args = parser.parse_args()

args.world_size = args.gpus * args.nodes
args.ddp = True if args.gpus > 1 else False

def main(gpu, ngpus_per_node, cfg, args):

    args.local_rank = gpu

    if args.local_rank <= 0:
        os.makedirs(args.save_path, exist_ok=True)
        
    logger = init_log_save(args.save_path, 'global', logging.INFO)
    logger.propagate = 0

    if args.local_rank <= 0:
        tb_dir = args.save_path
        tb = SummaryWriter(log_dir=tb_dir)

    if args.ddp:
        dist.init_process_group(backend='nccl', rank=args.local_rank, world_size=args.world_size)
        # dist.init_process_group(backend='nccl', init_method='env://', rank=args.local_rank, world_size=args.world_size)

    if args.local_rank <= 0:
        logger.info('{}\n'.format(pprint.pformat(cfg)))

    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.benchmark = True

    model = Discrepancy_DeepLabV3Plus(args, cfg)
    if args.local_rank <= 0:
        logger.info('Total params: {:.1f}M\n'.format(count_params(model)))

    # pdb.set_trace()
    # TODO: check the parameter !!!

    optimizer = SGD([{'params': model.branch1.backbone.parameters(), 'lr': args.base_lr * args.lr_backbone}, 
                     {'params': model.branch2.backbone.parameters(), 'lr': args.base_lr * args.lr_backbone},
                     {'params': [param for name, param in model.named_parameters() if 'backbone' not in name],
                      'lr': args.base_lr * args.lr_network}], lr=args.base_lr, momentum=0.9, weight_decay=1e-4)

    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model.cuda(args.local_rank)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank, broadcast_buffers=False, find_unused_parameters=False)

    # # try to load saved model
    # try:   
    #     model.load_model(args.load_path)
    #     if args.local_rank <= 0:
    #         logger.info('load saved model')
    # except:
    #     if args.local_rank <= 0:
    #         logger.info('no saved model')

    # ---- #
    # loss #
    # ---- #
    # CE loss for labeled data
    criterion_l = nn.CrossEntropyLoss(reduction='mean', ignore_index=255).cuda(args.local_rank)
    
    # consistency loss for unlabeled data
    criterion_u = nn.CrossEntropyLoss(reduction='none').cuda(args.local_rank)

    # ------- #
    # dataset #
    # ------- #
    trainset_u = SemiDataset(cfg['dataset'], cfg['data_root'], 'train_u',
                             args.crop_size, args.unlabeled_id_path)
    trainset_l = SemiDataset(cfg['dataset'], cfg['data_root'], 'train_l',
                             args.crop_size, args.labeled_id_path, nsample=len(trainset_u.ids))
    valset = SemiDataset(cfg['dataset'], cfg['data_root'], 'val')

    if args.ddp:
        trainsampler_l = torch.utils.data.distributed.DistributedSampler(trainset_l)
    else:
        trainsampler_l = None
    trainloader_l = DataLoader(trainset_l, batch_size=args.batch_size,
                               pin_memory=True, num_workers=args.num_workers, drop_last=True, sampler=trainsampler_l)

    if args.ddp:
        trainsampler_u = torch.utils.data.distributed.DistributedSampler(trainset_u)
    else:
        trainsampler_u = None
    
    trainloader_u = DataLoader(trainset_u, batch_size=args.batch_size,
                               pin_memory=True, num_workers=args.num_workers, drop_last=True, sampler=trainsampler_u)

    valloader = DataLoader(valset, batch_size=1, pin_memory=True, num_workers=args.num_workers, drop_last=False)
    # if args.ddp:
    #     valsampler = torch.utils.data.distributed.DistributedSampler(valset)
    # else:
    #     valsampler = None
    
    # valloader = DataLoader(valset, batch_size=1, pin_memory=True, num_workers=args.num_workers, drop_last=False, sampler=valsampler)

    total_iters = len(trainloader_u) * args.epochs
    previous_best = 0.0
    previous_best1 = 0.0
    previous_best2 = 0.0

    # can change with epochs, add SPL here
    conf_threshold = args.conf_threshold

    for epoch in range(args.epochs):
        if args.local_rank <= 0:
            logger.info('===========> Epoch: {:}, backbone1 LR: {:.4f}, backbone2 LR: {:.4f}, segmentation LR: {:.4f}'.format(
                epoch, optimizer.param_groups[0]['lr'], optimizer.param_groups[1]['lr'], optimizer.param_groups[-1]['lr']))
            logger.info('===========> Epoch: {:}, Previous best of ave: {:.2f}, Previous best of branch1: {:.2f}, Previous best of branch2: {:.2f}'.format(
                epoch, previous_best, previous_best1, previous_best2))

        total_loss, total_loss_CE, total_loss_con, total_loss_dis = 0.0, 0.0, 0.0, 0.0
        total_mask_ratio = 0.0

        trainloader_l.sampler.set_epoch(epoch)
        trainloader_u.sampler.set_epoch(epoch)

        loader = zip(trainloader_l, trainloader_u)

        total_labeled = 0
        total_unlabeled = 0

        for i, ((labeled_img, labeled_img_mask), (unlabeled_img, ignore_img_mask, cutmix_box)) in enumerate(loader):

            labeled_img, labeled_img_mask = labeled_img.cuda(args.local_rank), labeled_img_mask.cuda(args.local_rank)
            unlabeled_img, ignore_img_mask, cutmix_box = unlabeled_img.cuda(args.local_rank), ignore_img_mask.cuda(args.local_rank), cutmix_box.cuda(args.local_rank)

            model.train()

            optimizer.zero_grad()

            dist.barrier()

            num_lb, num_ulb = labeled_img.shape[0], unlabeled_img.shape[0]

            total_labeled += num_lb
            total_unlabeled += num_ulb

            # =========================================================================================
            # labeled data: labeled_img, labeled_img_mask
            # =========================================================================================
            labeled_logits = model(labeled_img)

            # =========================================================================================
            # unlabeled data: unlabeled_img, ignore_img_mask, cutmix_box
            # =========================================================================================
            unlabeled_logits = model(unlabeled_img)

            # to count the confident predictions
            unlabeled_pred_confidence1 = unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[0]
            unlabeled_pred_confidence2 = unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[0]

            # =========================================================================================
            # calculate loss
            # =========================================================================================

            # -------------
            # labeled
            # -------------
            # CE loss
            labeled_pred1 = labeled_logits['pred1']
            labeled_pred2 = labeled_logits['pred2']

            loss_CE1 = criterion_l(labeled_pred1, labeled_img_mask)
            loss_CE2 = criterion_l(labeled_pred2, labeled_img_mask)

            loss_CE = (loss_CE1 + loss_CE2) / 2
            loss_CE = loss_CE * args.w_CE

            # -------------
            # unlabeled
            # -------------
            # consistency loss
            unlabeled_pred1 = unlabeled_logits['pred1']
            unlabeled_pred2 = unlabeled_logits['pred2']

            if args.mode_confident == 'normal':
                loss_con1 = criterion_u(unlabeled_pred2, unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[1].detach().long()) * ((unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[0] > conf_threshold) & (ignore_img_mask != 255))
                loss_con1 = torch.sum(loss_con1) / torch.sum(ignore_img_mask != 255).item()
                loss_con2 = criterion_u(unlabeled_pred1, unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[1].detach().long()) * ((unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[0] > conf_threshold) & (ignore_img_mask != 255))
                loss_con2 = torch.sum(loss_con2) / torch.sum(ignore_img_mask != 255).item()
                
            elif args.mode_confident == 'soft':
                confident_pred1, confident_pred2, unconfident_pred1, unconfident_pred2 = soft_label_selection(unlabeled_pred1, unlabeled_pred2, conf_threshold)

                loss_con1_confident = criterion_u(unlabeled_pred2, unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[1].detach().long()) * (confident_pred1 & (ignore_img_mask != 255))
                loss_con2_confident = criterion_u(unlabeled_pred1, unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[1].detach().long()) * (confident_pred2 & (ignore_img_mask != 255))

                loss_con1_unconfident = args.w_unconfident * criterion_u(unlabeled_pred2, unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[1].detach().long()) * (unconfident_pred1 & (ignore_img_mask != 255))
                loss_con2_unconfident = args.w_unconfident * criterion_u(unlabeled_pred1, unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[1].detach().long()) * (unconfident_pred2 & (ignore_img_mask != 255))

                loss_con1 = (torch.sum(loss_con1_confident) + torch.sum(loss_con1_unconfident)) / torch.sum(ignore_img_mask != 255).item()
                loss_con2 = (torch.sum(loss_con2_confident) + torch.sum(loss_con2_unconfident)) / torch.sum(ignore_img_mask != 255).item()

            elif args.mode_confident == 'vote':
                same_pred, different_pred = vote_label_selection(unlabeled_pred1, unlabeled_pred2)

                loss_con1_same = criterion_u(unlabeled_pred2, unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[1].detach().long()) * (same_pred & (ignore_img_mask != 255))
                loss_con2_same = criterion_u(unlabeled_pred1, unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[1].detach().long()) * (same_pred & (ignore_img_mask != 255))

                loss_con1_different = args.w_confident * criterion_u(unlabeled_pred2, unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[1].detach().long()) * (different_pred & (ignore_img_mask != 255))
                loss_con2_different = args.w_confident * criterion_u(unlabeled_pred1, unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[1].detach().long()) * (different_pred & (ignore_img_mask != 255))

                loss_con1 = (torch.sum(loss_con1_same) + torch.sum(loss_con1_different)) / torch.sum(ignore_img_mask != 255).item()
                loss_con2 = (torch.sum(loss_con2_same) + torch.sum(loss_con2_different)) / torch.sum(ignore_img_mask != 255).item()

            elif args.mode_confident == 'vote_threshold':
                different1_confident, different1_else, different2_confident, different2_else = vote_threshold_label_selection(unlabeled_pred1, unlabeled_pred2, conf_threshold)

                loss_con1_else = criterion_u(unlabeled_pred2, unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[1].detach().long()) * (different1_else & (ignore_img_mask != 255))
                loss_con2_else = criterion_u(unlabeled_pred1, unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[1].detach().long()) * (different2_else & (ignore_img_mask != 255))

                loss_con1_cc = args.w_confident * criterion_u(unlabeled_pred2, unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[1].detach().long()) * (different1_confident & (ignore_img_mask != 255))
                loss_con2_cc = args.w_confident * criterion_u(unlabeled_pred1, unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[1].detach().long()) * (different2_confident & (ignore_img_mask != 255))

                loss_con1 = (torch.sum(loss_con1_else) + torch.sum(loss_con1_cc)) / torch.sum(ignore_img_mask != 255).item()
                loss_con2 = (torch.sum(loss_con2_else) + torch.sum(loss_con2_cc)) / torch.sum(ignore_img_mask != 255).item()

            elif args.mode_confident == 'vote_soft':
                same_pred, different_confident_pred1, different_confident_pred2, different_unconfident_pred1, different_unconfident_pred2 = vote_soft_label_selection(unlabeled_pred1, unlabeled_pred2, conf_threshold)

                loss_con1_same = criterion_u(unlabeled_pred2, unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[1].detach().long()) * (same_pred & (ignore_img_mask != 255))
                loss_con2_same = criterion_u(unlabeled_pred1, unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[1].detach().long()) * (same_pred & (ignore_img_mask != 255))

                loss_con1_different_confident = args.w_confident * criterion_u(unlabeled_pred2, unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[1].detach().long()) * (different_confident_pred1 & (ignore_img_mask != 255))
                loss_con2_different_confident = args.w_confident * criterion_u(unlabeled_pred1, unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[1].detach().long()) * (different_confident_pred2 & (ignore_img_mask != 255))

                loss_con1_different_unconfident = args.w_unconfident * criterion_u(unlabeled_pred2, unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[1].detach().long()) * (different_unconfident_pred1 & (ignore_img_mask != 255))
                loss_con2_different_unconfident = args.w_unconfident * criterion_u(unlabeled_pred1, unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[1].detach().long()) * (different_unconfident_pred2 & (ignore_img_mask != 255))

                loss_con1 = (torch.sum(loss_con1_same) + torch.sum(loss_con1_different_confident) + torch.sum(loss_con1_different_unconfident)) / torch.sum(ignore_img_mask != 255).item()
                loss_con2 = (torch.sum(loss_con2_same) + torch.sum(loss_con2_different_confident) + torch.sum(loss_con2_different_unconfident)) / torch.sum(ignore_img_mask != 255).item()

            else:
                loss_con1 = criterion_u(unlabeled_pred2, unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[1].detach().long()) * (ignore_img_mask != 255)
                loss_con1 = torch.sum(loss_con1) / torch.sum(ignore_img_mask != 255).item()
                loss_con2 = criterion_u(unlabeled_pred1, unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[1].detach().long()) * (ignore_img_mask != 255)
                loss_con2 = torch.sum(loss_con2) / torch.sum(ignore_img_mask != 255).item()

            loss_con = (loss_con1 + loss_con2) / 2
            loss_con = loss_con * args.w_con

            # -------------
            # both
            # -------------
            # discrepancy loss
            cos_dis = nn.CosineSimilarity(dim=1, eps=1e-6)

            # labeled
            labeled_feature1 = labeled_logits['feature1']
            labeled_feature2 = labeled_logits['feature2']
            loss_dis_labeled1 = 1 + cos_dis(labeled_feature1.detach(), labeled_feature2).mean()
            loss_dis_labeled2 = 1 + cos_dis(labeled_feature2.detach(), labeled_feature1).mean()
            loss_dis_labeled = (loss_dis_labeled1 + loss_dis_labeled2) / 2

            # unlabeled
            unlabeled_feature1 = unlabeled_logits['feature1']
            unlabeled_feature2 = unlabeled_logits['feature2']
            loss_dis_unlabeled1 = 1 + cos_dis(unlabeled_feature1.detach(), unlabeled_feature2).mean()
            loss_dis_unlabeled2 = 1 + cos_dis(unlabeled_feature2.detach(), unlabeled_feature1).mean()
            loss_dis_unlabeled = (loss_dis_unlabeled1 + loss_dis_unlabeled2) / 2

            loss_dis = (loss_dis_labeled + loss_dis_unlabeled) / 2
            loss_dis = loss_dis * args.w_dis

            # -------------
            # total
            # -------------
            loss = loss_CE
            if args.use_con:
                loss = loss + loss_con
            if args.use_dis:
                loss = loss + loss_dis

            dist.barrier()

            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total_loss_CE += loss_CE.item()
            total_loss_con += loss_con.item()
            total_loss_dis += loss_dis.item()

            total_confident = (((unlabeled_pred_confidence1 >= 0.95) & (ignore_img_mask != 255)).sum().item() + ((unlabeled_pred_confidence2 >= 0.95) & (ignore_img_mask != 255)).sum().item()) / 2
            total_mask_ratio += total_confident / (ignore_img_mask != 255).sum().item()

            iters = epoch * len(trainloader_u) + i

            # update lr

            backbone_lr = args.base_lr * (1 - iters / total_iters) ** args.mul_scheduler
            backbone_lr = backbone_lr * args.lr_backbone

            seg_lr = args.base_lr * (1 - iters / total_iters) ** args.mul_scheduler
            seg_lr = seg_lr * args.lr_network
                
            optimizer.param_groups[0]["lr"] = backbone_lr
            optimizer.param_groups[1]["lr"] = backbone_lr
            for ii in range(2, len(optimizer.param_groups)):
                optimizer.param_groups[ii]['lr'] = seg_lr

            if (i % (len(trainloader_u) // 8) == 0) and (args.local_rank <= 0):
                tb.add_scalar('train_loss_total', total_loss / (i+1), iters)
                tb.add_scalar('train_loss_CE', total_loss_CE / (i+1), iters)
                tb.add_scalar('train_loss_con', total_loss_con / (i+1), iters)
                tb.add_scalar('train_loss_dis', total_loss_dis / (i+1), iters)

            if (i % (len(trainloader_u) // 8) == 0) and (args.local_rank <= 0):
                logger.info('Iters: {:}, Total loss: {:.3f}, Loss CE: {:.3f}, '
                            'Loss consistency: {:.3f}, Loss discrepancy: {:.3f}, Mask: {:.3f}'.format(
                    i, total_loss / (i+1), total_loss_CE / (i+1), total_loss_con / (i+1), total_loss_dis / (i+1), 
                    total_mask_ratio / (i+1)))

        if args.use_SPL:
            conf_threshold += 0.01
            if conf_threshold >= 0.95:
                conf_threshold = 0.95

        if cfg['dataset'] == 'cityscapes':
            eval_mode = 'center_crop' if epoch < args.epochs - 20 else 'sliding_window'
        else:
            eval_mode = 'original'
        
        dist.barrier()

        # test with different branches
        if args.local_rank <= 0:
            if epoch == 4:
                evaluate_result = evaluate_save(cfg['dataset'], args.save_path, args.local_rank, model, valloader, eval_mode, args, cfg, idx_epoch=5)
            elif epoch == 9:
                evaluate_result = evaluate_save(cfg['dataset'], args.save_path, args.local_rank, model, valloader, eval_mode, args, cfg, idx_epoch=10)
            elif epoch == 19:
                evaluate_result = evaluate_save(cfg['dataset'], args.save_path, args.local_rank, model, valloader, eval_mode, args, cfg, idx_epoch=20)
            elif epoch == 39:
                evaluate_result = evaluate_save(cfg['dataset'], args.save_path, args.local_rank, model, valloader, eval_mode, args, cfg, idx_epoch=40)
            else:
                evaluate_result = evaluate(args.local_rank, model, valloader, eval_mode, args, cfg)


            mIOU1 = evaluate_result['IOU1']
            mIOU2 = evaluate_result['IOU2']
            mIOU_ave = evaluate_result['IOU_ave']

            tb.add_scalar('meanIOU_branch1', mIOU1, epoch)
            tb.add_scalar('meanIOU_branch2', mIOU2, epoch)
            tb.add_scalar('meanIOU_ave', mIOU_ave, epoch)

            logger.info('***** Evaluation with branch 1 {} ***** >>>> meanIOU: {:.2f}\n'.format(eval_mode, mIOU1))
            logger.info('***** Evaluation with branch 2 {} ***** >>>> meanIOU: {:.2f}\n'.format(eval_mode, mIOU2))
            logger.info('***** Evaluation with two branches {} ***** >>>> meanIOU: {:.2f}\n'.format(eval_mode, mIOU_ave))

            if mIOU1 > previous_best1:
                if previous_best1 != 0:
                    os.remove(os.path.join(args.save_path, 'branch1_%s_%.2f.pth' % (args.backbone, previous_best1)))
                previous_best1 = mIOU1
                torch.save(model.module.state_dict(),
                        os.path.join(args.save_path, 'branch1_%s_%.2f.pth' % (args.backbone, mIOU1)))
            
            if mIOU2 > previous_best2:
                if previous_best2 != 0:
                    os.remove(os.path.join(args.save_path, 'branch2_%s_%.2f.pth' % (args.backbone, previous_best2)))
                previous_best2 = mIOU2
                torch.save(model.module.state_dict(),
                        os.path.join(args.save_path, 'branch2_%s_%.2f.pth' % (args.backbone, mIOU2)))

            if mIOU_ave > previous_best:
                if previous_best != 0:
                    os.remove(os.path.join(args.save_path, 'ave_%s_%.2f.pth' % (args.backbone, previous_best)))
                previous_best = mIOU_ave
                torch.save(model.module.state_dict(),
                        os.path.join(args.save_path, 'ave_%s_%.2f.pth' % (args.backbone, mIOU_ave)))

CCVC_ago.py

# https://github.com/xiaoyao3302/CCVC/blob/master/CCVC_aug.py

args = parser.parse_args()

args.world_size = args.gpus * args.nodes
args.ddp = True if args.gpus > 1 else False

def main(gpu, ngpus_per_node, cfg, args):

    args.local_rank = gpu

    if args.local_rank <= 0:
        os.makedirs(args.save_path, exist_ok=True)
        
    logger = init_log_save(args.save_path, 'global', logging.INFO)
    logger.propagate = 0

    if args.local_rank <= 0:
        tb_dir = args.save_path
        tb = SummaryWriter(log_dir=tb_dir)

    if args.ddp:
        dist.init_process_group(backend='nccl', rank=args.local_rank, world_size=args.world_size)
        # dist.init_process_group(backend='nccl', init_method='env://', rank=args.local_rank, world_size=args.world_size)

    if args.local_rank <= 0:
        logger.info('{}\n'.format(pprint.pformat(cfg)))

    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.benchmark = True

    model = Discrepancy_DeepLabV3Plus(args, cfg)
    if args.local_rank <= 0:
        logger.info('Total params: {:.1f}M\n'.format(count_params(model)))

    # pdb.set_trace()
    # TODO: check the parameter !!!

    optimizer = SGD([{'params': model.branch1.backbone.parameters(), 'lr': args.base_lr * args.lr_backbone}, 
                     {'params': model.branch2.backbone.parameters(), 'lr': args.base_lr * args.lr_backbone},
                     {'params': [param for name, param in model.named_parameters() if 'backbone' not in name],
                      'lr': args.base_lr * args.lr_network}], lr=args.base_lr, momentum=0.9, weight_decay=1e-4)

    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model.cuda(args.local_rank)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank, broadcast_buffers=False, find_unused_parameters=False)

    # # try to load saved model
    # try:   
    #     model.load_model(args.load_path)
    #     if args.local_rank <= 0:
    #         logger.info('load saved model')
    # except:
    #     if args.local_rank <= 0:
    #         logger.info('no saved model')

    # ---- #
    # loss #
    # ---- #
    # CE loss for labeled data
    if args.mode_criterion == 'CE':
        criterion_l = nn.CrossEntropyLoss(reduction='mean', ignore_index=255).cuda(args.local_rank)
    else:
        criterion_l = ProbOhemCrossEntropy2d(**cfg['criterion']['kwargs']).cuda(args.local_rank)
    
    # consistency loss for unlabeled data
    criterion_u = nn.CrossEntropyLoss(reduction='none').cuda(args.local_rank)

    # ------- #
    # dataset #
    # ------- #
    trainset_u = SemiDataset(cfg['dataset'], cfg['data_root'], 'train_u',
                             args.crop_size, args.unlabeled_id_path)
    trainset_l = SemiDataset(cfg['dataset'], cfg['data_root'], 'train_l',
                             args.crop_size, args.labeled_id_path, nsample=len(trainset_u.ids))
    valset = SemiDataset(cfg['dataset'], cfg['data_root'], 'val')

    if args.ddp:
        trainsampler_l = torch.utils.data.distributed.DistributedSampler(trainset_l)
    else:
        trainsampler_l = None
    trainloader_l = DataLoader(trainset_l, batch_size=args.batch_size,
                               pin_memory=True, num_workers=args.num_workers, drop_last=True, sampler=trainsampler_l)

    if args.ddp:
        trainsampler_u = torch.utils.data.distributed.DistributedSampler(trainset_u)
    else:
        trainsampler_u = None
    
    trainloader_u = DataLoader(trainset_u, batch_size=args.batch_size,
                               pin_memory=True, num_workers=args.num_workers, drop_last=True, sampler=trainsampler_u)

    valloader = DataLoader(valset, batch_size=1, pin_memory=True, num_workers=args.num_workers, drop_last=False)
    # if args.ddp:
    #     valsampler = torch.utils.data.distributed.DistributedSampler(valset)
    # else:
    #     valsampler = None
    
    # valloader = DataLoader(valset, batch_size=1, pin_memory=True, num_workers=args.num_workers, drop_last=False, sampler=valsampler)

    total_iters = len(trainloader_u) * args.epochs
    previous_best = 0.0
    previous_best1 = 0.0
    previous_best2 = 0.0

    # can change with epochs, add SPL here
    conf_threshold = args.conf_threshold

    for epoch in range(args.epochs):
        if args.local_rank <= 0:
            logger.info('===========> Epoch: {:}, backbone1 LR: {:.4f}, backbone2 LR: {:.4f}, segmentation LR: {:.4f}'.format(
                epoch, optimizer.param_groups[0]['lr'], optimizer.param_groups[1]['lr'], optimizer.param_groups[-1]['lr']))
            logger.info('===========> Epoch: {:}, Previous best of ave: {:.2f}, Previous best of branch1: {:.2f}, Previous best of branch2: {:.2f}'.format(
                epoch, previous_best, previous_best1, previous_best2))

        total_loss, total_loss_CE, total_loss_con, total_loss_dis = 0.0, 0.0, 0.0, 0.0
        total_mask_ratio = 0.0

        trainloader_l.sampler.set_epoch(epoch)
        trainloader_u.sampler.set_epoch(epoch)

        loader = zip(trainloader_l, trainloader_u)

        total_labeled = 0
        total_unlabeled = 0

        for i, ((labeled_img, labeled_img_mask), (unlabeled_img, aug_unlabeled_img, ignore_img_mask, unlabeled_cutmix_box)) in enumerate(loader):

            labeled_img, labeled_img_mask = labeled_img.cuda(args.local_rank), labeled_img_mask.cuda(args.local_rank)
            unlabeled_img, aug_unlabeled_img, ignore_img_mask, unlabeled_cutmix_box = unlabeled_img.cuda(args.local_rank), aug_unlabeled_img.cuda(args.local_rank), ignore_img_mask.cuda(args.local_rank), unlabeled_cutmix_box.cuda(args.local_rank)

            optimizer.zero_grad()

            dist.barrier()

            num_lb, num_ulb = labeled_img.shape[0], unlabeled_img.shape[0]

            total_labeled += num_lb
            total_unlabeled += num_ulb

            # =========================================================================================
            # labeled data: labeled_img, labeled_img_mask
            # =========================================================================================
            model.train()
            labeled_logits = model(labeled_img)

            # =========================================================================================
            # unlabeled data: unlabeled_img, ignore_img_mask, cutmix_box
            # =========================================================================================
            # first feed the data into the model with no grad to get gt
            with torch.no_grad():
                model.eval()
                unlabeled_logits = model(unlabeled_img)

                unlabeled_pseudo_label1 = unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[1].detach().long()
                unlabeled_pseudo_label2 = unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[1].detach().long()
                unlabeled_pred1 = unlabeled_logits['pred1'].detach()
                unlabeled_pred2 = unlabeled_logits['pred2'].detach()

            # then perform cutmix
            aug_unlabeled_img_for_mix = aug_unlabeled_img.clone()
            aug_unlabeled_ignore_img_mask_for_mix = ignore_img_mask.clone()
            aug_unlabeled_pseudo_label1_for_mix = unlabeled_pseudo_label1.clone()
            aug_unlabeled_pseudo_label2_for_mix = unlabeled_pseudo_label2.clone()
            aug_unlabeled_pred1_for_mix = unlabeled_pred1.clone()
            aug_unlabeled_pred2_for_mix = unlabeled_pred2.clone()

            # initial cutmix is companied with a probablility 0.5, increase the data divergence
            aug_unlabeled_img_for_mix[unlabeled_cutmix_box.unsqueeze(1).expand(aug_unlabeled_img_for_mix.shape) == 1] = aug_unlabeled_img_for_mix.flip(0)[unlabeled_cutmix_box.unsqueeze(1).expand(aug_unlabeled_img_for_mix.shape) == 1]
            aug_unlabeled_ignore_img_mask_for_mix[unlabeled_cutmix_box == 1] = aug_unlabeled_ignore_img_mask_for_mix.flip(0)[unlabeled_cutmix_box == 1]
            aug_unlabeled_pseudo_label1_for_mix[unlabeled_cutmix_box == 1] = aug_unlabeled_pseudo_label1_for_mix.flip(0)[unlabeled_cutmix_box == 1]
            aug_unlabeled_pseudo_label2_for_mix[unlabeled_cutmix_box == 1] = aug_unlabeled_pseudo_label2_for_mix.flip(0)[unlabeled_cutmix_box == 1]
            aug_unlabeled_pred1_for_mix[unlabeled_cutmix_box.unsqueeze(1).expand(aug_unlabeled_pred1_for_mix.shape) == 1] = aug_unlabeled_pred1_for_mix.flip(0)[unlabeled_cutmix_box.unsqueeze(1).expand(aug_unlabeled_pred1_for_mix.shape) == 1]
            aug_unlabeled_pred2_for_mix[unlabeled_cutmix_box.unsqueeze(1).expand(aug_unlabeled_pred2_for_mix.shape) == 1] = aug_unlabeled_pred2_for_mix.flip(0)[unlabeled_cutmix_box.unsqueeze(1).expand(aug_unlabeled_pred2_for_mix.shape) == 1]

            # finally feed the mixed data into the model
            model.train()
            cutmixed_aug_unlabeled_logits = model(aug_unlabeled_img_for_mix)

            # one extra branch: unlabeled data
            raw_unlabeled_logits = model(unlabeled_img)


            # to count the confident predictions
            unlabeled_pred_confidence1 = raw_unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[0]
            unlabeled_pred_confidence2 = raw_unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[0]

            # =========================================================================================
            # calculate loss
            # =========================================================================================

            # -------------
            # labeled
            # -------------
            # CE loss
            labeled_pred1 = labeled_logits['pred1']
            labeled_pred2 = labeled_logits['pred2']

            loss_CE1 = criterion_l(labeled_pred1, labeled_img_mask)
            loss_CE2 = criterion_l(labeled_pred2, labeled_img_mask)

            loss_CE = (loss_CE1 + loss_CE2) / 2
            loss_CE = loss_CE * args.w_CE

            # -------------
            # unlabeled
            # -------------
            # consistency loss
            raw_unlabeled_pred1 = raw_unlabeled_logits['pred1']
            raw_unlabeled_pred2 = raw_unlabeled_logits['pred2']

            cutmixed_aug_unlabeled_pred1 = cutmixed_aug_unlabeled_logits['pred1']
            cutmixed_aug_unlabeled_pred2 = cutmixed_aug_unlabeled_logits['pred2']

            if args.mode_confident == 'normal':
                loss_con1 = criterion_u(cutmixed_aug_unlabeled_pred2, aug_unlabeled_pseudo_label1_for_mix) * ((aug_unlabeled_pred1_for_mix.softmax(dim=1).max(dim=1)[0] > conf_threshold) & (aug_unlabeled_ignore_img_mask_for_mix != 255))
                loss_con1 = torch.sum(loss_con1) / torch.sum(aug_unlabeled_ignore_img_mask_for_mix != 255).item()
                loss_con2 = criterion_u(cutmixed_aug_unlabeled_pred1, aug_unlabeled_pseudo_label2_for_mix) * ((aug_unlabeled_pred2_for_mix.softmax(dim=1).max(dim=1)[0] > conf_threshold) & (aug_unlabeled_ignore_img_mask_for_mix != 255))
                loss_con2 = torch.sum(loss_con2) / torch.sum(aug_unlabeled_ignore_img_mask_for_mix != 255).item()

                loss_raw_con1 = criterion_u(raw_unlabeled_pred2, raw_unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[1].detach().long()) * ((raw_unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[0] > conf_threshold) & (ignore_img_mask != 255))
                loss_raw_con1 = torch.sum(loss_raw_con1) / torch.sum(ignore_img_mask != 255).item()
                loss_raw_con2 = criterion_u(raw_unlabeled_pred1, raw_unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[1].detach().long()) * ((raw_unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[0] > conf_threshold) & (ignore_img_mask != 255))
                loss_raw_con2 = torch.sum(loss_raw_con2) / torch.sum(ignore_img_mask != 255).item()

            elif args.mode_confident == 'soft':
                confident_pred1, confident_pred2, unconfident_pred1, unconfident_pred2 = soft_label_selection(aug_unlabeled_pred1_for_mix, aug_unlabeled_pred2_for_mix, conf_threshold)

                loss_con1_confident = criterion_u(cutmixed_aug_unlabeled_pred2, aug_unlabeled_pseudo_label1_for_mix) * (confident_pred1 & (aug_unlabeled_ignore_img_mask_for_mix != 255))
                loss_con2_confident = criterion_u(cutmixed_aug_unlabeled_pred1, aug_unlabeled_pseudo_label2_for_mix) * (confident_pred2 & (aug_unlabeled_ignore_img_mask_for_mix != 255))

                loss_con1_unconfident = 0.5 * criterion_u(cutmixed_aug_unlabeled_pred2, aug_unlabeled_pseudo_label1_for_mix) * (unconfident_pred1 & (aug_unlabeled_ignore_img_mask_for_mix != 255))
                loss_con2_unconfident = 0.5 * criterion_u(cutmixed_aug_unlabeled_pred1, aug_unlabeled_pseudo_label2_for_mix) * (unconfident_pred2 & (aug_unlabeled_ignore_img_mask_for_mix != 255))

                loss_con1 = (torch.sum(loss_con1_confident) + torch.sum(loss_con1_unconfident)) / torch.sum(aug_unlabeled_ignore_img_mask_for_mix != 255).item()
                loss_con2 = (torch.sum(loss_con2_confident) + torch.sum(loss_con2_unconfident)) / torch.sum(aug_unlabeled_ignore_img_mask_for_mix != 255).item()

                raw_confident_pred1, raw_confident_pred2, raw_unconfident_pred1, raw_unconfident_pred2 = soft_label_selection(raw_unlabeled_pred1, raw_unlabeled_pred2, conf_threshold)

                loss_raw_con1_confident = criterion_u(raw_unlabeled_pred2, raw_unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[1].detach().long()) * (raw_confident_pred1 & (ignore_img_mask != 255))
                loss_raw_con2_confident = criterion_u(raw_unlabeled_pred1, raw_unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[1].detach().long()) * (raw_confident_pred2 & (ignore_img_mask != 255))

                loss_raw_con1_unconfident = 0.5 * criterion_u(raw_unlabeled_pred2, raw_unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[1].detach().long()) * (raw_unconfident_pred1 & (ignore_img_mask != 255))
                loss_raw_con2_unconfident = 0.5 * criterion_u(raw_unlabeled_pred1, raw_unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[1].detach().long()) * (raw_unconfident_pred2 & (ignore_img_mask != 255))

                loss_raw_con1 = (torch.sum(loss_raw_con1_confident) + torch.sum(loss_raw_con1_unconfident)) / torch.sum(ignore_img_mask != 255).item()
                loss_raw_con2 = (torch.sum(loss_raw_con2_confident) + torch.sum(loss_raw_con2_unconfident)) / torch.sum(ignore_img_mask != 255).item()

            elif args.mode_confident == 'vote':
                same_pred, different_pred = vote_label_selection(unlabeled_pred1, unlabeled_pred2)

                loss_con1_same = criterion_u(cutmixed_aug_unlabeled_pred2, aug_unlabeled_pseudo_label1_for_mix) * (same_pred & (aug_unlabeled_ignore_img_mask_for_mix != 255))
                loss_con2_same = criterion_u(cutmixed_aug_unlabeled_pred1, aug_unlabeled_pseudo_label2_for_mix) * (same_pred & (aug_unlabeled_ignore_img_mask_for_mix != 255))

                loss_con1_different = 1.5 * criterion_u(cutmixed_aug_unlabeled_pred2, aug_unlabeled_pseudo_label1_for_mix) * (different_pred & (aug_unlabeled_ignore_img_mask_for_mix != 255))
                loss_con2_different = 1.5 * criterion_u(cutmixed_aug_unlabeled_pred1, aug_unlabeled_pseudo_label2_for_mix) * (different_pred & (aug_unlabeled_ignore_img_mask_for_mix != 255))

                loss_con1 = (torch.sum(loss_con1_same) + torch.sum(loss_con1_different)) / torch.sum(aug_unlabeled_ignore_img_mask_for_mix != 255).item()
                loss_con2 = (torch.sum(loss_con2_same) + torch.sum(loss_con2_different)) / torch.sum(aug_unlabeled_ignore_img_mask_for_mix != 255).item()

                raw_same_pred, raw_different_pred = vote_label_selection(raw_unlabeled_pred1, raw_unlabeled_pred2)

                loss_raw_con1_same = criterion_u(raw_unlabeled_pred2, raw_unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[1].detach().long()) * (raw_same_pred & (ignore_img_mask != 255))
                loss_raw_con2_same = criterion_u(raw_unlabeled_pred1, raw_unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[1].detach().long()) * (raw_same_pred & (ignore_img_mask != 255))

                loss_raw_con1_different = 1.5 * criterion_u(raw_unlabeled_pred2, raw_unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[1].detach().long()) * (raw_different_pred & (ignore_img_mask != 255))
                loss_raw_con2_different = 1.5 * criterion_u(raw_unlabeled_pred1, raw_unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[1].detach().long()) * (raw_different_pred & (ignore_img_mask != 255))

                loss_raw_con1 = (torch.sum(loss_raw_con1_same) + torch.sum(loss_raw_con1_different)) / torch.sum(ignore_img_mask != 255).item()
                loss_raw_con2 = (torch.sum(loss_raw_con2_same) + torch.sum(loss_raw_con2_different)) / torch.sum(ignore_img_mask != 255).item()

            elif args.mode_confident == 'vote_threshold':
                different1_confident, different1_else, different2_confident, different2_else = vote_threshold_label_selection(unlabeled_pred1, unlabeled_pred2, conf_threshold)

                loss_con1_else = criterion_u(cutmixed_aug_unlabeled_pred2, aug_unlabeled_pseudo_label1_for_mix) * (different1_else & (aug_unlabeled_ignore_img_mask_for_mix != 255))
                loss_con2_else = criterion_u(cutmixed_aug_unlabeled_pred1, aug_unlabeled_pseudo_label2_for_mix) * (different2_else & (aug_unlabeled_ignore_img_mask_for_mix != 255))

                loss_con1_cc = args.w_confident * criterion_u(cutmixed_aug_unlabeled_pred2, aug_unlabeled_pseudo_label1_for_mix) * (different1_confident & (aug_unlabeled_ignore_img_mask_for_mix != 255))
                loss_con2_cc = args.w_confident * criterion_u(cutmixed_aug_unlabeled_pred1, aug_unlabeled_pseudo_label2_for_mix) * (different2_confident & (aug_unlabeled_ignore_img_mask_for_mix != 255))

                loss_con1 = (torch.sum(loss_con1_else) + torch.sum(loss_con1_cc)) / torch.sum(aug_unlabeled_ignore_img_mask_for_mix != 255).item()
                loss_con2 = (torch.sum(loss_con2_else) + torch.sum(loss_con2_cc)) / torch.sum(aug_unlabeled_ignore_img_mask_for_mix != 255).item()

                raw_different1_confident, raw_different1_else, raw_different2_confident, raw_different2_else = vote_threshold_label_selection(raw_unlabeled_pred1, raw_unlabeled_pred2, conf_threshold)

                loss_raw_con1_else = criterion_u(raw_unlabeled_pred2, raw_unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[1].detach().long()) * (raw_different1_else & (ignore_img_mask != 255))
                loss_raw_con2_else = criterion_u(raw_unlabeled_pred1, raw_unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[1].detach().long()) * (raw_different2_else & (ignore_img_mask != 255))

                loss_raw_con1_cc = args.w_confident * criterion_u(raw_unlabeled_pred2, raw_unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[1].detach().long()) * (raw_different1_confident & (ignore_img_mask != 255))
                loss_raw_con2_cc = args.w_confident * criterion_u(raw_unlabeled_pred1, raw_unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[1].detach().long()) * (raw_different2_confident & (ignore_img_mask != 255))

                loss_raw_con1 = (torch.sum(loss_raw_con1_else) + torch.sum(loss_raw_con1_cc)) / torch.sum(ignore_img_mask != 255).item()
                loss_raw_con2 = (torch.sum(loss_raw_con2_else) + torch.sum(loss_raw_con2_cc)) / torch.sum(ignore_img_mask != 255).item()

            elif args.mode_confident == 'vote_soft':
                same_pred, different_confident_pred1, different_confident_pred2, different_unconfident_pred1, different_unconfident_pred2 = vote_soft_label_selection(unlabeled_pred1, unlabeled_pred2, conf_threshold)

                loss_con1_same = criterion_u(cutmixed_aug_unlabeled_pred2, aug_unlabeled_pseudo_label1_for_mix) * (same_pred & (aug_unlabeled_ignore_img_mask_for_mix != 255))
                loss_con2_same = criterion_u(cutmixed_aug_unlabeled_pred1, aug_unlabeled_pseudo_label2_for_mix) * (same_pred & (aug_unlabeled_ignore_img_mask_for_mix != 255))

                loss_con1_different_confident = 1.5 * criterion_u(cutmixed_aug_unlabeled_pred2, aug_unlabeled_pseudo_label1_for_mix) * (different_confident_pred1 & (aug_unlabeled_ignore_img_mask_for_mix != 255))
                loss_con2_different_confident = 1.5 * criterion_u(cutmixed_aug_unlabeled_pred1, aug_unlabeled_pseudo_label2_for_mix) * (different_confident_pred2 & (aug_unlabeled_ignore_img_mask_for_mix != 255))

                loss_con1_different_unconfident = 0.5 * criterion_u(cutmixed_aug_unlabeled_pred2, aug_unlabeled_pseudo_label1_for_mix) * (different_unconfident_pred1 & (aug_unlabeled_ignore_img_mask_for_mix != 255))
                loss_con2_different_unconfident = 0.5 * criterion_u(cutmixed_aug_unlabeled_pred1, aug_unlabeled_pseudo_label2_for_mix) * (different_unconfident_pred2 & (aug_unlabeled_ignore_img_mask_for_mix != 255))

                loss_con1 = (torch.sum(loss_con1_same) + torch.sum(loss_con1_different_confident) + torch.sum(loss_con1_different_unconfident)) / torch.sum(aug_unlabeled_ignore_img_mask_for_mix != 255).item()
                loss_con2 = (torch.sum(loss_con2_same) + torch.sum(loss_con2_different_confident) + torch.sum(loss_con2_different_unconfident)) / torch.sum(aug_unlabeled_ignore_img_mask_for_mix != 255).item()

                raw_same_pred, raw_different_confident_pred1, raw_different_confident_pred2, raw_different_unconfident_pred1, raw_different_unconfident_pred2 = vote_soft_label_selection(unlabeled_pred1, unlabeled_pred2, conf_threshold)

                loss_raw_con1_same = criterion_u(raw_unlabeled_pred2, raw_unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[1].detach().long()) * (raw_same_pred & (ignore_img_mask != 255))
                loss_raw_con2_same = criterion_u(raw_unlabeled_pred1, raw_unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[1].detach().long()) * (raw_same_pred & (ignore_img_mask != 255))

                loss_raw_con1_different_confident = 1.5 * criterion_u(raw_unlabeled_pred2, raw_unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[1].detach().long()) * (raw_different_confident_pred1 & (ignore_img_mask != 255))
                loss_raw_con2_different_confident = 1.5 * criterion_u(raw_unlabeled_pred1, raw_unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[1].detach().long()) * (raw_different_confident_pred2 & (ignore_img_mask != 255))

                loss_raw_con1_different_unconfident = 0.5 * criterion_u(raw_unlabeled_pred2, raw_unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[1].detach().long()) * (raw_different_unconfident_pred1 & (ignore_img_mask != 255))
                loss_raw_con2_different_unconfident = 0.5 * criterion_u(raw_unlabeled_pred1, raw_unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[1].detach().long()) * (raw_different_unconfident_pred2 & (ignore_img_mask != 255))

                loss_raw_con1 = (torch.sum(loss_raw_con1_same) + torch.sum(loss_raw_con1_different_confident) + torch.sum(loss_raw_con1_different_unconfident)) / torch.sum(ignore_img_mask != 255).item()
                loss_raw_con2 = (torch.sum(loss_raw_con2_same) + torch.sum(loss_raw_con2_different_confident) + torch.sum(loss_raw_con2_different_unconfident)) / torch.sum(ignore_img_mask != 255).item()

            else:
                loss_con1 = criterion_u(cutmixed_aug_unlabeled_pred2, aug_unlabeled_pseudo_label1_for_mix) * (aug_unlabeled_ignore_img_mask_for_mix != 255)
                loss_con1 = torch.sum(loss_con1) / torch.sum(aug_unlabeled_ignore_img_mask_for_mix != 255).item()
                loss_con2 = criterion_u(cutmixed_aug_unlabeled_pred1, aug_unlabeled_pseudo_label2_for_mix) * (aug_unlabeled_ignore_img_mask_for_mix != 255)
                loss_con2 = torch.sum(loss_con2) / torch.sum(aug_unlabeled_ignore_img_mask_for_mix != 255).item()

                loss_raw_con1 = criterion_u(raw_unlabeled_pred2, raw_unlabeled_logits['pred1'].softmax(dim=1).max(dim=1)[1].detach().long()) * (ignore_img_mask != 255)
                loss_raw_con1 = torch.sum(loss_raw_con1) / torch.sum(ignore_img_mask != 255).item()
                loss_raw_con2 = criterion_u(raw_unlabeled_pred1, raw_unlabeled_logits['pred2'].softmax(dim=1).max(dim=1)[1].detach().long()) * (ignore_img_mask != 255)
                loss_raw_con2 = torch.sum(loss_raw_con2) / torch.sum(ignore_img_mask != 255).item()

            loss_con = (loss_con1 + loss_con2 + loss_raw_con1 + loss_raw_con2) / 4
            loss_con = loss_con * args.w_con

            # -------------
            # both
            # -------------
            # discrepancy loss
            cos_dis = nn.CosineSimilarity(dim=1, eps=1e-6)

            # labeled
            labeled_feature1 = labeled_logits['feature1']
            labeled_feature2 = labeled_logits['feature2']
            loss_dis_labeled1 = 1 + cos_dis(labeled_feature1.detach(), labeled_feature2).mean()
            loss_dis_labeled2 = 1 + cos_dis(labeled_feature2.detach(), labeled_feature1).mean()
            loss_dis_labeled = (loss_dis_labeled1 + loss_dis_labeled2) / 2

            # unlabeled
            cutmixed_aug_unlabeled_feature1 = cutmixed_aug_unlabeled_logits['feature1']
            cutmixed_aug_unlabeled_feature2 = cutmixed_aug_unlabeled_logits['feature2']
            loss_dis_cutmixed_aug_unlabeled1 = 1 + cos_dis(cutmixed_aug_unlabeled_feature1.detach(), cutmixed_aug_unlabeled_feature2).mean()
            loss_dis_cutmixed_aug_unlabeled2 = 1 + cos_dis(cutmixed_aug_unlabeled_feature2.detach(), cutmixed_aug_unlabeled_feature1).mean()
            loss_dis_cutmixed_aug_unlabeled = (loss_dis_cutmixed_aug_unlabeled1 + loss_dis_cutmixed_aug_unlabeled2) / 2

            raw_unlabeled_feature1 = raw_unlabeled_logits['feature1']
            raw_unlabeled_feature2 = raw_unlabeled_logits['feature2']
            loss_dis_raw_unlabeled1 = 1 + cos_dis(raw_unlabeled_feature1.detach(), raw_unlabeled_feature2).mean()
            loss_dis_raw_unlabeled2 = 1 + cos_dis(raw_unlabeled_feature2.detach(), raw_unlabeled_feature1).mean()
            loss_dis_raw_unlabeled = (loss_dis_raw_unlabeled1 + loss_dis_raw_unlabeled2) / 2

            loss_dis_unlabeled = (loss_dis_cutmixed_aug_unlabeled + loss_dis_raw_unlabeled) / 2

            loss_dis = (loss_dis_labeled + loss_dis_unlabeled) / 2
            loss_dis = loss_dis * args.w_dis

            # -------------
            # total
            # -------------
            loss = loss_CE + loss_con + loss_dis

            dist.barrier()

            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total_loss_CE += loss_CE.item()
            total_loss_con += loss_con.item()
            total_loss_dis += loss_dis.item()

            total_confident = (((unlabeled_pred_confidence1 >= 0.95) & (ignore_img_mask != 255)).sum().item() + ((unlabeled_pred_confidence2 >= 0.95) & (ignore_img_mask != 255)).sum().item()) / 2
            total_mask_ratio += total_confident / (ignore_img_mask != 255).sum().item()

            iters = epoch * len(trainloader_u) + i

            # update lr

            backbone_lr = args.base_lr * (1 - iters / total_iters) ** args.mul_scheduler
            backbone_lr = backbone_lr * args.lr_backbone

            seg_lr = args.base_lr * (1 - iters / total_iters) ** args.mul_scheduler
            seg_lr = seg_lr * args.lr_network
                
            optimizer.param_groups[0]['lr'] = backbone_lr
            optimizer.param_groups[1]['lr'] = backbone_lr
            for ii in range(2, len(optimizer.param_groups)):
                optimizer.param_groups[ii]['lr'] = seg_lr

            if (i % (len(trainloader_u) // 8) == 0) and (args.local_rank <= 0):
                tb.add_scalar('train_loss_total', total_loss / (i+1), iters)
                tb.add_scalar('train_loss_CE', total_loss_CE / (i+1), iters)
                tb.add_scalar('train_loss_con', total_loss_con / (i+1), iters)
                tb.add_scalar('train_loss_dis', total_loss_dis / (i+1), iters)

            if (i % (len(trainloader_u) // 8) == 0) and (args.local_rank <= 0):
                logger.info('Iters: {:}, Total loss: {:.3f}, Loss CE: {:.3f}, '
                            'Loss consistency: {:.3f}, Loss discrepancy: {:.3f}, Mask: {:.3f}'.format(
                    i, total_loss / (i+1), total_loss_CE / (i+1), total_loss_con / (i+1), total_loss_dis / (i+1), 
                    total_mask_ratio / (i+1)))

        if args.use_SPL:
            conf_threshold += 0.01
            if conf_threshold >= 0.95:
                conf_threshold = 0.95

        if cfg['dataset'] == 'cityscapes':
            eval_mode = 'center_crop' if epoch < args.epochs - 20 else 'sliding_window'
        else:
            eval_mode = 'original'
        
        dist.barrier()

        # test with different branches  
        if args.local_rank <= 0:
            if epoch == 4:
                evaluate_result = evaluate_save(cfg['dataset'], args.save_path, args.local_rank, model, valloader, eval_mode, args, cfg, idx_epoch=5)
            elif epoch == 9:
                evaluate_result = evaluate_save(cfg['dataset'], args.save_path, args.local_rank, model, valloader, eval_mode, args, cfg, idx_epoch=10)
            elif epoch == 19:
                evaluate_result = evaluate_save(cfg['dataset'], args.save_path, args.local_rank, model, valloader, eval_mode, args, cfg, idx_epoch=20)
            elif epoch == 39:
                evaluate_result = evaluate_save(cfg['dataset'], args.save_path, args.local_rank, model, valloader, eval_mode, args, cfg, idx_epoch=40)
            else:
                evaluate_result = evaluate(args.local_rank, model, valloader, eval_mode, args, cfg)

            mIOU1 = evaluate_result['IOU1']
            mIOU2 = evaluate_result['IOU2']
            mIOU_ave = evaluate_result['IOU_ave']

            tb.add_scalar('meanIOU_branch1', mIOU1, epoch)
            tb.add_scalar('meanIOU_branch2', mIOU2, epoch)
            tb.add_scalar('meanIOU_ave', mIOU_ave, epoch)

            logger.info('***** Evaluation with branch 1 {} ***** >>>> meanIOU: {:.2f}\n'.format(eval_mode, mIOU1))
            logger.info('***** Evaluation with branch 2 {} ***** >>>> meanIOU: {:.2f}\n'.format(eval_mode, mIOU2))
            logger.info('***** Evaluation with two branches {} ***** >>>> meanIOU: {:.2f}\n'.format(eval_mode, mIOU_ave))

            if mIOU1 > previous_best1:
                if previous_best1 != 0:
                    os.remove(os.path.join(args.save_path, 'branch1_%s_%.2f.pth' % (args.backbone, previous_best1)))
                previous_best1 = mIOU1
                torch.save(model.module.state_dict(),
                        os.path.join(args.save_path, 'branch1_%s_%.2f.pth' % (args.backbone, mIOU1)))
            
            if mIOU2 > previous_best2:
                if previous_best2 != 0:
                    os.remove(os.path.join(args.save_path, 'branch2_%s_%.2f.pth' % (args.backbone, previous_best2)))
                previous_best2 = mIOU2
                torch.save(model.module.state_dict(),
                        os.path.join(args.save_path, 'branch2_%s_%.2f.pth' % (args.backbone, mIOU2)))

            if mIOU_ave > previous_best:
                if previous_best != 0:
                    os.remove(os.path.join(args.save_path, 'ave_%s_%.2f.pth' % (args.backbone, previous_best)))
                previous_best = mIOU_ave
                torch.save(model.module.state_dict(),
                        os.path.join(args.save_path, 'ave_%s_%.2f.pth' % (args.backbone, mIOU_ave)))

Acho que você gosta

Origin blog.csdn.net/m0_61899108/article/details/131148390
Recomendado
Clasificación