深い学習ニューロンモデル7MNIST練習手書き数字認識問題への応用

特定のコードを参照してくださいgithubのを

1つの問題

ここに画像を挿入説明

2つのコーディング慣行

2.1データのロード

import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
mnist=input_data.read_data_sets("MNIST/",one_hot=True)

tensorflowそれなりに遅いダウンロード場合は、データ・セット・MNISTファイルにダウンロードすることができます

print("训练集数量:",mnist.train.num_examples,',验证集数量:',mnist.validation.num_examples,',测试集数量:',mnist.test.num_examples)

ここに画像を挿入説明

2.2モデル変数

XおよびYは、プレースホルダを定義します

# mnist中每张图片共有28*28=784个像素点
x=tf.placeholder(tf.float32,[None,784],name="X")
# 0-9一共10个数字--10个类别
y=tf.placeholder(tf.float32,[None,10],name="Y")

変数の作成

W=tf.Variable(tf.random_normal([784,10]),name='W')
b=tf.Variable(tf.zeros([10]),name="b")

単一ニューロンニューラルネットワークの構築

forward=tf.matmul(x,W)+b# 向前计算

決定は機能を追加して、これらの機能を置かれたとき、この判断はあるに:私たちは、多くの場合、マルチ分類タスクを扱っているソフトマックス回帰について使用ソフトマックス回帰モデルのソフトマックスは、作品の種類ごとに確率を推定しますが必要です確率クラス

pred=tf.nn.softmax(forward)

2.3トレーニングモデルハイパー

train_epochs=100
batch_size=50
total_batch=int(mnist.train.num_examples/batch_size)
display_step=1
learning_rate=0.02

2.4定義モデル

定義された損失関数

# 定义交叉熵损失函数
loss_function=tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1))

定義のオプティマイザ

#梯度下降优化器
optimizer=tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_function)

正確なレートを定義します

# 检查预测类别tf.argmax(pred,1)与实际类别tf.argmax(y,1)的匹配情况
correct_prediction=tf.equal(tf.argmax(pred,1),tf.argmax(y,1))
# 准确率,将布尔值转化为浮点数,并计算平均值
accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
sess=tf.Session()
init=tf.global_variables_initializer()
sess.run(init)

2.5モデルのトレーニング

# 开始训练
for epoch in range(train_epochs):
    for batch in range(total_batch):
        xs,ys=mnist.train.next_batch(batch_size)#读取数据
        sess.run(optimizer,feed_dict={x:xs,y:ys})# 执行批次训练
        # total_batch个批次训练完成后,使用验证数据计算误差和准确率;验证集没有分批
        loss,acc=sess.run([loss_function,accuracy],feed_dict={x:mnist.validation.images,y:mnist.validation.labels})
        #打印训练过程中信息
        if(epoch+1)%display_step==0:
            print("Train Epoch:","%02d"%(epoch+1),"Loss=","{:.9f}".format(loss),\
                  'Accuracy=',"{:4f}".format(acc))
print("Train Finished")

ここに画像を挿入説明

3評価モデル

accu_test=sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
print("test accuracy:",accu_test)

ここに画像を挿入説明

accu_train=sess.run(accuracy,feed_dict={x:mnist.train.images,y:mnist.train.labels})
print("test accuracy:",accu_test)

ここに画像を挿入説明

accu_validation=sess.run(accuracy,feed_dict={x:mnist.validation.images,y:mnist.validation.labels})
print("test accuracy:",accu_validation)

ここに画像を挿入説明

4アプリケーションモデル

#由于pred预测结果是0ne-hot编码格式,所以转换为0-9的数字
prediction_result=sess.run(tf.argmax(pred,1),feed_dict={x:mnist.test.images})
# 查看结果
prediction_result[0:10]

ここに画像を挿入説明
可視化機能が定義されています

import matplotlib.pyplot as plt
import numpy as np
def plot_images_labels_prediction(images, #图像列表
                                 labels, # 标签列表
                                 prediction, # 预测值列表
                                 index, #从第index个开始显示
                                 num=10): # 缺省一次显示10幅
    fig=plt.gcf() # 获取当前图标
    fig.set_size_inches(10,12)# 1英寸等于2.54cm
    if num >25: # 最多显示25个图
        num=25
    for i in range(0,num):
        ax=plt.subplot(5,5,i+1)#获取当前要处理的子图
        ax.imshow(np.reshape(images[index],(28,28)),cmap='binary')
        title='label='+str(np.argmax([labels[index]]))
        if len(prediction)>0:
            title+=",predict="+str(prediction[index])
        ax.set_title(title,fontsize=10)
        ax.set_xticks([])
        ax.set_yticks([])
        index+=1
    plt.show
              
plot_images_labels_prediction(mnist.test.images,
                             mnist.test.labels,
                             prediction_result,10,10)

ここに画像を挿入説明

結果の5ディスカッション

次のハイパーパラメータ、88.4パーセントの正解率で

train_epochs=100
batch_size=100
total_batch=int(mnist.train.num_examples/batch_size)
display_step=1
learning_rate=0.01

経時的に次のようにパラメータがある場合、90.7パーセントの正解率は、条件を満たし

train_epochs=100
batch_size=50
total_batch=int(mnist.train.num_examples/batch_size)
display_step=1
learning_rate=0.02
公開された284元の記事 ウォン称賛19 ビュー20000 +

おすすめ

転載: blog.csdn.net/weixin_39289876/article/details/104740027