TensorFlow 识别 MNIST 并用 PIL 显示

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from PIL import ImageFont
from PIL import Image
from PIL import ImageDraw
import numpy as np
# Load the MNIST dataset from ./MNIST_data (read_data_sets presumably downloads it
# there when the files are absent -- confirm for the installed TF version).
# one_hot=True returns labels as one-hot vectors, matching the [None, 10] label
# placeholder in Net and the argmax-based accuracy / prediction code below.
mnist = input_data.read_data_sets('./MNIST_data',one_hot=True)

class Net:
    """Three-layer fully connected MNIST classifier (784 -> 512 -> 256 -> 10).

    Build the graph by constructing the object and then calling ``forward()``
    followed by ``backward()``. Tensors intended for ``Session.run``:

    * ``x``        -- input batch, shape [None, 784] float32.
    * ``y``        -- label batch, shape [None, 10]; compared via argmax, so
                      one-hot labels are expected.
    * ``training`` -- boolean scalar, defaults to True (dropout active, batch
                      norm uses batch statistics). Feed ``False`` when
                      evaluating or predicting.
    * ``logits``   -- unnormalised class scores.
    * ``y3``       -- softmax probabilities over the 10 classes.
    * ``loss`` / ``opt`` / ``acc`` -- mean cross-entropy, Adam train op and
                      batch accuracy.
    """

    def __init__(self, keep_prob=0.5, learning_rate=0.0001):
        """Create the input placeholders and layer weights.

        Args:
            keep_prob: dropout keep probability used while ``training`` is True.
            learning_rate: Adam learning rate used by ``backward()``.
        """
        self.keep_prob = keep_prob
        self.learning_rate = learning_rate
        self.x = tf.placeholder(shape=[None, 784], dtype=tf.float32)
        self.y = tf.placeholder(shape=[None, 10], dtype=tf.float32)
        # Defaults to True so callers that never feed it keep the original
        # training-mode behaviour; feed False for evaluation.
        self.training = tf.placeholder_with_default(True, shape=[], name='training')
        self.w1 = tf.Variable(tf.truncated_normal(shape=[784, 512], dtype=tf.float32, stddev=0.1))
        self.b1 = tf.Variable(tf.zeros([512]))
        self.w2 = tf.Variable(tf.truncated_normal(shape=[512, 256], dtype=tf.float32, stddev=0.1))
        self.b2 = tf.Variable(tf.zeros([256]))
        self.w3 = tf.Variable(tf.truncated_normal(shape=[256, 10], dtype=tf.float32, stddev=0.1))
        self.b3 = tf.Variable(tf.zeros([10]))

    def forward(self):
        """Build the inference graph: two BN+ReLU+dropout hidden layers, softmax output."""
        # Dropout only while training; at evaluation time every unit is kept.
        keep_prob = tf.cond(self.training,
                            lambda: tf.constant(self.keep_prob, dtype=tf.float32),
                            lambda: tf.constant(1.0, dtype=tf.float32))
        # batch_normalization must be told whether it is training: without the
        # flag it always runs in inference mode and normalises with its moving
        # statistics, which are never updated.
        self.y1 = tf.nn.dropout(
            tf.nn.relu(tf.layers.batch_normalization(
                tf.matmul(self.x, self.w1) + self.b1, training=self.training)),
            keep_prob=keep_prob)
        self.y2 = tf.nn.dropout(
            tf.nn.relu(tf.layers.batch_normalization(
                tf.matmul(self.y1, self.w2) + self.b2, training=self.training)),
            keep_prob=keep_prob)
        # Keep the raw scores separately: the loss needs logits, not probabilities.
        self.logits = tf.layers.batch_normalization(
            tf.matmul(self.y2, self.w3) + self.b3, training=self.training)
        self.y3 = tf.nn.softmax(self.logits)

    def backward(self):
        """Build the loss, optimizer and accuracy ops. Call after ``forward()``."""
        # softmax_cross_entropy_with_logits applies softmax internally, so it must
        # receive the unnormalised logits; passing y3 applied softmax twice.
        self.loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=self.y))
        # Batch norm's moving mean/variance updates live in UPDATE_OPS; tie them to
        # each optimizer step so the statistics used at evaluation are meaningful.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            self.opt = tf.train.AdamOptimizer(self.learning_rate).minimize(self.loss)
        self.acc = tf.reduce_mean(tf.cast(
            tf.equal(tf.argmax(self.y3, axis=1), tf.argmax(self.y, axis=1)),
            dtype=tf.float32))
if __name__ == '__main__':
    # Train for 1000 mini-batch steps. Every 100 steps, evaluate one validation
    # batch, print the metrics and show its first image with the predicted digit.
    net = Net()
    net.forward()
    net.backward()
    init = tf.global_variables_initializer()

    # Load the font once instead of on every display step.
    try:
        font = ImageFont.truetype(font='msyh.ttf', size=10)
    except OSError:
        # msyh.ttf (Microsoft YaHei) is typically only installed on Windows;
        # fall back to PIL's built-in font instead of crashing.
        font = ImageFont.load_default()

    with tf.Session() as sess:
        sess.run(init)
        for step in range(1000):
            x, y = mnist.train.next_batch(100)
            loss, _ = sess.run([net.loss, net.opt], feed_dict={net.x: x, net.y: y})
            if step % 100 == 0:
                xs, ys = mnist.validation.next_batch(100)
                # Evaluate on the validation batch itself and do NOT run net.opt:
                # previously this fed the training batch (x, y) and took an extra
                # training step, so the drawn prediction belonged to a different
                # image than the one displayed.
                val_loss, val_acc, out = sess.run(
                    [net.loss, net.acc, net.y3], feed_dict={net.x: xs, net.y: ys})
                print('step {}: train loss {:.4f}, val loss {:.4f}, val acc {:.4f}'.format(
                    step, loss, val_loss, val_acc))
                # Pixels are assumed to be floats in [0, 1] (hence * 255); convert to
                # uint8 so PIL builds an 8-bit greyscale ('L') image, not mode 'F'.
                img_array = np.clip(np.reshape(xs[0], [28, 28]) * 255, 0, 255).astype(np.uint8)
                img = Image.fromarray(img_array)  # array -> PIL image
                draw = ImageDraw.Draw(img)
                label = np.argmax(out[0])  # predicted digit for xs[0]
                draw.text(xy=(0, 0), text=str(label), font=font, fill=255)
                img.show()

猜你喜欢

转载自：https://blog.csdn.net/weixin_38241876/article/details/89533082