(Series complete) Deep learning from scratch: running the MNIST dataset with the TensorFlow framework, Day 2: training the model

1. Introduction

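Today is day two of trying to run the MNIST handwritten-digit dataset with the TensorFlow framework; the focus is on training the network. This blog mainly records my learning path and collects the learning materials I used.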

Note: this code was written for Python 2.7.

Day 1 (building the LeNet network): https://blog.csdn.net/qq_36627158/article/details/108245969

Day 2 (training the network): https://blog.csdn.net/qq_36627158/article/details/108315239

Day 3 (testing the network): https://blog.csdn.net/qq_36627158/article/details/108321673

Day 4 (single-example test): https://blog.csdn.net/qq_36627158/article/details/108397018

2. Code(mnist_train.py)

import tensorflow as tf
import matplotlib.pyplot as plt
import mnist_lenet
from tensorflow.examples.tutorials.mnist import input_data


batch_size = 64
learn_rate = 0.01
iteration = 1500


def train_model(train_dataset):
    images_holder = tf.placeholder(
        dtype=tf.float32,
        shape=[batch_size, 28, 28, 1]
    )
    labels_holder = tf.placeholder(
        dtype=tf.float32,
        shape=[batch_size, 10]
    )

    label_predict = mnist_lenet.build_model_and_forward(images_holder)

    # per-image loss; the one-hot labels are converted back to class ids
    # because the sparse variant expects integer labels
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=label_predict,
        labels=tf.argmax(labels_holder, axis=1)
    )
    # mean loss over one batch of images
    loss = tf.reduce_mean(cross_entropy)

    optimizer = tf.train.GradientDescentOptimizer(learning_rate=learn_rate)
    train_update_op = optimizer.minimize(loss)

    saver = tf.train.Saver()

    with tf.Session() as sess:
        init_op = tf.global_variables_initializer()
        sess.run(init_op)

        print("Start training:")

        loss_plt = []

        for i in range(iteration):
            batch_images, batch_labels = train_dataset.next_batch(batch_size)
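            # batch_images.shape == (batch_size, 784)
            # batch_labels.shape == (batch_size, 10)
            # Note: tf.reshape below adds a new op to the graph on every
            # iteration; it works, but the graph keeps growing during training.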
            batch_images_reshaped = tf.reshape(
                tensor=batch_images,
                shape=[batch_size, 28, 28, 1]
            )

            loss_value, _ = sess.run(
                [loss, train_update_op],
                feed_dict={
                    images_holder: batch_images_reshaped.eval(),
                    labels_holder: batch_labels
                }
            )

            if (i+1) % 50 == 0:
                # a formatted string prints the same way in Python 2 and 3
                print("After %d iterations, loss on training batch is %g" % (i + 1, loss_value))
                loss_plt.append(loss_value)

                saver.save(sess, "models/model.ckpt", global_step=i+1)

        print("End training")
        plt.plot(loss_plt, color=(0, 0, 0), label='loss')
        plt.legend()
        plt.show()


if __name__ == '__main__':
    mnist_data = input_data.read_data_sets('MNIST_data/', one_hot=True)

    if mnist_data is not None:
        print("Data loaded successfully!")
        train_model(mnist_data.train)

3. Materials

1、TensorFlow official documentation

https://tensorflow.google.cn/versions/r1.15/api_docs/python/tf/

2、input_data.py

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/tutorials/mnist/input_data.py

4. Code Details

1、Note: TensorFlow 2.0 does not include the tensorflow.examples.tutorials module

Workaround:

https://blog.csdn.net/likeyou1314918273/article/details/107535539?utm_medium=distribute.pc_relevant.none-task-blog-title-8&spm=1001.2101.3001.4242

2、read_data_sets() in input_data

What the one_hot encoding parameter does: https://blog.csdn.net/weiyumeizi/article/details/81502471

Below is the source of the read_data_sets() function, excerpted from that file:
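
To make the effect of one_hot=True concrete, here is a minimal NumPy sketch of one-hot encoding. The helper name to_one_hot is hypothetical and is not part of input_data.py; it only illustrates the idea that each integer label becomes a row with a single 1.

```python
import numpy as np


def to_one_hot(labels, num_classes=10):
    """Turn integer class ids into one-hot rows (hypothetical helper)."""
    labels = np.asarray(labels)
    one_hot = np.zeros((labels.shape[0], num_classes))
    # put a 1 in the column given by each label
    one_hot[np.arange(labels.shape[0]), labels] = 1.0
    return one_hot


encoded = to_one_hot([3, 0, 9])
print(encoded.shape)
print(encoded[0])
```

With one_hot=True the labels returned by next_batch have shape (batch_size, 10), which is why labels_holder in the training script is declared with that shape.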

def read_data_sets(train_dir,
                   fake_data=False,
                   one_hot=False,
                   dtype=dtypes.float32,
                   reshape=True,
                   validation_size=5000,
                   seed=None,
                   source_url=DEFAULT_SOURCE_URL):
  if fake_data:

    def fake():
      return _DataSet([], [],
                      fake_data=True,
                      one_hot=one_hot,
                      dtype=dtype,
                      seed=seed)

    train = fake()
    validation = fake()
    test = fake()
    return _Datasets(train=train, validation=validation, test=test)

  if not source_url:  # empty string check
    source_url = DEFAULT_SOURCE_URL

  train_images_file = 'train-images-idx3-ubyte.gz'
  train_labels_file = 'train-labels-idx1-ubyte.gz'
  test_images_file = 't10k-images-idx3-ubyte.gz'
  test_labels_file = 't10k-labels-idx1-ubyte.gz'

  local_file = _maybe_download(train_images_file, train_dir,
                               source_url + train_images_file)
  with gfile.Open(local_file, 'rb') as f:
    train_images = _extract_images(f)

  local_file = _maybe_download(train_labels_file, train_dir,
                               source_url + train_labels_file)
  with gfile.Open(local_file, 'rb') as f:
    train_labels = _extract_labels(f, one_hot=one_hot)

  local_file = _maybe_download(test_images_file, train_dir,
                               source_url + test_images_file)
  with gfile.Open(local_file, 'rb') as f:
    test_images = _extract_images(f)

  local_file = _maybe_download(test_labels_file, train_dir,
                               source_url + test_labels_file)
  with gfile.Open(local_file, 'rb') as f:
    test_labels = _extract_labels(f, one_hot=one_hot)

  if not 0 <= validation_size <= len(train_images):
    raise ValueError(
        'Validation size should be between 0 and {}. Received: {}.'.format(
            len(train_images), validation_size))

  validation_images = train_images[:validation_size]
  validation_labels = train_labels[:validation_size]
  train_images = train_images[validation_size:]
  train_labels = train_labels[validation_size:]

  options = dict(dtype=dtype, reshape=reshape, seed=seed)

  train = _DataSet(train_images, train_labels, **options)
  validation = _DataSet(validation_images, validation_labels, **options)
  test = _DataSet(test_images, test_labels, **options)

  return _Datasets(train=train, validation=validation, test=test)
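
The validation split near the end of the function is plain slicing: the first validation_size training examples become the validation set and the rest stay in the training set. A small sketch with index arrays standing in for the real images (the 60000 figure is the size of the MNIST training file):

```python
import numpy as np

num_examples = 60000      # size of the MNIST training file
validation_size = 5000    # default value in read_data_sets()

indices = np.arange(num_examples)  # stand-in for train_images / train_labels
validation_idx = indices[:validation_size]
train_idx = indices[validation_size:]

print(len(validation_idx))
print(len(train_idx))
```

So with the default arguments, mnist_data.train holds 55000 examples rather than 60000.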

3、tf.placeholder()

https://blog.csdn.net/kdongyi/article/details/82343712

4、tf.nn.sparse_softmax_cross_entropy_with_logits()

https://blog.csdn.net/ZJRN1027/article/details/80199248
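
As a rough illustration of what this op computes, the NumPy sketch below applies a softmax to each row of logits and returns the negative log-probability of the true class, one value per example. It is a reimplementation for understanding only, not TensorFlow's code, and the function name and sample values are made up.

```python
import numpy as np


def sparse_softmax_cross_entropy(logits, labels):
    """Per-example cross entropy from raw logits and integer labels."""
    logits = np.asarray(logits, dtype=np.float64)
    labels = np.asarray(labels)
    # subtract the row maximum for numerical stability
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # pick the log-probability of the true class in each row
    return -log_probs[np.arange(len(labels)), labels]


logits = np.array([[2.0, 1.0, 0.1],
                   [0.0, 0.0, 0.0]])
labels = np.array([0, 2])
per_example_loss = sparse_softmax_cross_entropy(logits, labels)
mean_loss = per_example_loss.mean()
print(per_example_loss)
print(mean_loss)
```

Note that the labels here are integer class ids, not one-hot vectors, which is why the training script passes tf.argmax(labels_holder, axis=1) to the op.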

5、tf.argmax()

Returns the index of the largest value along the given axis.

https://blog.csdn.net/qq575379110/article/details/70538051/
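
In the training script, argmax along axis 1 turns each one-hot label row back into a class id. The same operation in NumPy, on a made-up batch of three labels with four classes:

```python
import numpy as np

one_hot_labels = np.array([[0, 0, 1, 0],
                           [1, 0, 0, 0],
                           [0, 0, 0, 1]], dtype=np.float32)

# axis=1 searches within each row, giving one index per example
class_ids = np.argmax(one_hot_labels, axis=1)
print(class_ids)
```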

6、tf.reduce_mean()

https://blog.csdn.net/dcrmg/article/details/79797826
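
tf.reduce_mean with no axis argument averages over every element and returns a scalar, which is how the per-image losses become one batch loss. np.mean behaves the same way with respect to the axis argument, so a NumPy sketch on made-up values shows the idea:

```python
import numpy as np

values = np.array([[1.0, 2.0],
                   [3.0, 4.0]])

overall_mean = np.mean(values)        # all elements -> scalar
per_column = np.mean(values, axis=0)  # collapse rows
per_row = np.mean(values, axis=1)     # collapse columns

print(overall_mean)
print(per_column)
print(per_row)
```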

7、tf.train.GradientDescentOptimizer()

https://tensorflow.google.cn/versions/r1.15/api_docs/python/tf/train/GradientDescentOptimizer

https://www.cnblogs.com/smallredness/p/11203250.html
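
The optimizer applies the plain gradient descent rule: each variable moves against its gradient, scaled by the learning rate. The sketch below runs that rule by hand on a toy one-variable objective, (w - 3)^2, chosen only for illustration; learn_rate and the step count are the values from the training script.

```python
learn_rate = 0.01   # same value as in mnist_train.py
iteration = 1500    # same value as in mnist_train.py

w = 5.0             # arbitrary starting point
for step in range(iteration):
    grad = 2.0 * (w - 3.0)   # derivative of (w - 3)^2
    w -= learn_rate * grad   # gradient descent update

print(w)
```

In the real script, optimizer.minimize(loss) builds an op that computes the gradients of loss with respect to all trainable variables and applies this update to each of them.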

8、tf.train.Saver()

https://blog.csdn.net/yz19930510/article/details/80324389

9、tf.global_variables_initializer()

10、What exactly does mnist.train.next_batch do?
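
As a rough answer, next_batch(batch_size) hands back the next batch_size images together with their labels and reshuffles the data when an epoch ends. The class below is a simplified sketch of that behaviour, not the library code: SimpleDataSet and its attributes are hypothetical names, and at an epoch boundary it simply drops the leftover examples, which I believe is simpler than what the real implementation does.

```python
import numpy as np


class SimpleDataSet(object):
    """Simplified stand-in for the MNIST DataSet class (hypothetical)."""

    def __init__(self, images, labels, seed=0):
        self._images = np.asarray(images)
        self._labels = np.asarray(labels)
        self._num_examples = self._images.shape[0]
        self._rng = np.random.RandomState(seed)
        self._index = 0
        self._shuffle()

    def _shuffle(self):
        # apply one permutation to both arrays so pairs stay aligned
        perm = self._rng.permutation(self._num_examples)
        self._images = self._images[perm]
        self._labels = self._labels[perm]

    def next_batch(self, batch_size):
        if self._index + batch_size > self._num_examples:
            # epoch finished: reshuffle and start over (leftover is dropped)
            self._shuffle()
            self._index = 0
        start = self._index
        self._index += batch_size
        return self._images[start:self._index], self._labels[start:self._index]


# toy data: 10 "images" of 784 values, image k starts with the value k * 784
images = np.arange(10 * 784).reshape(10, 784)
labels = np.arange(10)
dataset = SimpleDataSet(images, labels)
batch_images, batch_labels = dataset.next_batch(4)
print(batch_images.shape)
print(batch_labels)
```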

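11、What does "_" mean in this code?

loss_value, _ = sess.run(
    [loss, train_update_op],
    feed_dict={
        images_holder: batch_images_reshaped.eval(),
        labels_holder: batch_labels
    }
)

I tried printing it, and it is just None. The "_" is simply a throwaway variable name that receives the second fetched value, the result of running train_update_op.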

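12、In this code, batch_images_reshaped has to be passed through eval()

loss_value, _ = sess.run(
    [loss, train_update_op],
    feed_dict={
        images_holder: batch_images_reshaped.eval(),
        labels_holder: batch_labels
    }
)

Without eval(), TensorFlow raises an error, because batch_images_reshaped is the tf.Tensor returned by tf.reshape and a feed value cannot be a tensor: TypeError: The value of a feed cannot be a tf.Tensor object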

https://blog.csdn.net/sdnuwjw/article/details/85935373?utm_medium=distribute.pc_relevant_t0.none-task-blog-BlogCommendFromMachineLearnPai2-1.channel_param&depth_1-utm_source=distribute.pc_relevant_t0.none-task-blog-BlogCommendFromMachineLearnPai2-1.channel_param
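
One possible alternative, shown here only as a sketch, is to reshape the NumPy array that next_batch returns instead of calling tf.reshape. The result is already a NumPy array, so it could be fed directly with no eval() call and no extra graph op per iteration. The zero array below is a stand-in for a real batch.

```python
import numpy as np

batch_size = 64

# stand-in for the (batch_size, 784) array returned by next_batch
batch_images = np.zeros((batch_size, 784), dtype=np.float32)

# NumPy reshape: the result is still an ndarray, so it can be fed as-is
batch_images_reshaped = batch_images.reshape(batch_size, 28, 28, 1)
print(batch_images_reshaped.shape)
```

With this variant the feed entry would read images_holder: batch_images_reshaped, without .eval().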

Reposted from blog.csdn.net/qq_36627158/article/details/108315239