我的第一个tensorflow程序

第一个tensorflow是从网上抄来的,但是还是爬了个大坑,在预测文件中的图片转换为28*28尺寸的时候用PIL一直报错(原作者的代码),后来改用cv2模块resize问题就解决了,这是一个关于数字识别的程序,程序能够在一张只有一个0-9数字的图片中准确识别出数字是多少,准确率高达99+%,然后我用PyQt5封装了一下,使其可视化。环境为Python3+tensorflow2.0+PyQt5,首先创建一个python project,然后往里面添加文件夹,然后在v4_cnn目录下创建三个文件,mainUI.py,predict.py,train.py三个文件 ,ckpt文件目录是没有的,运行程序后生成的,想要运行该程序,必须先运行训练代码,train.py文件,然后再运行主UI文件maiUI.py文件。

训练文件train.py,代码如下:

import os
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import numpy as np

y1 = [0, 0.8, 0.1, 0.1, 0, 0, 0, 0, 0, 0]
y2 = [0, 0.1, 0.1, 0.1, 0.5, 0, 0.2, 0, 0, 0]
np.argmax(y1) # 1
np.argmax(y2) # 4

class CNN(object):
    def __init__(self):
        model = models.Sequential()
        # 第1层卷积,卷积核大小为3*3,32个,28*28为待训练图片的大小
        model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
        model.add(layers.MaxPooling2D((2, 2)))
        # 第2层卷积,卷积核大小为3*3,64个
        model.add(layers.Conv2D(64, (3, 3), activation='relu'))
        model.add(layers.MaxPooling2D((2, 2)))
        # 第3层卷积,卷积核大小为3*3,64个
        model.add(layers.Conv2D(64, (3, 3), activation='relu'))
        model.add(layers.Flatten())
        model.add(layers.Dense(64, activation='relu'))
        model.add(layers.Dense(10, activation='softmax'))
        model.summary()
        self.model = model

class DataSource(object):
    def __init__(self):
        # mnist数据集存储的位置,如何不存在将自动下载
        data_path = os.path.abspath(os.path.dirname(__file__)) + '/../data_set_tf2/mnist.npz'
        (train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data(path=data_path)
        # 6万张训练图片,1万张测试图片
        train_images = train_images.reshape((60000, 28, 28, 1))
        test_images = test_images.reshape((10000, 28, 28, 1))
        # 像素值映射到 0 - 1 之间
        train_images, test_images = train_images / 255.0, test_images / 255.0

        self.train_images, self.train_labels = train_images, train_labels
        self.test_images, self.test_labels = test_images, test_labels

class Train:
    def __init__(self):
        self.cnn = CNN()
        self.data = DataSource()

    def train(self):
        check_path = './ckpt/cp-{epoch:04d}.ckpt'
        # period 每隔5epoch保存一次
        save_model_cb = tf.keras.callbacks.ModelCheckpoint(check_path, save_weights_only=True, verbose=1, period=5)
        self.cnn.model.compile(optimizer='adam',
                               loss='sparse_categorical_crossentropy',
                               metrics=['accuracy'])
        self.cnn.model.fit(self.data.train_images, self.data.train_labels, epochs=5, callbacks=[save_model_cb])
        test_loss, test_acc = self.cnn.model.evaluate(self.data.test_images, self.data.test_labels)
        print("准确率: %.4f,共测试了%d张图片 " % (test_acc, len(self.data.test_labels)))


if __name__ == "__main__":
    app = Train()
    app.train()

预测文件predict.py,代码如下:

import tensorflow as tf
from PIL import Image
import numpy as np
from v4_cnn.train import CNN
import cv2


class Predict(object):
    def __init__(self):
        latest = tf.train.latest_checkpoint('./ckpt')
        self.cnn = CNN()
        # 恢复网络权重
        self.cnn.model.load_weights(latest)

    def predict(self, image_path):
        # 以黑白方式读取图片
        img = Image.open(image_path).convert('L') #爬了个大坑
        img = np.asarray(img)
        img = cv2.resize(img,(28,28))
        flatten_img = np.reshape(img, (28, 28, 1))
        x = np.array([1 - flatten_img])
        # API refer: https://keras.io/models/model/
        self.y = self.cnn.model.predict(x)

主UI文件mainUI.py,代码如下:

import cv2
import sys
from PyQt5 import QtGui
from PyQt5.QtCore import Qt
from PyQt5.QtWidgets import QApplication, QWidget, QLabel, QHBoxLayout, QMainWindow, QDockWidget, QPushButton, \
    QVBoxLayout, QTextEdit, QFileDialog
from v4_cnn.predict import Predict
import numpy as np

class QPixmapDemo(QMainWindow):
    def __init__(self):
        super().__init__()
        self.txt = {0:'0', 1:'1', 2:'2', 3:'3', 4:'4', 5:'5', 6:'6', 7:'7', 8:'8', 9:'9'}
        self.setWindowTitle('picture')
        self.wgt = QWidget()
        # self.wgt.resize(600, 500)
        self.imgLabel = QLabel()
        self.imgLabel.resize(600, 600)  # 设置label的大小,图片会适配label的大小
        self.hbox = QHBoxLayout()
        self.hbox.addWidget(self.imgLabel)
        self.wgt.setLayout(self.hbox)
        self.setCentralWidget(self.wgt)
        self.docker = docker(self)
        self.addDockWidget(Qt.LeftDockWidgetArea,self.docker)
        self.docker.btn_openFile.clicked.connect(self.openFile)
        self.docker.btn_startDiscern.clicked.connect(self.start)
        self.resize(800,600)

    def openFile(self):
        self.file, filetype = QFileDialog.getOpenFileName(self,
                                                          "选择只有一个数字的图片",
                                                          "./",
                                                          "All Files (*);;Text Files (*.txt)")
        if self.file is not None:
            self.setImage(self.file)

    def start(self):
        discern = Predict()
        discern.predict(self.file)
        num = np.argmax(discern.y[0])
        self.docker.texEdit.setText(str(num))

    def setImage(self, file):
        img = cv2.imread(file)  # opencv读取图片
        img2 = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # opencv读取的bgr格式图片转换成rgb格式
        _image = QtGui.QImage(img2[:], img2.shape[1], img2.shape[0], img2.shape[1] * 3,
                              QtGui.QImage.Format_RGB888)  # pyqt5转换成自己能放的图片格式
        jpg_out = QtGui.QPixmap(_image).scaled(self.imgLabel.width(), self.imgLabel.height())  # 设置图片大小
        self.imgLabel.setPixmap(jpg_out)  # 设置图片显示


class docker(QDockWidget):
    def __init__(self, parent):
        super().__init__(parent)
        self.btn_openFile = QPushButton('打开图片')
        self.btn_startDiscern = QPushButton('开始识别')
        self.texEdit = QTextEdit()
        self.vbox = QVBoxLayout()
        self.vbox.addWidget(self.btn_openFile)
        self.vbox.addWidget(self.btn_startDiscern)
        self.vbox.addWidget(self.texEdit)
        self.wgt = QWidget()
        self.wgt.setLayout(self.vbox)
        self.setWidget(self.wgt)


if __name__ == '__main__':
    app = QApplication(sys.argv)
    win = QPixmapDemo()
    win.show()
    sys.exit(app.exec_())

运行结果:

发布了92 篇原创文章 · 获赞 15 · 访问量 1万+

猜你喜欢

转载自blog.csdn.net/zZzZzZ__/article/details/103723563
今日推荐