25-Mnist05

from tensorflow.examples.tutorials.mnist.input_data import read_data_sets
import numpy as np
import cv2
import tensorflow as tf
# import os
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

class Config:
    """Hyper-parameters and file locations for the MNIST CNN experiment."""

    def __init__(self):
        # --- data ---
        self.sample_path = '../deeplearning_ai12/p07_mnist/MNIST_data'

        # --- optimisation ---
        self.lr = 0.001          # Adam learning rate
        self.epoches = 200       # number of passes over the training split
        self.batch_size = 20     # examples per training / validation batch
        self.eps = 1e-10         # small numerical-stability constant

        # --- model capacity ---
        self.base_filters = 16  # should be 32 at least

        # --- checkpointing ---
        name = 'mnist04'
        self.name = name
        self.save_path = '../models/{name}/{name}'.format(name=name)


class Tensors:
    """Build the TF1 graph: input placeholders, CNN logits, loss, train op and accuracy.

    Must be constructed while the graph it should populate is the default graph
    (``App`` builds it inside ``g.as_default()``).
    """

    def __init__(self, config: Config):
        self.config = config
        # Flattened grayscale images: 784 = 28 * 28 pixels per example.
        self.x = tf.placeholder(tf.float32, [None, 784], 'x')
        x = tf.reshape(self.x, [-1, 28, 28, 1])  # [-1, 28, 28, 1]
        logits = self.get_logits(x)  # [-1, 10]
        self.y_predict = tf.argmax(logits, axis=1, output_type=tf.int32)  # [-1]

        # Integer class label for each image.
        self.y = tf.placeholder(tf.int32, [None], 'y')  # [-1]

        # Mean cross-entropy computed directly from the logits by the fused op.
        # It is numerically stable (log-sum-exp) and needs no eps clamp. A
        # softmax -> maximum(p, eps) -> log chain has zero gradient whenever the
        # true class's probability falls below eps, so confidently wrong
        # examples would stop contributing to training.
        self.loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(labels=self.y, logits=logits))
        opt = tf.train.AdamOptimizer(config.lr)
        self.train_op = opt.minimize(self.loss)

        # Accuracy: fraction of the batch whose predicted class equals the label.
        self.precise = tf.reduce_mean(tf.cast(tf.equal(self.y, self.y_predict), tf.float32))

        # Print per-variable and total trainable parameter counts.
        params = 0
        for var in tf.trainable_variables():
            ps = _params(var.shape)
            print(var.name, var.shape, ps)
            params += ps
        print('-' * 200)
        print('Total:', params)

    def get_logits(self, x):
        """
        CNN body.

        Three 3x3 'same' convolutions; the last two are each followed by 2x2
        max-pooling (28 -> 14 -> 7). Then two dense layers.
        With F = config.base_filters, the channels are F, 2F and 4F.

        :param x: [-1, 28, 28, 1]
        :return: [-1, 10] unnormalised class scores (logits)
        """
        config = self.config
        filters = config.base_filters
        x = tf.layers.conv2d(x, filters, (3, 3), (1, 1), 'same',
                             activation=tf.nn.relu, name='conv1')  # [-1, 28, 28, F]
        for i in range(2):
            filters *= 2
            # No activation on the conv itself: ReLU is applied after pooling,
            # which is equivalent because max-pooling and ReLU commute.
            x = tf.layers.conv2d(x, filters, 3, 1, 'same',
                                 name='conv2_%d' % i)  # i=0: [-1, 28, 28, 2F]; i=1: [-1, 14, 14, 4F]
            x = tf.layers.max_pooling2d(x, (2, 2), (2, 2), 'same')  # i=0: [-1, 14, 14, 2F]; i=1: [-1, 7, 7, 4F]
            x = tf.nn.relu(x)

        # x: [-1, 7, 7, 4F]
        x = tf.layers.flatten(x)  # [-1, 7*7*4F] (3136 when F = 16)
        x = tf.layers.dense(x, 1000, activation=tf.nn.relu, name='dense1')  # [-1, 1000]
        x = tf.layers.dense(x, 10, name='dense2')  # [-1, 10]
        return x


def _params(shape):
    result = 1
    for sh in shape:
        result *= sh.value
    return result


class Samples:
    """Load MNIST from ``config.sample_path`` and expose its three splits as ``SubSamples``."""

    def __init__(self, config):
        datasets = read_data_sets(config.sample_path)
        # Wrap every split in the same thin adapter so callers see one interface.
        self.train, self.validation, self.test = (
            SubSamples(datasets.train),
            SubSamples(datasets.validation),
            SubSamples(datasets.test),
        )


class SubSamples:
    """Thin adapter over one dataset split, forwarding ``num_examples`` and ``next_batch``."""

    def __init__(self, data):
        # Wrapped split object; must provide a ``num_examples`` attribute and a
        # ``next_batch(batch_size)`` method.
        self.data = data

    def num_examples(self):
        """Return the size of the wrapped split."""
        split = self.data
        return split.num_examples

    def next_batch(self, batch_size):
        """Return the wrapped split's next batch of ``batch_size`` examples.

        Per the original author: xs is [batch_size, 784] and ys is [batch_size].
        """
        fetch = self.data.next_batch
        return fetch(batch_size)


def show_imgs(xs, ys, cols=20):
    """Print the labels and show the images tiled as a grid in an OpenCV window.

    Blocks until a key is pressed (``cv2.waitKey()`` with no timeout).

    :param xs: images; any array-like reshapeable to [-1, 28, 28] (e.g. [n, 784])
    :param ys: labels; only printed
    :param cols: images per grid row (default 20). If the image count is not a
        multiple of ``cols``, the last row is padded with blank (zero) tiles
        rather than failing to reshape.
    :raises ValueError: if ``cols`` is not positive
    """
    if cols <= 0:
        raise ValueError('cols must be positive, got %r' % (cols,))
    print(ys)
    imgs = np.reshape(xs, [-1, 28, 28])  # [n, 28, 28]

    # Pad up to a whole number of rows so the grid reshape below always works.
    pad = (-len(imgs)) % cols
    if pad:
        blanks = np.zeros((pad, 28, 28), dtype=imgs.dtype)
        imgs = np.concatenate([imgs, blanks], axis=0)

    rows = len(imgs) // cols
    # [rows, cols, 28, 28] -> [rows, 28, cols, 28] -> [rows*28, cols*28]:
    # image g*cols + j lands at pixel rows g*28.., pixel columns j*28..
    grid = imgs.reshape(rows, cols, 28, 28).transpose(0, 2, 1, 3).reshape(rows * 28, cols * 28)

    cv2.imshow('My digits', grid)
    cv2.waitKey()


class App:
    """Owns the graph, session and saver; restores or initialises the model and trains it."""

    def __init__(self, config: Config):
        self.config = config
        self.samples = Samples(config)

        g = tf.Graph()
        with g.as_default():
            self.tensors = Tensors(config)
            self.session = tf.Session(graph=g)
            self.saver = tf.train.Saver()

            try:
                self.saver.restore(self.session, config.save_path)
                print('Restore the model from %s successfully' % config.save_path)
            except (ValueError, tf.errors.OpError):
                # Missing/invalid checkpoint (ValueError) or a failed restore op
                # (NotFoundError, InvalidArgumentError, DataLossError, ... all OpError):
                # fall back to a freshly initialised model. A bare ``except:``
                # here would also swallow KeyboardInterrupt and SystemExit.
                print('Fail to restore the model from %s, use a new model instead' % config.save_path)
                self.session.run(tf.global_variables_initializer())

    def close(self):
        """Release the TensorFlow session."""
        self.session.close()

    def train(self):
        """Train for ``config.epoches`` epochs.

        After each training batch, accuracy is measured on one validation batch
        and printed together with the training loss. The model is checkpointed
        to ``config.save_path`` at the end of every epoch.
        """
        import os

        train_samples = self.samples.train
        config = self.config
        ts = self.tensors

        # tf.train.Saver.save does not create missing parent directories; without
        # this, training would crash at the end of the first epoch when the
        # checkpoint directory does not exist yet.
        save_dir = os.path.dirname(config.save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

        batches = train_samples.num_examples() // config.batch_size
        for epoch in range(config.epoches):
            for batch in range(batches):
                xs, ys = train_samples.next_batch(config.batch_size)
                _, loss_v = self.session.run([ts.train_op, ts.loss], {ts.x: xs, ts.y: ys})

                # Accuracy on a single validation batch: cheap but noisy.
                xs, ys = self.samples.validation.next_batch(config.batch_size)
                precise_v = self.session.run(ts.precise, {ts.x: xs, ts.y: ys})

                print('Epoch: %d, batch %d: loss=%.6f, precise=%.6f' % (epoch, batch, loss_v, precise_v))
            self.saver.save(self.session, config.save_path)
            print('Model saved into', config.save_path)
        print('Training is finished!')


if __name__ == '__main__':
    config = Config()
    app = App(config)

    try:
        app.train()
    finally:
        # Release the TF session even if training fails or is interrupted.
        app.close()

D:\Anaconda\python.exe D:/AI20/HJZ/05-深度学习项目/deeplearning_20/p25_mnist/mnist05_precise.py
WARNING:tensorflow:From D:/AI20/HJZ/05-深度学习项目/deeplearning_20/p25_mnist/mnist05_precise.py:81: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
Extracting ../deeplearning_ai12/p07_mnist/MNIST_data\train-images-idx3-ubyte.gz
WARNING:tensorflow:From D:\Anaconda\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
WARNING:tensorflow:From D:\Anaconda\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting ../deeplearning_ai12/p07_mnist/MNIST_data\train-labels-idx1-ubyte.gz
Extracting ../deeplearning_ai12/p07_mnist/MNIST_data\t10k-images-idx3-ubyte.gz
WARNING:tensorflow:From D:\Anaconda\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting ../deeplearning_ai12/p07_mnist/MNIST_data\t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From D:\Anaconda\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
conv1/kernel:0 (3, 3, 1, 16) 144
conv1/bias:0 (16,) 16
conv2_0/kernel:0 (3, 3, 16, 32) 4608
conv2_0/bias:0 (32,) 32
conv2_1/kernel:0 (3, 3, 32, 64) 18432
conv2_1/bias:0 (64,) 64
dense1/kernel:0 (3136, 1000) 3136000
dense1/bias:0 (1000,) 1000
dense2/kernel:0 (1000, 10) 10000
dense2/bias:0 (10,) 10
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Total: 3170306
2020-03-11 11:46:21.378514: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX AVX2
2020-03-11 11:46:21.429291: W tensorflow/core/framework/op_kernel.cc:1318] OP_REQUIRES failed at save_restore_tensor.cc:170 : Invalid argument: Unsuccessful TensorSliceReader constructor: Failed to get matching files on ../models/mnist04/mnist04: Not found: FindFirstFile failed for: ../models/mnist04 : 系统找不到指定的路径。
; No such process

发布了125 篇原创文章 · 获赞 2 · 访问量 2613

猜你喜欢

转载自blog.csdn.net/HJZ11/article/details/104793340
今日推荐