深度学习VAD人声检测之训练数据加载

VAD训练数据的特征已在前面的博客中准备好,下面利用这些数据进行模型训练,主要包含以下几个步骤:
(1)加载数据
(2)搭建模型
(3)模型训练

加载数据:主要是加载训练数据的特征和对应标签,具体实现如下:

import torch
from torch.utils import data
import json
import os
import soundfile as sf
import numpy as np
import math


def read_audio_file1(path,fmt,flag = 0):
    """Recursively collect files under *path* whose names end with *fmt*.

    Args:
        path: root directory to walk.
        fmt: filename suffix to match, e.g. ``'.npy'``.
        flag: naming scheme for the returned ``names``:
            0 -> the file's basename up to its first ``'.'``;
            1 -> ``<grandparent-dir>_<parent-dir>_<stem>``.

    Returns:
        ``(files, names)``: two parallel lists, in sorted traversal order.
        Each path in ``files`` is built as ``root + '/' + filename``.
    """
    files = []
    names = []
    for root, dirnames, filenames in os.walk(path):
        # os.walk yields entries in arbitrary filesystem order. Sorting
        # dirnames in place also fixes the traversal order, so the
        # resulting dataset order is reproducible.
        dirnames.sort()
        for filename in sorted(filenames):
            if not filename.endswith(fmt):
                continue
            file_path = root + '/' + filename
            files.append(file_path)

            stem = filename.split('.')[0]
            if flag == 1:
                # Build the name from path components. Splitting the whole
                # path on '.' broke whenever a directory contained a dot
                # (e.g. './data').
                parts = file_path.split('/')[:-1] + [stem]
                stem = '_'.join(parts[-3:])
            names.append(stem)

    return files, names


def _parse_label_file(path, label_dict):
    """Add ``name -> [int, ...]`` entries from one ``name:l0 l1 ...`` label file.

    The text before the first ``':'`` is the name (surrounding whitespace
    stripped). The text after the last ``':'`` holds whitespace-separated
    integer labels. Blank lines are skipped, and an existing entry with the
    same name is overwritten.
    """
    with open(path, 'r') as fp:
        for line in fp:
            line = line.strip()
            if not line:
                continue
            fields = line.split(':')
            label_dict[fields[0].strip()] = [int(tok) for tok in fields[-1].split()]


def read_audio_labels(params):
    """Load frame labels from the speech and noise label files.

    Args:
        params: config dict. ``params['labels']['speech_label_dir']`` and
            ``params['labels']['noise_label_dir']`` are paths to text files
            with lines of the form ``name:l0 l1 l2 ...``.

    Returns:
        dict mapping each name to its list of integer labels.
    """
    label_dict = dict()
    # Order matters: entries from the noise file overwrite speech-file
    # entries that share a name.
    for key in ('speech_label_dir', 'noise_label_dir'):
        _parse_label_file(params['labels'][key], label_dict)
    return label_dict
    




def read_file_feats(path,snr):
    """Read a whitespace-separated text feature file, one frame per line.

    Args:
        path: text file; each non-blank line holds the float values of one frame.
        snr: value recorded once per frame in the returned ``snrs`` list.

    Returns:
        ``(feats, snrs)``: ``feats`` is a list of float32 1-D arrays, one per
        non-blank line; ``snrs`` repeats *snr* once per entry of ``feats``.
    """
    feats = []
    with open(path, 'r') as fp:
        for line in fp:
            tokens = line.split()
            if not tokens:  # tolerate blank lines instead of failing on float('')
                continue
            feats.append(np.array([float(tok) for tok in tokens], dtype=np.float32))

    snrs = [snr] * len(feats)
    return feats, snrs


def read_file_labels(path,snr):
    """Read whitespace-separated integer labels from a text file.

    Labels from all lines are flattened into one list, in file order.

    Args:
        path: text file of integer labels.
        snr: value recorded once per label in the returned ``snrs`` list.

    Returns:
        ``(labels, snrs)``: the flat list of ints, and *snr* repeated once per label.
    """
    labels = []
    with open(path, 'r') as fp:
        for line in fp:
            # split() with no argument ignores repeated spaces and blank lines.
            labels.extend(int(tok) for tok in line.split())

    snrs = [snr] * len(labels)
    return labels, snrs


def get_feats_labels(dpath,snr):
    """Load and return the array saved in the ``.npy`` file at *dpath*.

    Despite the function name, only the features are returned; *snr* is
    not used.
    """
    loaded = np.load(dpath)
    return loaded



class VADDataset(data.Dataset):
    """Frame-level VAD dataset backed by per-utterance ``.npy`` feature files.

    Feature files are discovered under ``params[mode]['read_feats_dir']``.
    Frame labels come from the label files listed in ``params['labels']``
    (loaded by ``read_audio_labels``).
    """

    dataset_name = 'VAD'

    def __init__(self, params,mode,ex_num = 16, frame_num=2999):
        """
        Args:
            params: configuration dict. ``params[mode]['read_feats_dir']`` is
                read here; ``params`` is also passed to ``read_audio_labels``.
            mode: key into *params*. In ``'test'`` mode each item also carries
                its SNR, and ``collate_fn`` returns the batch SNRs.
            ex_num: stored as ``self.ex_num``; not used inside this class.
            frame_num: maximum number of labels kept per utterance
                (previously hard-coded to 2999).
        """
        super().__init__()

        self.params = params
        self.mode = mode
        self.ex_num = ex_num
        self.dpaths, self.names = read_audio_file1(self.params[mode]['read_feats_dir'],'.npy')
        self.labels_dict = read_audio_labels(self.params)
        self.frame_num = frame_num

        print('dpaths.len = ',len(self.dpaths))
        print('labels.len = ',len(self.labels_dict))

    def __len__(self):
        return len(self.dpaths)

    @staticmethod
    def _parse_snr(name):
        """Extract the SNR from a name such as ``'<utt>_snr5_noisy'``.

        The token after the last ``'snr'``, up to the next ``'_'``, is parsed
        as an int. A purely alphabetic token (e.g. ``'clean'``) maps to 100,
        presumably standing in for clean speech.
        """
        token = name.split('snr')[-1].split('_')[0]
        if token.isalpha():
            return 100
        return int(token)

    def __getitem__(self, idx):
        """Return ``(feats, labels)``, or ``(feats, labels, snr)`` in ``'test'`` mode.

        ``feats`` is the stored array transposed; ``collate_fn`` treats its
        axis 1 as the frame axis (NOTE(review): confirm the on-disk layout is
        (frames, dim)). ``labels`` holds at most ``self.frame_num`` labels.
        The SNR is parsed only in ``'test'`` mode, so malformed names cannot
        break training.
        """
        dpath = self.dpaths[idx]
        name = self.names[idx]

        # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary
        # objects; only load feature files from a trusted source.
        feats = np.load(dpath, allow_pickle=True).transpose()

        # Drop the '_snr...' and '_noisy...' suffixes to get the label key.
        key = name.split('_snr')[0].split('_noisy')[0]
        labels = np.array(self.labels_dict[key][:self.frame_num])

        feats = torch.from_numpy(feats)
        labels = torch.from_numpy(labels)

        if self.mode == 'test':
            return feats, labels, self._parse_snr(name)
        return feats, labels

    def collate_fn(self,batch):
        """Zero-pad variable-length features to a common length and merge a batch.

        Used when feature files differ in length. Every feature matrix is
        padded with zeros along axis 1 (the frame axis of ``__getitem__``'s
        output) up to the longest one in the batch.

        Args:
            batch: list of items from ``__getitem__``, either
                ``(feats, labels)`` or ``(feats, labels, snr)``.

        Returns:
            ``(feats, labels, feats_len)``, plus a 4th element ``snrs`` in
            ``'test'`` mode.
            - ``feats`` has shape (batch, dim, max_len).
            - ``labels`` is every item's labels concatenated into one 1-D tensor.
            - ``feats_len`` is a list of the unpadded lengths.
            - ``snrs`` is a 1-D tensor.
        """
        feats = [torch.as_tensor(item[0]) for item in batch]
        labels = [torch.as_tensor(item[1]) for item in batch]
        feats_len = [feat.shape[1] for feat in feats]
        longest = max(feats_len)

        padded = feats[0].new_zeros((len(feats), feats[0].shape[0], longest))
        for i, feat in enumerate(feats):
            padded[i, :, :feat.shape[1]] = feat

        new_labels = torch.cat([lab.reshape(-1) for lab in labels])

        if self.mode == 'test':
            new_snrs = torch.tensor([item[2] for item in batch])
            return padded, new_labels, feats_len, new_snrs

        return padded, new_labels, feats_len

水平有限,不当之处还请指教,谢谢!

猜你喜欢

转载自blog.csdn.net/pikaqiu_n95/article/details/114197279
今日推荐