[Step-by-step, anti-involution] A brand-new multimodal AI task, audio-visual segmentation: code practice and optimization tutorial (2)

Foreword

Please see the previous article for the theoretical part:

Brief overview: the task is to find which object in the image is making a sound, as in the following video demonstration:

A GIF has no sound, so imagine it: there are many cars in the scene, but only this ambulance (a "120", China's ambulance emergency number) is making a sound, so it is the only object that gets segmented.

 

 

This scene shows a singer who sometimes sings and sometimes plays the piano. When only the piano is playing, the person is not segmented; when the person sings, the person is segmented.

 

Code directory layout (my version, unofficial)

 

You can download my version from Baidu Netdisk (it contains all the data and code), or download the official code, which does not include the data; the dataset has to be requested from the authors.

Training

First, look at train.py.

The help text of each argument below explains what it does.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--session_name", default="MS3", type=str,
                    help="MS3: train on the Multi-sources subset of the data, i.e. multi-source data where several objects may be making sound at the same time")
parser.add_argument("--visual_backbone", default="resnet", type=str,
                    help="use resnet50 or pvt-v2 as the visual backbone")

parser.add_argument("--train_batch_size", default=4, type=int)
parser.add_argument("--val_batch_size", default=1, type=int)
parser.add_argument("--max_epoches", default=5, type=int)
parser.add_argument("--lr", default=0.0001, type=float)
parser.add_argument("--num_workers", default=0, type=int)
parser.add_argument("--wt_dec", default=5e-4, type=float)  # weight decay

parser.add_argument('--masked_av_flag', action='store_true', default=True,
                    help="use the loss from the authors' paper: sa/masked_va loss")
parser.add_argument("--lambda_1", default=0.5, type=float, help='balancing coefficient: weight for balancing l4 loss')
parser.add_argument("--masked_av_stages", default=[0, 1, 2, 3], nargs='+', type=int,
                    help="authors' setting: compute sa/masked_va loss in which stages: [0, 1, 2, 3]")
parser.add_argument('--threshold_flag', action='store_true', default=False,
                    help='whether thresholding the generated masks')
parser.add_argument("--mask_pooling_type", default='avg', type=str, help='the manner to downsample predicted masks')
parser.add_argument('--norm_fea_flag', action='store_true', default=False, help='normalize audio-visual features')
parser.add_argument('--closer_flag', action='store_true', default=False, help='use closer loss for masked_va loss')
parser.add_argument('--euclidean_flag', action='store_true', default=False,
                    help='use euclidean distance for masked_va loss')
parser.add_argument('--kl_flag', action='store_true', default=True, help='KL divergence: use kl loss for masked_va loss')

parser.add_argument("--load_s4_params", action='store_true', default=False,
                    help='use S4 parameters for initialization')
parser.add_argument("--trained_s4_model_path", type=str, default='', help='pretrained S4 model')

parser.add_argument("--tpavi_stages", default=[0, 1, 2, 3], nargs='+', type=int,
                    help='TPAVI module: add tpavi block in which stages: [0, 1, 2, 3]')
parser.add_argument("--tpavi_vv_flag", action='store_true', default=False, help='visual-visual self-attention')
parser.add_argument("--tpavi_va_flag", action='store_true', default=True, help='visual-audio cross-attention')

parser.add_argument("--weights", type=str, default='', help='path of a trained model to initialize training from (optional)')
parser.add_argument('--log_dir', default='./train_logs', type=str)

args = parser.parse_args()
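Several flags above (--masked_av_flag, --kl_flag, --tpavi_va_flag) pair action='store_true' with default=True, so they are always True and cannot be switched off from the command line. If you want to toggle them, one option (my suggestion, not the official code) is argparse.BooleanOptionalAction, available since Python 3.9, which also registers a --no-... variant. A minimal sketch with two of the flags:

```python
import argparse


def build_parser():
    parser = argparse.ArgumentParser()
    # As in train.py: store_true + default=True, so passing the flag changes nothing
    parser.add_argument("--masked_av_flag", action="store_true", default=True)
    # Toggleable version: accepts both --kl_flag and --no-kl_flag
    parser.add_argument("--kl_flag", action=argparse.BooleanOptionalAction, default=True)
    return parser


parser = build_parser()
args = parser.parse_args(["--no-kl_flag"])
print(args.masked_av_flag, args.kl_flag)  # True False
```

The default stays True, so existing train.sh commands behave the same; only the new --no-kl_flag form can turn the option off.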

You can start training with train.sh.

Code details

Next, the visual backbone is chosen according to --visual_backbone (ResNet50 or PVT-v2), while audio features are extracted with VGGish by default.

if args.visual_backbone.lower() == "resnet":
    from model import ResNet_AVSModel as AVSModel

    print('==> Use ResNet50 as the visual backbone...')
elif args.visual_backbone.lower() == "pvt":
    from model import PVT_AVSModel as AVSModel

    print('==> Use pvt-v2 as the visual backbone...')
else:
    raise NotImplementedError("only support the resnet50 and pvt-v2")

Data loading:

import os

import pandas as pd
import torch
from torch.utils.data import Dataset
from torchvision import transforms

# cfg (the config object) and the helpers load_audio_lm and
# load_image_in_PIL_to_Tensor come from the AVSBench repo and are not shown here.


class MS3Dataset(Dataset):
    """Dataset for multiple sound source segmentation"""
    def __init__(self, split='train'):
        super(MS3Dataset, self).__init__()
        self.split = split
        self.mask_num = 5  # MS3 provides a mask for each of the 5 frames
        df_all = pd.read_csv(cfg.DATA.ANNO_CSV, sep=',')
        self.df_split = df_all[df_all['split'] == split]
        print("{}/{} videos are used for {}".format(len(self.df_split), len(df_all), self.split))
        self.img_transform = transforms.Compose([
            transforms.ToTensor(),
            # ImageNet mean / std
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])
        self.mask_transform = transforms.Compose([
            transforms.ToTensor(),
        ])

    def __getitem__(self, index):
        df_one_video = self.df_split.iloc[index]
        video_name = df_one_video.iloc[0]  # the first CSV column holds the video name
        img_base_path = os.path.join(cfg.DATA.DIR_IMG, video_name)
        audio_lm_path = os.path.join(cfg.DATA.DIR_AUDIO_LOG_MEL, self.split, video_name + '.pkl')
        mask_base_path = os.path.join(cfg.DATA.DIR_MASK, self.split, video_name)
        audio_log_mel = load_audio_lm(audio_lm_path)  # precomputed log-mel, [5, 1, 96, 64]
        # audio_lm_tensor = torch.from_numpy(audio_log_mel)
        imgs, masks = [], []
        for img_id in range(1, 6):  # 5 frames, one per second
            img = load_image_in_PIL_to_Tensor(os.path.join(img_base_path, "%s.mp4_%d.png" % (video_name, img_id)), transform=self.img_transform)
            imgs.append(img)
        for mask_id in range(1, self.mask_num + 1):
            mask = load_image_in_PIL_to_Tensor(os.path.join(mask_base_path, "%s_%d.png" % (video_name, mask_id)), transform=self.mask_transform, mode='P')
            masks.append(mask)
        imgs_tensor = torch.stack(imgs, dim=0)    # [5, 3, H, W]
        masks_tensor = torch.stack(masks, dim=0)  # [5, 1, H, W]

        return imgs_tensor, audio_log_mel, masks_tensor, video_name

    def __len__(self):
        return len(self.df_split)

Each call to __getitem__ therefore reads 5 frames. All the videos are 5 seconds long, so one sample is one whole video: one frame per second, each paired with its ground-truth mask and with the audio of that second.
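To make the file layout concrete, here is a small path-building helper that follows the naming patterns in __getitem__ (frames named <video>.mp4_<k>.png, masks named <video>_<k>.png, k = 1..5). The function name sample_paths and the example directory names are hypothetical; only the two filename patterns come from the code above.

```python
import os


def sample_paths(img_dir, mask_dir, split, video_name, num_frames=5):
    """Hypothetical helper: frame and mask paths for one video, one frame per second."""
    img_base = os.path.join(img_dir, video_name)            # frames are not split by train/val/test
    mask_base = os.path.join(mask_dir, split, video_name)   # masks are stored per split
    imgs = [os.path.join(img_base, "%s.mp4_%d.png" % (video_name, k))
            for k in range(1, num_frames + 1)]
    masks = [os.path.join(mask_base, "%s_%d.png" % (video_name, k))
             for k in range(1, num_frames + 1)]
    return imgs, masks


imgs, masks = sample_paths("visual_frames", "gt_masks", "train", "demo_video")
print(os.path.basename(imgs[0]), os.path.basename(masks[-1]))  # demo_video.mp4_1.png demo_video_5.png
```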

# Excerpt from the training loop in train.py (runs inside the epoch loop)
for n_iter, batch_data in enumerate(train_dataloader):
    imgs, audio, mask, _ = batch_data  # [bs, 5, 3, 224, 224], [bs, 5, 1, 96, 64], [bs, 5 or 1, 1, 224, 224]

    imgs = imgs.cuda()
    audio = audio.cuda()
    mask = mask.cuda()
    B, frame, C, H, W = imgs.shape
    imgs = imgs.view(B * frame, C, H, W)  # fold the 5 frames into the batch: [B*5, 3, 224, 224]
    mask_num = 5  # MS3 has a mask for every frame
    mask = mask.view(B * mask_num, 1, H, W)
    audio = audio.view(-1, audio.shape[2], audio.shape[3], audio.shape[4])  # [B*T, 1, 96, 64]
    with torch.no_grad():  # the VGGish audio backbone is not trained
        audio_feature = audio_backbone(audio)  # [B*T, 128]

    output, v_map_list, a_fea_list = model(imgs, audio_feature)  # [bs*5, 1, 224, 224]
    loss, loss_dict = IouSemanticAwareLoss(output, mask, a_fea_list, v_map_list,
                                           sa_loss_flag=args.masked_av_flag, lambda_1=args.lambda_1,
                                           count_stages=args.masked_av_stages,
                                           mask_pooling_type=args.mask_pooling_type,
                                           threshold=args.threshold_flag, norm_fea=args.norm_fea_flag,
                                           closer_flag=args.closer_flag, euclidean_flag=args.euclidean_flag,
                                           kl_flag=args.kl_flag)

    avg_meter_total_loss.add({'total_loss': loss.item()})
    avg_meter_iou_loss.add({'iou_loss': loss_dict['iou_loss']})
    avg_meter_sa_loss.add({'sa_loss': loss_dict['sa_loss']})

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    global_step += 1
    if (global_step - 1) % 20 == 0:  # report running averages every 20 iterations
        train_log = 'Iter:%5d/%5d, Total_Loss:%.4f, iou_loss:%.4f, sa_loss:%.4f, lr: %.4f' % (
            global_step - 1, max_step, avg_meter_total_loss.pop('total_loss'),
            avg_meter_iou_loss.pop('iou_loss'), avg_meter_sa_loss.pop('sa_loss'),
            optimizer.param_groups[0]['lr'])
        print(train_log)  # or write it to the training log

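The training loop is simple: the 5 frames of each video are merged into the batch dimension, the audio features are extracted by the VGGish backbone (under torch.no_grad(), so it is not trained), and both are fed to the model. The loss, which combines the IoU loss and the semantic-aware (sa) loss, is then computed and back-propagated.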
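The loop logs running averages through avg_meter_*.add() and .pop(), whose class is not shown in this article. Judging only from how it is called (add() every iteration, pop() every 20 iterations), a meter like the following would behave as needed. This is a hypothetical minimal version, not the repo's implementation:

```python
from collections import defaultdict


class AvgMeter:
    """Hypothetical running-average meter: add() accumulates, pop() returns the mean and resets."""

    def __init__(self):
        self._sums = defaultdict(float)
        self._counts = defaultdict(int)

    def add(self, values):
        # values: a dict such as {'total_loss': 0.83}
        for key, value in values.items():
            self._sums[key] += float(value)
            self._counts[key] += 1

    def pop(self, key):
        # Mean since the last pop of this key; raises KeyError if nothing was added
        return self._sums.pop(key) / self._counts.pop(key)


meter = AvgMeter()
for loss in [1.0, 3.0]:
    meter.add({'total_loss': loss})
print(meter.pop('total_loss'))  # 2.0
```

Resetting on pop() means each logged value is the average over the last 20 iterations rather than over the whole run.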

The model takes two inputs: the image frames, [bs*5, 3, 224, 224], where the factor of 5 comes from the 5 frames per video, and the audio features, [bs*5, 128], one 128-dimensional VGGish embedding per log-mel segment (roughly one second of audio).
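The shape bookkeeping can be checked without a GPU. Below is a NumPy sketch of the reshapes in the loop: reshape plays the role of .view(), and a fixed random projection is a hypothetical stand-in for VGGish (only its output size, 128, matters here). It also shows that frame t of video b lands at row b*5 + t.

```python
import numpy as np

B, T = 2, 5  # 2 videos per batch, 5 frames (1 s each) per video
imgs = np.zeros((B, T, 3, 224, 224), dtype=np.float32)
audio = np.zeros((B, T, 1, 96, 64), dtype=np.float32)
imgs[1, 2] = 1.0  # mark frame t=2 of video b=1

# Fold time into batch, like imgs.view(B * frame, C, H, W) in the loop
imgs_flat = imgs.reshape(B * T, 3, 224, 224)      # [B*T, 3, 224, 224]
audio_flat = audio.reshape(-1, *audio.shape[2:])  # [B*T, 1, 96, 64]

# Hypothetical stand-in for the VGGish backbone: one 128-d vector per 96x64 log-mel patch
rng = np.random.default_rng(0)
proj = rng.standard_normal((96 * 64, 128)).astype(np.float32)
audio_feature = audio_flat.reshape(B * T, -1) @ proj  # [B*T, 128]

print(imgs_flat.shape, audio_feature.shape)  # (10, 3, 224, 224) (10, 128)
print(imgs_flat[1 * T + 2].max())            # 1.0
```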

import torch
import torch.nn as nn
import torchvision.models as models

# B2_ResNet, Classifier_Module, FeatureFusionBlock, Interpolate and TPAVIModule
# come from the AVSBench repo (TPAVIModule is shown below).


class Pred_endecoder(nn.Module):
    # resnet based encoder decoder
    def __init__(self, channel=256, config=None, tpavi_stages=[], tpavi_vv_flag=False, tpavi_va_flag=True):
        super(Pred_endecoder, self).__init__()
        self.cfg = config
        self.tpavi_stages = tpavi_stages
        self.tpavi_vv_flag = tpavi_vv_flag
        self.tpavi_va_flag = tpavi_va_flag

        self.resnet = B2_ResNet()
        self.relu = nn.ReLU(inplace=True)

        # project each ResNet stage (2048/1024/512/256 channels) to `channel` channels
        self.conv4 = self._make_pred_layer(Classifier_Module, [3, 6, 12, 18], [3, 6, 12, 18], channel, 2048)
        self.conv3 = self._make_pred_layer(Classifier_Module, [3, 6, 12, 18], [3, 6, 12, 18], channel, 1024)
        self.conv2 = self._make_pred_layer(Classifier_Module, [3, 6, 12, 18], [3, 6, 12, 18], channel, 512)
        self.conv1 = self._make_pred_layer(Classifier_Module, [3, 6, 12, 18], [3, 6, 12, 18], channel, 256)

        # decoder blocks; each doubles the spatial size (see the shape comments in forward)
        self.path4 = FeatureFusionBlock(channel)
        self.path3 = FeatureFusionBlock(channel)
        self.path2 = FeatureFusionBlock(channel)
        self.path1 = FeatureFusionBlock(channel)

        for i in self.tpavi_stages:
            setattr(self, f"tpavi_b{i + 1}", TPAVIModule(in_channels=channel, mode='dot'))
            print("==> Build TPAVI block...")

        self.output_conv = nn.Sequential(
            nn.Conv2d(channel, 128, kernel_size=3, stride=1, padding=1),
            Interpolate(scale_factor=2, mode="bilinear"),
            nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
        )

        if self.training:
            self.initialize_weights()

    def pre_reshape_for_tpavi(self, x):
        # x: [B*5, C, H, W]
        _, C, H, W = x.shape
        x = x.reshape(-1, 5, C, H, W)
        x = x.permute(0, 2, 1, 3, 4).contiguous()  # [B, C, T, H, W]
        return x

    def post_reshape_for_tpavi(self, x):
        # x: [B, C, T, H, W]
        # return: [B*T, C, H, W]
        _, C, _, H, W = x.shape
        x = x.permute(0, 2, 1, 3, 4)  # [B, T, C, H, W]
        x = x.view(-1, C, H, W)
        return x

    def tpavi_vv(self, x, stage):
        # x: visual, [B*T, C=256, H, W]
        tpavi_b = getattr(self, f'tpavi_b{stage + 1}')
        x = self.pre_reshape_for_tpavi(x)  # [B, C, T, H, W]
        x, _ = tpavi_b(x)  # [B, C, T, H, W]
        x = self.post_reshape_for_tpavi(x)  # [B*T, C, H, W]
        return x

    def tpavi_va(self, x, audio, stage):
        # x: visual, [B*T, C=256, H, W]
        # audio: [B*T, 128]
        # ra_flag: return audio feature list or not
        tpavi_b = getattr(self, f'tpavi_b{stage + 1}')
        audio = audio.view(-1, 5, audio.shape[-1])  # [B, T, 128]
        x = self.pre_reshape_for_tpavi(x)  # [B, C, T, H, W]
        x, a = tpavi_b(x, audio)  # [B, C, T, H, W], [B, T, C]
        x = self.post_reshape_for_tpavi(x)  # [B*T, C, H, W]
        return x, a

    def _make_pred_layer(self, block, dilation_series, padding_series, NoLabels, input_channel):
        return block(dilation_series, padding_series, NoLabels, input_channel)

    def forward(self, x, audio_feature=None):
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)
        x1 = self.resnet.layer1(x)  # BF x 256  x 56 x 56
        x2 = self.resnet.layer2(x1)  # BF x 512  x 28 x 28
        x3 = self.resnet.layer3_1(x2)  # BF x 1024 x 14 x 14
        x4 = self.resnet.layer4_1(x3)  # BF x 2048 x  7 x  7
        # print(x1.shape, x2.shape, x3.shape, x4.shape)

        conv1_feat = self.conv1(x1)  # BF x 256 x 56 x 56
        conv2_feat = self.conv2(x2)  # BF x 256 x 28 x 28
        conv3_feat = self.conv3(x3)  # BF x 256 x 14 x 14
        conv4_feat = self.conv4(x4)  # BF x 256 x  7 x  7
        # print(conv1_feat.shape, conv2_feat.shape, conv3_feat.shape, conv4_feat.shape)

        feature_map_list = [conv1_feat, conv2_feat, conv3_feat, conv4_feat]
        a_fea_list = [None] * 4

        if len(self.tpavi_stages) > 0:
            if (not self.tpavi_vv_flag) and (not self.tpavi_va_flag):
                raise Exception('tpavi_vv_flag and tpavi_va_flag cannot be False at the same time if len(tpavi_stages)>0, \
                    tpavi_vv_flag is for video self-attention while tpavi_va_flag indicates the standard version (audio-visual attention)')
            for i in self.tpavi_stages:
                tpavi_count = 0
                conv_feat = torch.zeros_like(feature_map_list[i])  # already on the same device as the features
                if self.tpavi_vv_flag:
                    conv_feat_vv = self.tpavi_vv(feature_map_list[i], stage=i)
                    conv_feat += conv_feat_vv
                    tpavi_count += 1
                if self.tpavi_va_flag:
                    conv_feat_va, a_fea = self.tpavi_va(feature_map_list[i], audio_feature, stage=i)
                    conv_feat += conv_feat_va
                    tpavi_count += 1
                    a_fea_list[i] = a_fea
                conv_feat /= tpavi_count
                feature_map_list[i] = conv_feat  # update features of stage-i which conduct TPAVI

        conv4_feat = self.path4(feature_map_list[3])  # BF x 256 x 14 x 14
        conv43 = self.path3(conv4_feat, feature_map_list[2])  # BF x 256 x 28 x 28
        conv432 = self.path2(conv43, feature_map_list[1])  # BF x 256 x 56 x 56
        conv4321 = self.path1(conv432, feature_map_list[0])  # BF x 256 x 112 x 112
        # print(conv4_feat.shape, conv43.shape, conv432.shape, conv4321.shape)

        pred = self.output_conv(conv4321)  # BF x 1 x 224 x 224
        # print(pred.shape)

        return pred, feature_map_list, a_fea_list

    def initialize_weights(self):
        # Copy ImageNet-pretrained torchvision ResNet50 weights into B2_ResNet;
        # B2_ResNet keys like layer3_1.* map to layer3.* after stripping '_1' / '_2'
        res50 = models.resnet50(pretrained=False)
        resnet50_dict = torch.load(self.cfg.TRAIN.PRETRAINED_RESNET50_PATH)
        res50.load_state_dict(resnet50_dict)
        pretrained_dict = res50.state_dict()
        # print(pretrained_dict.keys())
        all_params = {}
        for k, v in self.resnet.state_dict().items():
            if k in pretrained_dict.keys():
                v = pretrained_dict[k]
                all_params[k] = v
            elif '_1' in k:
                name = k.split('_1')[0] + k.split('_1')[1]
                v = pretrained_dict[name]
                all_params[k] = v
            elif '_2' in k:
                name = k.split('_2')[0] + k.split('_2')[1]
                v = pretrained_dict[name]
                all_params[k] = v
        assert len(all_params.keys()) == len(self.resnet.state_dict().keys())
        self.resnet.load_state_dict(all_params)
        print(f'==> Load pretrained ResNet50 parameters from {self.cfg.TRAIN.PRETRAINED_RESNET50_PATH}')
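pre_reshape_for_tpavi and post_reshape_for_tpavi must be exact inverses; otherwise frames from different videos would get mixed inside TPAVI. A quick NumPy check of the same logic (transpose stands in for permute; the function names pre_reshape and post_reshape are mine):

```python
import numpy as np


def pre_reshape(x, T=5):
    # [B*T, C, H, W] -> [B, C, T, H, W], same logic as pre_reshape_for_tpavi
    _, C, H, W = x.shape
    return x.reshape(-1, T, C, H, W).transpose(0, 2, 1, 3, 4)


def post_reshape(x):
    # [B, C, T, H, W] -> [B*T, C, H, W], same logic as post_reshape_for_tpavi
    _, C, _, H, W = x.shape
    return x.transpose(0, 2, 1, 3, 4).reshape(-1, C, H, W)


x = np.arange(10 * 4 * 3 * 3, dtype=np.float32).reshape(10, 4, 3, 3)  # B=2, T=5, C=4, H=W=3
y = pre_reshape(x)
print(y.shape)                             # (2, 4, 5, 3, 3)
print(np.array_equal(post_reshape(y), x))  # True
```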

The network definition is straightforward and holds no real surprises. The interesting part is forward:

def forward(self, x, audio_feature=None):  # inputs: image frames, and VGGish features of the audio log-mel spectrograms
    # first, ResNet feature extraction
    x = self.resnet.conv1(x)
    x = self.resnet.bn1(x)
    x = self.resnet.relu(x)
    x = self.resnet.maxpool(x)
    x1 = self.resnet.layer1(x)  # BF x 256  x 56 x 56
    x2 = self.resnet.layer2(x1)  # BF x 512  x 28 x 28
    x3 = self.resnet.layer3_1(x2)  # BF x 1024 x 14 x 14
    x4 = self.resnet.layer4_1(x3)  # BF x 2048 x  7 x  7
    # print(x1.shape, x2.shape, x3.shape, x4.shape)

    # project every stage to 256 channels
    conv1_feat = self.conv1(x1)  # BF x 256 x 56 x 56
    conv2_feat = self.conv2(x2)  # BF x 256 x 28 x 28
    conv3_feat = self.conv3(x3)  # BF x 256 x 14 x 14
    conv4_feat = self.conv4(x4)  # BF x 256 x  7 x  7
    # print(conv1_feat.shape, conv2_feat.shape, conv3_feat.shape, conv4_feat.shape)

    feature_map_list = [conv1_feat, conv2_feat, conv3_feat, conv4_feat]
    a_fea_list = [None] * 4

    if len(self.tpavi_stages) > 0:  # stages that get a TPAVI module; the paper uses all 4
        if (not self.tpavi_vv_flag) and (not self.tpavi_va_flag):
            raise Exception('tpavi_vv_flag and tpavi_va_flag cannot be False at the same time if len(tpavi_stages)>0, \
                tpavi_vv_flag is for video self-attention while tpavi_va_flag indicates the standard version (audio-visual attention)')
        for i in self.tpavi_stages:
            tpavi_count = 0
            conv_feat = torch.zeros_like(feature_map_list[i])  # already on the same device as the features
            if self.tpavi_vv_flag:
                conv_feat_vv = self.tpavi_vv(feature_map_list[i], stage=i)
                conv_feat += conv_feat_vv
                tpavi_count += 1
            if self.tpavi_va_flag:
                # TPAVI module (audio-visual attention)
                conv_feat_va, a_fea = self.tpavi_va(feature_map_list[i], audio_feature, stage=i)
                conv_feat += conv_feat_va
                tpavi_count += 1
                a_fea_list[i] = a_fea
            conv_feat /= tpavi_count  # average over the enabled branches
            feature_map_list[i] = conv_feat  # update features of stage-i which conduct TPAVI

    # decoding
    conv4_feat = self.path4(feature_map_list[3])  # BF x 256 x 14 x 14
    conv43 = self.path3(conv4_feat, feature_map_list[2])  # BF x 256 x 28 x 28
    conv432 = self.path2(conv43, feature_map_list[1])  # BF x 256 x 56 x 56
    conv4321 = self.path1(conv432, feature_map_list[0])  # BF x 256 x 112 x 112
    # print(conv4_feat.shape, conv43.shape, conv432.shape, conv4321.shape)

    pred = self.output_conv(conv4321)  # BF x 1 x 224 x 224
    # print(pred.shape)

    return pred, feature_map_list, a_fea_list

Each selected stage goes through a TPAVI module, whose code looks fairly complicated:

import torch
import torch.nn as nn
import torch.nn.functional as F


class TPAVIModule(nn.Module):
    def __init__(self, in_channels, inter_channels=None, mode='dot',
                 dimension=3, bn_layer=True):
        """
        args:
            in_channels: original channel size (1024 in the paper)
            inter_channels: channel size inside the block; if not specified, reduced to half (512 in the paper)
            mode: supports Gaussian, Embedded Gaussian, Dot Product, and Concatenation
            dimension: can be 1 (temporal), 2 (spatial), 3 (spatiotemporal)
            bn_layer: whether to add batch norm
        """
        super(TPAVIModule, self).__init__()

        assert dimension in [1, 2, 3]

        if mode not in ['gaussian', 'embedded', 'dot', 'concatenate']:
            raise ValueError('`mode` must be one of `gaussian`, `embedded`, `dot` or `concatenate`')

        self.mode = mode
        self.dimension = dimension

        self.in_channels = in_channels
        self.inter_channels = inter_channels

        # the channel size is reduced to half inside the block
        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        ## add align channel: project the 128-d VGGish audio feature to in_channels
        self.align_channel = nn.Linear(128, in_channels)
        self.norm_layer = nn.LayerNorm(in_channels)

        # assign appropriate convolutional, max pool, and batch norm layers for different dimensions
        if dimension == 3:
            conv_nd = nn.Conv3d
            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
            bn = nn.BatchNorm3d
        elif dimension == 2:
            conv_nd = nn.Conv2d
            max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
            bn = nn.BatchNorm2d
        else:
            conv_nd = nn.Conv1d
            max_pool_layer = nn.MaxPool1d(kernel_size=(2))
            bn = nn.BatchNorm1d

        # function g in the paper which goes through conv. with kernel size 1
        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1)

        if bn_layer:
            self.W_z = nn.Sequential(
                    conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1),
                    bn(self.in_channels)
                )
            # zero init: W_y is 0 at the start, so z = x before the LayerNorm
            nn.init.constant_(self.W_z[1].weight, 0)
            nn.init.constant_(self.W_z[1].bias, 0)
        else:
            self.W_z = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1)

            nn.init.constant_(self.W_z.weight, 0)
            nn.init.constant_(self.W_z.bias, 0)

        # define theta and phi for all operations except gaussian
        if self.mode == "embedded" or self.mode == "dot" or self.mode == "concatenate":
            self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1)
            self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1)

        if self.mode == "concatenate":
            self.W_f = nn.Sequential(
                    nn.Conv2d(in_channels=self.inter_channels * 2, out_channels=1, kernel_size=1),
                    nn.ReLU()
                )

    def forward(self, x, audio=None):
        """
        args:
            x: (N, C, T, H, W) for dimension=3; (N, C, H, W) for dimension 2; (N, C, T) for dimension 1
            audio: (N, T, C)
        """

        audio_temp = 0
        batch_size, C = x.size(0), x.size(1)
        if audio is not None:
            # print('==> audio.shape', audio.shape)
            H, W = x.shape[-2], x.shape[-1]
            audio_temp = self.align_channel(audio)  # [bs, T, C]
            audio = audio_temp.permute(0, 2, 1)  # [bs, C, T]
            audio = audio.unsqueeze(-1).unsqueeze(-1)  # [bs, C, T, 1, 1]
            audio = audio.repeat(1, 1, 1, H, W)  # [bs, C, T, H, W]
        else:
            audio = x

        # (N, C, THW)
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)  # [bs, C, THW]
        # print('g_x.shape', g_x.shape)
        # g_x = x.view(batch_size, C, -1)  # [bs, C, THW]
        g_x = g_x.permute(0, 2, 1)  # [bs, THW, C]

        if self.mode == "gaussian":
            theta_x = x.view(batch_size, self.in_channels, -1)
            phi_x = audio.view(batch_size, self.in_channels, -1)
            theta_x = theta_x.permute(0, 2, 1)
            f = torch.matmul(theta_x, phi_x)

        elif self.mode == "embedded" or self.mode == "dot":
            theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)  # [bs, C', THW]
            phi_x = self.phi(audio).view(batch_size, self.inter_channels, -1)  # [bs, C', THW]
            theta_x = theta_x.permute(0, 2, 1)  # [bs, THW, C']
            f = torch.matmul(theta_x, phi_x)  # [bs, THW, THW]

        elif self.mode == "concatenate":
            theta_x = self.theta(x).view(batch_size, self.inter_channels, -1, 1)
            phi_x = self.phi(audio).view(batch_size, self.inter_channels, 1, -1)

            h = theta_x.size(2)
            w = phi_x.size(3)
            theta_x = theta_x.repeat(1, 1, 1, w)
            phi_x = phi_x.repeat(1, 1, h, 1)

            concat = torch.cat([theta_x, phi_x], dim=1)
            f = self.W_f(concat)
            f = f.view(f.size(0), f.size(2), f.size(3))

        if self.mode == "gaussian" or self.mode == "embedded":
            f_div_C = F.softmax(f, dim=-1)
        elif self.mode == "dot" or self.mode == "concatenate":
            N = f.size(-1)  # number of positions in x
            f_div_C = f / N  # [bs, THW, THW]

        y = torch.matmul(f_div_C, g_x)  # [bs, THW, C]

        # contiguous here just allocates contiguous chunk of memory
        y = y.permute(0, 2, 1).contiguous()  # [bs, C, THW]
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])  # [bs, C', T, H, W]

        W_y = self.W_z(y)  # [bs, C, T, H, W]
        # residual connection
        z = W_y + x  # [bs, C, T, H, W]

        # add LayerNorm
        z = z.permute(0, 2, 3, 4, 1)  # [bs, T, H, W, C]
        z = self.norm_layer(z)
        z = z.permute(0, 4, 1, 2, 3)  # [bs, C, T, H, W]

        return z, audio_temp

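The code looks complicated, but most of it handles mode selection and channel/shape conversion. With the settings used here (mode='dot', dimension=3), what actually runs is a few 1×1×1 3D convolutions (g, theta, phi and W_z) plus some matrix multiplications. Because the kernels are 1×1×1, these convolutions are per-position linear projections of the channels; the mixing across time and space comes from the [THW × THW] attention matrix f between the visual features and the broadcast audio features. The result then goes through a residual connection and LayerNorm. The convolution type is selected by `dimension`: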

if dimension == 3:
    conv_nd = nn.Conv3d
    max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
    bn = nn.BatchNorm3d
elif dimension == 2:
    conv_nd = nn.Conv2d
    max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
    bn = nn.BatchNorm2d
else:
    conv_nd = nn.Conv1d
    max_pool_layer = nn.MaxPool1d(kernel_size=(2))
    bn = nn.BatchNorm1d
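To see what remains in the default configuration (mode='dot', dimension=3), here is a NumPy sketch of the TPAVI forward pass for audio-visual attention. Each 1×1×1 convolution is written as a channel-mixing matrix applied at every (t, h, w) position. Simplifications (mine, not the repo's): no bias terms, the BatchNorm inside W_z and the LayerNorm affine parameters are left out, and the weight names are hypothetical. Because the original zero-initializes the BatchNorm in W_z, the block starts out as just LayerNorm(x); passing w_z = 0 reproduces that here.

```python
import numpy as np


def conv1x1x1(w, v):
    """1x1x1 conv without bias: w [C_out, C_in], v [N, C_in, T, H, W] -> [N, C_out, T, H, W]."""
    return np.moveaxis(np.tensordot(w, v, axes=(1, 1)), 0, 1)


def tpavi_dot(x, audio, w_align, w_theta, w_phi, w_g, w_z, eps=1e-5):
    """x: visual features [N, C, T, H, W]; audio: VGGish features [N, T, 128]."""
    N, C, T, H, W = x.shape
    P = T * H * W                                           # number of positions
    a = audio @ w_align                                     # align_channel: [N, T, C]
    a = np.broadcast_to(a.transpose(0, 2, 1)[..., None, None], x.shape)  # repeat over H, W
    theta = conv1x1x1(w_theta, x).reshape(N, -1, P)         # [N, C', P]
    phi = conv1x1x1(w_phi, a).reshape(N, -1, P)             # [N, C', P]
    g = conv1x1x1(w_g, x).reshape(N, -1, P).transpose(0, 2, 1)  # [N, P, C']
    f = theta.transpose(0, 2, 1) @ phi / P                  # 'dot' mode: [N, P, P], divided by P
    y = (f @ g).transpose(0, 2, 1).reshape(N, -1, T, H, W)  # [N, C', T, H, W]
    z = conv1x1x1(w_z, y) + x                               # W_z, then residual connection
    mu = z.mean(axis=1, keepdims=True)                      # LayerNorm over channels
    var = z.var(axis=1, keepdims=True)
    return (z - mu) / np.sqrt(var + eps)


rng = np.random.default_rng(0)
N, C, Ci, T, H, W = 2, 8, 4, 5, 3, 3
x = rng.standard_normal((N, C, T, H, W))
audio = rng.standard_normal((N, T, 128))
weights = [rng.standard_normal(s) for s in [(128, C), (Ci, C), (Ci, C), (Ci, C), (C, Ci)]]
print(tpavi_dot(x, audio, *weights).shape)  # (2, 8, 5, 3, 3)
```

The [P, P] matrix f is what makes the cost grow with (T·H·W)², which is one reason the module operates on downsampled feature maps.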

Finally, the decoder (the FeatureFusionBlock stages path4 to path1) upsamples the features step by step, and output_conv reduces them to a single channel at 224×224:

conv4_feat = self.path4(feature_map_list[3])  # BF x 256 x 14 x 14
conv43 = self.path3(conv4_feat, feature_map_list[2])  # BF x 256 x 28 x 28
conv432 = self.path2(conv43, feature_map_list[1])  # BF x 256 x 56 x 56
conv4321 = self.path1(conv432, feature_map_list[0])  # BF x 256 x 112 x 112
# print(conv4_feat.shape, conv43.shape, conv432.shape, conv4321.shape)

pred = self.output_conv(conv4321)  # BF x 1 x 224 x 224

The output [BF x 1 x 224 x 224] has a single channel: for each of the bs*frame images, the network predicts one 224×224 map of per-pixel scores (logits). Passing it through a sigmoid and thresholding it gives a binary 0/1 map, which is the predicted mask.
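A minimal NumPy sketch of turning the single-channel output into a mask and scoring it. The 0.5 threshold and the per-image IoU averaged over the batch are my assumptions for illustration; test.py may compute the metric differently.

```python
import numpy as np


def logits_to_mask(logits, threshold=0.5):
    """[BF, 1, H, W] logits -> binary {0, 1} mask via sigmoid + threshold."""
    probs = 1.0 / (1.0 + np.exp(-logits))
    return (probs > threshold).astype(np.uint8)


def mean_iou(pred, gt, eps=1e-7):
    """Per-image IoU of binary masks [BF, 1, H, W], averaged over BF."""
    axes = tuple(range(1, pred.ndim))
    inter = np.logical_and(pred, gt).sum(axis=axes)
    union = np.logical_or(pred, gt).sum(axis=axes)
    return float(np.mean(inter / (union + eps)))


logits = np.array([[[[-2.0, 0.5], [3.0, -0.1]]]])  # one 2x2 "image"
pred = logits_to_mask(logits)
gt = np.array([[[[0, 1], [0, 0]]]])
print(pred[0, 0].tolist(), round(mean_iou(pred, gt), 3))  # [[0, 1], [1, 0]] 0.5
```

Thresholding the probability at 0.5 is the same as checking logit > 0. With this eps convention, an empty prediction against an empty ground truth scores 0, which you may want to handle differently.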

Here are my prediction results:

 

Test

First, look at the contents of ms3_meta_data.csv:

 

As you can see, the data is divided into three splits: training, validation and test. After training the model, run test.py, which evaluates on the test split and saves the results in the test_log folder. Change the model path in the test code to your trained model, then run it to see the results.
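The split column is what MS3Dataset filters on (df_all['split'] == split). To check the split sizes without pandas, here is a standard-library sketch; the rows are made up, and apart from split the column name is a placeholder:

```python
import csv
import io
from collections import Counter


def split_counts(csv_text):
    """Count rows per value of the 'split' column."""
    reader = csv.DictReader(io.StringIO(csv_text))
    return Counter(row["split"] for row in reader)


demo_csv = "video_id,split\nv1,train\nv2,train\nv3,val\nv4,test\n"  # made-up rows
print(split_counts(demo_csv))
```

For the real file, read ms3_meta_data.csv into a string (or pass the open file to csv.DictReader directly) and call the same function.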

Test your own video

Put the videos you want to test into raw_videos/ under avsbench_data/det/det. Clips of 5 s are recommended, because the code cuts exactly 5 frames from each video; other lengths require changing the code.

Then run preprocess_scripts/preprocess_ms3.py, which computes the log-mel spectrograms of the audio and extracts the frames; the results are saved in folders at the same level as raw_videos.

Then run detect.py (in the same directory as train.py) to run inference on your videos.
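Why 5-second clips? Assuming the preprocessing follows VGGish's standard input pipeline (the [bs, 5, 1, 96, 64] audio shape in the training loop suggests it does), 16 kHz mono audio is turned into log-mel frames with a 25 ms window and a 10 ms hop, and those frames are grouped into non-overlapping 0.96 s examples of 96 frames × 64 mel bands. The constants below are VGGish's published defaults; the arithmetic shows that a 5 s clip yields exactly 5 examples, one per extracted video frame:

```python
# Assumed VGGish defaults (vggish_params.py)
SAMPLE_RATE = 16000
STFT_WINDOW_S = 0.025
STFT_HOP_S = 0.010
EXAMPLE_WINDOW_S = 0.96
EXAMPLE_HOP_S = 0.96
NUM_MEL_BINS = 64


def num_windows(length, window, hop):
    """Number of complete windows of size `window` taken every `hop` steps."""
    return 1 + (length - window) // hop


def log_mel_examples_shape(duration_s):
    """Shape of the log-mel input for a 16 kHz mono clip of duration_s seconds (>= 1 s)."""
    samples = int(round(duration_s * SAMPLE_RATE))
    win = int(round(STFT_WINDOW_S * SAMPLE_RATE))        # 400 samples
    hop = int(round(STFT_HOP_S * SAMPLE_RATE))           # 160 samples
    mel_frames = num_windows(samples, win, hop)          # about 100 log-mel frames per second
    frame_rate = 1.0 / STFT_HOP_S
    ex_win = int(round(EXAMPLE_WINDOW_S * frame_rate))   # 96 frames
    ex_hop = int(round(EXAMPLE_HOP_S * frame_rate))      # 96 frames
    return (num_windows(mel_frames, ex_win, ex_hop), 1, ex_win, NUM_MEL_BINS)


print(log_mel_examples_shape(5))  # (5, 1, 96, 64)
print(log_mel_examples_shape(3))  # (3, 1, 96, 64)
```

A clip of a different length gives a different number of audio examples, which no longer matches the 5 frames the model expects; that is why other lengths need code changes.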

Real-time detection: I am still writing this code, so please wait a little.

Code links (I cannot upload my local files here, so only the original GitHub repository is given): https://github.com/OpenNLPLab/AVSBench

Finally

Soon I will record a video that walks through the theory, the code, and training and inference. Stay tuned~


Origin blog.csdn.net/qq_46098574/article/details/126255334