生成全连接网络模型，并输入自定义图片验证

生成全连接网络模型,并输入自定义图片验证
参考：https://github.com/niektemme/tensorflow-mnist-predict

a.py(全连接模型训练)

a.py(全连接模型训练)

import tensorflow.examples.tutorials.mnist.input_data as input_data
input_data = input_data.read_data_sets("MNIST/data",one_hot=True)
import tensorflow as tf
import matplotlib.pyplot as plt
import cv2

class Net:
    """Five-layer fully connected MNIST classifier (TensorFlow 1.x graph API).

    Layer widths are 784 -> 512 -> 256 -> 128 -> 64 -> 10. The hidden layers
    use a sigmoid activation. The output layer is left linear, so ``y5`` holds
    raw logits, which is what ``softmax_cross_entropy_with_logits`` expects.
    """

    def __init__(self):
        # Inputs: flattened 28x28 images and one-hot labels over 10 classes.
        self.x = tf.placeholder(dtype=tf.float32, shape=[None, 784])
        self.y = tf.placeholder(dtype=tf.float32, shape=[None, 10])

        # Weights come from a truncated normal with stddev sqrt(1 / fan_out).
        # `1.0 / n` (not `1 / n`) keeps the stddev non-zero even under
        # integer division.
        # NOTE: variables are created in the same order as before. Their
        # auto-generated names (Variable, Variable_1, ...) must keep matching
        # the checkpoint that b.py restores.
        self.w1 = tf.Variable(tf.truncated_normal(shape=[784, 512], stddev=tf.sqrt(1.0 / 512), dtype=tf.float32))
        self.b1 = tf.Variable(tf.zeros([512]))

        self.w2 = tf.Variable(tf.truncated_normal(shape=[512, 256], stddev=tf.sqrt(1.0 / 256), dtype=tf.float32))
        self.b2 = tf.Variable(tf.zeros([256]))

        self.w3 = tf.Variable(tf.truncated_normal(shape=[256, 128], stddev=tf.sqrt(1.0 / 128), dtype=tf.float32))
        self.b3 = tf.Variable(tf.zeros([128]))

        self.w4 = tf.Variable(tf.truncated_normal(shape=[128, 64], stddev=tf.sqrt(1.0 / 64), dtype=tf.float32))
        self.b4 = tf.Variable(tf.zeros([64]))

        self.w5 = tf.Variable(tf.truncated_normal(shape=[64, 10], stddev=tf.sqrt(1.0 / 10), dtype=tf.float32))
        self.b5 = tf.Variable(tf.zeros([10]))

    def forward(self):
        """Build the forward pass.

        Sets ``y1``..``y4`` to the sigmoid hidden activations and ``y5`` to
        the raw (pre-activation) output logits. Sigmoid is monotonic, so
        argmax over ``y5`` gives the same class as argmax over sigmoid(y5).
        """
        self.y1 = tf.sigmoid(tf.matmul(self.x, self.w1) + self.b1)
        self.y2 = tf.sigmoid(tf.matmul(self.y1, self.w2) + self.b2)
        self.y3 = tf.sigmoid(tf.matmul(self.y2, self.w3) + self.b3)
        self.y4 = tf.sigmoid(tf.matmul(self.y3, self.w4) + self.b4)
        # No activation on the output layer. The loss applies softmax itself,
        # and squashing the logits through sigmoid first would confine them
        # to (0, 1), leaving the softmax almost uniform and gradients tiny.
        self.y5 = tf.matmul(self.y4, self.w5) + self.b5

    def backward(self):
        """Build the loss (mean softmax cross-entropy) and an Adam train op.

        Must be called after ``forward()``, because it reads ``self.y5``.
        """
        self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=self.y, logits=self.y5))
        self.optimizer = tf.train.AdamOptimizer().minimize(self.loss)


if __name__ == "__main__":
    # Train the network on MNIST and write the trained weights to
    # save/data.ckpt, which b.py restores for inference.
    import os

    net = Net()
    net.forward()
    net.backward()
    init = tf.global_variables_initializer()

    # Build the evaluation ops once. Creating them inside the training loop
    # would add new nodes to the graph on every evaluation.
    correct = tf.equal(tf.argmax(net.y, 1), tf.argmax(net.y5, 1))
    accuracy = tf.reduce_mean(tf.cast(correct, dtype=tf.float32))

    with tf.Session() as sess:
        sess.run(init)
        saver = tf.train.Saver()

        for i in range(5000):
            s1, s2 = input_data.train.next_batch(100)
            _, y_5 = sess.run([net.optimizer, net.y5], feed_dict={net.x: s1, net.y: s2})

            if i % 200 == 0:
                # Label of the first sample in the batch vs. the model's guess.
                # Both are NumPy arrays, so no extra TF ops are needed.
                print(s2[0].argmax(), "------------->", y_5[0].argmax())

            if i % 500 == 0:
                # Accuracy over the full training set.
                print(sess.run(accuracy, feed_dict={net.x: input_data.train.images, net.y: input_data.train.labels}))

        # Save AFTER training. Saving before the loop (as originally written)
        # produced a checkpoint containing only the random initial weights.
        # tf.train.Saver refuses to write into a missing directory.
        os.makedirs("save", exist_ok=True)
        saver.save(sess, "save/data.ckpt")
b.py（输入数字图片，进行识别）


# import modules
import sys
import tensorflow as tf
from PIL import Image, ImageFilter
import a

def predictint(imvalue, checkpoint_path="save/data.ckpt"):
    """Return the model's predicted digit for one prepared image.

    Args:
        imvalue: the 784 normalized pixel values returned by imageprepare().
        checkpoint_path: checkpoint written by a.py's training script.

    Returns:
        The result of ``sess.run`` on ``argmax(y5, 1)`` for a batch of one,
        i.e. a length-1 array; index it with ``[0]`` to get the digit.
    """
    # Build the model in a private graph. Reusing the default graph would add
    # another copy of every variable on each call. Their auto-generated names
    # (Variable_10, ...) would then no longer match the checkpoint, and
    # restore would fail from the second call onward.
    graph = tf.Graph()
    with graph.as_default():
        net = a.Net()
        net.forward()
        prediction = tf.argmax(net.y5, 1)
        # Inference needs only the model weights, so the training ops
        # (net.backward()) are not built. saver.restore() assigns every
        # variable it covers, so no initializer run is required.
        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        saver.restore(sess, checkpoint_path)
        return sess.run(prediction, feed_dict={net.x: [imvalue]})

def imageprepare(argv):
    """Load an image file and convert it to MNIST-style network input.

    The image is converted to grayscale and scaled so its longer side is 20
    pixels, keeping the aspect ratio. It is then sharpened and pasted,
    centered, onto a white 28x28 canvas.

    Args:
        argv: path of the image file (e.g. a PNG).

    Returns:
        A list of 784 floats (row-major pixel order). 0.0 is pure white and
        1.0 is pure black, so dark ink on white paper becomes high values.
    """
    with Image.open(argv) as src:
        im = src.convert('L')
    width = float(im.size[0])
    height = float(im.size[1])
    canvas = Image.new('L', (28, 28), 255)  # white 28x28 canvas

    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
    resample = Image.LANCZOS

    if width > height:
        # Width is bigger: it becomes 20 px, height scales by the same ratio
        # (at least 1 px).
        nheight = max(1, int(round((20.0 / width * height), 0)))
        img = im.resize((20, nheight), resample).filter(ImageFilter.SHARPEN)
        wtop = int(round(((28 - nheight) / 2), 0))  # vertical offset
        canvas.paste(img, (4, wtop))
    else:
        # Height is bigger (or equal): it becomes 20 px, width scales by the
        # same ratio (at least 1 px).
        nwidth = max(1, int(round((20.0 / height * width), 0)))
        img = im.resize((nwidth, 20), resample).filter(ImageFilter.SHARPEN)
        wleft = int(round(((28 - nwidth) / 2), 0))  # horizontal offset
        canvas.paste(img, (wleft, 4))

    # Invert and scale so the white background maps to 0 and black ink to 1.
    return [(255 - x) * 1.0 / 255.0 for x in canvas.getdata()]



if __name__ == "__main__":
    # Usage: python b.py [image_path]
    # Classifies one image and prints the predicted digit. Without an
    # argument it falls back to "9.png" in the current directory.
    image_path = sys.argv[1] if len(sys.argv) > 1 else "9.png"
    imvalue = imageprepare(image_path)
    predint = predictint(imvalue)
    print(predint[0])  # first value in list

猜你喜欢

转载自blog.csdn.net/qq_34649170/article/details/89473872