Text classification with an LSTM using TFLearn

Model training code:

# -*- coding: utf-8 -*-
from __future__ import division, print_function, absolute_import
import os
import re
import numpy
import jieba
import fire
import tflearn
from tflearn.data_utils import to_categorical, pad_sequences

def load_data1(keywordPath, stopwords_set, filepath, dictfilepath, n_words,
               valid_portion=0.1, sort_by_len=True):

    # Load the custom keyword dictionary so jieba keeps domain terms as single tokens.
    jieba.load_userdict(keywordPath)
    pathDir = os.listdir(filepath)

    data_set = []
    train_set_x = []
    train_set_y = []
    test_set_x = []
    test_set_y = []

    # Build a lookup table of stopwords.
    stopwords = {}
    fstop = open(stopwords_set, 'rb')
    for eachWord in fstop:
        word = eachWord.strip().decode('utf-8', 'ignore')
        stopwords[word] = word
    fstop.close()

    # Vocabulary file: one "word index" entry per line.
    f1 = open(dictfilepath, 'w', encoding='UTF-8')
    dic = dict()

    i = 0  # running word index
    j = 0  # running line counter (used only in the progress prints)

    # Build the vocabulary and encode every document as a list of word indices.
    for allDir in pathDir:
        child = os.path.join(filepath, allDir)  # one subdirectory per class
        if os.path.isdir(child):
            pathSubDir = os.listdir(child)
            k = 1
            for subDir in pathSubDir:
                des = child + os.sep + subDir
                invert = []  # word-index encoding of the current document
                fOpen = open(des, "r", encoding='UTF-8')
                for eachLine in fOpen:
                    line = eachLine.strip()
                    line1 = re.sub("[\s+\.\!\/_,$%^*()?;;:-【】+\"\']+|[+——!,;:。?、~@#¥%……&*()]+",
                                   "", line)
                    wordList = list(jieba.cut(line1))
                    for word in wordList:
                        if word not in stopwords:
                            data_set.append(word)
                            if word not in dic:
                                # New word: assign the next index and persist it to the dict file.
                                i = i + 1
                                dic[word] = i
                                if re.match('[^ \t\n\x0B\f\r]', word, flags=0):
                                    f1.write(word + " " + str(i))
                                    f1.write("\n")
                            # Append the word's index to this document's encoding,
                            # e.g. invert = [22, 123, 424, ...].
                            invert.append(dic[word])
                    j = j + 1
                # Roughly the first 10% of files in each class directory go to the test set.
                n = len(pathSubDir)
                if k <= n * 0.1:
                    print(str(j) + " test " + allDir)
                    test_set_x.append(invert)
                    test_set_y.append(allDir)  # the directory name doubles as the class label
                else:
                    print(str(j) + " train " + allDir)
                    train_set_x.append(invert)
                    train_set_y.append(allDir)
                k += 1
                fOpen.close()
    f1.close()

    print("the number of words : "+str(i))

    # Shuffle the remaining documents and hold out `valid_portion` of them for validation.
    n_samples = len(train_set_x)
    sidx = numpy.random.permutation(n_samples)
    n_train = int(numpy.round(n_samples * (1. - valid_portion)))
    valid_set_x = [train_set_x[s] for s in sidx[n_train:]]
    valid_set_y = [train_set_y[s] for s in sidx[n_train:]]
    train_set_x = [train_set_x[s] for s in sidx[:n_train]]
    train_set_y = [train_set_y[s] for s in sidx[:n_train]]

    train_set = (train_set_x, train_set_y)
    valid_set = (valid_set_x, valid_set_y)

    def remove_unk(x):
        # Map any index outside the vocabulary size to 1, the "unknown word" token.
        return [[1 if w >= n_words else w for w in sen] for sen in x]

    valid_set_x, valid_set_y = valid_set
    train_set_x, train_set_y = train_set

    train_set_x = remove_unk(train_set_x)
    valid_set_x = remove_unk(valid_set_x)
    test_set_x = remove_unk(test_set_x)


    def len_argsort(seq):
        # Indices that would sort the sequences by length.
        return sorted(range(len(seq)), key=lambda x: len(seq[x]))

    if sort_by_len:
        sorted_index = len_argsort(test_set_x)
        test_set_x = [test_set_x[i] for i in sorted_index]
        test_set_y = [test_set_y[i] for i in sorted_index]

        sorted_index = len_argsort(valid_set_x)
        valid_set_x = [valid_set_x[i] for i in sorted_index]
        valid_set_y = [valid_set_y[i] for i in sorted_index]

        sorted_index = len_argsort(train_set_x)
        train_set_x = [train_set_x[i] for i in sorted_index]
        train_set_y = [train_set_y[i] for i in sorted_index]

    train = (train_set_x, train_set_y)
    valid = (valid_set_x, valid_set_y)
    test = (test_set_x, test_set_y)
    return train, valid, test
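
# Note (assumption, not stated in the original post): load_data1 expects the corpus under
# `filepath` to contain one subdirectory per class, and it uses the directory name directly
# as the class label, e.g.:
#
#   d:\data\
#       0\          <- class label "0"
#           doc_a.txt
#           doc_b.txt
#       1\          <- class label "1"
#           doc_c.txt
#
# Because those labels are later fed to to_categorical(..., nb_classes=classnum), the
# directory names should be numeric class indices in the range [0, classnum).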

def train(keywordPath, dictPath, modelPath, stopword_setPath, classnum):
    print("#######################")
    print("#         train       #")
    print("#######################")
    # The vocabulary size is read from an existing dict file (one "word index" entry per line).
    words = []
    f = open(dictPath, "r", encoding="utf-8")
    for line in f:
        words.append(line)
    f.close()
    word_num = len(words)
    classnum = int(classnum)
    s = os.sep  # platform path separator
    dataPath = "d:" + s + "data"  # hard-coded corpus root: one subdirectory per class
    train_set, valid, test = load_data1(keywordPath=keywordPath, stopwords_set=stopword_setPath,
                                        filepath=dataPath, dictfilepath=dictPath,
                                        n_words=word_num, valid_portion=0.1)
    trainX, trainY = train_set
    valX, valY = valid
    # Pad or truncate every document to a fixed length of 30 word indices.
    trainX = pad_sequences(trainX, maxlen=30, value=0.)
    valX = pad_sequences(valX, maxlen=30, value=0.)

    # Labels are the class-directory names; they must be numeric indices in [0, classnum).
    trainY = to_categorical(trainY, nb_classes=classnum)
    valY = to_categorical(valY, nb_classes=classnum)

    # Network: word-embedding layer -> LSTM -> softmax classifier.
    net = tflearn.input_data([None, 30])
    net = tflearn.embedding(net, input_dim=word_num, output_dim=128)
    net = tflearn.lstm(net, 128, dropout=0.8)
    net = tflearn.fully_connected(net, classnum, activation='softmax')
    net = tflearn.regression(net, optimizer='adam', learning_rate=0.01,
                             loss='categorical_crossentropy')

    model = tflearn.DNN(net, tensorboard_verbose=0)
    model.fit(trainX, trainY, n_epoch=1, validation_set=(valX, valY), show_metric=True, batch_size=256)
    model.save(modelPath)


if __name__ == '__main__':
    fire.Fire(train)
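
Because fire.Fire(train) maps positional command-line arguments onto the parameters of train(), the script can be run as, for example, python lstm_train.py keywords.txt dict.txt model.tfl stopwords.txt 5 (the script and file names here are only placeholders).

The post only covers training. The sketch below shows one possible way to load the saved model and classify a new piece of text. It is an illustrative assumption, not part of the original article: the function predict and its path parameters are hypothetical, the text is encoded with the dict file written by load_data1 (one "word index" line per word), and the network definition must match the one built in train() so the saved weights can be restored.

# -*- coding: utf-8 -*-
# Inference sketch (assumption, not from the original post).
import re
import jieba
import tflearn
from tflearn.data_utils import pad_sequences

def predict(text, keywordPath, dictPath, modelPath, classnum):
    jieba.load_userdict(keywordPath)

    # Rebuild the word -> index vocabulary from the dict file written during training.
    dic = {}
    with open(dictPath, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.strip().split(" ")
            if len(parts) == 2:
                dic[parts[0]] = int(parts[1])
    word_num = len(dic)  # must equal the vocabulary size the model was trained with

    # Encode the text the same way as in load_data1; words never seen during training
    # (including stopwords, which were kept out of the dict) fall back to index 1.
    line1 = re.sub("[\s+\.\!\/_,$%^*()?;;:-【】+\"\']+|[+——!,;:。?、~@#¥%……&*()]+", "", text)
    indices = [dic.get(w, 1) for w in jieba.cut(line1)]
    indices = [w if w < word_num else 1 for w in indices]
    x = pad_sequences([indices], maxlen=30, value=0.)

    # The network has to be identical to the one built in train() before loading weights.
    net = tflearn.input_data([None, 30])
    net = tflearn.embedding(net, input_dim=word_num, output_dim=128)
    net = tflearn.lstm(net, 128, dropout=0.8)
    net = tflearn.fully_connected(net, classnum, activation='softmax')
    net = tflearn.regression(net, optimizer='adam', learning_rate=0.01,
                             loss='categorical_crossentropy')
    model = tflearn.DNN(net)
    model.load(modelPath)

    probs = model.predict(x)[0]
    return int(max(range(classnum), key=lambda c: probs[c]))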

Reprinted from blog.csdn.net/qq_20780183/article/details/80178995