TSM (Temporal Shift Module) source code analysis


Paper: TSM: Temporal Shift Module for Efficient Video Understanding

Code link: https://github.com/mit-han-lab/temporal-shift-module

The main structure of the code is as follows:

File                     Purpose
main.py                  Main training script
opts.py                  Argument (parameter) configuration
ops/dataset.py           Dataset loading; the core is the __getitem__ function
ops/dataset_config.py    Per-dataset configuration
ops/models.py            Model assembly
ops/temporal_shift.py    The core temporal shift operation

1. opts.py holds the argument configuration. Besides paths and hyperparameters, several arguments deserve attention (these differ from TSN):

parser.add_argument('--shift', default=False, action="store_true", help='use shift for models')
parser.add_argument('--shift_div', default=8, type=int, help='number of div for shift (default: 8)')

- modality: the input type, e.g. RGB or Flow.
- num_segments: the number of frames (segments) sampled per video, typically 8 or 16.
- shift: whether to insert the TSM module.
- shift_div: the inverse of the fraction of channels shifted in each direction, typically 8. With shift_div=8, 2 × 1/8 = 1/4 of the channels move: 1/8 are shifted left, another 1/8 are shifted right, and the rest stay in place.
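To see how these flags behave, here is a minimal sketch that declares only the options discussed above and parses a sample command line. It is not the repository's opts.py: the --modality and --num_segments declarations (names from the list above, written as optional flags with assumed defaults) are assumptions; --shift and --shift_div are copied from the snippet.

```python
import argparse

# Sketch: only the TSM-related options. --modality and --num_segments are
# assumed declarations; --shift and --shift_div mirror the snippet above.
parser = argparse.ArgumentParser(description="TSM options (sketch)")
parser.add_argument('--modality', type=str, default='RGB', choices=['RGB', 'Flow'])
parser.add_argument('--num_segments', type=int, default=8)
parser.add_argument('--shift', default=False, action="store_true", help='use shift for models')
parser.add_argument('--shift_div', default=8, type=int, help='number of div for shift (default: 8)')

args = parser.parse_args(['--num_segments', '8', '--shift', '--shift_div', '8'])

# Fraction of channels moved in each direction, and in total.
per_direction = 1 / args.shift_div
total_shifted = 2 * per_direction
print(args.modality, args.num_segments, args.shift, per_direction, total_shifted)
```

Because --shift uses action="store_true", it is False unless the flag appears on the command line.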

2. dataset_config.py holds the dataset-level configuration.

Each dataset implements a return_xxx(modality) function, which returns information such as the categories supported by the dataset, the train_list path, the val_list path, and the root path of the dataset.
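The pattern can be sketched as follows. Everything here is hypothetical: the dataset name mydataset, the root path, the list-file names, the class count, and the fifth return value (an image filename template matching TSNDataSet's default image_tmpl) are placeholders, and the real return_xxx functions may return a different tuple.

```python
ROOT_DATASET = '/path/to/datasets'  # hypothetical root directory

def return_mydataset(modality):
    """Hypothetical return_xxx(modality) for a dataset called 'mydataset'."""
    if modality == 'RGB':
        num_classes = 10  # placeholder class count
        root_data = ROOT_DATASET + '/mydataset/frames'
        filename_imglist_train = 'mydataset/train_list.txt'
        filename_imglist_val = 'mydataset/val_list.txt'
        prefix = 'img_{:05d}.jpg'  # frame filename template (assumed)
    else:
        raise NotImplementedError('no such modality: ' + modality)
    return num_classes, filename_imglist_train, filename_imglist_val, root_data, prefix

def return_dataset(dataset, modality):
    """Dispatch to the per-dataset function by name (sketch)."""
    registry = {'mydataset': return_mydataset}
    if dataset not in registry:
        raise ValueError('Unknown dataset ' + dataset)
    return registry[dataset](modality)

n_class, train_list, val_list, root, tmpl = return_dataset('mydataset', 'RGB')
print(n_class, train_list, tmpl.format(1))
```

Keeping one function per dataset behind a name-keyed dictionary means adding a dataset only requires a new return_xxx and one registry entry.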

3. dataset.py handles dataset loading.

Its job is to read the dataset, sparsely sample frames from each video, and return the sampled result.

The TSNDataSet class, implemented in dataset.py, processes the raw data. It inherits from torch.utils.data.Dataset.

The file first defines a simple helper class:

3.1 VideoRecord wraps one video's entry: the path to its frames, the number of frames, and the label.

class VideoRecord(object):
    def __init__(self, row):
        self._data = row

    @property
    def path(self):
        return self._data[0]

    @property
    def num_frames(self):
        return int(self._data[1])

    @property
    def label(self):
        return int(self._data[2])

3.2 TSNDataSet:

class TSNDataSet(data.Dataset):
    def __init__(self, root_path, list_file,
                 num_segments=3, new_length=1, modality='RGB',
                 image_tmpl='img_{:05d}.jpg', transform=None,
                 random_shift=True, test_mode=False,
                 remove_missing=False, dense_sample=False, twice_sample=False):

        self.root_path = root_path
        self.list_file = list_file
        self.num_segments = num_segments
        self.new_length = new_length
        self.modality = modality
        self.image_tmpl = image_tmpl
        self.transform = transform
        self.random_shift = random_shift
        self.test_mode = test_mode
        self.remove_missing = remove_missing
        self.dense_sample = dense_sample  # using dense sample as I3D
        self.twice_sample = twice_sample  # twice sample for more validation
        if self.dense_sample:
            print('=> Using dense sample for the dataset...')
        if self.twice_sample:
            print('=> Using twice sample for the dataset...')

        if self.modality == 'RGBDiff':
            self.new_length += 1  # Diff needs one more image to calculate diff

        self._parse_list()

After setting the parameters and their default values, __init__ calls _parse_list(). It reads the list file, where each line holds three space-separated fields: the frame directory path, the number of frames in the video, and the label. tmp is therefore a list with one entry per training video, and each entry is wrapped in a VideoRecord, giving self.video_list:

tmp = [x.strip().split(' ') for x in open(self.list_file)]
self.video_list = [VideoRecord(item) for item in tmp]
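As a self-contained check of this parsing step, the sketch below repeats the VideoRecord class from 3.1 and feeds it two made-up list-file lines through io.StringIO instead of a real file; the paths, frame counts, and labels are invented for illustration.

```python
import io

class VideoRecord(object):
    """Same as VideoRecord in 3.1: one line of the list file."""
    def __init__(self, row):
        self._data = row

    @property
    def path(self):
        return self._data[0]

    @property
    def num_frames(self):
        return int(self._data[1])

    @property
    def label(self):
        return int(self._data[2])

# Made-up list file: "<frame dir> <num frames> <label>" per line.
list_file = io.StringIO("videos/v_001 120 3\nvideos/v_002 87 0\n")

tmp = [x.strip().split(' ') for x in list_file]
video_list = [VideoRecord(item) for item in tmp]
for r in video_list:
    print(r.path, r.num_frames, r.label)
```

The fields stay strings inside the record; the properties convert num_frames and label to int on access.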

3.3 TSNDataSet must implement the core method __getitem__()

__getitem__ first fetches the VideoRecord for the requested index; the frames and label of that video are then loaded from it:

    def __getitem__(self, index):
        record = self.video_list[index]

How are num_segments frames chosen from a video? The sampling function depends on the mode:

        if not self.test_mode:  # test_mode: False; random_shift: True;
            segment_indices = self._sample_indices(record) if self.random_shift else self._get_val_indices(record)
        else:
            segment_indices = self._get_test_indices(record)
        return self.get(record, segment_indices)

There are three sampling functions, one each for training, validation, and testing:

_sample_indices implements training-time sampling: it randomly picks num_segments indices, using either dense sampling (dense_sample=True) or sparse sampling (the default, "normal").

By default, _get_test_indices picks num_segments indices sparsely and deterministically: the center frame of each segment.

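A simple example with num_frames=120 and num_segments=3:

Normal sampling in _sample_indices may return, for example, [4, 44, 84], [5, 45, 85], or [11, 51, 91]: one index from each third of the video.

Dense sampling in _sample_indices may return, for example, [15, 36, 57], [30, 51, 72], or [44, 65, 86]: a fixed stride of 21 frames from a random start.

dense_sample in _get_test_indices uses the same fixed stride, but its start positions are evenly spaced rather than random.

twice_sample in _get_test_indices is deterministic: it returns 2 × num_segments indices, the segment centers followed by the segment starts, e.g. [11, 31, 51, 1, 21, 41] for num_frames=60 and num_segments=3.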

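In the end, each video contributes num_segments frames. The transforms stack these frames along the channel axis, so one sample from __getitem__ is a tensor of shape [t*c, h, w] plus an integer label; the DataLoader batches samples into [n, t*c, h, w], which the model reshapes to [n*t, c, h, w] before the 2D backbone.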
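The shape bookkeeping can be checked with a tiny tensor. This sketch assumes, as in TSN-style pipelines, that each sample's frames are stacked along the channel axis; the sizes are arbitrary toy values, and the view call illustrates the reshape a 2D backbone needs rather than quoting a specific line of models.py.

```python
import torch

n, t, c, h, w = 2, 8, 3, 4, 4  # batch, segments, channels, height, width (toy sizes)

# One DataLoader batch: each sample holds t frames stacked on the channel axis.
batch = torch.randn(n, t * c, h, w)

# Fold the segment axis into the batch axis so a 2D CNN sees individual frames.
frames = batch.view(-1, c, h, w)
print(batch.shape, '->', frames.shape)
```

Frame k of video i lands at row i * t + k, which is exactly the [n, t, ...] layout that the temporal shift later restores with view(n_batch, n_segment, c, h, w).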

4. models.py

models.py prepares the model for training. It takes a classic network such as ResNet-50 as the backbone, modifies its last fully connected layer (and, for non-RGB input modalities, its first convolution) to obtain the TSN model, and optionally inserts the TSM module to obtain the TSM model.

__init__ sets the parameters and their default values and then assembles the model. Specifically, it:

1. Calls _prepare_base_model(base_model) to build the backbone.

2. Calls _prepare_tsn(num_class) to resize the fc layer to the number of classes of the dataset.

3. For Flow and RGBDiff input, calls _construct_flow_model or _construct_diff_model to adapt the first convolution layer to the different number of input channels.
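Step 3 boils down to changing the number of input channels of the first convolution. A common way to do that, used here as an assumption about what _construct_flow_model does, is TSN's cross-modality initialization: average the pretrained RGB kernel over its 3 input channels and replicate the result for every flow channel (2 per frame, x and y). The helper name construct_flow_conv and the new_length default of 5 are hypothetical.

```python
import torch
import torch.nn as nn

def construct_flow_conv(conv: nn.Conv2d, new_length: int = 5) -> nn.Conv2d:
    """Hypothetical helper: adapt an RGB first conv to 2 * new_length flow channels."""
    in_ch = 2 * new_length  # x and y flow for each of new_length frames
    new_conv = nn.Conv2d(in_ch, conv.out_channels, conv.kernel_size,
                         stride=conv.stride, padding=conv.padding,
                         bias=conv.bias is not None)
    with torch.no_grad():
        # Average over the RGB input channels -> [out, 1, kh, kw], then replicate.
        mean_kernel = conv.weight.mean(dim=1, keepdim=True)
        new_conv.weight.copy_(mean_kernel.expand(-1, in_ch, -1, -1))
        if conv.bias is not None:
            new_conv.bias.copy_(conv.bias)
    return new_conv

rgb_conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)  # e.g. a ResNet stem
flow_conv = construct_flow_conv(rgb_conv, new_length=5)
print(flow_conv)
```

Only the input-channel dimension changes; kernel size, stride, padding, and output channels are copied so the rest of the pretrained network still fits.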

4.1 _prepare_base_model() function

Besides building the backbone, it calls make_temporal_shift() to insert the TSM module when shift is enabled:

    def _prepare_base_model(self, base_model):
        print('=> base model: {}'.format(base_model))

        if 'resnet' in base_model:
            # torchvision.models.resnet50
            self.base_model = getattr(torchvision.models, base_model)(True if self.pretrain == 'imagenet' else False)
            if self.is_shift:  # default: False
                print('Adding temporal shift...')
                from ops.temporal_shift import make_temporal_shift
                make_temporal_shift(self.base_model, self.num_segments,
                                    n_div=self.shift_div, place=self.shift_place, temporal_pool=self.temporal_pool)

            if self.non_local:  # default: False
                print('Adding non-local module...')
                from ops.non_local import make_non_local
                make_non_local(self.base_model, self.num_segments)

            self.base_model.last_layer_name = 'fc'
            self.input_size = 224
            self.input_mean = [0.485, 0.456, 0.406]
            self.input_std = [0.229, 0.224, 0.225]
            # 1 is the output spatial size, e.g. torch.Size([2, 32, 16, 16]) -> torch.Size([2, 32, 1, 1])
            self.base_model.avgpool = nn.AdaptiveAvgPool2d(1)  

4.2 _prepare_tsn function

_prepare_tsn modifies the backbone (base_model) by replacing its last fully connected layer so that the output size matches the number of classes in the dataset.

    def _prepare_tsn(self, num_class):
        # input dimension (in_features) of the backbone's last layer
        feature_dim = getattr(self.base_model, self.base_model.last_layer_name).in_features
        # if dropout == 0, replace the last layer directly with a new fc layer that outputs num_class
        if self.dropout == 0:
            setattr(self.base_model, self.base_model.last_layer_name, nn.Linear(feature_dim, num_class))
            self.new_fc = None
        # if dropout != 0, replace the last layer with a Dropout layer and create a separate fc layer (new_fc) that outputs num_class
        else:
            setattr(self.base_model, self.base_model.last_layer_name, nn.Dropout(p=self.dropout))
            self.new_fc = nn.Linear(feature_dim, num_class)

        std = 0.001
        if self.new_fc is None:
            normal_(getattr(self.base_model, self.base_model.last_layer_name).weight, 0, std)
            constant_(getattr(self.base_model, self.base_model.last_layer_name).bias, 0)
        else:
            if hasattr(self.new_fc, 'weight'):
                normal_(self.new_fc.weight, 0, std)
                constant_(self.new_fc.bias, 0)
        return feature_dim
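To see where new_fc ends up, the sketch below applies the same logic to a tiny stand-in backbone (ToyBackbone and prepare_tsn are hypothetical; ToyBackbone only mimics a torchvision ResNet by exposing an fc attribute). When dropout is non-zero, the backbone's old fc becomes a Dropout and the separate new_fc has to be applied to the backbone's output, so it acts as the classification head.

```python
import torch
import torch.nn as nn
from torch.nn.init import normal_, constant_

class ToyBackbone(nn.Module):
    """Hypothetical stand-in for a ResNet: features -> fc."""
    def __init__(self, feature_dim=16, pretrain_classes=1000):
        super().__init__()
        self.features = nn.Sequential(nn.Conv2d(3, feature_dim, 3, padding=1),
                                      nn.AdaptiveAvgPool2d(1), nn.Flatten())
        self.fc = nn.Linear(feature_dim, pretrain_classes)
        self.last_layer_name = 'fc'

    def forward(self, x):
        return self.fc(self.features(x))

def prepare_tsn(base_model, num_class, dropout, std=0.001):
    """Same idea as _prepare_tsn; returns new_fc (or None when dropout == 0)."""
    name = base_model.last_layer_name
    feature_dim = getattr(base_model, name).in_features
    if dropout == 0:
        fc = nn.Linear(feature_dim, num_class)
        setattr(base_model, name, fc)  # classifier lives inside the backbone
        new_fc = None
    else:
        setattr(base_model, name, nn.Dropout(p=dropout))  # backbone now ends in dropout
        fc = new_fc = nn.Linear(feature_dim, num_class)   # separate head
    normal_(fc.weight, 0, std)
    constant_(fc.bias, 0)
    return new_fc

backbone = ToyBackbone()
head = prepare_tsn(backbone, num_class=5, dropout=0.5)
logits = head(backbone(torch.randn(2, 3, 8, 8)))  # shape [2, 5]
```

With dropout, the forward pass is backbone (ending in Dropout) followed by the head; without it, the backbone alone produces num_class outputs.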

4.3 make_temporal_shift(): the core of the TSM module

make_temporal_shift(self.base_model, self.num_segments,
                    n_div=self.shift_div, place=self.shift_place, temporal_pool=self.temporal_pool)

make_temporal_shift wraps layers of the backbone as TemporalShift(b, n_segment=this_segment, n_div=n_div). So how is TemporalShift implemented?

The input has shape [n*t, c, h, w], where t is the number of segments. It is first reshaped to [n, t, c, h, w].

If the feature map has 256 channels and fold_div=8, then fold = 256/8 = 32 channels are shifted left, another 32 channels are shifted right, and the remaining 192 channels are unchanged.

    @staticmethod
    def shift(x, n_segment, fold_div=3, inplace=False):
        nt, c, h, w = x.size()
        n_batch = nt // n_segment
        x = x.view(n_batch, n_segment, c, h, w)

        fold = c // fold_div
        if inplace:
            # Due to some out of order error when performing parallel computing.
            # May need to write a CUDA kernel.
            raise NotImplementedError
            # out = InplaceShift.apply(x, fold)
        else:
            out = torch.zeros_like(x)
            out[:, :-1, :fold] = x[:, 1:, :fold]  # shift left
            out[:, 1:, fold: 2 * fold] = x[:, :-1, fold: 2 * fold]  # shift right
            out[:, :, 2 * fold:] = x[:, :, 2 * fold:]  # not shift

        return out.view(nt, c, h, w)
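The non-inplace branch can be run on its own. The function below copies that logic into a free function (the name temporal_shift and the default fold_div=8 are our choices) and checks it on a batch of two 4-frame clips. It also shows that the shift never mixes frames of different videos, because after the reshape it only moves data along the n_segment axis.

```python
import torch

def temporal_shift(x, n_segment, fold_div=8):
    """Out-of-place temporal shift on an [n*t, c, h, w] tensor."""
    nt, c, h, w = x.size()
    n_batch = nt // n_segment
    x = x.view(n_batch, n_segment, c, h, w)
    fold = c // fold_div
    out = torch.zeros_like(x)
    out[:, :-1, :fold] = x[:, 1:, :fold]                  # shift left: frame t gets frame t+1
    out[:, 1:, fold:2 * fold] = x[:, :-1, fold:2 * fold]  # shift right: frame t gets frame t-1
    out[:, :, 2 * fold:] = x[:, :, 2 * fold:]             # unshifted channels
    return out.view(nt, c, h, w)

x = torch.arange(2 * 4 * 8, dtype=torch.float32).view(8, 8, 1, 1)  # n=2, t=4, c=8
out = temporal_shift(x, n_segment=4, fold_div=8)
print(out.view(8, 8))
```

The shift adds no parameters and no multiplications; the vacated positions at the clip boundaries are zero-filled.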

A simple example:

Take c = 8 and num_segment = 4, and view the features as a 2D grid of frames × channels.

0_xx denotes the features of the first frame, 1_xx those of the second frame, and so on; each frame's feature has 8 channels.


With fold_div = 8, fold = 8/8 = 1, so one channel is shifted in each direction.


As a result, features of the second frame are mixed into the first frame, and features of the third and first frames are mixed into the second frame.

With fold_div = 4, more channels move (fold = 2 in each direction), so each frame's features carry more information from the previous and next frames.
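The shifted layouts for c = 8 and num_segment = 4 can be printed with a few lines of NumPy. This sketch applies the same slicing as shift() to a grid of string labels t_k (channel k of frame t) for a single video, with "0" standing for the zero padding; it drops the batch and spatial axes for readability.

```python
import numpy as np

def label_grid(num_segment, c):
    """x[t, k] = "t_k": channel k of frame t (single video, spatial axes dropped)."""
    x = np.empty((num_segment, c), dtype=object)
    for t in range(num_segment):
        for k in range(c):
            x[t, k] = f"{t}_{k}"
    return x

def shift_labels(x, fold_div):
    """Same slicing as shift(), without the batch axis; "0" marks zero padding."""
    fold = x.shape[1] // fold_div
    out = np.full(x.shape, "0", dtype=object)
    out[:-1, :fold] = x[1:, :fold]                  # shift left
    out[1:, fold:2 * fold] = x[:-1, fold:2 * fold]  # shift right
    out[:, 2 * fold:] = x[:, 2 * fold:]             # not shifted
    return out

x = label_grid(4, 8)
for fold_div in (8, 4):
    print(f"fold_div = {fold_div}")
    for row in shift_labels(x, fold_div):
        print(" ".join(f"{v:>3}" for v in row))
```

Each printed row is one frame after the shift; reading down a shifted column shows which neighboring frame contributed that channel.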


5. main.py is the main training script

Finally, main.py ties together the classes and functions described above:

5.1 Call dataset_config.return_dataset to obtain the dataset paths and related information.

5.2 Instantiate the model and the optimizer.

5.3 Load a pretrained model or resume from a checkpoint.

5.4 Prepare the data loaders train_loader and val_loader.

5.5 Run the epoch loop, saving the latest model and the best model so far.

At the start of each epoch, adjust_learning_rate sets the learning rate according to the chosen schedule:

    for epoch in range(args.start_epoch, args.epochs):
        # Sets the learning rate to the initial LR decayed by 10 every 30 epochs
        adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, log_training, tf_writer)

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1 = validate(val_loader, model, criterion, epoch, log_training, tf_writer)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            tf_writer.add_scalar('acc/test_top1_best', best_prec1, epoch)

            output_best = 'Best Prec@1: %.3f\n' % (best_prec1)
            print(output_best)
            log_training.write(output_best + '\n')
            log_training.flush()  # flush the log buffer

            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'best_prec1': best_prec1,
            }, is_best)
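adjust_learning_rate itself is not shown. The sketch below illustrates two common schedules that an lr_type switch could select: a step schedule that multiplies the base rate by 0.1 at each milestone in lr_steps, and a half-cosine decay. It is an assumption about the strategy, not the repository's function; set_lr and its per-group lr_mult key are hypothetical and operate on plain dicts shaped like optimizer.param_groups.

```python
import math

def step_lr(base_lr, epoch, lr_steps, gamma=0.1):
    """Decay base_lr by gamma once for every milestone in lr_steps already reached."""
    num_decays = sum(1 for s in lr_steps if epoch >= s)
    return base_lr * gamma ** num_decays

def cosine_lr(base_lr, epoch, total_epochs):
    """Half-cosine decay: base_lr at epoch 0, approaching 0 at total_epochs."""
    return 0.5 * base_lr * (1 + math.cos(math.pi * epoch / total_epochs))

def set_lr(param_groups, lr):
    """Write lr into each group, scaled by an optional (hypothetical) 'lr_mult'."""
    for group in param_groups:
        group['lr'] = lr * group.get('lr_mult', 1.0)

groups = [{'lr': 0.0}, {'lr': 0.0, 'lr_mult': 5}]
for epoch in (0, 20, 40):
    set_lr(groups, step_lr(0.01, epoch, lr_steps=[20, 40]))
    print(epoch, [g['lr'] for g in groups])
```

Computing the rate from the epoch number, rather than decaying it in place, makes the schedule independent of where training resumes from a checkpoint.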

The core of the train function is as follows:

def train(train_loader, model, criterion, optimizer, epoch, log, tf_writer):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    if args.no_partialbn:  # default: False
        model.module.partialBN(False)
    else:
        model.module.partialBN(True)

    # switch to train mode
    model.train()

    end = time.time()  # current timestamp
    for i, (input, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        # Variable is a no-op wrapper since PyTorch 0.4; plain tensors also work
        input_var = torch.autograd.Variable(input)
        target_var = torch.autograd.Variable(target)

        # compute output
        output = model(input_var)
        loss = criterion(output, target_var)

        # print('********************')
        # print(loss)
        # raise RuntimeError

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        top1.update(prec1.item(), input.size(0))
        top5.update(prec5.item(), input.size(0))

        # compute gradient and do SGD step
        loss.backward()

        if args.clip_gradient is not None:  # None
            total_norm = clip_grad_norm_(model.parameters(), args.clip_gradient)

        optimizer.step()
        optimizer.zero_grad()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:   # print_freq = 20
            output = ('Epoch: [{0}][{1}/{2}], lr: {lr:.5f}\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                epoch, i, len(train_loader), batch_time=batch_time,
                data_time=data_time, loss=losses, top1=top1, top5=top5, lr=optimizer.param_groups[-1]['lr'] * 0.1))  # TODO
            print(output)
            log.write(output + '\n')
            log.flush()

    tf_writer.add_scalar('loss/train', losses.avg, epoch)
    tf_writer.add_scalar('acc/train_top1', top1.avg, epoch)
    tf_writer.add_scalar('acc/train_top5', top5.avg, epoch)
    tf_writer.add_scalar('lr', optimizer.param_groups[-1]['lr'], epoch)
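train() relies on two helpers, AverageMeter and accuracy, whose definitions are not shown. Below is a minimal sketch of what they need to do given how they are called (update(val, n), .val, .avg, and accuracy(output, target, topk=(1, 5)) returning percentages); it is our own version, not the repository's.

```python
import torch

class AverageMeter:
    """Tracks the latest value and a count-weighted running average."""
    def __init__(self):
        self.val, self.sum, self.count, self.avg = 0.0, 0.0, 0, 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def accuracy(output, target, topk=(1,)):
    """Percentage of samples whose label is among the top-k scores, for each k."""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)  # [B, maxk]
    correct = pred.eq(target.view(-1, 1))                          # [B, maxk] bool
    res = []
    for k in topk:
        correct_k = correct[:, :k].reshape(-1).float().sum()
        res.append(correct_k * 100.0 / batch_size)
    return res

meter = AverageMeter()
meter.update(2.0, n=2)  # e.g. a batch of 2 with mean loss 2.0
meter.update(5.0, n=1)
output = torch.tensor([[0.1, 0.9, 0.0], [0.7, 0.1, 0.2]])
target = torch.tensor([1, 2])
prec1, prec2 = accuracy(output, target, topk=(1, 2))
```

Weighting each update by the batch size (input.size(0) in train()) keeps the epoch average correct even when the last batch is smaller.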


Original article: blog.csdn.net/better_boy/article/details/109003127