Reuters news topic classification with Keras

Reference: "Python Deep Learning".
Just for fun!

import keras
from keras.datasets import reuters
import matplotlib.pyplot as plt
import numpy as np

Using TensorFlow backend.

1 Load the Reuters dataset

(train_data,train_label),(test_data,test_label)=reuters.load_data(num_words=10000)

1.1 Split the training set into training and validation sets

val_data=train_data[:1000]
train_data=train_data[1000:]
val_label=train_label[:1000]
train_label=train_label[1000:]
print(train_data.shape)
print(test_data.shape)
len(train_data[0])

(7982,)
(2246,)
626

1.2 Decode indices back into news text

word_index=reuters.get_word_index()
rev_word_index=dict([(value,key) for (key,value) in word_index.items()])
dec=' '.join([rev_word_index.get(i-3,'?') for i in train_data[1]])
dec
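The `i-3` offset is needed because Keras reserves indices 0, 1, and 2 for "padding", "start of sequence", and "unknown". A minimal sketch of the convention, using a tiny hypothetical word index (the real one comes from `reuters.get_word_index()`):

```python
# Hypothetical miniature word index, for illustration only
word_index = {'oil': 1, 'trade': 2}
rev_word_index = {v: k for k, v in word_index.items()}

# 1 is the start-of-sequence marker; 4 and 5 map back to 'oil' and 'trade'
sample = [1, 4, 5]
decoded = ' '.join(rev_word_index.get(i - 3, '?') for i in sample)
print(decoded)  # -> '? oil trade'
```

Reserved indices have no entry in the reversed index, which is why they decode to `'?'`.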

'? qtly div 19 cts vs 19 cts prior pay april 15 record april one reuter 3 '

2 One-hot encode the data

def one_hot(seq,dim=10000):
    # Multi-hot encoding: row i gets a 1 at every word index present in sequence i
    res=np.zeros((len(seq),dim))
    for i,j in enumerate(seq):
        res[i,j]=1  # j is a list of indices, so this sets several columns at once
    return res
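A quick sanity check on a toy input (the helper is redefined here so the snippet is self-contained; vocabulary size and documents are made up):

```python
import numpy as np

def one_hot(seq, dim=10000):
    # Multi-hot encoding: row i gets a 1 at every word index in sequence i
    res = np.zeros((len(seq), dim))
    for i, j in enumerate(seq):
        res[i, j] = 1
    return res

toy = [[0, 2], [1]]        # two "documents" over a 4-word vocabulary
enc = one_hot(toy, dim=4)
print(enc[0])  # [1. 0. 1. 0.]
print(enc[1])  # [0. 1. 0. 0.]
```

Note that word order and counts are discarded; each document becomes a fixed-length presence vector.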

2.1 Encode the data

train_data=one_hot(train_data)
val_data=one_hot(val_data)
test_data=one_hot(test_data)

2.2 Encode the labels

train_label=keras.utils.to_categorical(train_label)
val_label=keras.utils.to_categorical(val_label)
test_label=keras.utils.to_categorical(test_label)
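`to_categorical` simply one-hot encodes the integer class ids. The equivalent by hand, sketched in plain NumPy with toy labels:

```python
import numpy as np

labels = np.array([3, 0, 2])   # toy class ids (the real dataset has 46 topics)
num_classes = 4
cat = np.zeros((len(labels), num_classes))
cat[np.arange(len(labels)), labels] = 1
print(cat)
# [[0. 0. 0. 1.]
#  [1. 0. 0. 0.]
#  [0. 0. 1. 0.]]
```

An alternative is to keep the integer labels as-is and compile the model with `loss='sparse_categorical_crossentropy'`, which accepts class ids directly.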

3 Build the model architecture

model=keras.models.Sequential()
model.add(keras.layers.Dense(64,activation='relu',input_shape=(10000,)))
model.add(keras.layers.Dense(64,activation='relu'))
model.add(keras.layers.Dense(46,activation='softmax'))
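Each Dense layer contributes inputs × units + units (weights plus biases) parameters, so the size of this network can be checked by hand; `model.summary()` reports the same totals:

```python
# Dense(64) on a 10000-dim input, Dense(64), Dense(46) output
layer1 = 10000 * 64 + 64   # 640064
layer2 = 64 * 64 + 64      # 4160
layer3 = 64 * 46 + 46      # 2990
total = layer1 + layer2 + layer3
print(total)  # 647214
```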

4 Define the optimizer and loss function

model.compile(optimizer='rmsprop',
              loss='categorical_crossentropy',
              metrics=['accuracy'])
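On a one-hot target, categorical cross-entropy reduces to the negative log of the probability the model assigns to the true class. A toy check in plain NumPy (made-up probabilities):

```python
import numpy as np

y_true = np.array([0., 0., 1., 0.])      # one-hot label, true class is 2
y_pred = np.array([0.1, 0.2, 0.6, 0.1])  # softmax output, sums to 1
loss = -np.sum(y_true * np.log(y_pred))  # = -log(0.6)
print(loss)  # ~0.511
```

The loss falls toward 0 as the probability of the true class approaches 1, which is what drives the softmax output toward the correct topic.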

5 Training + validation

his=model.fit(train_data,
              train_label,
              epochs=20,
              batch_size=512,
              validation_data=(val_data,val_label))
Train on 7982 samples, validate on 1000 samples
Epoch 1/20
7982/7982 [==============================] - 2s 292us/step - loss: 2.5309 - acc: 0.4959 - val_loss: 1.7227 - val_acc: 0.6110
Epoch 2/20
7982/7982 [==============================] - 1s 179us/step - loss: 1.4463 - acc: 0.6877 - val_loss: 1.3463 - val_acc: 0.7060
Epoch 3/20
7982/7982 [==============================] - 1s 170us/step - loss: 1.0953 - acc: 0.7648 - val_loss: 1.1710 - val_acc: 0.7440
Epoch 4/20
7982/7982 [==============================] - 1s 168us/step - loss: 0.8697 - acc: 0.8161 - val_loss: 1.0806 - val_acc: 0.7580
Epoch 5/20
7982/7982 [==============================] - 1s 174us/step - loss: 0.7030 - acc: 0.8472 - val_loss: 0.9834 - val_acc: 0.7820
Epoch 6/20
7982/7982 [==============================] - 2s 192us/step - loss: 0.5660 - acc: 0.8796 - val_loss: 0.9419 - val_acc: 0.8020
Epoch 7/20
7982/7982 [==============================] - 1s 181us/step - loss: 0.4578 - acc: 0.9048 - val_loss: 0.9090 - val_acc: 0.8010
Epoch 8/20
7982/7982 [==============================] - 1s 167us/step - loss: 0.3691 - acc: 0.9231 - val_loss: 0.9381 - val_acc: 0.7890
Epoch 9/20
7982/7982 [==============================] - 1s 165us/step - loss: 0.3030 - acc: 0.9312 - val_loss: 0.8910 - val_acc: 0.8090
Epoch 10/20
7982/7982 [==============================] - 1s 165us/step - loss: 0.2537 - acc: 0.9416 - val_loss: 0.9066 - val_acc: 0.8120
Epoch 11/20
7982/7982 [==============================] - 1s 168us/step - loss: 0.2182 - acc: 0.9469 - val_loss: 0.9192 - val_acc: 0.8140
Epoch 12/20
7982/7982 [==============================] - 1s 163us/step - loss: 0.1873 - acc: 0.9511 - val_loss: 0.9070 - val_acc: 0.8130
Epoch 13/20
7982/7982 [==============================] - 1s 171us/step - loss: 0.1699 - acc: 0.9523 - val_loss: 0.9364 - val_acc: 0.8070
Epoch 14/20
7982/7982 [==============================] - 1s 167us/step - loss: 0.1535 - acc: 0.9555 - val_loss: 0.9675 - val_acc: 0.8060
Epoch 15/20
7982/7982 [==============================] - 1s 172us/step - loss: 0.1389 - acc: 0.9559 - val_loss: 0.9707 - val_acc: 0.8150
Epoch 16/20
7982/7982 [==============================] - 1s 165us/step - loss: 0.1313 - acc: 0.9559 - val_loss: 1.0249 - val_acc: 0.8050
Epoch 17/20
7982/7982 [==============================] - 1s 173us/step - loss: 0.1218 - acc: 0.9582 - val_loss: 1.0294 - val_acc: 0.7960
Epoch 18/20
7982/7982 [==============================] - 1s 164us/step - loss: 0.1198 - acc: 0.9579 - val_loss: 1.0454 - val_acc: 0.8030
Epoch 19/20
7982/7982 [==============================] - 1s 166us/step - loss: 0.1139 - acc: 0.9598 - val_loss: 1.0980 - val_acc: 0.7980
Epoch 20/20
7982/7982 [==============================] - 1s 172us/step - loss: 0.1112 - acc: 0.9595 - val_loss: 1.0721 - val_acc: 0.8010
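The log above shows validation loss bottoming out around epoch 9 and climbing afterwards while training loss keeps falling, the classic sign of overfitting. Picking the best epoch from the history is a one-liner; in a rerun one could either retrain for that many epochs or attach a `keras.callbacks.EarlyStopping` callback:

```python
import numpy as np

# val_loss values copied from the training log above
val_loss = [1.7227, 1.3463, 1.1710, 1.0806, 0.9834, 0.9419, 0.9090,
            0.9381, 0.8910, 0.9066, 0.9192, 0.9070, 0.9364, 0.9675,
            0.9707, 1.0249, 1.0294, 1.0454, 1.0980, 1.0721]
best_epoch = int(np.argmin(val_loss)) + 1  # epochs are 1-indexed
print(best_epoch)  # 9
```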

6 Process and visualize the results

6.1 Extract the training history

his_dict=his.history
loss=his_dict['loss']
val_loss=his_dict['val_loss']
acc=his_dict['acc']
val_acc=his_dict['val_acc']
epoch=range(1,len(loss)+1)

6.2 Plot the loss curves

plt.plot(epoch,loss,'b',label='train_loss')
plt.plot(epoch,val_loss,'r',label='val_loss')
plt.title('train and validation')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
plt.show()

(Figure: training and validation loss curves)

6.3 Plot the accuracy curves

plt.clf()
plt.plot(epoch,acc,'k',label='train_acc')
plt.plot(epoch,val_acc,'g',label='val_acc')
plt.title('train and validation')
plt.xlabel('epoch')
plt.ylabel('acc')
plt.legend()
plt.show()

(Figure: training and validation accuracy curves)

7 Test and predict (inspect the results)

7.1 Test

test_loss,test_acc=model.evaluate(test_data,test_label)

2246/2246 [==============================] - 0s 197us/step

print('test_loss=',test_loss,'\ntest_acc=',test_acc)

test_loss = 1.216040284741912
test_acc = 0.778717720444884

7.2 Prediction

prediction=model.predict(test_data)
print('predict_result=',np.argmax(prediction[0]))
print('correct_result=',np.argmax(test_label[0]))

predict_result= 3
correct_result= 3
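`predict` returns one probability distribution over the 46 topics per document, so each row sums to 1 and test accuracy can be recomputed by comparing argmaxes. A toy sketch with made-up probabilities (shapes and values are illustrative, not from the run above):

```python
import numpy as np

# Fake predictions for 3 documents over 4 classes (each row sums to 1)
prediction = np.array([[0.1, 0.2, 0.6, 0.1],
                       [0.7, 0.1, 0.1, 0.1],
                       [0.2, 0.1, 0.3, 0.4]])
true_ids = np.array([2, 0, 1])
pred_ids = prediction.argmax(axis=1)        # [2, 0, 3]
accuracy = (pred_ids == true_ids).mean()    # 2 of 3 correct
print(accuracy)
```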

The final result is pretty decent...

Origin blog.csdn.net/weixin_41650348/article/details/108461525