使用RNN进行文本分类

本文使用RNN对IMDB数据集进行情感语义分析。

1.1、输入数据

1.1.1、加载数据集

from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow_datasets as tfds
import tensorflow as tf


dataset, info = tfds.load('imdb_reviews/subwords8k', with_info=True,
                          as_supervised=True)
train_dataset, test_dataset = dataset['train'], dataset['test']

1.1.2、encoder

# encoder
encoder = info.features['text'].encoder
print('Vocabulary size: {}'.format(encoder.vocab_size))

sample_string = 'Hello TensorFlow.'
encoded_string = encoder.encode(sample_string)
print('Encoded string is {}'.format(encoded_string))

original_string = encoder.decode(encoded_string)
print('The original string: "{}"'.format(original_string))
# index --> word
for index in encoded_string:
    print('{} ----> {}'.format(index, encoder.decode([index])))

# print
Vocabulary size: 8185
Encoded string is [4025, 222, 6307, 2327, 4043, 2120, 7975]
The original string: "Hello TensorFlow."
4025 ----> Hell
222 ----> o 
6307 ----> Ten
2327 ----> sor
4043 ----> Fl
2120 ----> ow
7975 ----> .

1.1.3、生成批次

接下来,创建这些经过编码的字符串的批次。使用padded_batch方法将序列零填充到该批处理中最长的字符串的长度:

# 填充 padded_batch
BUFFER_SIZE = 10000
BATCH_SIZE = 64

train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.padded_batch(BATCH_SIZE, train_dataset.output_shapes)

test_dataset = test_dataset.padded_batch(BATCH_SIZE, test_dataset.output_shapes)

1.2、训练模型

# 创建模型
model = tf.keras.Sequential([
    tf.keras.layers.Embedding(encoder.vocab_size, 64),  # word embeddings
    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64)),	# RNN层,使用Bidirectional包装RNN层
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(1, activation='sigmoid')
])

# 编译模型
model.compile(loss='binary_crossentropy',
              optimizer=tf.keras.optimizers.Adam(1e-4),
              metrics=['accuracy'])

1.3、训练模型

# 训练模型
history = model.fit(train_dataset, epochs=10,
                    validation_data=test_dataset,
                    validation_steps=30)
发布了784 篇原创文章 · 获赞 90 · 访问量 44万+

猜你喜欢

转载自blog.csdn.net/wuxintdrh/article/details/103686145