pytorch实现手语识别

代码地址

直接给大佬的代码指路:GitHub_SLR

微调

大佬已经把轮子都造好了,自己写一个主函数,先加载数据集:

import dataset
import train

import torchvision.transforms as transforms


transform = transforms.Compose([transforms.Resize([128, 128]), transforms.ToTensor()])

dataset = dataset.CSL_Continuous(
    data_path="D:/Download/CSL_Continuous",
    dict_path="D:/Download/CSL_Continuous/dictionary.txt",
    corpus_path="D:/Download/CSL_Continuous/corpus.txt",
    train=True, transform=transform
    )

print(len(dataset))
images, tokens = dataset[1000]
print(images.shape, tokens)
print(dataset.output_dim)

先准备好数据集,更改目录:

  1. CSL:Chinese Sign Language Recognition Dataset
  2. dictionary.txt源代码中没有自动生成,作者在issues中给出:dictionary.txt

错误:No module named 'tensorboard'
解决方法:pip install tensorboard

错误:'gbk' codec can't decode byte 0x80 in position 25: illegal multibyte sequence以及'gbk' codec can't decode byte 0xad in position 17: illegal multibyte sequence
解决方法:是同一个问题,由于gbk编码和utf-8编码的不同导致的,在dataset.py文件中的215和242行打开txt文件时,后面加上encoding='utf-8'即可

错误:list index out of range
解决方法:这是因为作者对原始数据进行了预处理,将连续手语数据集中的视频全部变成了图像,所以在加载数据集时直接读取了图像,但是我们在下载数据后只是视频,为了真正实现端到端的训练,我对dataset进行了更改,使其对视频文件进行索引的时候可以直接转化为一组图片

修改后代码

dataset

class CSL_Continuous(Dataset):
    def __init__(self, data_path, dict_path, corpus_path, frames=12, train=True, transform=None):
        super(CSL_Continuous, self).__init__()
        # 3个路径
        self.data_path = data_path
        self.dict_path = dict_path
        self.corpus_path = corpus_path
        # 帧数在读取图像时用到
        self.frames = frames
        # 模式,变换
        self.train = train
        self.transform = transform
        # 其他参数
        self.num_sentences = 100
        self.signers = 50
        self.repetition = 5

        # 根据任务不同划分训练集测试集的大小,0.8*50*5=200,训练集每个句子对应200个样本
        if self.train:
            self.videos_per_folder = int(0.8 * self.signers * self.repetition)
        else:
            self.videos_per_folder = int(0.2 * self.signers * self.repetition)

        # dictionary
        self.dict = {
    
    '<pad>': 0, '<sos>': 1, '<eos>': 2}
        self.output_dim = 3
        try:
            dict_file = open(self.dict_path, 'r', encoding='utf-8')
            for line in dict_file.readlines():
                line = line.strip().split('\t')
                # word with multiple expressions
                if '(' in line[1] and ')' in line[1]:
                    for delimeter in ['(', ')', '、']:
                        line[1] = line[1].replace(delimeter, " ")
                    words = line[1].split()
                else:
                    words = [line[1]]
                for word in words:
                    self.dict[word] = self.output_dim
                self.output_dim += 1
        except Exception as e:
            raise

        # img data
        self.data_folder = []
        try:
            # 列出data_path下所有文件,obs_path包括所有item的路径
            obs_path = [os.path.join(self.data_path, item) for item in os.listdir(self.data_path)]
            self.data_folder = sorted([item for item in obs_path if os.path.isdir(item)])
        except Exception as e:
            raise
        # print(self.data_folder[1]) # 就是000000-000099的目录,这里是\\,加了索引就变成了\

        # corpus
        self.corpus = {
    
    }
        self.unknown = set()
        try:
            corpus_file = open(self.corpus_path, 'r', encoding='utf-8')
            for line in corpus_file.readlines():
                line = line.strip().split()
                sentence = line[1]
                raw_sentence = (line[1]+'.')[:-1]
                paired = [False for i in range(len(line[1]))]
                # print(id(raw_sentence), id(line[1]), id(sentence))
                # pair long words with higher priority
                for token in sorted(self.dict, key=len, reverse=True):
                    index = raw_sentence.find(token)
                    # print(index, line[1])
                    if index != -1 and not paired[index]:
                        line[1] = line[1].replace(token, " "+token+" ")
                        # mark as paired
                        for i in range(len(token)):
                            paired[index+i] = True
                # add sos
                tokens = [self.dict['<sos>']]
                for token in line[1].split():
                    if token in self.dict:
                        tokens.append(self.dict[token])
                    else:
                        self.unknown.add(token)
                # add eos
                tokens.append(self.dict['<eos>'])
                self.corpus[line[0]] = tokens
        except Exception as e:
            raise

        # add padding
        length = [len(tokens) for key, tokens in self.corpus.items()]
        self.max_length = max(length)
        # print(max(length))
        for key, tokens in self.corpus.items():
            if len(tokens) < self.max_length:
                tokens.extend([self.dict['<pad>']]*(self.max_length-len(tokens)))
        # print(self.corpus)
        # print(self.unknown)

    def read_images(self, folder_path):
        # 在条件不满足程序运行的情况下直接返回错误,而不必等待程序运行后出现崩溃
        # assert len(os.listdir(folder_path)) >= self.frames, "Too few images in your data folder: " + str(folder_path)

        images = [] # list
        capture = cv2.VideoCapture(folder_path)

        # fps = capture.get(cv2.CAP_PROP_FPS)
        fps_all = capture.get(cv2.CAP_PROP_FRAME_COUNT)
        # 取整数部分
        timeF = int(fps_all/self.frames)
        n = 1

        # 对一个视频文件进行操作
        while capture.isOpened():
            ret, frame = capture.read()
            if ret is False:
                break
            # 每隔timeF帧进行存储操作
            if (n % timeF == 0):
                image = frame # frame是PIL
                image = Image.fromarray(image) # np array
                if self.transform is not None:
                    image = self.transform(image) # tensor
                images.append(image)
            n = n + 1
            # cv2.waitKey(1)
        capture.release()
        # print('读取视频完成')
        # print("采样间隔:", timeF)

        lenB = len(images)
        # 将列表随机去除一部分元素,剩下的顺序不变

        for o in range(0, int(lenB-self.frames)):
            # 删除一个长度内随机索引对应的元素,不包括len(images)即不会超出索引
            del images[np.random.randint(0, len(images))]
            # images.pop(np.random.randint(0, len(images)))
        lenF = len(images)

        # 沿着一个新维度对输入张量序列进行连接,序列中所有的张量都应该为相同形状
        images = torch.stack(images, dim=0)
        # 原本是帧,通道,h,w,需要换成可供3D CNN使用的形状
        images = images.permute(1, 0, 2, 3)

        print("数据类型:", images.dtype)
        print("图像形状:", images.shape)
        print("总帧数:%d, 采样后帧数:%d, 抽帧后帧数:%d" % (fps_all, lenB, lenF))

        return images

    def __len__(self):
        # 100*200=20000
        return self.num_sentences * self.videos_per_folder

    def __getitem__(self, idx):
        # 根据索引确定访问的文件夹,1000为第5个文件夹,就是obs_path中的某个
        # 新思路,索引就是样本,哪个样本就是哪个文件夹,在索引前面补充0至6位
        s = "%06d" % int(idx/self.videos_per_folder)
        top_folder = os.path.join(self.data_path, s)
        # top_folder = self.data_folder[int(idx/self.videos_per_folder)]
        # top_folder 'D:/Download/CSL_Continuous/color\\000005'
        # os.listdir 用于返回指定的文件夹包含的文件或文件夹的名字的列表

        # selected_folders就是文件夹内全部视频的路径
        selected_folders = [os.path.join(top_folder, item) for item in os.listdir(top_folder)]
        # sorted可以对所有可迭代的对象进行排序操作,但是结果表明此列表不可迭代
        # selected_folders = sorted([item for item in selected_folders_s if os.path.isdir(item)])

        # print(selected_folders)
        # 根据索引选定一个视频文件
        if self.train:
            selected_folder = selected_folders[idx%self.videos_per_folder]
        else:
            selected_folder = selected_folders[idx%self.videos_per_folder + int(0.8*self.signers*self.repetition)]
        # 给定文件夹(索引类别)进行读取,其中250个视频(否)
        images = self.read_images(selected_folder)

        # print(selected_folder, int(idx/self.videos_per_folder))
        # print(self.corpus['{:06d}'.format(int(idx/self.videos_per_folder))])
        tokens = torch.LongTensor(self.corpus['{:06d}'.format(int(idx/self.videos_per_folder))])
        len_label = len(tokens)

        dict_file = open(self.dict_path, 'r', encoding='utf-8')
        len_voc = len(dict_file.readlines()) + 2

        print("标签长度:%d 词典长度: %d" % (len_label, len_voc))

        return images, tokens, len_label, len_voc

之后直接进行运行CSL_Continuous_Seq2Seq.py即可。

猜你喜欢

转载自blog.csdn.net/qq_45510888/article/details/116380366