TSN实验配置

TSN实验环境

实验所需性能较高,这里选择在网络平台上进行实验。
注: 与在Linux系统不同,在此平台上运行命令需要在命令前加上!

一、准备工作

1.1 数据集准备

本次实验使用UCF101数据集
动作识别数据集,从youtube收集而得,共包含101类动作。其中每类动作由25个人做动作,每人做4-7组,共13320个视频,分辨率为320*240,共6.5G。
UCF101在动作的采集上具有非常大的多样性,包括相机运行、外观变化、姿态变化、物体比例变化、背景变化、光纤变化等。
在下面给出的下载链接中有官方对数据集更详细的介绍,这里不再赘述。

下载地址:https://www.crcv.ucf.edu/research/data-sets/ucf101/
数据集下载

下载完成后,UCF-101文件夹共101个子目录,其中每个子目录分别有若干相同动作的视频,ucfTrainTestlist文件夹下共有3中划分方式,在后续实验过程中自行选择。

1.2 代码准备

需要使用tsn-pytorch和mmaction的代码,这里直接从GitHub上拉取

  • tsn-pytorch
    !git clone --recursive https://github.com/yjxiong/tsn-pytorch
    
  • mmaction
    !git clone --recursive https://github.com/open-mmlab/mmaction.git
    

二、处理数据

首先在mmaction/data/下创建一个UCF101文件夹,用于存放数据集相关文件,并建立如下四个目录
数据放置
它们的作用为:

  • annotations:ucf101之后进行分割训练集、测试集的依据文件
  • rawframes:视频提帧后存放的文件目录
  • videos:复制ucf101数据集中的101个文件目录(考虑平台存储空间和运行时间,这里仅拷贝3个)

2.1 提帧

进入mmaction/data_tools/目录下,查看视频提帧的代码文件
bulid_rawframes.py
运行之前需要安装mmcv工具包,根据pytorch和cuda版本安装

!pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu102/torch1.9.0/index.html

然后运行此命令,对videos目录下视频提帧,结果放在rawframes文件夹下

!python build_rawframes.py ../data/ucf101/videos ../data/ucf101/rawframes/ --level 2  --ext avi

这里只拷贝了UCF101中3种类型的视频,需要15分钟左右,提帧完成后在rawframes目录下会有3个对应的文件夹
提帧完成

这一步对每个视频进行帧率提取,将视频分割成图片,随机打开一个文件了,里面全部为视频中截取的图片
视频提帧

2.2 生成file_list

切换工作目录到mmaction/data_tools/ucf101/中,找到generate_filelist.sh
generate_filelist
运行

!bash generate_filelist.sh

在这里插入图片描述
运行结束,会在ucf101目录下生成分割文件file_list。(这里批处理文件同时处理了videos,而我们需要的是划分rawframes,可以忽略videos)
分割文件
打开其中一个文件,内容如下:
文件内容
file_list中有三列,第一列代表文件的地址,第二列代表视频的帧数,第三列代表视频的类别。这里仅仅使用ucf101数据集的3个文件夹,所以类别只有0 1 2。

注意: 这一步运行可能出现缺少mmaction相关文件的错误,需要先将目录切换至mmaction下,找到setup.py文件,执行如下命令

!python setup.py install

等待安装完成后即可正常运行。

三、训练工作

3.1 修改代码

在训练之前,需要对tsn-pytorch的一些内容做相应的修改

修改tsn-pytorch中的main.py

  1. ucf有101个类别,这里仅测试了3种所以修改代码
    if args.dataset == 'ucf101':
        num_class = 3
    
  2. 在TSNDataSet中,将将args.train_list和args.val_list改为2.2生成文件的绝对路径(任选一个),所以将代码修改为
    训练路径
    train_loader = torch.utils.data.DataLoader(
    	# 修改部分
        TSNDataSet("", "/content/drive/MyDrive/TSN/mmaction/data/ucf101/ucf101_train_split_1_rawframes.txt", num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl="img_{:05d}.jpg" if args.modality in ["RGB", "RGBDiff"] else args.flow_prefix+"{}_{:05d}.jpg",
                   transform=torchvision.transforms.Compose([
                       train_augmentation,
                       Stack(roll=args.arch == 'BNInception'),
                       ToTorchFormatTensor(div=args.arch != 'BNInception'),
                       normalize,
                   ])),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)
    
    验证路径
    val_loader = torch.utils.data.DataLoader(
    	# 修改部分
        TSNDataSet("", "/content/drive/MyDrive/TSN/mmaction/data/ucf101/ucf101_val_split_1_rawframes.txt", num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl="img_{:05d}.jpg" if args.modality in ["RGB", "RGBDiff"] else args.flow_prefix+"{}_{:05d}.jpg",
                   random_shift=False,
                   transform=torchvision.transforms.Compose([
                       GroupScale(int(scale_size)),
                       GroupCenterCrop(crop_size),
                       Stack(roll=args.arch == 'BNInception'),
                       ToTorchFormatTensor(div=args.arch != 'BNInception'),
                       normalize,
                   ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)
    
  3. 因为我们只选取了三个类别,而pytorch的topk算法默认是(1,5),所以还要对这部分进行更改,否则会报越界错误,更改部分如下图
    topk更改
    可以搜索在main.py中搜索topk关键字,将其改为(1,3),为了方便,前面的prec5不再做更改,在后期输出时注意区分即可。

同时需要更改datase.py中的对应路径
在get函数的self._load_image时要加入绝对路径进行定位

	def get(self, record, indices):

        images = list()
        for seg_ind in indices:
            p = int(seg_ind)
            for i in range(self.new_length):
            	# 修改部分 
                seg_imgs = self._load_image("/content/drive/My Drive/TSN/mmaction/data/ucf101/rawframes"+record.path, p)
                images.extend(seg_imgs)
                if p < record.num_frames:
                    p += 1

        process_data = self.transform(images)
        return process_data, record.label

3.2 开始训练

在tsn-pytorch的README中有训练和测试命令,其中训练命令如下
训练命令
这里选择第一种,即RGB模态下分段为3的网络。
切换到tsn-pytorch目录下,运行下面的命令开始训练,训练5个epoch。

!python main.py ucf101 RGB /content/drive/MyDrive/TSN/mmaction/data/ucf101/ucf101_val_split_1_rawframes.txt /content/drive/MyDrive/TSN/mmaction/data/ucf101/ucf101_val_split_1_rawframes.txt \
   --num_segments 3 \
   --gd 20 --lr 0.001 --lr_steps 30 60 --epochs 5 \
   -b 8 -j 8 --dropout 0.8 \
   --snapshot_pref ucf101_bninception_ 

运行后可看到一些预设参数

预设参数

训练结果如下:

训练结果
最终训练准确率为100%。

模型文件保存
模型文件

四、测试工作

4.1 修改代码

测试需要对tsn-pytorch目录下test_models.py进行一些修改

  1. 修改分段数,减少运算量
    修改分段数

  2. 修改ucf类别数
    ucf类别数

  3. 修改为单gpu运行
    单gpu运行

4.2 开始测试

同样从上述README文件中获取测试命令,执行

!python test_models.py ucf101 RGB /content/drive/MyDrive/TSN/mmaction/data/ucf101/ucf101_val_split_1_rawframes.txt ucf101_bninception__rgb_checkpoint.pth.tar \

首先打印TSN相关配置

TSN测试配置

测试过程很快,测试结果如下:

测试结果
可以看到最终测试准确率为87.96%。

五、一些错误记录

在部署过程中,主要修改点都在训练代码的修改上,主要有一下几点。

  1. 训练时,运行main.py出现以下错误
    错误1
    将target = target.cuda(async=True)改为target = target.cuda()

  2. 在运行训练命令后,可能会有如下报错:

    RuntimeError: Legacy autograd function with non-static forward method is deprecated. Please use new-style autograd function with static forward method. 
    

    这个是因为当前的pytorch版本过高,而原代码的版本较低。如果pytorch版本高于1.3会出现该问题,当前版本要求forward过程是静态的,所以需要将原代码进行修改。需要更改tsn-pytorch目录下对定义模型的models.py,在forward函数中更改ConsensusModule的调用方式
    models修改
    并且更改main.py文件下train函数对TSN模型的调用方式
    TSN修改
    然后需要对tsn-pytorch/ops目录下basic_ops.py代码进行更改,修改方式可参考文章自定义autograd function

  3. 进行训练时,由于pytorch版本不同,可以还有其他小问题,例如用多卡训练的时候tensor不连续,需要对tensor进行操作前先调用contiguous()方法(eg. tensor.contiguous().view()),这些都很容易解决,不再一一列举。

猜你喜欢

转载自blog.csdn.net/qq_41533576/article/details/119751330
TSN
今日推荐