Deep learning with TensorFlow: cats vs. dogs on your own dataset

This walkthrough mainly follows the Youku video series "Tensorflow tutorial Cats vs. dogs".

Video link: http://i.youku.com/deeplearning101

Dataset link: http://pan.baidu.com/s/1dFd8kmt (password: psor)

Environment: Windows 10, 64-bit, CPU build of TensorFlow. My graphics card isn't up to GPU training, so the whole run (10,000 steps) took about ten hours.

Training data directory: D:\cat_VS_dog\cats_vs_dogs\data\train

Model output directory: D:\cat_VS_dog\cats_vs_dogs\logs\train (a checkpoint is saved every 2,000 steps)
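
Since a checkpoint is written every 2,000 steps, you can verify afterwards what actually landed in the log directory. A minimal sketch, assuming the paths above and the standard TF1 checkpoint API:

```python
# Minimal sketch: list the checkpoints that training wrote to the log directory.
# logs_train_dir is the directory used in this post; adjust it to yours.
import tensorflow as tf

logs_train_dir = 'D:/cat_VS_dog/cats_vs_dogs/logs/train'
ckpt = tf.train.get_checkpoint_state(logs_train_dir)
if ckpt:
    print(ckpt.model_checkpoint_path)             # latest checkpoint, e.g. .../model.ckpt-9999
    print(list(ckpt.all_model_checkpoint_paths))  # all retained checkpoints
```

The summaries written to the same directory can be viewed with `tensorboard --logdir=D:/cat_VS_dog/cats_vs_dogs/logs/train`.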

Three Python files are involved:

1. input_data.py

2. model.py

3. training.py

1. input_data.py

```python
#By @Kevin Xu
#[email protected]
#Youtube: https://www.youtube.com/channel/UCVCSn4qQXTDAtGWpWAe4Plw
#
#The aim of this project is to use TensorFlow to process our own data.
#    - input_data.py:  read in data and generate batches
#    - model: build the model architecture
#    - training: train

# I used Ubuntu with Python 3.5, TensorFlow 1.0*, other OS should also be good.
# With current settings, 10000 training steps needed 50 minutes on my laptop.

# data: cats vs. dogs from Kaggle
# Download link: https://www.kaggle.com/c/dogs-vs-cats-redux-kernels-edition/data
# data size: ~540M

# How to run?
# 1. run training.py once
# 2. call run_training() in the console to train the model.

# Note:
# it is suggested to restart your kernel to train the model multiple times
# (in order to clear all the variables in the memory)
# Otherwise errors may occur: conv1/weights/biases already exist......

#%%
import tensorflow as tf
import numpy as np
import os

#%%

# you need to change this to your data directory
#train_dir = '/home/kevin/tensorflow/cats_vs_dogs/data/train/'
train_dir = 'D:/cat_VS_dog/cats_vs_dogs/data/train/'   # My dir--20170727-csq

def get_files(file_dir):
    '''
    Args:
        file_dir: file directory
    Returns:
        list of images and labels
    '''
    cats = []
    label_cats = []
    dogs = []
    label_dogs = []
    for file in os.listdir(file_dir):
        name = file.split(sep='.')
        if name[0] == 'cat':
            cats.append(file_dir + file)
            label_cats.append(0)
        else:
            dogs.append(file_dir + file)
            label_dogs.append(1)
    print('There are %d cats\nThere are %d dogs' % (len(cats), len(dogs)))

    image_list = np.hstack((cats, dogs))
    label_list = np.hstack((label_cats, label_dogs))

    # shuffle images and labels together so the pairing is preserved
    temp = np.array([image_list, label_list])
    temp = temp.transpose()
    np.random.shuffle(temp)

    image_list = list(temp[:, 0])
    label_list = list(temp[:, 1])
    label_list = [int(i) for i in label_list]

    return image_list, label_list


#%%
def get_batch(image, label, image_W, image_H, batch_size, capacity):
    '''
    Args:
        image: list type
        label: list type
        image_W: image width
        image_H: image height
        batch_size: batch size
        capacity: the maximum elements in queue
    Returns:
        image_batch: 4D tensor [batch_size, width, height, 3], dtype=tf.float32
        label_batch: 1D tensor [batch_size], dtype=tf.int32
    '''

    image = tf.cast(image, tf.string)
    label = tf.cast(label, tf.int32)

    # make an input queue
    input_queue = tf.train.slice_input_producer([image, label])

    label = input_queue[1]
    image_contents = tf.read_file(input_queue[0])
    image = tf.image.decode_jpeg(image_contents, channels=3)

    ######################################
    # data augmentation should go here
    ######################################

    image = tf.image.resize_image_with_crop_or_pad(image, image_W, image_H)

    # if you want to test the generated batches of images, you might want to comment out the following line.
    image = tf.image.per_image_standardization(image)

    image_batch, label_batch = tf.train.batch([image, label],
                                              batch_size=batch_size,
                                              num_threads=64,
                                              capacity=capacity)

    # you can also use shuffle_batch
#    image_batch, label_batch = tf.train.shuffle_batch([image, label],
#                                                      batch_size=BATCH_SIZE,
#                                                      num_threads=64,
#                                                      capacity=CAPACITY,
#                                                      min_after_dequeue=CAPACITY-1)

    label_batch = tf.reshape(label_batch, [batch_size])
    image_batch = tf.cast(image_batch, tf.float32)

    return image_batch, label_batch

#%% TEST
# To test the generated batches of images
# When training the model, DO comment out the following code

import matplotlib.pyplot as plt

BATCH_SIZE = 2
CAPACITY = 256
IMG_W = 208
IMG_H = 208

#train_dir = '/home/kevin/tensorflow/cats_vs_dogs/data/train/'
train_dir = 'D:/cat_VS_dog/cats_vs_dogs/data/train/'   # My dir--20170727-csq
image_list, label_list = get_files(train_dir)
image_batch, label_batch = get_batch(image_list, label_list, IMG_W, IMG_H, BATCH_SIZE, CAPACITY)

with tf.Session() as sess:
    i = 0
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    try:
        while not coord.should_stop() and i < 1:

            img, label = sess.run([image_batch, label_batch])

            # just test one batch
            for j in np.arange(BATCH_SIZE):
                print('label: %d' % label[j])  # j: index within the batch
                plt.imshow(img[j, :, :, :])
                plt.show()
            i += 1

    except tf.errors.OutOfRangeError:
        print('done!')
    finally:
        coord.request_stop()
    coord.join(threads)

#%%
```
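
Note that `get_files()` derives the label purely from the filename prefix, so the training folder must follow the Kaggle naming scheme (`cat.<n>.jpg` / `dog.<n>.jpg`); anything without the `cat` prefix is labeled as a dog. A quick sanity-check sketch (the filename shown in the comment is hypothetical):

```python
# Sanity-check sketch: get_files() labels images by filename prefix,
# so the folder must contain Kaggle-style names like cat.0.jpg / dog.0.jpg.
image_list, label_list = get_files('D:/cat_VS_dog/cats_vs_dogs/data/train/')
print(len(image_list))               # 25000 for the full Kaggle training set
print(image_list[0], label_list[0])  # e.g. '.../train/dog.4412.jpg' 1   (0 = cat, 1 = dog)
```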







2. model.py

```python
#By @Kevin Xu
#Youtube: https://www.youtube.com/channel/UCVCSn4qQXTDAtGWpWAe4Plw
#
#The aim of this project is to use TensorFlow to process our own data.
#    - input_data.py: read in data and generate batches
#    - model: build the model architecture
#    - training: train

# data: cats vs. dogs from Kaggle
# Download link: https://www.kaggle.com/c/dogs-vs-cats-redux-kernels-edition/data
# data size: ~540M

# How to run?
# 1. run training.py once
# 2. call run_training() in the console to train the model.

#%%
import tensorflow as tf

#%%
def inference(images, batch_size, n_classes):
    '''Build the model
    Args:
        images: image batch, 4D tensor, tf.float32, [batch_size, width, height, channels]
    Returns:
        output tensor with the computed logits, float, [batch_size, n_classes]
    '''
    # conv1, shape = [kernel size, kernel size, channels, kernel numbers]
    with tf.variable_scope('conv1') as scope:
        weights = tf.get_variable('weights',
                                  shape=[3, 3, 3, 16],
                                  dtype=tf.float32,
                                  initializer=tf.truncated_normal_initializer(stddev=0.1, dtype=tf.float32))
        biases = tf.get_variable('biases',
                                 shape=[16],
                                 dtype=tf.float32,
                                 initializer=tf.constant_initializer(0.1))
        conv = tf.nn.conv2d(images, weights, strides=[1, 1, 1, 1], padding='SAME')
        pre_activation = tf.nn.bias_add(conv, biases)
        conv1 = tf.nn.relu(pre_activation, name=scope.name)

    # pool1 and norm1
    with tf.variable_scope('pooling1_lrn') as scope:
        pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
                               padding='SAME', name='pooling1')
        norm1 = tf.nn.lrn(pool1, depth_radius=4, bias=1.0, alpha=0.001/9.0,
                          beta=0.75, name='norm1')

    # conv2
    with tf.variable_scope('conv2') as scope:
        weights = tf.get_variable('weights',
                                  shape=[3, 3, 16, 16],
                                  dtype=tf.float32,
                                  initializer=tf.truncated_normal_initializer(stddev=0.1, dtype=tf.float32))
        biases = tf.get_variable('biases',
                                 shape=[16],
                                 dtype=tf.float32,
                                 initializer=tf.constant_initializer(0.1))
        conv = tf.nn.conv2d(norm1, weights, strides=[1, 1, 1, 1], padding='SAME')
        pre_activation = tf.nn.bias_add(conv, biases)
        conv2 = tf.nn.relu(pre_activation, name='conv2')

    # pool2 and norm2
    with tf.variable_scope('pooling2_lrn') as scope:
        norm2 = tf.nn.lrn(conv2, depth_radius=4, bias=1.0, alpha=0.001/9.0,
                          beta=0.75, name='norm2')
        pool2 = tf.nn.max_pool(norm2, ksize=[1, 3, 3, 1], strides=[1, 1, 1, 1],
                               padding='SAME', name='pooling2')

    # local3 (fully connected)
    with tf.variable_scope('local3') as scope:
        reshape = tf.reshape(pool2, shape=[batch_size, -1])
        dim = reshape.get_shape()[1].value
        weights = tf.get_variable('weights',
                                  shape=[dim, 128],
                                  dtype=tf.float32,
                                  initializer=tf.truncated_normal_initializer(stddev=0.005, dtype=tf.float32))
        biases = tf.get_variable('biases',
                                 shape=[128],
                                 dtype=tf.float32,
                                 initializer=tf.constant_initializer(0.1))
        local3 = tf.nn.relu(tf.matmul(reshape, weights) + biases, name=scope.name)

    # local4 (fully connected)
    with tf.variable_scope('local4') as scope:
        weights = tf.get_variable('weights',
                                  shape=[128, 128],
                                  dtype=tf.float32,
                                  initializer=tf.truncated_normal_initializer(stddev=0.005, dtype=tf.float32))
        biases = tf.get_variable('biases',
                                 shape=[128],
                                 dtype=tf.float32,
                                 initializer=tf.constant_initializer(0.1))
        local4 = tf.nn.relu(tf.matmul(local3, weights) + biases, name='local4')

    # softmax
    with tf.variable_scope('softmax_linear') as scope:
        weights = tf.get_variable('softmax_linear',
                                  shape=[128, n_classes],
                                  dtype=tf.float32,
                                  initializer=tf.truncated_normal_initializer(stddev=0.005, dtype=tf.float32))
        biases = tf.get_variable('biases',
                                 shape=[n_classes],
                                 dtype=tf.float32,
                                 initializer=tf.constant_initializer(0.1))
        softmax_linear = tf.add(tf.matmul(local4, weights), biases, name='softmax_linear')

    return softmax_linear

#%%
def losses(logits, labels):
    '''Compute loss from logits and labels
    Args:
        logits: logits tensor, float, [batch_size, n_classes]
        labels: label tensor, tf.int32, [batch_size]
    Returns:
        loss tensor of float type
    '''
    with tf.variable_scope('loss') as scope:
        cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=labels, name='xentropy_per_example')
        loss = tf.reduce_mean(cross_entropy, name='loss')
        tf.summary.scalar(scope.name + '/loss', loss)
    return loss

#%%
def trainning(loss, learning_rate):
    '''Training ops, the Op returned by this function is what must be passed to
       the 'sess.run()' call to cause the model to train.
    Args:
        loss: loss tensor, from losses()
    Returns:
        train_op: The op for training
    '''
    with tf.name_scope('optimizer'):
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
        global_step = tf.Variable(0, name='global_step', trainable=False)
        train_op = optimizer.minimize(loss, global_step=global_step)
    return train_op

#%%
def evaluation(logits, labels):
    """Evaluate the quality of the logits at predicting the label.
    Args:
        logits: Logits tensor, float - [batch_size, NUM_CLASSES].
        labels: Labels tensor, int32 - [batch_size], with values in the
            range [0, NUM_CLASSES).
    Returns:
        A scalar int32 tensor with the number of examples (out of batch_size)
        that were predicted correctly.
    """
    with tf.variable_scope('accuracy') as scope:
        correct = tf.nn.in_top_k(logits, labels, 1)
        correct = tf.cast(correct, tf.float16)
        accuracy = tf.reduce_mean(correct)
        tf.summary.scalar(scope.name + '/accuracy', accuracy)
    return accuracy

#%%
```



3. training.py

```python
#By @Kevin Xu
#Youtube: https://www.youtube.com/channel/UCVCSn4qQXTDAtGWpWAe4Plw
#
#The aim of this project is to use TensorFlow to process our own data.
#    - input_data.py: read in data and generate batches
#    - model: build the model architecture
#    - training: train

# I used Ubuntu with Python 3.5, TensorFlow 1.0*, other OS should also be good.
# With current settings, 10000 training steps needed 50 minutes on my laptop.

# data: cats vs. dogs from Kaggle
# Download link: https://www.kaggle.com/c/dogs-vs-cats-redux-kernels-edition/data
# data size: ~540M

# How to run?
# 1. run training.py once
# 2. call run_training() in the console to train the model.

# Note:
# it is suggested to restart your kernel to train the model multiple times
# (in order to clear all the variables in the memory)
# Otherwise errors may occur: conv1/weights/biases already exist......

#%%
import os
import numpy as np
import tensorflow as tf
import input_data
import model

#%%
N_CLASSES = 2
IMG_W = 208  # resize the image; if the input image is too large, training will be very slow.
IMG_H = 208
BATCH_SIZE = 16
CAPACITY = 2000
MAX_STEP = 10000  # with current parameters, it is suggested to use MAX_STEP>10k
learning_rate = 0.0001  # with current parameters, it is suggested to use learning rate<0.0001

#%%
def run_training():

    # you need to change the directories to yours.
    #train_dir = '/home/kevin/tensorflow/cats_vs_dogs/data/train/'
    train_dir = 'D:/cat_VS_dog/cats_vs_dogs/data/train/'   # My dir--20170727-csq
    #logs_train_dir = '/home/kevin/tensorflow/cats_vs_dogs/logs/train/'
    logs_train_dir = 'D:/cat_VS_dog/cats_vs_dogs/logs/train/'

    train, train_label = input_data.get_files(train_dir)
    train_batch, train_label_batch = input_data.get_batch(train,
                                                          train_label,
                                                          IMG_W,
                                                          IMG_H,
                                                          BATCH_SIZE,
                                                          CAPACITY)
    train_logits = model.inference(train_batch, BATCH_SIZE, N_CLASSES)
    train_loss = model.losses(train_logits, train_label_batch)
    train_op = model.trainning(train_loss, learning_rate)
    train__acc = model.evaluation(train_logits, train_label_batch)

    summary_op = tf.summary.merge_all()
    sess = tf.Session()
    train_writer = tf.summary.FileWriter(logs_train_dir, sess.graph)
    saver = tf.train.Saver()

    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    try:
        for step in np.arange(MAX_STEP):
            if coord.should_stop():
                break
            _, tra_loss, tra_acc = sess.run([train_op, train_loss, train__acc])

            if step % 50 == 0:
                print('Step %d, train loss = %.2f, train accuracy = %.2f%%' % (step, tra_loss, tra_acc*100.0))
                summary_str = sess.run(summary_op)
                train_writer.add_summary(summary_str, step)

            if step % 2000 == 0 or (step + 1) == MAX_STEP:
                checkpoint_path = os.path.join(logs_train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)

    except tf.errors.OutOfRangeError:
        print('Done training -- epoch limit reached')
    finally:
        coord.request_stop()

    coord.join(threads)
    sess.close()

#%% Evaluate one image
# when training, comment out the following code.

from PIL import Image
import matplotlib.pyplot as plt

def get_one_image(train):
    '''Randomly pick one image from the training data
    Return: ndarray
    '''
    n = len(train)
    ind = np.random.randint(0, n)
    img_dir = train[ind]

    image = Image.open(img_dir)
    plt.imshow(image)
    image = image.resize([208, 208])
    image = np.array(image)
    return image

def evaluate_one_image():
    '''Test one image against the saved models and parameters
    '''
    # you need to change the directories to yours.
    #train_dir = '/home/kevin/tensorflow/cats_vs_dogs/data/train/'
    train_dir = 'D:/cat_VS_dog/cats_vs_dogs/data/train/'
    train, train_label = input_data.get_files(train_dir)
    image_array = get_one_image(train)

    with tf.Graph().as_default():
        BATCH_SIZE = 1
        N_CLASSES = 2

        image = tf.cast(image_array, tf.float32)
        image = tf.image.per_image_standardization(image)
        image = tf.reshape(image, [1, 208, 208, 3])
        logit = model.inference(image, BATCH_SIZE, N_CLASSES)
        logit = tf.nn.softmax(logit)

        # NOTE: logit above is built from image_array directly, so this placeholder
        # is never actually used by the fetch below; the feed_dict is effectively a no-op.
        x = tf.placeholder(tf.float32, shape=[208, 208, 3])

        # you need to change the directories to yours.
        #logs_train_dir = '/home/kevin/tensorflow/cats_vs_dogs/logs/train/'
        logs_train_dir = 'D:/cat_VS_dog/cats_vs_dogs/logs/train'

        saver = tf.train.Saver()

        with tf.Session() as sess:

            print("Reading checkpoints...")
            ckpt = tf.train.get_checkpoint_state(logs_train_dir)
            if ckpt and ckpt.model_checkpoint_path:
                global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                saver.restore(sess, ckpt.model_checkpoint_path)
                print('Loading success, global_step is %s' % global_step)
            else:
                print('No checkpoint file found')

            prediction = sess.run(logit, feed_dict={x: image_array})
            max_index = np.argmax(prediction)
            if max_index == 0:
                print('This is a cat with possibility %.6f' % prediction[:, 0])
            else:
                print('This is a dog with possibility %.6f' % prediction[:, 1])

#%%
```
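
Per the "How to run?" comments at the top of the files, the intended workflow is to run training.py once and then call the entry points from a console. A minimal driver sketch, assuming the three files sit in the working directory:

```python
# Minimal driver sketch, following the "How to run?" comments above.
# Assumes input_data.py, model.py and training.py are in the working directory.
import training

training.run_training()   # trains for MAX_STEP steps, writing summaries and checkpoints

# Restart the kernel first (to clear the default graph), then:
# import training
# training.evaluate_one_image()   # restores the latest checkpoint and classifies one image
```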






For detailed, line-by-line comments on the code, see: http://blog.csdn.net/hjxu2016/article/details/75305123


When the test code displays a few sample images, the colors of the color images look wrong. The explanation given in http://blog.csdn.net/hjxu2016/article/details/75305123?locationNum=10&fps=1 is as follows:
the images were converted to float32 earlier, so the images shown by imshow() have somewhat odd colors. imshow() normally displays uint8 data (grayscale values range over 0~255 as uint8); after conversion to float32 the values fall outside that range, hence the strange colors. This does not affect the subsequent training of the model.
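
If you want the previews to render with sensible colors anyway, either skip `per_image_standardization()` while testing (as the comment in `get_batch()` suggests), or rescale each batch image back into [0, 1] before plotting. A small sketch of the latter, my own workaround rather than part of the tutorial:

```python
import numpy as np

def to_displayable(img):
    # Rescale a standardized float32 image into [0, 1] so that
    # plt.imshow() renders it with reasonable colors (visualization only).
    img = img - img.min()
    return img / max(img.max(), 1e-8)

# usage inside the test loop: plt.imshow(to_displayable(img[j]))
```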

The final model from this video tutorial still makes classification errors: over several test runs, it occasionally misclassified a cat as a dog.




Reposted from blog.csdn.net/gdgyzl/article/details/80915749