TSM (Temporal Shift Module) source code analysis
Paper: TSM: Temporal Shift Module for Efficient Video Understanding
Code link: https://github.com/mit-han-lab/temporal-shift-module
The main structure of the code is as follows:
python file | function |
---|---|
main.py | Main training script |
opts.py | Parameter configuration |
ops/dataset.py | Dataset loading; the core is the __getitem__ function |
ops/dataset_config.py | Configuration for the different datasets |
ops/models.py | Model assembly |
ops/temporal_shift.py | The core temporal shift operation |
1.opts.py holds the parameter configuration. Besides paths and hyperparameters, a few parameters deserve attention because they differ from TSN:
parser.add_argument('--shift', default=False, action="store_true", help='use shift for models')
parser.add_argument('--shift_div', default=8, type=int, help='number of div for shift (default: 8)')
- modality is the input type: RGB or Flow.
- num_segments is the number of frames sampled per video, generally 8 or 16.
- shift indicates whether to add the TSM module.
- shift_div sets the proportion of channels that are shifted, generally 8. With shift_div=8, 2 * 1/8 of the channels move: 1/8 of the channels are shifted left along the time axis, and another 1/8 are shifted right.
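As a rough illustration of what shift_div controls, the sketch below rebuilds only the two flags quoted above and computes how many channels move in each direction. The helper shifted_channels is a hypothetical name introduced here for illustration; it is not part of the repository.

```python
import argparse

# Minimal parser containing only the two TSM-specific flags quoted above.
parser = argparse.ArgumentParser(description="TSM shift options (sketch)")
parser.add_argument('--shift', default=False, action="store_true", help='use shift for models')
parser.add_argument('--shift_div', default=8, type=int, help='number of div for shift (default: 8)')


def shifted_channels(num_channels, shift_div):
    """Hypothetical helper: return (left, right, unchanged) channel counts."""
    fold = num_channels // shift_div  # channels moved in each direction
    return fold, fold, num_channels - 2 * fold


# Parse an explicit argument list so the sketch runs without a command line.
args = parser.parse_args(['--shift', '--shift_div', '8'])
left, right, unchanged = shifted_channels(256, args.shift_div)
print(args.shift, left, right, unchanged)  # True 32 32 192
```

A smaller shift_div moves more channels: with shift_div=4, half of the channels are shifted and half stay in place.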
2.dataset_config.py is the data-level configuration.
Each dataset implements a return_xxx(modality) function. It returns information such as the category names of the dataset, the train_list path, the val_list path, and the root path of the dataset.
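A minimal sketch of what such a function might look like is shown below. Everything in it is an assumption made for illustration: the dataset name mydataset, the paths, and the exact fields of the returned tuple are hypothetical, since the article does not show the repository's actual return values.

```python
import os

ROOT_DATASET = '/data/video'  # hypothetical dataset root; adjust to your setup


def return_mydataset(modality):
    """Hypothetical return_xxx(modality) function for a dataset named 'mydataset'."""
    if modality == 'RGB':
        categories_file = 'mydataset/category.txt'        # one class name per line (assumed)
        train_list = 'mydataset/train_videofolder.txt'    # training list file (assumed)
        val_list = 'mydataset/val_videofolder.txt'        # validation list file (assumed)
        root_data = os.path.join(ROOT_DATASET, 'mydataset/frames')
        prefix = 'img_{:05d}.jpg'                          # frame file name template
    else:
        raise NotImplementedError('no such modality: ' + modality)
    return categories_file, train_list, val_list, root_data, prefix


categories_file, train_list, val_list, root_data, prefix = return_mydataset('RGB')
print(train_list, val_list, prefix.format(3))
```

Keeping one such function per dataset means the training script only needs the dataset name and the modality to locate every file it reads.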
3.dataset.py is the loading part of the dataset.
The main job of dataset.py is to read the dataset, sample frames sparsely, and return the sampled frames.
dataset.py implements the TSNDataSet class to process the raw data. This class inherits from torch.utils.data.Dataset.
The file first defines a small helper class:
3.1 VideoRecord encapsulates one video: the path to its frames, the number of frames, and the label.
class VideoRecord(object):
    def __init__(self, row):
        self._data = row

    @property
    def path(self):
        return self._data[0]

    @property
    def num_frames(self):
        return int(self._data[1])

    @property
    def label(self):
        return int(self._data[2])
3.2 TSNDataSet:
class TSNDataSet(data.Dataset):
    def __init__(self, root_path, list_file,
                 num_segments=3, new_length=1, modality='RGB',
                 image_tmpl='img_{:05d}.jpg', transform=None,
                 random_shift=True, test_mode=False,
                 remove_missing=False, dense_sample=False, twice_sample=False):
        self.root_path = root_path
        self.list_file = list_file
        self.num_segments = num_segments
        self.new_length = new_length
        self.modality = modality
        self.image_tmpl = image_tmpl
        self.transform = transform
        self.random_shift = random_shift
        self.test_mode = test_mode
        self.remove_missing = remove_missing
        self.dense_sample = dense_sample  # using dense sample as I3D
        self.twice_sample = twice_sample  # twice sample for more validation
        if self.dense_sample:
            print('=> Using dense sample for the dataset...')
        if self.twice_sample:
            print('=> Using twice sample for the dataset...')
        if self.modality == 'RGBDiff':
            self.new_length += 1  # Diff needs one more image to calculate diff
        self._parse_list()
After storing the parameters and their default values, __init__ calls _parse_list(). There, tmp is a list with one entry per video in the list file. Each entry is a list of 3 strings, obtained by splitting the line on spaces: the frame path, the number of frames in the video, and the label. Each entry is then wrapped in a VideoRecord object, which exposes those three fields as properties, and the resulting objects are stored in self.video_list.
tmp = [x.strip().split(' ') for x in open(self.list_file)]
self.video_list = [VideoRecord(item) for item in tmp]
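To make the parsing step concrete, here is a small self-contained sketch that applies the same two lines to in-memory text instead of a file. The sample lines and the helper name parse_list are made up for illustration; only VideoRecord and the split logic come from the code above.

```python
class VideoRecord(object):
    """Same helper class as above: wraps one row of the list file."""

    def __init__(self, row):
        self._data = row

    @property
    def path(self):
        return self._data[0]

    @property
    def num_frames(self):
        return int(self._data[1])

    @property
    def label(self):
        return int(self._data[2])


def parse_list(lines):
    """Hypothetical stand-alone version of _parse_list() that takes an iterable of lines."""
    tmp = [x.strip().split(' ') for x in lines]
    return [VideoRecord(item) for item in tmp]


# Made-up list-file content: "<frame path> <number of frames> <label>"
sample_lines = ["frames/video_0001 120 3\n", "frames/video_0002 85 0\n"]
video_list = parse_list(sample_lines)
for record in video_list:
    print(record.path, record.num_frames, record.label)
```

Note that the split is on a single space, so a frame path containing spaces would break this format.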
3.3 TSNDataSet needs to implement the core function __getitem__()
First, we need to get the image and label corresponding to the video. The format of the image is [n*t, c, h, w]
def __getitem__(self, index):
    record = self.video_list[index]
So, for one video, how are the num_segments frames chosen?
    if not self.test_mode:  # test_mode: False; random_shift: True;
        segment_indices = self._sample_indices(record) if self.random_shift else self._get_val_indices(record)
    else:
        segment_indices = self._get_test_indices(record)
    return self.get(record, segment_indices)
There are 3 different sampling functions (for train, val, and test): _sample_indices, _get_val_indices, and _get_test_indices.
Among them, _sample_indices is the sampling method for training. By default it picks num_segments indices at random, and it supports two modes: dense sampling (dense) and sparse sampling (normal).
_get_test_indices picks num_segments indices sparsely at fixed positions.
To give a simple example, when num_frames=120, num_segments=3,
The normal sample in _sample_indices will randomly return: [4, 44, 84], [5, 45, 85], [11, 51, 91].
The dense sample in _sample_indices will randomly return: [15, 36, 57], [30, 51, 72], [44, 65, 86].
The same applies to dense_sample in _get_test_indices
The twice_sample in _get_test_indices is deterministic and returns two sets of indices concatenated together, for example: [11,31,51,1,21,41]
Finally, each video samples num_segments frames, and the dimensions returned by getitem are: [n * t, c, h, w] (frames), 1 (label).
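The sparse sampling idea can be sketched in a few lines of plain Python. This is a simplified illustration of segment-based sampling, not the repository's exact code: the function names are hypothetical, and details such as new_length, dense sampling, and the repository's handling of short videos are left out.

```python
import random


def sample_train_indices(num_frames, num_segments, rng=random):
    """Sparse training sampling: split the video into equal segments and
    pick one random frame inside each segment (1-based frame numbers)."""
    average_duration = num_frames // num_segments
    if average_duration == 0:
        # Video shorter than num_segments: this sketch falls back to frame 1.
        return [1] * num_segments
    return [i * average_duration + rng.randrange(average_duration) + 1
            for i in range(num_segments)]


def sample_center_indices(num_frames, num_segments):
    """Deterministic sampling: take the center frame of each segment."""
    tick = num_frames / float(num_segments)
    return [int(tick / 2.0 + tick * i) + 1 for i in range(num_segments)]


print(sample_train_indices(120, 3))   # random, one index per 40-frame segment
print(sample_center_indices(120, 3))  # fixed, identical on every call
```

Random sampling during training acts as temporal data augmentation, while fixed positions at evaluation time make results reproducible.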
4.models.py
The main job of models.py is to prepare the model for training. It starts from a classic backbone such as resnet50, adapts it to the input modality, and replaces the last fully connected layer to obtain the TSN model. It can also insert the TSM module, which yields the TSM model we need.
The __init__ function sets the parameters and their default values, builds the TSN model by calling the functions that modify the backbone, and adds the TSM module when requested. __init__ proceeds as follows:
1 Call _prepare_base_model(base_model) to build a basic model
2 Call _prepare_tsn(num_class) to adapt the size of the fc layer according to the number of subclasses of different data sets
3 For the input of flow and rgbdiff, call _construct_flow_model and _construct_diff_model to change the size of the first convolution kernel
4.1 _prepare_base_model() function
This function builds the backbone and, when shift is enabled, calls the make_temporal_shift() function to add the TSM module:
def _prepare_base_model(self, base_model):
    print('=> base model: {}'.format(base_model))
    if 'resnet' in base_model:
        # e.g. torchvision.models.resnet50
        self.base_model = getattr(torchvision.models, base_model)(True if self.pretrain == 'imagenet' else False)
        if self.is_shift:  # default: False
            print('Adding temporal shift...')
            from ops.temporal_shift import make_temporal_shift
            make_temporal_shift(self.base_model, self.num_segments,
                                n_div=self.shift_div, place=self.shift_place, temporal_pool=self.temporal_pool)
        if self.non_local:  # default: False
            print('Adding non-local module...')
            from ops.non_local import make_non_local
            make_non_local(self.base_model, self.num_segments)
        self.base_model.last_layer_name = 'fc'
        self.input_size = 224
        self.input_mean = [0.485, 0.456, 0.406]
        self.input_std = [0.229, 0.224, 0.225]
        # 1 is the output feature-map size, e.g. torch.Size([2, 32, 16, 16]) -> torch.Size([2, 32, 1, 1])
        self.base_model.avgpool = nn.AdaptiveAvgPool2d(1)
4.2 _prepare_tsn function
The _prepare_tsn function modifies the last layer (the fully connected layer) of the base model so that its output size matches the number of classes in the dataset.
def _prepare_tsn(self, num_class):
    # input dimension of the model's last layer
    feature_dim = getattr(self.base_model, self.base_model.last_layer_name).in_features
    if self.dropout == 0:
        # no dropout: replace the last layer directly with a new fully connected layer of output size num_class
        setattr(self.base_model, self.base_model.last_layer_name, nn.Linear(feature_dim, num_class))
        self.new_fc = None
    else:
        # with dropout: replace the last layer with a dropout layer, then add a fully connected layer of output size num_class
        setattr(self.base_model, self.base_model.last_layer_name, nn.Dropout(p=self.dropout))
        self.new_fc = nn.Linear(feature_dim, num_class)
    std = 0.001
    if self.new_fc is None:
        normal_(getattr(self.base_model, self.base_model.last_layer_name).weight, 0, std)
        constant_(getattr(self.base_model, self.base_model.last_layer_name).bias, 0)
    else:
        if hasattr(self.new_fc, 'weight'):
            normal_(self.new_fc.weight, 0, std)
            constant_(self.new_fc.bias, 0)
    return feature_dim
4.3 make_temporal_shift(): the core of the TSM module
make_temporal_shift(self.base_model, self.num_segments,
n_div=self.shift_div, place=self.shift_place, temporal_pool=self.temporal_pool)
So how is TemporalShift(b, n_segment=this_segment, n_div=n_div) implemented?
For an input of shape [n*t, c, h, w], where t is the number of segments, the tensor is first reshaped to [n, t, c, h, w].
If the current feature map has 256 channels and fold_div=8, then 256/8 = 32 channels are shifted left, another 32 channels are shifted right, and the remaining 192 channels stay unchanged.
@staticmethod
def shift(x, n_segment, fold_div=3, inplace=False):
    nt, c, h, w = x.size()
    n_batch = nt // n_segment
    x = x.view(n_batch, n_segment, c, h, w)
    fold = c // fold_div
    if inplace:
        # Due to some out of order error when performing parallel computing.
        # May need to write a CUDA kernel.
        raise NotImplementedError
        # out = InplaceShift.apply(x, fold)
    else:
        out = torch.zeros_like(x)
        out[:, :-1, :fold] = x[:, 1:, :fold]  # shift left
        out[:, 1:, fold: 2 * fold] = x[:, :-1, fold: 2 * fold]  # shift right
        out[:, :, 2 * fold:] = x[:, :, 2 * fold:]  # not shift
    return out.view(nt, c, h, w)
Here is a simple example:
When c = 8 and num_segment = 4, the features can be laid out as a 2-D grid of frames by channels.
0_xx denotes the features of the first frame and 1_xx those of the second frame; each frame has 8 channels. The original features are as follows:
[Figure: original features, 4 frames x 8 channels]
When fold_div = 8, the features move as follows:
[Figure: features after the shift with fold_div = 8]
It can be seen that features of the second frame are incorporated into the first frame, and features of the third and first frames are incorporated into the second frame.
When fold_div=4, more channels move, that is, the features of the current frame contain more information about the previous and next frames.
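The example with c = 8 and num_segment = 4 can be reproduced with a minimal pure-Python sketch. It mirrors the indexing of the shift function above on a single video with one value per channel (no batch or spatial dimensions); the function name temporal_shift and the text labels are illustrative only. A label such as 1_3 means frame 1, channel 3, and 0 marks a slot left empty by the shift.

```python
def temporal_shift(frames, fold_div):
    """frames: list of T frames, each a list of C channel values.
    Returns a new T x C grid after the temporal shift; vacated slots are 0."""
    t, c = len(frames), len(frames[0])
    fold = c // fold_div
    out = [[0] * c for _ in range(t)]
    for ti in range(t):
        for ci in range(c):
            if ci < fold:
                # shift left: frame ti receives the channel from the next frame
                if ti + 1 < t:
                    out[ti][ci] = frames[ti + 1][ci]
            elif ci < 2 * fold:
                # shift right: frame ti receives the channel from the previous frame
                if ti - 1 >= 0:
                    out[ti][ci] = frames[ti - 1][ci]
            else:
                # remaining channels are not shifted
                out[ti][ci] = frames[ti][ci]
    return out


# 4 frames x 8 channels, labelled "<frame>_<channel>"
frames = [['{}_{}'.format(t, c) for c in range(8)] for t in range(4)]
for fold_div in (8, 4):
    print('fold_div =', fold_div)
    for row in temporal_shift(frames, fold_div):
        print(' '.join(str(v).rjust(3) for v in row))
```

With fold_div = 8, the first row starts with 1_0 (taken from the second frame) followed by 0 (nothing precedes the first frame), and the second row starts with 2_0 and 0_1. With fold_div = 4, two channels move in each direction instead of one.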
5.main.py is the main training function
Finally, we walk through the main training function, which ties together the classes and functions described above.
5.1 Call dataset_config.return_dataset to obtain information about each path
5.2 Instantiate the model and optimizer
5.3 Load the pre-trained model or recover from training
5.4 Prepare data train_loader, val_loader
5.5 Then iterate over the epochs, saving the latest model and the best model
adjust_learning_rate adjusts the learning rate at each epoch according to the chosen strategy.
for epoch in range(args.start_epoch, args.epochs):
    # Sets the learning rate to the initial LR decayed by 10 every 30 epochs
    adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps)
    # train for one epoch
    train(train_loader, model, criterion, optimizer, epoch, log_training, tf_writer)
    # evaluate on validation set
    if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
        prec1 = validate(val_loader, model, criterion, epoch, log_training, tf_writer)
        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        tf_writer.add_scalar('acc/test_top1_best', best_prec1, epoch)
        output_best = 'Best Prec@1: %.3f\n' % (best_prec1)
        print(output_best)
        log_training.write(output_best + '\n')
        log_training.flush()  # flush the write buffer to the log file
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'best_prec1': best_prec1,
        }, is_best)
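The body of adjust_learning_rate is not shown in this article. As a hedged sketch of what a step or cosine schedule driven by lr_type and lr_steps could look like, the pure-Python function below computes the learning rate for a given epoch. The name compute_lr, its signature, and the two strategies are assumptions for illustration; a real implementation would also write the value into each of the optimizer's param_groups.

```python
import math


def compute_lr(base_lr, epoch, lr_type, lr_steps, total_epochs):
    """Hypothetical helper: learning rate for `epoch` under a step or cosine schedule."""
    if lr_type == 'step':
        # multiply by 0.1 once for every milestone in lr_steps already reached
        passed = sum(1 for step in lr_steps if epoch >= step)
        return base_lr * (0.1 ** passed)
    if lr_type == 'cos':
        # half-cosine decay from base_lr down to 0 over total_epochs
        return 0.5 * base_lr * (1 + math.cos(math.pi * epoch / total_epochs))
    raise NotImplementedError(lr_type)


for epoch in (0, 19, 20, 40, 49):
    print(epoch,
          compute_lr(0.01, epoch, 'step', [20, 40], 50),
          compute_lr(0.01, epoch, 'cos', [], 50))
```

A step schedule keeps the rate constant between milestones, while the cosine schedule lowers it smoothly at every epoch.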
The core of the train function is as follows:
def train(train_loader, model, criterion, optimizer, epoch, log, tf_writer):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    if args.no_partialbn:  # default: False
        model.module.partialBN(False)
    else:
        model.module.partialBN(True)
    # switch to train mode
    model.train()
    end = time.time()  # current timestamp
    for i, (input, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        target = target.cuda()  # move labels to the GPU
        input_var = torch.autograd.Variable(input)
        target_var = torch.autograd.Variable(target)
        # compute output
        output = model(input_var)
        loss = criterion(output, target_var)
        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        top1.update(prec1.item(), input.size(0))
        top5.update(prec5.item(), input.size(0))
        # compute gradient and do SGD step
        loss.backward()
        if args.clip_gradient is not None:  # default: None
            total_norm = clip_grad_norm_(model.parameters(), args.clip_gradient)
        optimizer.step()
        optimizer.zero_grad()
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:  # print_freq = 20
            output = ('Epoch: [{0}][{1}/{2}], lr: {lr:.5f}\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                epoch, i, len(train_loader), batch_time=batch_time,
                data_time=data_time, loss=losses, top1=top1, top5=top5, lr=optimizer.param_groups[-1]['lr'] * 0.1))  # TODO
            print(output)
            log.write(output + '\n')
            log.flush()
    tf_writer.add_scalar('loss/train', losses.avg, epoch)
    tf_writer.add_scalar('acc/train_top1', top1.avg, epoch)
    tf_writer.add_scalar('acc/train_top5', top5.avg, epoch)
    tf_writer.add_scalar('lr', optimizer.param_groups[-1]['lr'], epoch)
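The train function relies on AverageMeter, whose definition is not shown in this article. The sketch below is a conventional implementation that is consistent with how the class is used above (update(value, n) plus the val and avg attributes read by the log format string); treat it as an assumption rather than the repository's exact code.

```python
class AverageMeter(object):
    """Tracks the latest value and the running, sample-weighted average."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0    # most recent value
        self.avg = 0    # weighted average so far
        self.sum = 0    # weighted sum of all values
        self.count = 0  # total weight (e.g. number of samples seen)

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


# Two batches of 4 samples each with mean losses 0.9 and 0.5
losses = AverageMeter()
losses.update(0.9, n=4)
losses.update(0.5, n=4)
print('Loss {loss.val:.4f} ({loss.avg:.4f})'.format(loss=losses))
```

Weighting each update by the batch size keeps the average correct even when the last batch of an epoch is smaller than the others.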