Building an RNN for Spam Classification with TensorFlow

1 Import Libraries

import os
import re
import io
import requests
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from zipfile import ZipFile

2 Load the Data

data_dir = 'temp'
data_file = 'text_data.txt'
if not os.path.exists(data_dir): # create the temp folder if it does not exist
    os.makedirs(data_dir)
if not os.path.isfile(os.path.join(data_dir, data_file)): # download only if text_data.txt does not exist yet
    zip_url = 'http://archive.ics.uci.edu/ml/machine-learning-databases/00228/smsspamcollection.zip'
    r = requests.get(zip_url)
    z = ZipFile(io.BytesIO(r.content))
    file = z.read('SMSSpamCollection')
    text_data = file.decode()
    text_data = text_data.encode('ascii', errors='ignore')
    text_data = text_data.decode().split('\n')
    with open(os.path.join(data_dir, data_file), 'w') as f:
        for x in text_data:
            f.write(x + '\n') # write one message per line
else:
    with open(os.path.join(data_dir, data_file), 'r') as f:
        text_data = f.readlines() # read all messages into a list
text_data = [x.split('\t') for x in text_data[0:-1]] # the last element is an empty line, so drop it
text_data_target = [x[0] for x in text_data] # labels
text_data_train = [x[1] for x in text_data] # messages
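
As a quick sanity check (this snippet is not in the original post), each line of the corpus is "label\tmessage", the label being 'ham' or 'spam', and the full collection holds 5,574 messages:

print(len(text_data_train))   # 5574
print(text_data_target[:3])   # ['ham', 'ham', 'spam']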

3 Clean the Data

def clean_text(text_string):
    text_string = re.sub(r'([^\s\w]|_[0-9])+', '', text_string) # strip punctuation and similar symbols
    text_string = " ".join(text_string.split()) # collapse whitespace to single spaces
    text_string = text_string.lower() # lowercase everything
    return text_string
text_data_train = [clean_text(x) for x in text_data_train]
print(text_data_train[0:3])
['go until jurong point crazy available only in bugis n great world la e buffet cine there got amore wat', 'ok lar joking wif u oni', 'free entry in 2 a wkly comp to win fa cup final tkts 21st may 2005 text fa to 87121 to receive entry questionstd txt ratetcs apply 08452810075over18s']
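
To make the cleaning concrete, a small made-up example (not from the post): punctuation is stripped, whitespace is collapsed, case is lowered, and digits survive:

print(clean_text('Free entry!!  Win 1000 NOW...'))   # 'free entry win 1000 now'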

4 Convert Text to Index Lists

max_sequence_length = 25
min_word_frequency = 10
vocab_processor = tf.contrib.learn.preprocessing.VocabularyProcessor(max_sequence_length, min_word_frequency) # maps each message to a sequence of word indices
text_processed = np.array(list(vocab_processor.fit_transform(text_data_train))) # each message becomes a fixed-length index array
print(text_processed[0:3])
WARNING:tensorflow:From <ipython-input-4-35bf3897e6d9>:3: VocabularyProcessor.__init__ (from tensorflow.contrib.learn.python.learn.preprocessing.text) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tensorflow/transform or tf.data.
WARNING:tensorflow:From D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\preprocessing\text.py:154: CategoricalVocabulary.__init__ (from tensorflow.contrib.learn.python.learn.preprocessing.categorical_vocabulary) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tensorflow/transform or tf.data.
WARNING:tensorflow:From D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\preprocessing\text.py:170: tokenizer (from tensorflow.contrib.learn.python.learn.preprocessing.text) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tensorflow/transform or tf.data.
[[ 46 459   0 832 723 684  64   9   0  89 120 372   0 155   0   0  68  58
    0 138   0   0   0   0   0]
 [ 48 322   0 462   6   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0]
 [ 50 469   9  21   4 791 907   1 179   0   0 635   0   0 257   0  71   0
    1   0   1 325 469   0  79]]
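
Note that index 0 does double duty: it stands in for words below the min_word_frequency cutoff and also right-pads messages shorter than 25 tokens. A round-trip sketch with a made-up sentence (transform() reuses the fitted vocabulary):

sample = list(vocab_processor.transform(['free entry to win a prize']))[0]
print(sample)   # 25 indices; rare or unseen words map to 0, the tail is 0-padding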

5 Shuffle the Dataset

text_data_target = np.array([1 if x == 'ham' else 0 for x in text_data_target]) # 1: ham (normal), 0: spam
shuffled_ix = np.random.permutation(np.arange(len(text_processed))) # random permutation; does not modify the originals
x_shuffled = text_processed[shuffled_ix]
y_shuffled = text_data_target[shuffled_ix]

6 Split the Dataset

ix_cutoff = int(len(text_processed) * 0.8) # 80% train, 20% test
x_train, x_test = x_shuffled[0:ix_cutoff], x_shuffled[ix_cutoff:]
y_train, y_test = y_shuffled[0:ix_cutoff], y_shuffled[ix_cutoff:]
vocab_size = len(vocab_processor.vocabulary_) # vocabulary size
print("Vocabulary Size: {}".format(vocab_size))
print("80-20 Train Test Split: {} -- {}".format(len(y_train), len(y_test)))
Vocabulary Size: 954
80-20 Train Test Split: 4459 -- 1115

7 Placeholders

x_data = tf.placeholder(shape=[None,max_sequence_length], dtype=tf.int32) # word indices
y_output = tf.placeholder(shape=[None], dtype=tf.int32) # class labels

8 Create the Embedding Matrix and Lookup

embedding_size = 50
embedding_mat = tf.Variable(tf.random_uniform([vocab_size, embedding_size], -1.0, 1.0)) # one embedding per word, so the matrix is vocab_size x embedding_size
embedding_output = tf.nn.embedding_lookup(embedding_mat, x_data) # look up the embedding of every word index
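
A quick shape check (a sketch, assuming the placeholders above): the lookup turns a batch of word indices of shape [batch_size, 25] into dense vectors of shape [batch_size, 25, 50]:

print(embedding_output.get_shape())   # (?, 25, 50)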

9 Define the Model

rnn_size = 10
dropout_keep_prob = tf.placeholder(tf.float32)
cell = tf.nn.rnn_cell.BasicRNNCell(num_units=rnn_size) # rnn_size = 10 hidden units
output, state = tf.nn.dynamic_rnn(cell, embedding_output, dtype=tf.float32) # embedding_output is the input sequence
output = tf.nn.dropout(output, dropout_keep_prob)
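
The basic RNN cell is interchangeable here; an LSTM or GRU cell could be swapped in without changing anything downstream (a sketch, not in the original post):

# either line can replace the BasicRNNCell above; dynamic_rnn and dropout stay the same
# cell = tf.nn.rnn_cell.LSTMCell(num_units=rnn_size)
# cell = tf.nn.rnn_cell.GRUCell(num_units=rnn_size)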

10 Prediction

output = tf.transpose(output, [1, 0, 2]) # output has shape [batch_size, step, rnn_size]; transposing to [step, batch_size, rnn_size] puts time first, so the last row along axis 0 is the output of the final time step
last = tf.gather(output, int(output.get_shape()[0])-1)
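
For reference, the same last-step tensor can be obtained by slicing before the transpose (an equivalent sketch under the shapes above):

# instead of the transpose + gather: slice [batch_size, step, rnn_size] directly
# last = output[:, -1, :]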

11 Fully Connected Layer

weight = tf.Variable(tf.truncated_normal([rnn_size, 2], mean=0, stddev=0.1)) # weight shape: rnn_size x number of classes
bias = tf.Variable(tf.constant(0.1, shape=[2]))
logits_out = tf.add(tf.matmul(last, weight), bias)
losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits_out, labels=y_output)
loss = tf.reduce_mean(losses)
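
A tiny numeric check of the loss (a sketch, not from the post): sparse_softmax_cross_entropy_with_logits takes integer class labels directly, so no one-hot encoding is needed. For logits [2.0, 0.0] and true label 0, the loss is -log(e^2 / (e^2 + e^0)) ≈ 0.127:

probs = np.exp([2.0, 0.0]) / np.sum(np.exp([2.0, 0.0]))
print(-np.log(probs[0]))   # ≈ 0.1269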

12 Accuracy

predict = tf.argmax(tf.nn.softmax(logits_out),1) # predicted class
accuracy = tf.reduce_mean(tf.cast(tf.equal(predict, tf.cast(y_output, tf.int64)), tf.float32)) # accuracy

13 Optimizer

sess = tf.Session()
learning_rate = 0.0005
optimizer = tf.train.RMSPropOptimizer(learning_rate)
train_step = optimizer.minimize(loss)
init = tf.global_variables_initializer()
sess.run(init)
D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\gradients_impl.py:108: UserWarning: Converting sparse IndexedSlices to a dense Tensor of unknown shape. This may consume a large amount of memory.
  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "

14 Training

batch_size = 250
epochs = 100
train_loss = []
test_loss = []
train_acc = []
test_acc = []
for epoch in range(epochs):
    # train
    shuffled_ix = np.random.permutation(np.arange(len(x_train)))
    x_train = x_train[shuffled_ix]
    y_train = y_train[shuffled_ix]
    num_batches = int(len(x_train) / batch_size) + 1
    for i in range(num_batches):
        min_ix = i * batch_size
        max_ix = np.min([len(x_train), (i+1)*batch_size])
        x_train_batch = x_train[min_ix:max_ix]
        y_train_batch = y_train[min_ix:max_ix]
        train_dict = {x_data: x_train_batch, y_output: y_train_batch, dropout_keep_prob: 0.5}
        _, loss_b, accuracy_b = sess.run([train_step, loss, accuracy], feed_dict=train_dict)
    # record the last batch's loss/accuracy as this epoch's training metrics
    train_loss.append(loss_b)
    train_acc.append(accuracy_b)
    # evaluate on the test set
    test_dict = {x_data: x_test, y_output: y_test, dropout_keep_prob: 1.0}
    temp_test_loss, temp_test_acc = sess.run([loss, accuracy], feed_dict=test_dict)
    test_loss.append(temp_test_loss)
    test_acc.append(temp_test_acc)
    print("Epoch: {}, Train_loss: {}, Train_acc: {} --- Test_loss: {}, Test_acc: {}".format(epoch+1, loss_b, accuracy_b, temp_test_loss, temp_test_acc))
Epoch: 1, Train_loss: 0.7470735311508179, Train_acc: 0.35885167121887207 --- Test_loss: 0.7354287505149841, Test_acc: 0.17399102449417114
Epoch: 2, Train_loss: 0.6956271529197693, Train_acc: 0.4784688949584961 --- Test_loss: 0.7028129696846008, Test_acc: 0.17309416830539703
Epoch: 3, Train_loss: 0.6544756293296814, Train_acc: 0.6746411323547363 --- Test_loss: 0.660607635974884, Test_acc: 0.8349775671958923
Epoch: 4, Train_loss: 0.6046224236488342, Train_acc: 0.7846890091896057 --- Test_loss: 0.6118777990341187, Test_acc: 0.8457399010658264
Epoch: 5, Train_loss: 0.5662848949432373, Train_acc: 0.8564593195915222 --- Test_loss: 0.5601459741592407, Test_acc: 0.8466367721557617
Epoch: 6, Train_loss: 0.5476279854774475, Train_acc: 0.8086124658584595 --- Test_loss: 0.511917769908905, Test_acc: 0.847533643245697
Epoch: 7, Train_loss: 0.46341803669929504, Train_acc: 0.8708133697509766 --- Test_loss: 0.47170114517211914, Test_acc: 0.8484305143356323
Epoch: 8, Train_loss: 0.43886634707450867, Train_acc: 0.8708133697509766 --- Test_loss: 0.44291701912879944, Test_acc: 0.8484305143356323
Epoch: 9, Train_loss: 0.5139765739440918, Train_acc: 0.7846890091896057 --- Test_loss: 0.42077308893203735, Test_acc: 0.847533643245697
Epoch: 10, Train_loss: 0.38777387142181396, Train_acc: 0.8851674795150757 --- Test_loss: 0.4064062237739563, Test_acc: 0.8457399010658264
Epoch: 11, Train_loss: 0.39114782214164734, Train_acc: 0.8468899726867676 --- Test_loss: 0.39904409646987915, Test_acc: 0.8457399010658264
Epoch: 12, Train_loss: 0.35100817680358887, Train_acc: 0.89952152967453 --- Test_loss: 0.3950567841529846, Test_acc: 0.847533643245697
Epoch: 13, Train_loss: 0.33624348044395447, Train_acc: 0.8947368264198303 --- Test_loss: 0.3925701677799225, Test_acc: 0.8484305143356323
Epoch: 14, Train_loss: 0.400579035282135, Train_acc: 0.8468899726867676 --- Test_loss: 0.3902212977409363, Test_acc: 0.8484305143356323
Epoch: 15, Train_loss: 0.36475613713264465, Train_acc: 0.8612440228462219 --- Test_loss: 0.38799500465393066, Test_acc: 0.8493273258209229
Epoch: 16, Train_loss: 0.4579482674598694, Train_acc: 0.8181818127632141 --- Test_loss: 0.38558176159858704, Test_acc: 0.8538116812705994
Epoch: 17, Train_loss: 0.42817819118499756, Train_acc: 0.8612440228462219 --- Test_loss: 0.38300079107284546, Test_acc: 0.8538116812705994
Epoch: 18, Train_loss: 0.3685735762119293, Train_acc: 0.8708133697509766 --- Test_loss: 0.3801892101764679, Test_acc: 0.8565022349357605
Epoch: 19, Train_loss: 0.32959237694740295, Train_acc: 0.8947368264198303 --- Test_loss: 0.3774225115776062, Test_acc: 0.8573991060256958
Epoch: 20, Train_loss: 0.39528053998947144, Train_acc: 0.8468899726867676 --- Test_loss: 0.37414032220840454, Test_acc: 0.8573991060256958
Epoch: 21, Train_loss: 0.38511836528778076, Train_acc: 0.8564593195915222 --- Test_loss: 0.37055113911628723, Test_acc: 0.8582959771156311
Epoch: 22, Train_loss: 0.4212755858898163, Train_acc: 0.8468899726867676 --- Test_loss: 0.3664109408855438, Test_acc: 0.8600896596908569
Epoch: 23, Train_loss: 0.3296143412590027, Train_acc: 0.8851674795150757 --- Test_loss: 0.3622855246067047, Test_acc: 0.8600896596908569
Epoch: 24, Train_loss: 0.36846017837524414, Train_acc: 0.8851674795150757 --- Test_loss: 0.3570423722267151, Test_acc: 0.8618834018707275
Epoch: 25, Train_loss: 0.3019673228263855, Train_acc: 0.9138755798339844 --- Test_loss: 0.3511543273925781, Test_acc: 0.8627802729606628
Epoch: 26, Train_loss: 0.31075412034988403, Train_acc: 0.9043062329292297 --- Test_loss: 0.34401190280914307, Test_acc: 0.8672645688056946
Epoch: 27, Train_loss: 0.358598917722702, Train_acc: 0.8755980730056763 --- Test_loss: 0.3346104025840759, Test_acc: 0.8735426068305969
Epoch: 28, Train_loss: 0.32526490092277527, Train_acc: 0.8851674795150757 --- Test_loss: 0.323251873254776, Test_acc: 0.878923773765564
Epoch: 29, Train_loss: 0.3157275915145874, Train_acc: 0.89952152967453 --- Test_loss: 0.30881568789482117, Test_acc: 0.8825111985206604
Epoch: 30, Train_loss: 0.29228392243385315, Train_acc: 0.9090909361839294 --- Test_loss: 0.2899859845638275, Test_acc: 0.8914798498153687
Epoch: 31, Train_loss: 0.3699236810207367, Train_acc: 0.8468899726867676 --- Test_loss: 0.2690330743789673, Test_acc: 0.9058296084403992
Epoch: 32, Train_loss: 0.23446135222911835, Train_acc: 0.9377990365028381 --- Test_loss: 0.24655458331108093, Test_acc: 0.915695071220398
Epoch: 33, Train_loss: 0.1925920844078064, Train_acc: 0.9282296895980835 --- Test_loss: 0.22722524404525757, Test_acc: 0.9255605340003967
Epoch: 34, Train_loss: 0.24157537519931793, Train_acc: 0.9234449863433838 --- Test_loss: 0.22854343056678772, Test_acc: 0.9372197389602661
Epoch: 35, Train_loss: 0.1614789217710495, Train_acc: 0.9665071964263916 --- Test_loss: 0.20862336456775665, Test_acc: 0.9300448298454285
Epoch: 36, Train_loss: 0.15638592839241028, Train_acc: 0.9569377899169922 --- Test_loss: 0.19794519245624542, Test_acc: 0.9363228678703308
Epoch: 37, Train_loss: 0.17246992886066437, Train_acc: 0.9425837397575378 --- Test_loss: 0.18526984751224518, Test_acc: 0.9488789439201355
Epoch: 38, Train_loss: 0.15459716320037842, Train_acc: 0.9665071964263916 --- Test_loss: 0.1810169667005539, Test_acc: 0.949775755405426
Epoch: 39, Train_loss: 0.14387375116348267, Train_acc: 0.9760765433311462 --- Test_loss: 0.17268356680870056, Test_acc: 0.9524663686752319
Epoch: 40, Train_loss: 0.13036970794200897, Train_acc: 0.9569377899169922 --- Test_loss: 0.16984421014785767, Test_acc: 0.9533632397651672
Epoch: 41, Train_loss: 0.14434468746185303, Train_acc: 0.9617224931716919 --- Test_loss: 0.17072317004203796, Test_acc: 0.9515694975852966
Epoch: 42, Train_loss: 0.19085589051246643, Train_acc: 0.9569377899169922 --- Test_loss: 0.16185931861400604, Test_acc: 0.9533632397651672
Epoch: 43, Train_loss: 0.1424720585346222, Train_acc: 0.9760765433311462 --- Test_loss: 0.16070027649402618, Test_acc: 0.957847535610199
Epoch: 44, Train_loss: 0.163960263133049, Train_acc: 0.9521530866622925 --- Test_loss: 0.17735721170902252, Test_acc: 0.926457405090332
Epoch: 45, Train_loss: 0.1660783886909485, Train_acc: 0.9473684430122375 --- Test_loss: 0.16055651009082794, Test_acc: 0.9488789439201355
Epoch: 46, Train_loss: 0.06969129294157028, Train_acc: 0.9952152967453003 --- Test_loss: 0.15094122290611267, Test_acc: 0.9605380892753601
Epoch: 47, Train_loss: 0.1361570805311203, Train_acc: 0.9760765433311462 --- Test_loss: 0.14727219939231873, Test_acc: 0.9605380892753601
Epoch: 48, Train_loss: 0.11492782831192017, Train_acc: 0.9569377899169922 --- Test_loss: 0.14613188803195953, Test_acc: 0.9605380892753601
Epoch: 49, Train_loss: 0.10670477896928787, Train_acc: 0.9712918400764465 --- Test_loss: 0.1442699283361435, Test_acc: 0.9641255736351013
Epoch: 50, Train_loss: 0.14280426502227783, Train_acc: 0.9665071964263916 --- Test_loss: 0.1441742181777954, Test_acc: 0.9641255736351013
Epoch: 51, Train_loss: 0.09684259444475174, Train_acc: 0.9760765433311462 --- Test_loss: 0.14695033431053162, Test_acc: 0.9542601108551025
Epoch: 52, Train_loss: 0.08919525146484375, Train_acc: 0.9760765433311462 --- Test_loss: 0.14686116576194763, Test_acc: 0.9524663686752319
Epoch: 53, Train_loss: 0.09268508106470108, Train_acc: 0.9904305934906006 --- Test_loss: 0.13317707180976868, Test_acc: 0.9677129983901978
Epoch: 54, Train_loss: 0.08003458380699158, Train_acc: 0.9712918400764465 --- Test_loss: 0.13285662233829498, Test_acc: 0.9677129983901978
Epoch: 55, Train_loss: 0.05673607066273689, Train_acc: 0.9856459498405457 --- Test_loss: 0.13103769719600677, Test_acc: 0.9668161273002625
Epoch: 56, Train_loss: 0.11883581429719925, Train_acc: 0.9617224931716919 --- Test_loss: 0.13139890134334564, Test_acc: 0.9695067405700684
Epoch: 57, Train_loss: 0.06191867217421532, Train_acc: 0.9904305934906006 --- Test_loss: 0.17760926485061646, Test_acc: 0.9327354431152344
Epoch: 58, Train_loss: 0.08464467525482178, Train_acc: 0.9712918400764465 --- Test_loss: 0.12995852530002594, Test_acc: 0.9695067405700684
Epoch: 59, Train_loss: 0.13568198680877686, Train_acc: 0.9665071964263916 --- Test_loss: 0.12945900857448578, Test_acc: 0.9686098694801331
Epoch: 60, Train_loss: 0.08627452701330185, Train_acc: 0.9760765433311462 --- Test_loss: 0.129767507314682, Test_acc: 0.9686098694801331
Epoch: 61, Train_loss: 0.11554665863513947, Train_acc: 0.9760765433311462 --- Test_loss: 0.12860263884067535, Test_acc: 0.9713004231452942
Epoch: 62, Train_loss: 0.07882294058799744, Train_acc: 0.9856459498405457 --- Test_loss: 0.12875868380069733, Test_acc: 0.9677129983901978
Epoch: 63, Train_loss: 0.11422040313482285, Train_acc: 0.9712918400764465 --- Test_loss: 0.12917351722717285, Test_acc: 0.9713004231452942
Epoch: 64, Train_loss: 0.07914420962333679, Train_acc: 0.9760765433311462 --- Test_loss: 0.12913142144680023, Test_acc: 0.9695067405700684
Epoch: 65, Train_loss: 0.07433979213237762, Train_acc: 0.9760765433311462 --- Test_loss: 0.13048280775547028, Test_acc: 0.9704036116600037
Epoch: 66, Train_loss: 0.07714660465717316, Train_acc: 0.9712918400764465 --- Test_loss: 0.13510984182357788, Test_acc: 0.9623318314552307
Epoch: 67, Train_loss: 0.07190703600645065, Train_acc: 0.980861246585846 --- Test_loss: 0.12696626782417297, Test_acc: 0.9677129983901978
Epoch: 68, Train_loss: 0.09927602857351303, Train_acc: 0.980861246585846 --- Test_loss: 0.12567484378814697, Test_acc: 0.9686098694801331
Epoch: 69, Train_loss: 0.08709713816642761, Train_acc: 0.9856459498405457 --- Test_loss: 0.1345020979642868, Test_acc: 0.9614349603652954
Epoch: 70, Train_loss: 0.06702243536710739, Train_acc: 0.980861246585846 --- Test_loss: 0.12303993850946426, Test_acc: 0.9686098694801331
Epoch: 71, Train_loss: 0.15951356291770935, Train_acc: 0.9617224931716919 --- Test_loss: 0.1306402087211609, Test_acc: 0.963228702545166
Epoch: 72, Train_loss: 0.04567800089716911, Train_acc: 0.9904305934906006 --- Test_loss: 0.12566307187080383, Test_acc: 0.9659192562103271
Epoch: 73, Train_loss: 0.08264987170696259, Train_acc: 0.9904305934906006 --- Test_loss: 0.12613965570926666, Test_acc: 0.9668161273002625
Epoch: 74, Train_loss: 0.07273389399051666, Train_acc: 0.9856459498405457 --- Test_loss: 0.12133991718292236, Test_acc: 0.9695067405700684
Epoch: 75, Train_loss: 0.11598026752471924, Train_acc: 0.9665071964263916 --- Test_loss: 0.12171116471290588, Test_acc: 0.9686098694801331
Epoch: 76, Train_loss: 0.06833338737487793, Train_acc: 0.9904305934906006 --- Test_loss: 0.12302467972040176, Test_acc: 0.9659192562103271
Epoch: 77, Train_loss: 0.05746075138449669, Train_acc: 0.980861246585846 --- Test_loss: 0.12428522109985352, Test_acc: 0.9650224447250366
Epoch: 78, Train_loss: 0.03369761258363724, Train_acc: 0.9952152967453003 --- Test_loss: 0.12469173222780228, Test_acc: 0.9641255736351013
Epoch: 79, Train_loss: 0.09764815866947174, Train_acc: 0.9856459498405457 --- Test_loss: 0.11996684223413467, Test_acc: 0.9695067405700684
Epoch: 80, Train_loss: 0.03072713129222393, Train_acc: 0.9904305934906006 --- Test_loss: 0.12080205231904984, Test_acc: 0.9677129983901978
Epoch: 81, Train_loss: 0.06581427156925201, Train_acc: 0.9856459498405457 --- Test_loss: 0.1175323873758316, Test_acc: 0.9704036116600037
Epoch: 82, Train_loss: 0.09099692106246948, Train_acc: 0.980861246585846 --- Test_loss: 0.11873693019151688, Test_acc: 0.9721972942352295
Epoch: 83, Train_loss: 0.10331112891435623, Train_acc: 0.9856459498405457 --- Test_loss: 0.12006927281618118, Test_acc: 0.9686098694801331
Epoch: 84, Train_loss: 0.07458710670471191, Train_acc: 0.9904305934906006 --- Test_loss: 0.12819139659404755, Test_acc: 0.963228702545166
Epoch: 85, Train_loss: 0.06534530967473984, Train_acc: 0.9856459498405457 --- Test_loss: 0.14707119762897491, Test_acc: 0.957847535610199
Epoch: 86, Train_loss: 0.11099092662334442, Train_acc: 0.9760765433311462 --- Test_loss: 0.12355218082666397, Test_acc: 0.9677129983901978
Epoch: 87, Train_loss: 0.06861938536167145, Train_acc: 0.9856459498405457 --- Test_loss: 0.11979269981384277, Test_acc: 0.9739910364151001
Epoch: 88, Train_loss: 0.08237955719232559, Train_acc: 0.9856459498405457 --- Test_loss: 0.14108514785766602, Test_acc: 0.963228702545166
Epoch: 89, Train_loss: 0.06249811500310898, Train_acc: 0.9904305934906006 --- Test_loss: 0.12163014709949493, Test_acc: 0.9739910364151001
Epoch: 90, Train_loss: 0.06678888201713562, Train_acc: 0.9856459498405457 --- Test_loss: 0.11736054718494415, Test_acc: 0.9739910364151001
Epoch: 91, Train_loss: 0.06128064915537834, Train_acc: 0.9856459498405457 --- Test_loss: 0.1292845606803894, Test_acc: 0.9650224447250366
Epoch: 92, Train_loss: 0.08472830802202225, Train_acc: 0.980861246585846 --- Test_loss: 0.11756479740142822, Test_acc: 0.9748879075050354
Epoch: 93, Train_loss: 0.15842005610466003, Train_acc: 0.9712918400764465 --- Test_loss: 0.12553487718105316, Test_acc: 0.9668161273002625
Epoch: 94, Train_loss: 0.021188249811530113, Train_acc: 0.9952152967453003 --- Test_loss: 0.11657082289457321, Test_acc: 0.9748879075050354
Epoch: 95, Train_loss: 0.027060668915510178, Train_acc: 0.9952152967453003 --- Test_loss: 0.11732904613018036, Test_acc: 0.9739910364151001
Epoch: 96, Train_loss: 0.07324185222387314, Train_acc: 0.980861246585846 --- Test_loss: 0.14437833428382874, Test_acc: 0.9596412777900696
Epoch: 97, Train_loss: 0.021940842270851135, Train_acc: 0.9952152967453003 --- Test_loss: 0.12406717240810394, Test_acc: 0.9677129983901978
Epoch: 98, Train_loss: 0.03361573815345764, Train_acc: 0.9904305934906006 --- Test_loss: 0.11604513227939606, Test_acc: 0.9739910364151001
Epoch: 99, Train_loss: 0.08688019216060638, Train_acc: 0.9856459498405457 --- Test_loss: 0.1156449243426323, Test_acc: 0.9730941653251648
Epoch: 100, Train_loss: 0.06903882324695587, Train_acc: 0.9856459498405457 --- Test_loss: 0.11721988022327423, Test_acc: 0.9730941653251648
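
With the session still open, the trained graph can score new messages; a hedged sketch (the messages are made up, and recall that 1 means ham, 0 means spam):

new_texts = [clean_text(t) for t in ['win a free prize now', 'see you at dinner']]
new_idx = np.array(list(vocab_processor.transform(new_texts)))
preds = sess.run(predict, feed_dict={x_data: new_idx, dropout_keep_prob: 1.0})
print(['ham' if p == 1 else 'spam' for p in preds])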

15 Plot

epoch_seq = np.arange(1, epochs+1)
plt.plot(epoch_seq, train_loss, 'k--', label='Train Set')
plt.plot(epoch_seq, test_loss, 'r-', label='Test Set')
plt.title('Softmax Loss')
plt.xlabel('Epochs')
plt.ylabel('Softmax Loss')
plt.legend(loc='upper left')
plt.show()

plt.plot(epoch_seq, train_acc, 'k--', label='Train Set')
plt.plot(epoch_seq, test_acc, 'r-', label='Test Set')
plt.title('Test Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend(loc='upper left')
plt.show()

[Figure: Softmax Loss over epochs, train vs. test]
[Figure: Accuracy over epochs, train vs. test]
