Deep Learning Practical Chapter (10) -- TensorFlow Learning Road (7)

window of knowledge

PyTorch is an open-source Python machine learning library based on Torch, used for applications such as natural language processing.

In January 2017, Facebook AI Research (FAIR) released PyTorch, built on Torch. It is a Python-based scientific computing package that provides two high-level features: 1. tensor computation (like NumPy) with strong GPU acceleration; 2. deep neural networks built on an automatic differentiation (autograd) system.

The predecessor of PyTorch is Torch. PyTorch shares its underlying layer with the Torch framework, but much of it has been rewritten in Python, which makes it more flexible, adds support for dynamic graphs, and provides a Python interface. Developed by the Torch7 team, it is a Python-first deep learning framework that offers powerful GPU acceleration and supports dynamic neural networks.

PyTorch can be regarded as NumPy with GPU support, or as a powerful deep learning framework with automatic differentiation. Besides Facebook, it has been adopted by institutions such as Twitter, CMU, and Salesforce.

review

In last week's article, we integrated all of the code (data preprocessing, network model, training code) and ran the actual training. Training a neural network is not only about knowing how good and how effective the model is: we also need to test the trained model in practice, and it may later need to be deployed as an application. Of course, it will not be deployed directly; optimization, compression, pruning, and similar issues also have to be considered.

1. Model prediction

Implementation steps:

1. Save the model during training

2. Write the test code (data processing, model calling, data testing)

3. Output the model result and map it to the real label

1. Save the model during training

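import tensorflow as tf  # TensorFlow 1.x API

# Add before training starts (after the network is built, so the Saver can find its variables)
# Create a Saver to store the trained model
saver = tf.train.Saver()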

After each batch is trained, the entire validation set is evaluated (nowadays validation is usually run once per epoch instead). If the validation accuracy is higher than the best previous result and above 80%, the model is saved; this way, the checkpoint kept at the end is the best model.

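if avg_test_acc > pre_test_acc and avg_test_acc > 0.80:
    checkpoint_path = os.path.join(logs_checkpoint, 'model.ckpt')
    # step: the current training step; the file is saved as model.ckpt-<step>
    saver.save(sess, checkpoint_path, global_step=step)
    pre_test_acc = avg_test_acc  # remember the best accuracy so far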
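The save condition keeps a checkpoint only when it beats the best validation accuracy seen so far and also clears the 80% bar. The following is a minimal pure-Python sketch of that bookkeeping with no TensorFlow involved; the function name `select_checkpoints` and the sample accuracies are hypothetical, and a real training loop would call `saver.save` wherever this sketch records an epoch.

```python
def select_checkpoints(val_accuracies, threshold=0.80):
    """Return the epochs at which a checkpoint would be saved.

    A checkpoint is saved only when the validation accuracy is higher than
    the best accuracy saved so far AND higher than the threshold.
    """
    best_acc = 0.0  # plays the role of pre_test_acc
    saved_epochs = []
    for epoch, acc in enumerate(val_accuracies):
        if acc > best_acc and acc > threshold:
            saved_epochs.append(epoch)  # the real code would call saver.save(...) here
            best_acc = acc
    return saved_epochs


if __name__ == "__main__":
    # Hypothetical validation accuracies, one per epoch
    accs = [0.62, 0.78, 0.81, 0.79, 0.85, 0.84, 0.88]
    print(select_checkpoints(accs))  # -> [2, 4, 6]
```

Because the best accuracy is only raised when a model is saved, the last checkpoint written is always the best one above the threshold.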

2. Test code

1. Data preprocessing:

This step is the same as in training.

import numpy as np
import tensorflow as tf
from PIL import Image

# Load one image for testing
def get_one_image(img_dir):
    # Input:  img_dir, path of the image to test
    # Output: image, a float32 tensor of shape [1, 150, 150, 3] scaled to [0, 1]
    img = Image.open(img_dir).convert('RGB')  # make sure there are 3 channels
    img = np.array(img)

    # Resize to the input size used during training
    image = tf.image.resize_images(img, (150, 150))
    # Add the batch dimension: [batch, height, width, channels]
    image = tf.reshape(image, [1, 150, 150, 3])

    # Cast to float32 and normalize pixels to [0, 1], as in training
    image = tf.cast(image, tf.float32)
    image = image / 255.0

    return image
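What `get_one_image` does to the pixel values can be checked without TensorFlow. The sketch below is an illustration only, assuming the image is already resized and given as a nested Python list of shape [height][width][3] with integers in 0..255; it scales each channel to [0, 1] and adds the leading batch dimension, mirroring `image / 255` and `tf.reshape(..., [1, 150, 150, 3])`. The helper names `normalize_and_batch` and `shape_of` are hypothetical.

```python
def normalize_and_batch(image):
    """Scale an [H][W][3] image of 0..255 ints to [0, 1] and add a batch axis.

    Returns a nested list of shape [1][H][W][3], i.e. a batch of one image,
    which is the layout the network's placeholder expects.
    """
    normalized = [
        [[channel / 255.0 for channel in pixel] for pixel in row]
        for row in image
    ]
    return [normalized]  # batch dimension of size 1


def shape_of(nested):
    """Return the shape of a regular, non-empty nested list, e.g. [1, 2, 2, 3]."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        nested = nested[0]
    return shape


if __name__ == "__main__":
    # A tiny hypothetical 2x2 RGB image
    img = [[[0, 128, 255], [255, 255, 255]],
           [[10, 20, 30], [0, 0, 0]]]
    batch = normalize_and_batch(img)
    print(shape_of(batch))    # -> [1, 2, 2, 3]
    print(batch[0][0][0][2])  # -> 1.0
```

The point of the check is that test-time inputs must have the same shape and value range as the training inputs; otherwise the restored weights see data they were never trained on.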


2. Model call

In practice, this means restoring the parameters of the saved model and loading them into the current network for testing.

The current network only performs forward propagation and does not perform back propagation.

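# Create the Saver after the network has been built (see evaluate_one_image);
# with no variables in the graph, tf.train.Saver() raises an error
saver = tf.train.Saver()

with tf.Session() as sess:
    # image_array is the tensor returned by get_one_image; evaluate it to a NumPy array
    img_array = sess.run(image_array)

    print("Reading checkpoints...")
    ckpt = tf.train.get_checkpoint_state(logs_train_dir)
    if ckpt and ckpt.model_checkpoint_path:
        # The checkpoint name ends with -<global_step>, e.g. model.ckpt-1200
        global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
        saver.restore(sess, ckpt.model_checkpoint_path)
        print('Loading success, global_step is %s' % global_step)
    else:
        print('No checkpoint file found')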
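The `global_step` line relies on the checkpoint naming convention: when `saver.save` is called with a `global_step` argument, TensorFlow appends `-<step>` to the file prefix (for example `model.ckpt-1200`). Here is a stdlib-only sketch of the same parsing; the name `parse_global_step` and the example paths are hypothetical, and unlike the one-liner above it returns an `int`, or `None` when there is no step suffix.

```python
import os


def parse_global_step(checkpoint_path):
    """Extract the training step from a path such as 'logs/train/model.ckpt-1200'.

    Mirrors ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1],
    but returns None when the file name carries no '-<step>' suffix.
    """
    filename = os.path.basename(checkpoint_path)  # e.g. 'model.ckpt-1200'
    _, sep, step = filename.rpartition('-')
    if not sep or not step.isdigit():
        return None
    return int(step)


if __name__ == "__main__":
    print(parse_global_step("logs/train/model.ckpt-1200"))  # -> 1200
    print(parse_global_step("logs/train/model.ckpt"))       # -> None
```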

3. Data test

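# Test a single image
def evaluate_one_image(image_array):
    global graph
    graph = tf.get_default_graph()
    with graph.as_default():
        BATCH_SIZE = 1
        N_CLASSES = 2  # cat and dog

        # Placeholder for one 150x150 RGB image (batch size 1)
        x = tf.placeholder(tf.float32, shape=[1, 150, 150, 3])

        # Forward pass through the network defined in the previous articles
        logit = model.inference(x, BATCH_SIZE, N_CLASSES, 1)

        # Turn the raw logits into class probabilities
        logit = tf.nn.softmax(logit)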
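`tf.nn.softmax` turns the raw logits into probabilities that sum to 1 without changing which class scores highest. A pure-Python sketch of the same formula, softmax(z)_i = exp(z_i - max z) / sum_j exp(z_j - max z), where subtracting the maximum is the usual trick to avoid overflow; the logit values are made up for illustration.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)  # shift by the max so exp() cannot overflow
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


if __name__ == "__main__":
    # Hypothetical logits for [cat, dog]
    probs = softmax([1.0, 3.0])
    print([round(p, 4) for p in probs])  # -> [0.1192, 0.8808]
```

Since softmax is monotonic, taking the argmax of the probabilities gives the same class as taking the argmax of the logits; the softmax only makes the output readable as confidence.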

4. Output result:

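prediction = sess.run(logit, feed_dict={x: img_array})
# Index of the class with the highest probability
max_index = np.argmax(prediction)
# print(max_index)
# The label mapping can be a dictionary or a list
label_dict = {0: 'cat', 1: 'dog'}
label_list = ['cat', 'dog']
print("The model output is {}, and the corresponding real label is: {}".format(max_index, label_list[max_index]))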
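The final step is just an argmax followed by a lookup. A stdlib-only sketch, assuming `prediction` has the [1, N_CLASSES] shape that `sess.run(logit, ...)` returns for a batch of one; `predict_label` is a hypothetical helper, and the 0 -> cat, 1 -> dog mapping is the one used in this series.

```python
def predict_label(prediction, labels=("cat", "dog")):
    """Map a [1, N_CLASSES] probability output to (class index, label name).

    Equivalent to np.argmax on the flattened array followed by a list lookup;
    on ties the first index wins, as with np.argmax.
    """
    probs = prediction[0]  # batch of one: take the only row
    max_index = max(range(len(probs)), key=probs.__getitem__)
    return max_index, labels[max_index]


if __name__ == "__main__":
    # Hypothetical softmax output for a dog picture
    index, name = predict_label([[0.12, 0.88]])
    print("The model output is {}, and the corresponding real label is: {}".format(index, name))
    # -> The model output is 1, and the corresponding real label is: dog
```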

Full test code:

Actual prediction result

As the result shows, the test image we loaded is a picture of a dog, and the network's predicted label is 1. Since dogs were assigned label 1 at the start, this output maps to the real label "dog", so the prediction is correct.

epilogue

That concludes this installment. Together, these articles form a complete image classification project: from data processing to network construction, training, and calling the model for prediction, we have covered every step and commented on the code details along the way. We believe you can follow it; if you have any questions, feel free to leave us a message in the official account's backend.

Although this project is finished, there are surely parts that readers still don't fully understand; neither the data processing nor the network construction is necessarily simple. That's fine: next time, the editor will focus on this project and write a summary of the image classification project. You are also welcome to ask plenty of questions so that we can all make progress together.

Have a great weekend, see you next time!

Editor: Layman Yueyi|Review: Layman Xiaoquanquan


Advanced IT Tour


Origin blog.csdn.net/xyl666666/article/details/118077730