稀疏表示学习字典实现语种识别

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/u011930705/article/details/81738327

参考论文:Singh O P, Sinha R. Sparse coding of i-vector/JFA latent vector over ensemble dictionaries for language identification systems[J]. International Journal of Speech Technology, 2017(11):1-16.

实现通过kaldi的ivector语种识别搭配svm获得svmtrain和svmtest作为训练和测试数据,

然后构造完备字典,通过k-svd算法更新字典,类似k-means算法,然后通过lasso计算稀疏系数,

下面进行打分,并通过kaldi工具compute-eer计算eer

最后结果:

#origin result eer:5.03813%
#result eer : 5.0938%    (0.4,2)

并没有论文说的那么好,以后再看看

#!/usr/bin/env python
#coding:utf-8

import os, sys, math
import numpy as np
from sklearn.linear_model import Lasso
from Functions import *

# Number of target languages (class labels are 0 .. num_lang-1).
num_lang = 10
# Number of random-sampling rounds, i.e. the size of the dictionary ensemble.
num_round = 20
# Utterances drawn per language in each round: the first half seeds the
# dictionary, the second half is the training signal for the K-SVD update.
num_round_lang = 120


def get_scores(num_lang, coeff, data_lable):
  """Sum sparse-coding coefficients per language.

  Args:
    num_lang: number of language classes (labels 0 .. num_lang-1).
    coeff: 1-D array of sparse coefficients, one per dictionary atom.
    data_lable: 1-D int array aligned with coeff; data_lable[j] is the
      language id of the atom that produced coeff[j].

  Returns:
    np.ndarray of shape (num_lang,):
    [score of language 0, score of language 1, ..., score of language n-1].
    A language with no atoms scores 0.0.
  """
  # Boolean-mask indexing replaces the Python-2-only xrange/np.where pair;
  # atoms of the same language vote with the sum of their weights.
  return np.array([np.sum(coeff[data_lable == i]) for i in range(num_lang)])


if __name__ == "__main__":

	fin_train = sys.argv[1]
	fin_test = sys.argv[2]
        fout_scores = sys.argv[3]

	sys.stderr.write('Paras: ' + fin_train + ' ' + fin_test + ' '  + '\n')

	# data preparation for training set
	fin = open(fin_train, 'r')

	train_data = []
	train_lable = []

	for line in fin:
		#数据的每行第一列标识语种编号,后面就是实际数据,然后把实际数据存到train_data,标识编号存到train_lable
		line = line.strip()
		wordList = line.split()
		tempList = []
		tempList = np.array(wordList[1:], dtype=float)
                train_data.append(tempList)
                train_lable.append(np.array(wordList[0], dtype=int))

	sys.stderr.write('train_data: ' + str(len(train_data)) + ' ' + 'train_lable: ' + str(len(train_lable)) + '\n')

	fin.close()

	# construct dictionary using training data
        #train_data这里行数就是语段数量,列就是维数,通过lable来标示每段语种的编号是哪种语言
        train_data = np.array(train_data)
        train_lable = np.array(train_lable)
 

	# data preparation for test target test
	#测试数据准备同上
	test_data = []
	test_lable = []

	fin = open(fin_test, 'r')
	
	for line in fin:
		line = line.strip()
		wordList = line.split()
		tempList = []
		tempList = np.array(wordList[1:], dtype=float)
                test_data.append(tempList)
                test_lable.append(np.array(wordList[0], dtype=int))

	sys.stderr.write('test_data: ' + str(len(test_data)) + ' ' + 'test_lable: ' + str(len(train_lable)) + '\n')
        
      
	fin.close()	

        #using random sampling 
          
	matrix_all=[]
	matrix_label_all=[]
	matrix_all_svd=[]
	matrix_label_all_svd=[]
	for i_round in range(0, num_round):
	    
	    matrix_round = []
	    matrix_round_svd = []
	    curr_label_svd = []
	    label_round = []
	    label_round_svd = []
	    for i_num_lang in range(0,num_lang):
                #下面几句作用为了生成每个语种标签所对应ivector数据。
		lang_idx = np.where(train_lable == i_num_lang)
		lang_len = np.array(lang_idx).shape[1]
		#print lang_len
		curr_data = train_data[lang_idx]

		curr_label = train_lable[lang_idx]
		#print curr_label
		#随机选取语段,然后通过num_round_lang控制长度,然后排序matrix_round和label_round是每个语种小范围的字典
		shuffle_lang = range(0, lang_len)
		np.random.shuffle(shuffle_lang)
		shuffle_lang = shuffle_lang[:num_round_lang]
		shuffle_lang = np.sort(shuffle_lang)
		#print shuffle_lang
		#print train_lable[shuffle_lang]
		lang_idx = shuffle_lang
		#curr_label = curr_label[:num_round_lang]
		#matrix_round.append( curr_data[lang_idx])
		curr_label = curr_label[:num_round_lang/2]
		curr_label_svd = curr_label[num_round_lang/2:num_round_lang]
		matrix_round.append( curr_data[lang_idx[range(0,num_round_lang/2)]])
		matrix_round_svd.append( curr_data[lang_idx[range(num_round_lang/2,num_round_lang)]])
		label_round.append(curr_label)
		label_round_svd.append(curr_label)
	    #matrix_all就是所有语种完全的字典,matrix_label_all是语种标签
	    matrix_all.append(np.vstack(matrix_round))
	    matrix_label_all.append(np.hstack(label_round))
	    matrix_all_svd.append(np.vstack(matrix_round_svd))
	    matrix_label_all_svd.append(np.hstack(label_round_svd))
	matrix_all = np.array(matrix_all)
	matrix_label_all =  np.array(matrix_label_all)
	matrix_all_svd = np.array(matrix_all_svd)
	matrix_label_all_svd =  np.array(matrix_label_all_svd)
        sigma=0.4
	ksvd_iter=2


        #using lasso
	#通过lasso算法求稀疏矩阵,具体参考下面文档
	#http://scikit-learn.org/stable/modules/generated/sklearn.linear_model.Lasso.html#sklearn.linear_model.Lasso
        fout = open(fout_scores , 'w')

        alpha = 0.01
        lasso = Lasso(alpha=alpha,max_iter=1000)
        
        np.set_printoptions(formatter={'float': '{: 0.3f}'.format})
	# predict

	y1=[]
	basis_final_all = []
	#y1 = np.array(test_data[:matrix_all[i_round].shape[0]], dtype=float)
	round_result = np.zeros(num_lang)
	for i_round in range(0, num_round):
		basis = matrix_all[i_round]
		y1 = matrix_all_svd[i_round]
		print basis.shape,y1.shape
                #调用github开源算法更新稀疏矩阵:https://github.com/alsoltani/K-SVD.git
		basis_final, sparse_final, n_total = k_svd(y1.T, basis.T, sigma, single_channel_omp, ksvd_iter,'no')
		basis_final_all.append(np.vstack(basis_final))
	matrix_basis_all = np.array(basis_final_all)
		#print "========================================="
		#for i in range(sparse_final.shape[1]):
	 	#  sc = get_scores( num_lang, sparse_final.T[i] , matrix_label_all[i_round])
           	#  round_result = round_result + sc
          
                #round_result = 1.0 * round_result / num_lang

               # fout.write( str(round_result) + '\n')
                #print str(round_result)
                #if i % 100 == 0:
         	#   fout.flush()

	# predict
        #fout.close()
				
	#print basis_final

	for i in range(len(test_data)):
          lasso.fit(train_data.T, test_data[i])
          round_result = np.zeros(num_lang)
          for i_round in range(0, num_round):
	    #通过lasso计算稀疏矩阵,通过get_scores计算测试数据对应每个语种的得分
            #y_pred_lasso = lasso.fit( matrix_all[i_round].T, test_data[i])
	    y_pred_lasso = lasso.fit( matrix_basis_all[i_round], test_data[i])
            sc = get_scores( num_lang, y_pred_lasso.coef_ , matrix_label_all[i_round])
            round_result = round_result + sc
          
          round_result = 1.0 * round_result / num_lang

          fout.write( str(round_result) + '\n')
          print str(round_result)
          if i % 100 == 0:
            fout.flush()

	# predict
        fout.close()


#origin result eer:5.03813%
#result eer : 5.0938%    (0.4,2)

	

转载自 blog.csdn.net/u011930705/article/details/81738327