TensorFlow - Turning sequences into embeddings - Method 1: calling Keras utilities

 I will use several examples to show the data processing methods.

Apart from generating embeddings, the epoch-level training loop and the batch-data generation are also useful to look at.
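As a quick preview, the epoch-level loop built on batch_iter (defined later in cnews_loader.py) boils down to the pattern below. This is only a schematic sketch that reuses the names (session, train_steps, placeholders, hyperparameters) from the full code that follows.

for epoch in range(num_epochs):
    # batches are reshuffled and regenerated at the start of every epoch
    for x_batch, y_batch in batch_iter(x_train, y_train, batch_size):
        session.run(train_steps, feed_dict={input_x: x_batch,
                                            input_y: y_batch,
                                            keep_prob: dropout_keep_prob})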

Case 1 - Text classification with LSTM

import tensorflow as tf
import os
from text_classify_rnn_cnn.data.cnews_loader import read_vocab, read_category, batch_iter, process_file, process_predict_file
import numpy as np
from sklearn import metrics


seq_length = 600
num_classes = 10
vocab_size = 5000
embedding_dim = 600

num_layers = 2
hidden_dim = 128
dropout_keep_prob = 0.8
learning_rate = 0.00005

batch_size = 128
num_epochs = 10

base_dir = 'data/cnews'
train_dir = os.path.join(base_dir, 'cnews.train.txt')
test_dir = os.path.join(base_dir, 'cnews.test.txt')
val_dir = os.path.join(base_dir, 'cnews.val.txt')
vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt')  # vocab.txt already generated

input_x = tf.placeholder(tf.int32, [None, seq_length])
input_y = tf.placeholder(tf.float32, [None, num_classes])
keep_prob = tf.placeholder(tf.float32)


def text_rnn_mlp():

    embedding = tf.get_variable('embedding', [vocab_size, embedding_dim])  # trainable embedding matrix
    embedding_input = tf.nn.embedding_lookup(embedding, input_x)  # shape: [batch, seq_length, embedding_dim]

    def lstm_cell():
        cell = tf.contrib.rnn.BasicLSTMCell(hidden_dim, state_is_tuple=True)
        dropout = tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=keep_prob)
        return dropout

    cells = [lstm_cell() for _ in range(num_layers)]  # just a list of LSTM cells
    rnn_cell = tf.contrib.rnn.MultiRNNCell(cells, state_is_tuple=True)
    _output, state = tf.nn.dynamic_rnn(cell=rnn_cell, inputs=embedding_input, dtype=tf.float32)
    last_output = _output[:, -1, :]  # output at the last time step

    # mlp
    fc1 = tf.layers.dense(inputs=last_output, units=hidden_dim, name='fc1')
    fc1_dropout = tf.contrib.layers.dropout(fc1, keep_prob)
    fc1_relu = tf.nn.relu(fc1_dropout)

    logits = tf.layers.dense(inputs=fc1_relu, units=num_classes, name='fc2')
    y_pred = tf.argmax(tf.nn.softmax(logits), 1)

    cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=input_y)
    loss = tf.reduce_mean(cross_entropy)
    train_steps = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(loss)

    correct_pred = tf.equal(y_pred, tf.argmax(input_y, 1))
    acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
    return loss, train_steps, acc, y_pred


def train_model(state_flag):
    total_train_acc = 0.0
    loss, train_steps, acc, y_pred = text_rnn_mlp()
    saver = tf.train.Saver()
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    session = tf.Session()
    session.run(tf.global_variables_initializer())

    x_train, y_train = process_file(train_dir, word_to_id, cat_to_id, seq_length)
    x_val, y_val = process_file(val_dir, word_to_id, cat_to_id, seq_length)
    print('len(x_train)', len(x_train), 'len(x_val)', len(x_val))
    print(x_train[0:2])

    if state_flag == 'train':
        for i in range(num_epochs):
            batch_train = batch_iter(x_train, y_train, batch_size)

            for x_batch, y_batch in batch_train:
                batch_num = len(x_batch)
                session.run(train_steps, feed_dict={input_x: x_batch, input_y: y_batch, keep_prob: 0.75})
                loss2, accuracy2 = session.run([loss, acc], feed_dict={input_x: x_batch, input_y: y_batch, keep_prob: 0.75})
                print("accuracy : %f ,  loss : %g" % (accuracy2, loss2))
                total_train_acc = total_train_acc + accuracy2*batch_num

            train_acc = total_train_acc / len(x_train)
            print('epoch:%d,  train_acc:%f' % (i, train_acc))
            val_loss, val_acc = session.run([loss, acc], feed_dict={input_x: x_val[0:256], input_y: y_val[0:256], keep_prob: 1.0})  # disable dropout when evaluating
            print('epoch:%d,  val_acc:%f' % (i, val_acc))
            if val_acc > train_acc:
                saver.save(session, save_dir+'my_rnn_model.ckpt')

    if state_flag == 'test':
        ckpt = tf.train.get_checkpoint_state(save_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(session, ckpt.model_checkpoint_path)

            x_test, y_test = process_file(test_dir, word_to_id, cat_to_id, seq_length)
            print('len(x_test)', len(x_test))
            y_test_labels = np.argmax(y_test, 1)
            batch_size2 = 128
            data_len = len(x_test)
            num_batch = int((data_len - 1) / batch_size2) + 1

            y_pred_labels = np.zeros(shape=len(x_test), dtype=np.int32)  # store the predicted labels
            for i in range(num_batch):  # process batch by batch
                start_id = i * batch_size2
                end_id = min((i + 1) * batch_size2, data_len)
                y_pred_labels[start_id: end_id] = session.run(y_pred, feed_dict={input_x: x_test[start_id: end_id], keep_prob: 1.0})  # disable dropout at test time

            print("Precision, Recall and F1-Score...")
            print(metrics.classification_report(y_test_labels, y_pred_labels))

            # confusion matrix
            print("Confusion Matrix...")
            cm = metrics.confusion_matrix(y_test_labels, y_pred_labels)
            print(cm)

    if state_flag == 'predict':
        ckpt = tf.train.get_checkpoint_state(save_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(session, ckpt.model_checkpoint_path)

            x_pred = process_predict_file(pred_dir, word_to_id, seq_length)
            batch_size3 = 128
            data_len = len(x_pred)
            num_batch = int((data_len - 1) / batch_size3) + 1

            y_pred_labels = np.zeros(shape=len(x_pred), dtype=np.int32)  # store the predicted labels
            for i in range(num_batch):  # process batch by batch
                start_id = i * batch_size3
                end_id = min((i + 1) * batch_size3, data_len)
                y_pred_labels[start_id: end_id] = session.run(y_pred, feed_dict={input_x: x_pred[start_id: end_id], keep_prob: 1.0})  # disable dropout at prediction time

            print('y_pred_labels', y_pred_labels)
            categories = ['体育', '财经', '房产', '家居', '教育', '科技', '时尚', '时政', '游戏', '娱乐']
            for id in range(len(y_pred_labels)):
                print('The category of news item %d is: %s' % (id+1, categories[y_pred_labels[id]]))


save_dir = 'text_classify_rnn_cnn/my_ckpt/'
pred_dir = 'text_classify_rnn_cnn/pred_file_input.txt'
categories, cat_to_id = read_category()
words, word_to_id = read_vocab(vocab_dir)  # vocab.txt already generated
train_model('train')

The data processing that prepares the id sequences fed into the embedding layer is as follows:

cnews_loader.py

# coding: utf-8

import sys
from collections import Counter

import numpy as np
import tensorflow.contrib.keras as kr

if sys.version_info[0] > 2:
    is_py3 = True
else:
    reload(sys)  # required before setdefaultencoding under Python 2
    sys.setdefaultencoding("utf-8")
    is_py3 = False


def native_word(word, encoding='utf-8'):
    """如果在python2下面使用python3训练的模型,可考虑调用此函数转化一下字符编码"""
    if not is_py3:
        return word.encode(encoding)
    else:
        return word


def native_content(content):
    if not is_py3:
        return content.decode('utf-8')
    else:
        return content


def open_file(filename, mode='r'):
    # mode: 'r' or 'w' for read or write
    if is_py3:
        return open(filename, mode, encoding='utf-8', errors='ignore')
    else:
        return open(filename, mode)


def read_file(filename):
    """读取文件数据"""
    contents, labels = [], []
    with open_file(filename) as f:
        for line in f:
            try:
                label, content = line.strip().split('\t')
                if content:
                    contents.append(list(native_content(content)))
                    labels.append(native_content(label))
            except:
                pass
    return contents, labels


# content, label = read_file('D:/DL/lstm_blog_case/text_classify_rnn_cnn/data/cnews/cnews.train.txt')
# print("content",content[0:2])  #    content [['马', '晓', '旭', '意', '外', '受', '伤'....]]
# print("label",label[0:10])    #   ['体育', '体育'.....]


def build_vocab(train_dir, vocab_dir, vocab_size=5000):
    """Build the global vocabulary from the training set and write it to vocab_dir."""
    data_train, _ = read_file(train_dir)

    all_data = []
    for content in data_train:
        all_data.extend(content)

    counter = Counter(all_data)
    count_pairs = counter.most_common(vocab_size - 1)
    # (character, count) pairs sorted by frequency
    words, _ = list(zip(*count_pairs))
    #print(list(zip(*count_pairs)))
    # add a <PAD> token (id 0) used to pad all texts to the same length
    words = ['<PAD>'] + list(words)
    open_file(vocab_dir, mode='w').write('\n'.join(words) + '\n')


# Example call (adjust the paths to your own environment); do not run this at import time:
# build_vocab('D://DL//lstm_blog_case//text_classify_rnn_cnn//data//cnews//cnews.train.txt',
#             'D://DL//lstm_blog_case//text_classify_rnn_cnn//write.txt', vocab_size=5000)


def read_vocab(vocab_dir):
    """读取词汇表"""
    # words = open_file(vocab_dir).read().strip().split('\n')
    with open_file(vocab_dir) as fp:
        # under Python 2, convert every entry to unicode
        words = [native_content(_.strip()) for _ in fp.readlines()]
    word_to_id = dict(zip(words, range(len(words))))
    return words, word_to_id


def read_category():
    """读取分类目录,固定"""
    categories = ['体育', '财经', '房产', '家居', '教育', '科技', '时尚', '时政', '游戏', '娱乐']

    categories = [native_content(x) for x in categories]

    cat_to_id = dict(zip(categories, range(len(categories))))

    return categories, cat_to_id


def to_words(content, words):
    """将id表示的内容转换为文字"""
    return ''.join(words[x] for x in content)


def process_file(filename, word_to_id, cat_to_id, max_length=600):
    """convert all the words in the file into id (number)"""
    contents, labels = read_file(filename)

    data_id, label_id = [], []
    for i in range(len(contents)):
        data_id.append([word_to_id[x] for x in contents[i] if x in word_to_id])
        label_id.append(cat_to_id[labels[i]])

    # use Keras pad_sequences to pad the texts to a fixed length
    x_pad = kr.preprocessing.sequence.pad_sequences(data_id, max_length)
    y_pad = kr.utils.to_categorical(label_id, num_classes=len(cat_to_id))  # convert labels to one-hot vectors

    return x_pad, y_pad


def batch_iter(x, y, batch_size=64):
    """生成批次数据"""
    data_len = len(x)
    num_batch = int((data_len - 1) / batch_size) + 1

    indices = np.random.permutation(np.arange(data_len))
    x_shuffle = x[indices]
    y_shuffle = y[indices]

    for i in range(num_batch):
        start_id = i * batch_size
        end_id = min((i + 1) * batch_size, data_len)
        yield x_shuffle[start_id:end_id], y_shuffle[start_id:end_id]


# add my own function in prediction

def read_predict_file(filename):
    """Read a prediction file: one piece of content per line, no label."""
    contents = []
    with open_file(filename) as f:
        for line in f:
            try:
                content = line.strip()
                if content:
                    contents.append(list(native_content(content)))
            except:
                pass
    return contents


def process_predict_file(filename, word_to_id, max_length=600):
    """将文件转换为id表示"""
    contents = read_predcit_file(filename)
    data_id = []
    for i in range(len(contents)):
        data_id.append([word_to_id[x] for x in contents[i] if x in word_to_id])
    # use Keras pad_sequences to pad the texts to a fixed length
    x_pad = kr.preprocessing.sequence.pad_sequences(data_id, max_length)

    return x_pad

It assumes the vocabulary word list has already been generated.

1. Convert each word into its index in the vocabulary word list to represent the data.

For instance, each piece of news is represented like this: [1, 34, 32, 1006, 200, ...]

By the way, the data is in Chinese, so each token is a single Chinese character.
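A minimal sketch of this lookup step, using a hypothetical toy vocabulary (in the real code, word_to_id is built by read_vocab from cnews.vocab.txt):

word_to_id = {'<PAD>': 0, '马': 1, '晓': 2, '旭': 3}  # hypothetical toy vocabulary
content = ['马', '晓', '旭', '受']                     # one news item split into characters
data_id = [word_to_id[x] for x in content if x in word_to_id]
print(data_id)  # [1, 2, 3] -- characters missing from the vocabulary are simply dropped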

2. Use the pad_sequences and to_categorical utilities from Keras.

import tensorflow.contrib.keras as kr


def process_file(filename, word_to_id, cat_to_id, max_length=600):
    """convert all the words in the file into id (number)"""
    contents, labels = read_file(filename)

    data_id, label_id = [], []
    for i in range(len(contents)):
        data_id.append([word_to_id[x] for x in contents[i] if x in word_to_id])
        label_id.append(cat_to_id[labels[i]])

    # use Keras pad_sequences to pad the texts to a fixed length
    x_pad = kr.preprocessing.sequence.pad_sequences(data_id, max_length)
    y_pad = kr.utils.to_categorical(label_id, num_classes=len(cat_to_id))  # convert labels to one-hot vectors

    return x_pad, y_pad

If max_length is set to 800 and the text is shorter, zeros (the id of <PAD>) are padded at the front by default:

[[   0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0  387 1197 2173  215  110  264  814  219   16  725  981 2629
   245 1645   14 1190  231  110  808 2026 2784   43  581  224   98 2345
   470 1190 1609  659  188  209   34   32 1609  659    1   16  725  244
   716  153  165    8 1309 1362 1190  298    2 1061 1478    3  105   70
    49   12   62   61  951   91  164    1   16  725  244    2   62  234
   977  851  333  144  264   32   14 1190    2  900 1478    1  245 1645
   100   61  244  200   40  183 1181 1243   10   18   55   52  883   56
  1191 1191  246   57    3   49   12   62   20  951   12    7  164    1
    16  725  244    6  725  134   11  169  110   57  977  851    2   27
   475    1  125   56    5 1675 1327 1327    2    1  405  554  468  188
   336  185  143  125   61  951 1609  659   56    8   14 1190    1  108
  1135  121  244 1564   20  951    2  977  851  194  165    8  264   32
   330  464  900 1478    3   61  951   91  164    1  143  157  244 1296
   271  977  851   57   27    1   14 1190  167   63   61   10  385   22
   122   27    1   80  505 1055 1342  165    8  886   61   34    2  215
   730    3 1551  205  538    4  538    2  608  144    1  157  244   72
   404   10  143  125   61  951    2  644   36  977  851    1   18   55
    52  883   66  202   10    1  125  405  165    8  330  464  490  121
     2 1278  554    1   21   10  232  797  157  200   40    1   16  725
   244  526  126   11  853  143  125    2  977  851    1  117  244  371
   534 1404  267 1070  832    3    6 1190   11  977  851   39  589  157
   244   34   84  194    9    5  421  217 1712 1993  182    1  108    6
   725  492   35  534   86   72  404  100   65    1  117  244  326   68
    23 1950 1052   24   10    3    6 1609  659   71   59    4  309    2
   977  851    1   16  725  244  321  332   41  232  297   54    8    2
   157  200    9  333   33   54  215  110    2  814  931  162  477   31
   831  120  593  247  253   81  212    1  166  158   19    4 1015  576
   718  239  977  851  264  814   15  718  239  242 1569  151 1763  931
     2   33   54  230  244 1564  358    6   10  161  143  155   41    2
   172  555    3   80 1296  271 1609  659  100   59    1   11   59  699
  1361 2409 1584   56    4  339  165    8  977  851    1 1361 2409 1584
     5  105   70   18  105   62    6  132  744 1533   20   10  242 1569
     1  166  158   46  165    8  493  115   18   77   62  654 1269  257
   703  470    2  422  252  212    3  244 1564  639  845   84    1 1361
  2409 1584  194  165    8   33   54   37 1808 1758  721    1  108   21
    10  324  117  220  503    1   19  173  125   96    5  219   44 1084
  1012  961  365    1  151  242 1569 1464  611  121   10  100   59  333
  1434  562  977  851    3  269   60    8   10 1361 2409 1584   19   22
   644  139    1  166  158   16  725  244   39 1190   11  977  851   56
   336   68  170  316 1619 1524    1  109   41    5  747  169  157  200
    40  264 1931   80  545   37  242 1569    1  345   50  282  283 1040
   769  200    3   80  291  589  244  200  387 1197 2173    6  422  252
   212   11  264  814  293  610  245 1462  725  492    2   65  228    1
    46  219    6 1609  659    2   16  725  244   54    6  231  110  981
  2629    1   23  977  851   11    9  631 3222 4015  244  200   40   41
   483  215  111   69    1   28   40   51    9   45  333   33   19  184
     2  182  162   10    3   24    4  172  152   69   13  200  127  185
     3  175  132  744   32 1609  659    1 1190  298    4  533 1333  546
   205   16  725  244    1   23   46 2132   10    1   28   40  531   32
   986  662 1190   56   61   32  986  662    1    6  132  744  385  130
   977  851   88  230   14 1190  305 3011   30   10    1  165  265   32
    34 1609  659  433 1635   32   19  187  182  162    3   24    4  172
    16  725  157  200   46   39 1190  298    2   23  808 2026   24    8
   156    9  311    3]]
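The same left-padding behaviour can be checked on a tiny input; here is a minimal sketch (the ids are made up):

import tensorflow.contrib.keras as kr

data_id = [[5, 12, 7], [3, 9]]                               # two toy id sequences
x_pad = kr.preprocessing.sequence.pad_sequences(data_id, 6)  # pad to length 6
print(x_pad)
# [[ 0  0  0  5 12  7]
#  [ 0  0  0  0  3  9]]
# zeros are added at the front because pad_sequences defaults to padding='pre';
# sequences longer than maxlen are also truncated from the front by default.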


Reposted from blog.csdn.net/Zhou_Dao/article/details/103747103