MNIST 手写数字图片多分类问题——单隐藏层神经网络

采用单隐藏层、基于 TensorFlow 框架构建的分类神经网络——入门级小项目，供深度学习（TensorFlow）初学者参考。

#!/usr/bin/env python
# coding: utf-8

# In[ ]:


#导入package 读取数据
import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
from time import time


# In[ ]:


# Load MNIST from the "mnist_data/" directory with one-hot encoded labels.
mnist = input_data.read_data_sets("mnist_data/", one_hot=True)


# In[ ]:


# A bare expression only shows its value inside a notebook; print it so the
# training-set size is also visible when this file is run as a script.
print("Training examples:", mnist.train.num_examples)


# In[ ]:


# Author's note: no rush -- plan to reproduce the full MNIST experiment tomorrow.


# In[ ]:


# Read the data (same call as the loading cell, so this cell can run on its own).
mnist = input_data.read_data_sets("mnist_data/", one_hot=True)

# Input layer: flattened 28x28 images (784 features) and one-hot labels (10 classes).
x = tf.placeholder(tf.float32, [None, 784], name="X")
y = tf.placeholder(tf.float32, [None, 10], name="Y")

# Hidden layer with H1_NN ReLU units.
# ReLU supplies the non-linearity; without it the two matmul layers would
# collapse into a single linear model.
# Weights are drawn with a small stddev (0.1): the default stddev of 1.0 summed
# over 784 inputs can produce very large activations and logits, which
# saturates the softmax and slows plain gradient descent.
H1_NN = 256
w1 = tf.Variable(tf.truncated_normal([784, H1_NN], stddev=0.1))
b1 = tf.Variable(tf.zeros([H1_NN]))
y1 = tf.nn.relu(tf.matmul(x, w1) + b1)

# Output layer: one logit per digit class (this is not a second hidden layer).
w2 = tf.Variable(tf.truncated_normal([H1_NN, 10], stddev=0.1))
b2 = tf.Variable(tf.zeros([10]))
forward = tf.matmul(y1, w2) + b2  # raw logits
pred = tf.nn.softmax(forward)     # class probabilities, used for prediction/accuracy

# Loss: softmax cross-entropy computed directly from the logits. This is the
# numerically stable form of -sum(y * log(softmax(forward))), which can hit log(0).
loss_function = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(logits=forward, labels=y))

# Hyperparameters
train_epochs = 50
batch_size = 50
total_batch = mnist.train.num_examples // batch_size  # full mini-batches per epoch
display_step = 1  # report metrics every `display_step` epochs
learning_rate = 0.01

# Optimizer: plain stochastic gradient descent.
# tf.train.AdamOptimizer(learning_rate) can be swapped in here as an alternative.
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_function)

# Accuracy: fraction of samples whose arg-max prediction equals the label's class.
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(pred, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))


# In[ ]:


# Train the model
startTime = time()

# The session is deliberately left open (no `with` block): later cells reuse `sess`.
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)

for epoch in range(train_epochs):
    for _ in range(total_batch):
        xs, ys = mnist.train.next_batch(batch_size)
        sess.run(optimizer, feed_dict={x: xs, y: ys})

    # Evaluate on the validation set only on epochs that are actually reported.
    if (epoch + 1) % display_step == 0:
        loss, acc = sess.run(
            [loss_function, accuracy],
            feed_dict={x: mnist.validation.images, y: mnist.validation.labels})
        print("Train Epoch:", "%02d" % (epoch + 1),
              "Loss=", "{:.9f}".format(loss),
              "Accuracy= ", "{:.4f}".format(acc))

duration = time() - startTime
print("Train Finished takes: {:.2f}".format(duration))

# Final accuracy on the held-out test set. It is computed after `duration` so
# the reported time covers training only. The assigned and printed names now
# match (previously `accu_test` vs. `acc_test`, a NameError).
acc_test = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})
print("Test Accuracy :", acc_test)


# In[ ]:


import numpy as np


# In[ ]:


# Predicted class (arg-max of the softmax output) for every test image.
predict_op = tf.argmax(pred, 1)
prediction_result = sess.run(predict_op, feed_dict={x: mnist.test.images})

# Boolean mask over the test set: True where the prediction matches the label.
true_labels = np.argmax(mnist.test.labels, 1)
compare_lists = prediction_result == true_labels
print(compare_lists)


# In[ ]:


# Indices of misclassified test samples (positions where the comparison is False).
err_lists = [i for i, is_correct in enumerate(compare_lists) if not is_correct]
print(err_lists, "\n", len(err_lists))


# In[ ]:


# Print every misclassified sample and the total number of errors.
def print_predict_err(labels, predictions):
    """Print each misclassified sample followed by the error count.

    Args:
        labels: one-hot label matrix with one row per sample; the true class is
            the arg-max of each row.
        predictions: predicted class index per sample (array-like of ints).

    For every sample whose prediction differs from its true class, one line is
    printed with the sample index, the true class and the predicted class.
    A final line reports how many such samples there were.
    """
    predictions = np.asarray(predictions)
    # Compute all true classes once instead of one argmax per error.
    true_classes = np.argmax(labels, 1)
    err_indices = np.flatnonzero(predictions != true_classes)
    for index in err_indices:
        print("index=:" + str(index) + "标签值=", true_classes[index], "预测值= ", predictions[index])
    print("总计为:" + str(len(err_indices)))


# In[ ]:


# Report each misclassified test sample and the total error count.
print_predict_err(
    predictions=prediction_result,
    labels=mnist.test.labels,
)


# In[ ]:


#可视化查看预测错误的样本
import matplotlib.pyplot as plt
import numpy as np
def plot_images_labels_prediction(images, labels, prediction, index, num=10):
    """Plot consecutive samples in a 5x5 grid, titled with label (and prediction).

    Args:
        images: sequence of flattened images; each is reshaped to 28x28 for display.
        labels: one-hot labels aligned with ``images``; the title shows their arg-max.
        prediction: predicted class per sample; pass an empty sequence to leave
            predictions out of the titles.
        index: position of the first sample to plot.
        num: number of samples to plot. It is capped at 25 (the grid size) and at
            the number of samples remaining after ``index``.
    """
    fig = plt.gcf()
    fig.set_size_inches(10, 12)
    # The grid holds 25 cells, and we must not index past the end of the data.
    num = max(0, min(num, 25, len(images) - index))
    for offset in range(num):
        i = index + offset
        ax = plt.subplot(5, 5, offset + 1)
        ax.imshow(np.reshape(images[i], (28, 28)), cmap="binary")
        title = "label=" + str(np.argmax(labels[i]))
        if len(prediction) > 0:
            title += ".prediction=" + str(prediction[i])
        ax.set_title(title, fontsize=10)
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()


# In[ ]:


# Show 20 test images starting at index 610, titled with label and prediction.
plot_images_labels_prediction(
    mnist.test.images,
    mnist.test.labels,
    prediction_result,
    index=610,
    num=20,
)


# In[ ]:


# The visualization helper follows the course lesson
# "09 MNIST handwritten digit recognition (advanced): multi-layer neural
# networks and applications", parts 1-30.

参考：慕课公开课，浙江大学吴明辉教授的《深度学习应用开发-TensorFlow实践》。

发布了42 篇原创文章 · 获赞 6 · 访问量 2万+

猜你喜欢

转载自blog.csdn.net/zkyxgs518/article/details/102892514
今日推荐