keras超简单入门----电影评论分类

    刚开始接触到神经网络中的深度学习,用的框架是最简单的Keras框架,后端是TensorFlow。Keras框架及其简单,特别适合小白入门,所以在这里强烈推荐大家可以以Keras作为入门来学习深度学习。这里做一个简单的笔记,一是可以供自己以后查看,二是希望可以帮助到刚入门的你。

一、IMDB数据集

    这是一个公共数据集,IMDB数据集内置于Keras库中,包含50000条严重两极分化的评论,traindata和testdata各占25000条,且各包含50%的正面评论和50%的负面评论。

from keras.datasets import imdb
(train_data, train_labels), (test_data, test_labels) = imdb.load_data(num_words=10000)

    和Mnist数据集的导入方式一样,num_words=10000的意思是取频繁出现的前10000个单词。

   

    查看train_data[0]后发现是一串数字,这是第一个训练集的单词在单词表中的索引,即将所有评论中的单词取出来后放到一个集合里面,每个单词可以用对应的索引来表示,训练集里面就是对应的索引表示。

二、准备数据

     将数据转为二进制矩阵,直接上代码。

import numpy as np
def data2vector(data, ndim=10000):
    vec_data = np.zeros((len(data), ndim))
    for i, line in enumerate(data):
        vec_data[i, line] = 1
    return vec_data
train_data_vec = data2vector(train_data)
test_data_vec = data2vector(test_data)
train_labels = np.asarray(train_labels).astype('float32')
test_labels = np.asarray(test_labels).astype('float32')

    此代码所做的是:将每个训练集出现的单词,对应的索引位置都置为1,其他都是0。

三、构建网络

   

input = layers.Input(shape=(10000,))
hidden = layers.Dense(16, activation='relu')(input)
hidden2 = layers.Dense(16, activation='relu')(hidden)
output = layers.Dense(1, activation='sigmoid')(hidden2)

model = Model(input=input, output=output)
model.compile(optimizer='rmsprop',\
              loss='binary_crossentropy', metrics=['accuracy'])
x_val = train_data_vec[:10000]
y_val = train_labels[:10000]
pre_train_data = train_data_vec[10000:]
pre_train_labels = train_labels[10000:]
history = model.fit(pre_train_data, pre_train_labels,\
                    epochs=20, batch_size=512,\
                    validation_data=(x_val, y_val))
history_dict = history.history
loss_values = history_dict['loss']
val_loss_values = history_dict['val_loss']

epochs = range(1, len(loss_values) + 1)
plt.plot(epochs, loss_values, 'bo', label='Training loss')
plt.plot(epochs, val_loss_values, 'b', label='Validation loss')
plt.show()
# model = load_model('my_model.h5')
result = model.evaluate(test_data_vec, test_labels)
print(result)
model.save('my_model.h5')

猜你喜欢

转载自blog.csdn.net/zpf123456789zpf/article/details/88258212