[Anti-involución mano a mano] Cree una nueva segmentación audiovisual de tareas multimodal de IA: práctica de código, tutorial de optimización (2)

prefacio

Por favor, consulte el artículo anterior para la parte teórica:

Breve resumen: Necesitamos saber qué objeto en la imagen está emitiendo sonido de la siguiente demostración en video:

El gif no puede hacer ningún sonido, imaginemos que hay muchos autos en la escena, y solo este 120 está haciendo un sonido, por lo que el objeto que hace el sonido está segmentado.

 

 

 Esta es una escena en la que un cantante a veces canta y a veces toca el piano. Cuando solo toca el piano, el cuerpo humano no se divide. Cuando canta, el cuerpo humano se divide.

 

Introducción de la ruta relativa al código (mi versión, no oficial)

 

Puede descargar mi disco de red Baidu (con todos los datos y códigos), o puede descargar el código oficial, pero no contiene datos y solo se puede obtener mediante una aplicación.

tren

Primer vistazo a train.py

Consulte la ayuda del código a continuación.

parser.add_argument("--session_name", default="MS3", type=str, help="使用MS3是对数据里的Multi-sources下的数据进行训练,是多声源数据,也就是,可能同时有多个物体发声")
parser.add_argument("--visual_backbone", default="resnet", type=str,
                    help="use resnet50 or pvt-v2 as the visual backbone")
​
parser.add_argument("--train_batch_size", default=4, type=int)
parser.add_argument("--val_batch_size", default=1, type=int)
parser.add_argument("--max_epoches", default=5, type=int)
parser.add_argument("--lr", default=0.0001, type=float)
parser.add_argument("--num_workers", default=0, type=int)
parser.add_argument("--wt_dec", default=5-4, type=float)
​
parser.add_argument('--masked_av_flag', action='store_true', default=True,
                    help='使用作者论文里说的loss: sa/masked_va loss')
parser.add_argument("--lambda_1", default=0.5, type=float, help='均衡系数weight for balancing l4 loss')
parser.add_argument("--masked_av_stages", default=[0, 1, 2, 3], nargs='+', type=int,
                    help='作者的设置compute sa/masked_va loss in which stages: [0, 1, 2, 3]')
parser.add_argument('--threshold_flag', action='store_true', default=False,
                    help='whether thresholding the generated masks')
parser.add_argument("--mask_pooling_type", default='avg', type=str, help='the manner to downsample predicted masks')
parser.add_argument('--norm_fea_flag', action='store_true', default=False, help='音频标准化normalize audio-visual features')
parser.add_argument('--closer_flag', action='store_true', default=False, help='use closer loss for masked_va loss')
parser.add_argument('--euclidean_flag', action='store_true', default=False,
                    help='use euclidean distance for masked_va loss')
parser.add_argument('--kl_flag', action='store_true', default=True, help='KL散度 use kl loss for masked_va loss')
​
parser.add_argument("--load_s4_params", action='store_true', default=False,
                    help='use S4 parameters for initilization')
parser.add_argument("--trained_s4_model_path", type=str, default='', help='pretrained S4 model')
​
parser.add_argument("--tpavi_stages", default=[0, 1, 2, 3], nargs='+', type=int,
                    help='tpavi模块 add tpavi block in which stages: [0, 1, 2, 3]')
parser.add_argument("--tpavi_vv_flag", action='store_true', default=False, help='视觉自注意visual-visual self-attention')
parser.add_argument("--tpavi_va_flag", action='store_true', default=True, help='视听交叉注意visual-audio cross-attention')
​
parser.add_argument("--weights", type=str, default='', help='初始训练预训练模型,可以不写path of trained model')
parser.add_argument('--log_dir', default='./train_logs', type=str)

Todo el mundo puede entrenar según train.sh

detalles del código

A continuación, la columna vertebral se extraerá de acuerdo con las funciones visuales que desee, y las funciones de voz se extraerán de forma predeterminada mediante vggish.

if (args.visual_backbone).lower() == "resnet":
    from model import ResNet_AVSModel as AVSModel
​
    print('==> Use ResNet50 as the visual backbone...')
elif (args.visual_backbone).lower() == "pvt":
    from model import PVT_AVSModel as AVSModel
​
    print('==> Use pvt-v2 as the visual backbone...')
else:
    raise NotImplementedError("only support the resnet50 and pvt-v2")

Parte de lectura de datos:

class MS3Dataset(Dataset):
    """Dataset for multiple sound source segmentation"""
    def __init__(self, split='train'):
        super(MS3Dataset, self).__init__()
        self.split = split
        self.mask_num = 5
        df_all = pd.read_csv(cfg.DATA.ANNO_CSV, sep=',')
        self.df_split = df_all[df_all['split'] == split]
        print("{}/{} videos are used for {}".format(len(self.df_split), len(df_all), self.split))
        self.img_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])
        self.mask_transform = transforms.Compose([
            transforms.ToTensor(),
        ])
​
​
​
    def __getitem__(self, index):
        df_one_video = self.df_split.iloc[index]
        video_name = df_one_video[0]
        img_base_path =  os.path.join(cfg.DATA.DIR_IMG, video_name)
        audio_lm_path = os.path.join(cfg.DATA.DIR_AUDIO_LOG_MEL, self.split, video_name + '.pkl')
        mask_base_path = os.path.join(cfg.DATA.DIR_MASK, self.split, video_name)
        audio_log_mel = load_audio_lm(audio_lm_path)
        # audio_lm_tensor = torch.from_numpy(audio_log_mel)
        imgs, masks = [], []
        for img_id in range(1, 6):
            img = load_image_in_PIL_to_Tensor(os.path.join(img_base_path, "%s.mp4_%d.png"%(video_name, img_id)), transform=self.img_transform)
            imgs.append(img)
        for mask_id in range(1, self.mask_num + 1):
            mask = load_image_in_PIL_to_Tensor(os.path.join(mask_base_path, "%s_%d.png"%(video_name, mask_id)), transform=self.mask_transform, mode='P')
            masks.append(mask)
        imgs_tensor = torch.stack(imgs, dim=0)
        masks_tensor = torch.stack(masks, dim=0)
​
        return imgs_tensor, audio_log_mel, masks_tensor, video_name
​
    def __len__(self):
        return len(self.df_split)

Se puede ver que se leen 5 imágenes a la vez, y vi el video, todos los cuales duran 5 segundos, lo que indica que el autor entrena un video a la vez, y los cuadros por segundo de cada video se combinan con GT y voz para entrenamiento.

for n_iter, batch_data in enumerate(train_dataloader):
    imgs, audio, mask, _ = batch_data  # [bs, 5, 3, 224, 224], [bs, 5, 1, 96, 64], [bs, 5 or 1, 1, 224, 224]
​
    imgs = imgs.cuda()
    audio = audio.cuda()
    mask = mask.cuda()
    B, frame, C, H, W = imgs.shape
    imgs = imgs.view(B * frame, C, H, W)
    mask_num = 5
    mask = mask.view(B * mask_num, 1, H, W)
    audio = audio.view(-1, audio.shape[2], audio.shape[3], audio.shape[4])  # [B*T, 1, 96, 64]
    with torch.no_grad():
        audio_feature = audio_backbone(audio)  # [B*T, 128]
​
    output, v_map_list, a_fea_list = model(imgs, audio_feature)  # [bs*5, 1, 224, 224]
    loss, loss_dict = IouSemanticAwareLoss(output, mask, a_fea_list, v_map_list, \
                                           sa_loss_flag=args.masked_av_flag, lambda_1=args.lambda_1,
                                           count_stages=args.masked_av_stages, \
                                           mask_pooling_type=args.mask_pooling_type,
                                           threshold=args.threshold_flag, norm_fea=args.norm_fea_flag, \
                                           closer_flag=args.closer_flag, euclidean_flag=args.euclidean_flag,
                                           kl_flag=args.kl_flag)
​
    avg_meter_total_loss.add({'total_loss': loss.item()})
    avg_meter_iou_loss.add({'iou_loss': loss_dict['iou_loss']})
    avg_meter_sa_loss.add({'sa_loss': loss_dict['sa_loss']})
​
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
​
    global_step += 1
    if (global_step - 1) % 20 == 0:
        train_log = 'Iter:%5d/%5d, Total_Loss:%.4f, iou_loss:%.4f, sa_loss:%.4f, lr: %.4f' % (
            global_step - 1, max_step, avg_meter_total_loss.pop('total_loss'),
            avg_meter_iou_loss.pop('iou_loss'), avg_meter_sa_loss.pop('sa_loss'),
            optimizer.param_groups[0]['lr'])

Se puede ver que el entrenamiento es muy simple, primero se fusionan la imagen de carga y 5 marcos de vista, y luego se obtienen las características del habla y se envían al modelo. Luego calcule la pérdida y la puntuación Iou.

La entrada de datos al modelo se divide en dos partes, el cuadro de imagen [bs*5, 3, 224, 224], multiplicado por 5 significa que cada video tiene 5 cuadros, y la segunda parte es el cuadro de voz, con dimensiones similares .

class Pred_endecoder(nn.Module):
    # resnet based encoder decoder
    def __init__(self, channel=256, config=None, tpavi_stages=[], tpavi_vv_flag=False, tpavi_va_flag=True):
        super(Pred_endecoder, self).__init__()
        self.cfg = config
        self.tpavi_stages = tpavi_stages
        self.tpavi_vv_flag = tpavi_vv_flag
        self.tpavi_va_flag = tpavi_va_flag
​
        self.resnet = B2_ResNet()
        self.relu = nn.ReLU(inplace=True)
​
        self.conv4 = self._make_pred_layer(Classifier_Module, [3, 6, 12, 18], [3, 6, 12, 18], channel, 2048)
        self.conv3 = self._make_pred_layer(Classifier_Module, [3, 6, 12, 18], [3, 6, 12, 18], channel, 1024)
        self.conv2 = self._make_pred_layer(Classifier_Module, [3, 6, 12, 18], [3, 6, 12, 18], channel, 512)
        self.conv1 = self._make_pred_layer(Classifier_Module, [3, 6, 12, 18], [3, 6, 12, 18], channel, 256)
​
        self.path4 = FeatureFusionBlock(channel)
        self.path3 = FeatureFusionBlock(channel)
        self.path2 = FeatureFusionBlock(channel)
        self.path1 = FeatureFusionBlock(channel)
​
        for i in self.tpavi_stages:
            setattr(self, f"tpavi_b{i + 1}", TPAVIModule(in_channels=channel, mode='dot'))
            print("==> Build TPAVI block...")
​
        self.output_conv = nn.Sequential(
            nn.Conv2d(channel, 128, kernel_size=3, stride=1, padding=1),
            Interpolate(scale_factor=2, mode="bilinear"),
            nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
        )
​
        if self.training:
            self.initialize_weights()
​
    def pre_reshape_for_tpavi(self, x):
        # x: [B*5, C, H, W]
        _, C, H, W = x.shape
        x = x.reshape(-1, 5, C, H, W)
        x = x.permute(0, 2, 1, 3, 4).contiguous()  # [B, C, T, H, W]
        return x
​
    def post_reshape_for_tpavi(self, x):
        # x: [B, C, T, H, W]
        # return: [B*T, C, H, W]
        _, C, _, H, W = x.shape
        x = x.permute(0, 2, 1, 3, 4)  # [B, T, C, H, W]
        x = x.view(-1, C, H, W)
        return x
​
    def tpavi_vv(self, x, stage):
        # x: visual, [B*T, C=256, H, W]
        tpavi_b = getattr(self, f'tpavi_b{stage + 1}')
        x = self.pre_reshape_for_tpavi(x)  # [B, C, T, H, W]
        x, _ = tpavi_b(x)  # [B, C, T, H, W]
        x = self.post_reshape_for_tpavi(x)  # [B*T, C, H, W]
        return x
​
    def tpavi_va(self, x, audio, stage):
        # x: visual, [B*T, C=256, H, W]
        # audio: [B*T, 128]
        # ra_flag: return audio feature list or not
        tpavi_b = getattr(self, f'tpavi_b{stage + 1}')
        audio = audio.view(-1, 5, audio.shape[-1])  # [B, T, 128]
        x = self.pre_reshape_for_tpavi(x)  # [B, C, T, H, W]
        x, a = tpavi_b(x, audio)  # [B, C, T, H, W], [B, T, C]
        x = self.post_reshape_for_tpavi(x)  # [B*T, C, H, W]
        return x, a
​
    def _make_pred_layer(self, block, dilation_series, padding_series, NoLabels, input_channel):
        return block(dilation_series, padding_series, NoLabels, input_channel)
​
    def forward(self, x, audio_feature=None):
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)
        x1 = self.resnet.layer1(x)  # BF x 256  x 56 x 56
        x2 = self.resnet.layer2(x1)  # BF x 512  x 28 x 28
        x3 = self.resnet.layer3_1(x2)  # BF x 1024 x 14 x 14
        x4 = self.resnet.layer4_1(x3)  # BF x 2048 x  7 x  7
        # print(x1.shape, x2.shape, x3.shape, x4.shape)
​
        conv1_feat = self.conv1(x1)  # BF x 256 x 56 x 56
        conv2_feat = self.conv2(x2)  # BF x 256 x 28 x 28
        conv3_feat = self.conv3(x3)  # BF x 256 x 14 x 14
        conv4_feat = self.conv4(x4)  # BF x 256 x  7 x  7
        # print(conv1_feat.shape, conv2_feat.shape, conv3_feat.shape, conv4_feat.shape)
​
        feature_map_list = [conv1_feat, conv2_feat, conv3_feat, conv4_feat]
        a_fea_list = [None] * 4
​
        if len(self.tpavi_stages) > 0:
            if (not self.tpavi_vv_flag) and (not self.tpavi_va_flag):
                raise Exception('tpavi_vv_flag and tpavi_va_flag cannot be False at the same time if len(tpavi_stages)>0, \
                    tpavi_vv_flag is for video self-attention while tpavi_va_flag indicates the standard version (audio-visual attention)')
            for i in self.tpavi_stages:
                tpavi_count = 0
                conv_feat = torch.zeros_like(feature_map_list[i]).cuda()
                if self.tpavi_vv_flag:
                    conv_feat_vv = self.tpavi_vv(feature_map_list[i], stage=i)
                    conv_feat += conv_feat_vv
                    tpavi_count += 1
                if self.tpavi_va_flag:
                    conv_feat_va, a_fea = self.tpavi_va(feature_map_list[i], audio_feature, stage=i)
                    conv_feat += conv_feat_va
                    tpavi_count += 1
                    a_fea_list[i] = a_fea
                conv_feat /= tpavi_count
                feature_map_list[i] = conv_feat  # update features of stage-i which conduct TPAVI
​
        conv4_feat = self.path4(feature_map_list[3])  # BF x 256 x 14 x 14
        conv43 = self.path3(conv4_feat, feature_map_list[2])  # BF x 256 x 28 x 28
        conv432 = self.path2(conv43, feature_map_list[1])  # BF x 256 x 56 x 56
        conv4321 = self.path1(conv432, feature_map_list[0])  # BF x 256 x 112 x 112
        # print(conv4_feat.shape, conv43.shape, conv432.shape, conv4321.shape)
​
        pred = self.output_conv(conv4321)  # BF x 1 x 224 x 224
        # print(pred.shape)
​
        return pred, feature_map_list, a_fea_list
​
    def initialize_weights(self):
        res50 = models.resnet50(pretrained=False)
        resnet50_dict = torch.load(self.cfg.TRAIN.PRETRAINED_RESNET50_PATH)
        res50.load_state_dict(resnet50_dict)
        pretrained_dict = res50.state_dict()
        # print(pretrained_dict.keys())
        all_params = {}
        for k, v in self.resnet.state_dict().items():
            if k in pretrained_dict.keys():
                v = pretrained_dict[k]
                all_params[k] = v
            elif '_1' in k:
                name = k.split('_1')[0] + k.split('_1')[1]
                v = pretrained_dict[name]
                all_params[k] = v
            elif '_2' in k:
                name = k.split('_2')[0] + k.split('_2')[1]
                v = pretrained_dict[name]
                all_params[k] = v
        assert len(all_params.keys()) == len(self.resnet.state_dict().keys())
        self.resnet.load_state_dict(all_params)
        print(f'==> Load pretrained ResNet50 parameters from {self.cfg.TRAIN.PRETRAINED_RESNET50_PATH}')

La parte de la red es muy simple y la definición del modelo no tiene puntos brillantes. Veamos el código a continuación:

def forward(self, x, audio_feature=None):  #  输入图像帧和音频梅尔图经过vggish 的特征图。
    x = self.resnet.conv1(x)
    x = self.resnet.bn1(x)
    x = self.resnet.relu(x)
    x = self.resnet.maxpool(x)
    x1 = self.resnet.layer1(x)  # BF x 256  x 56 x 56
    x2 = self.resnet.layer2(x1)  # BF x 512  x 28 x 28
    x3 = self.resnet.layer3_1(x2)  # BF x 1024 x 14 x 14
    x4 = self.resnet.layer4_1(x3)  # BF x 2048 x  7 x  7  先进行resnet特征提取
    # print(x1.shape, x2.shape, x3.shape, x4.shape)
​
    conv1_feat = self.conv1(x1)  # BF x 256 x 56 x 56   维度转换一下
    conv2_feat = self.conv2(x2)  # BF x 256 x 28 x 28
    conv3_feat = self.conv3(x3)  # BF x 256 x 14 x 14
    conv4_feat = self.conv4(x4)  # BF x 256 x  7 x  7
    # print(conv1_feat.shape, conv2_feat.shape, conv3_feat.shape, conv4_feat.shape)
​
    feature_map_list = [conv1_feat, conv2_feat, conv3_feat, conv4_feat]
    a_fea_list = [None] * 4
​
    if len(self.tpavi_stages) > 0:   # 做几次tpavi模块,论文中是4次
        if (not self.tpavi_vv_flag) and (not self.tpavi_va_flag):
            raise Exception('tpavi_vv_flag and tpavi_va_flag cannot be False at the same time if len(tpavi_stages)>0, \
                tpavi_vv_flag is for video self-attention while tpavi_va_flag indicates the standard version (audio-visual attention)')
        for i in self.tpavi_stages:
            tpavi_count = 0
            conv_feat = torch.zeros_like(feature_map_list[i]).cuda()
            if self.tpavi_vv_flag:
                conv_feat_vv = self.tpavi_vv(feature_map_list[i], stage=i)
                conv_feat += conv_feat_vv
                tpavi_count += 1
            if self.tpavi_va_flag:
                # tpavi模块
                conv_feat_va, a_fea = self.tpavi_va(feature_map_list[i], audio_feature, stage=i)  
                conv_feat += conv_feat_va
                tpavi_count += 1
                a_fea_list[i] = a_fea
            conv_feat /= tpavi_count
            feature_map_list[i] = conv_feat  # update features of stage-i which conduct TPAVI
​
    conv4_feat = self.path4(feature_map_list[3])  # BF x 256 x 14 x 14  # 解码
    conv43 = self.path3(conv4_feat, feature_map_list[2])  # BF x 256 x 28 x 28
    conv432 = self.path2(conv43, feature_map_list[1])  # BF x 256 x 56 x 56
    conv4321 = self.path1(conv432, feature_map_list[0])  # BF x 256 x 112 x 112
    # print(conv4_feat.shape, conv43.shape, conv432.shape, conv4321.shape)
​
    pred = self.output_conv(conv4321)  # BF x 1 x 224 x 224
    # print(pred.shape)
​
    return pred, feature_map_list, a_fea_list

Se puede ver que es bastante complicado pasar por un módulo TPAVI:

class TPAVIModule(nn.Module):
    def __init__(self, in_channels, inter_channels=None, mode='dot', 
                 dimension=3, bn_layer=True):
        """
        args:
            in_channels: original channel size (1024 in the paper)
            inter_channels: channel size inside the block if not specifed reduced to half (512 in the paper)
            mode: supports Gaussian, Embedded Gaussian, Dot Product, and Concatenation 
            dimension: can be 1 (temporal), 2 (spatial), 3 (spatiotemporal)
            bn_layer: whether to add batch norm
        """
        super(TPAVIModule, self).__init__()
​
        assert dimension in [1, 2, 3]
        
        if mode not in ['gaussian', 'embedded', 'dot', 'concatenate']:
            raise ValueError('`mode` must be one of `gaussian`, `embedded`, `dot` or `concatenate`')
            
        self.mode = mode
        self.dimension = dimension
​
        self.in_channels = in_channels
        self.inter_channels = inter_channels
​
        # the channel size is reduced to half inside the block
        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1
        
        ## add align channel
        self.align_channel = nn.Linear(128, in_channels)
        self.norm_layer=nn.LayerNorm(in_channels)
​
        # assign appropriate convolutional, max pool, and batch norm layers for different dimensions
        if dimension == 3:
            conv_nd = nn.Conv3d
            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
            bn = nn.BatchNorm3d
        elif dimension == 2:
            conv_nd = nn.Conv2d
            max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
            bn = nn.BatchNorm2d
        else:
            conv_nd = nn.Conv1d
            max_pool_layer = nn.MaxPool1d(kernel_size=(2))
            bn = nn.BatchNorm1d
​
        # function g in the paper which goes through conv. with kernel size 1
        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1)
​
        if bn_layer:
            self.W_z = nn.Sequential(
                    conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1),
                    bn(self.in_channels)
                )
            nn.init.constant_(self.W_z[1].weight, 0)
            nn.init.constant_(self.W_z[1].bias, 0)
        else:
            self.W_z = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1)
​
            nn.init.constant_(self.W_z.weight, 0)
            nn.init.constant_(self.W_z.bias, 0)
​
        # define theta and phi for all operations except gaussian
        if self.mode == "embedded" or self.mode == "dot" or self.mode == "concatenate":
            self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1)
            self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1)
        
        if self.mode == "concatenate":
            self.W_f = nn.Sequential(
                    nn.Conv2d(in_channels=self.inter_channels * 2, out_channels=1, kernel_size=1),
                    nn.ReLU()
                )
​
            
    def forward(self, x, audio=None):
        """
        args:
            x: (N, C, T, H, W) for dimension=3; (N, C, H, W) for dimension 2; (N, C, T) for dimension 1
            audio: (N, T, C)
        """
​
        audio_temp = 0
        batch_size, C = x.size(0), x.size(1)
        if audio is not None:
            # print('==> audio.shape', audio.shape)
            H, W = x.shape[-2], x.shape[-1]
            audio_temp = self.align_channel(audio) # [bs, T, C]
            audio = audio_temp.permute(0, 2, 1) # [bs, C, T]
            audio = audio.unsqueeze(-1).unsqueeze(-1) # [bs, C, T, 1, 1]
            audio = audio.repeat(1, 1, 1, H, W) # [bs, C, T, H, W]
        else:
            audio = x
​
        # (N, C, THW)
        g_x = self.g(x).view(batch_size, self.inter_channels, -1) # [bs, C, THW]
        # print('g_x.shape', g_x.shape)
        # g_x = x.view(batch_size, C, -1)  # [bs, C, THW]
        g_x = g_x.permute(0, 2, 1) # [bs, THW, C]
​
        if self.mode == "gaussian":
            theta_x = x.view(batch_size, self.in_channels, -1)
            phi_x = audio.view(batch_size, self.in_channels, -1)
            theta_x = theta_x.permute(0, 2, 1)
            f = torch.matmul(theta_x, phi_x)
​
        elif self.mode == "embedded" or self.mode == "dot":
            theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) # [bs, C', THW]
            phi_x = self.phi(audio).view(batch_size, self.inter_channels, -1) # [bs, C', THW]
            theta_x = theta_x.permute(0, 2, 1) # [bs, THW, C']
            f = torch.matmul(theta_x, phi_x) # [bs, THW, THW]
​
        elif self.mode == "concatenate":
            theta_x = self.theta(x).view(batch_size, self.inter_channels, -1, 1)
            phi_x = self.phi(audio).view(batch_size, self.inter_channels, 1, -1)
            
            h = theta_x.size(2)
            w = phi_x.size(3)
            theta_x = theta_x.repeat(1, 1, 1, w)
            phi_x = phi_x.repeat(1, 1, h, 1)
            
            concat = torch.cat([theta_x, phi_x], dim=1)
            f = self.W_f(concat)
            f = f.view(f.size(0), f.size(2), f.size(3))
        
        if self.mode == "gaussian" or self.mode == "embedded":
            f_div_C = F.softmax(f, dim=-1)
        elif self.mode == "dot" or self.mode == "concatenate":
            N = f.size(-1) # number of position in x
            f_div_C = f / N  # [bs, THW, THW]
        
        y = torch.matmul(f_div_C, g_x) # [bs, THW, C]
        
        # contiguous here just allocates contiguous chunk of memory
        y = y.permute(0, 2, 1).contiguous() # [bs, C, THW]
        y = y.view(batch_size, self.inter_channels, *x.size()[2:]) # [bs, C', T, H, W]
        
        W_y = self.W_z(y)  # [bs, C, T, H, W]
        # residual connection
        z = W_y + x #  # [bs, C, T, H, W]
​
        # add LayerNorm
        z =  z.permute(0, 2, 3, 4, 1) # [bs, T, H, W, C]
        z = self.norm_layer(z)
        z = z.permute(0, 4, 1, 2, 3) # [bs, C, T, H, W]
        
        return z, audio_temp

El código parece complicado. De hecho, el autor ha realizado una gran cantidad de selección de módulos y conversión de canales de código. La operación final real es unas pocas circunvoluciones 3D 1* 1 *1. No necesitamos pensar en ello. Las circunvoluciones 3D son se usa para hacer extractos de características de tiempo. Luego haz algunas operaciones de multiplicación y acumulación.

if dimension == 3:
    conv_nd = nn.Conv3d
    max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
    bn = nn.BatchNorm3d
elif dimension == 2:
    conv_nd = nn.Conv2d
    max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
    bn = nn.BatchNorm2d
else:
    conv_nd = nn.Conv1d
    max_pool_layer = nn.MaxPool1d(kernel_size=(2))
    bn = nn.BatchNorm1d

Finalmente, después de varios decodificadores, el mapa de características se convierte en una dimensión:

conv4_feat = self.path4(feature_map_list[3])  # BF x 256 x 14 x 14
conv43 = self.path3(conv4_feat, feature_map_list[2])  # BF x 256 x 28 x 28
conv432 = self.path2(conv43, feature_map_list[1])  # BF x 256 x 56 x 56
conv4321 = self.path1(conv432, feature_map_list[0])  # BF x 256 x 112 x 112
# print(conv4_feat.shape, conv43.shape, conv432.shape, conv4321.shape)
​
pred = self.output_conv(conv4321)  # BF x 1 x 224 x 224

Se puede ver que [BF x 1 x 224 x 224], el cambio unidimensional, es una parte de predicción de regresión de la red. La imagen bs *frame final de 1 * 224 * 224 es la imagen de salida final (se muestra como 0, 1 clasificación después de argmax y otras operaciones), y se convierte en la imagen de máscara predicha.

Puedes ver mi mapa de previsión:

 

prueba

Primero mire los datos de ms3_meta_data.csv

 

Como puede ver, hay tres conjuntos de datos: conjuntos de entrenamiento, verificación y prueba.Después de haber entrenado el modelo, podemos usar test.py para probar, y los resultados de la prueba se colocarán en la carpeta test_log. Irá a prueba, los datos en la carpeta de prueba. Ejecute el código de prueba y cambie la ruta del modelo entrenado para ver el resultado.

probar un video

Haga clic en raw_videos/ de avsbench_data/det/det y coloque los videos que desea probar, se recomienda 5s, porque necesita cortar 5 cuadros, a menos que cambie el código.

Luego ejecute preprocess_scripts/preprocess_ms3.py, que es para generar el gráfico Mel de la voz y corte los fotogramas, que se guardarán en el mismo nivel que raw_videos.

Luego ejecute detect.py (al mismo nivel que train.py) para razonar sobre su video.

Detección en tiempo real, todavía estoy escribiendo este código, espera un minuto.

Todos los enlaces del código (los archivos locales no se pueden cargar, solo se puede proporcionar el github original): https://github.com/OpenNLPLab/AVSBench

por fin

En un futuro cercano, grabaré un video, repasaré el principio, el código y el razonamiento de entrenamiento, todos presten atención ~

Supongo que te gusta

Origin blog.csdn.net/qq_46098574/article/details/126255334
Recomendado
Clasificación