Deep Learning in Practice (3): Building a DNN to Compare Whether Two Digit Images Are the Same

Copyright notice: This is an original article by the blogger and may not be reproduced without the blogger's permission. https://blog.csdn.net/junjun150013652/article/details/81267305

Pretraining on an auxiliary task.

  1. In this exercise you will build a DNN that compares two MNIST digit images and predicts whether they represent the same digit or not. Then you will reuse the lower layers of this network to train an MNIST classifier using very little training data. Start by building two DNNs (let’s call them DNN A and B), both similar to the one you built earlier but without the output layer: each DNN should have five hidden layers of 100 neurons each, He initialization, and ELU activation. Next, add a single output layer on top of both DNNs. You should use TensorFlow’s concat() function with axis=1 to concatenate the outputs of both DNNs along the horizontal axis, then feed the result to the output layer. This output layer should contain a single neuron using the logistic activation function.

  2. Split the MNIST training set into two sets: split #1 should contain 55,000 images, and split #2 should contain 5,000 images. Create a function that generates a training batch where each instance is a pair of MNIST images picked from split #1. Half of the training instances should be pairs of images that belong to the same class, while the other half should be images from different classes. For each pair, the training label should be 0 if the images are from the same class, or 1 if they are from different classes.

  3. Train the DNN on this training set. For each image pair, you can simultaneously feed the first image to DNN A and the second image to DNN B. The whole network will gradually learn to tell whether two images belong to the same class or not.
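The tensor shapes implied by the steps above can be checked with a small NumPy sketch before building the TensorFlow graph. It is only an illustration under stated assumptions: the batch size of 4 and the random arrays standing in for the outputs of DNN A and DNN B are hypothetical, and `np.concatenate` is used here as a stand-in for TensorFlow's `concat()` with the same `axis=1` argument.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, n_inputs, n_neurons = 4, 28 * 28, 100  # toy batch size (assumption)

# A batch of image pairs: axis 1 indexes the two images in each pair.
X_pairs = rng.random((batch_size, 2, n_inputs)).astype(np.float32)

# Splitting along axis 1 yields one input batch per DNN.
X1, X2 = X_pairs[:, 0], X_pairs[:, 1]

# Random stand-ins for the top hidden layer outputs of DNN A and DNN B.
out_a = rng.standard_normal((batch_size, n_neurons)).astype(np.float32)
out_b = rng.standard_normal((batch_size, n_neurons)).astype(np.float32)

# Concatenating along axis 1 places the two feature vectors side by side.
merged = np.concatenate([out_a, out_b], axis=1)

print(X1.shape, X2.shape, merged.shape)  # (4, 784) (4, 784) (4, 200)
```

Each row of `merged` holds the 100 features of the first image followed by the 100 features of the second, which is what the output layer receives for one pair.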

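import tensorflow as tf
import numpy as np

# Note: with its default scale=1.0 this initializer is not strictly He
# initialization, which corresponds to scale=2.0 (mode="fan_in"). The default
# is kept here so that the code matches the logged run below.
he_init = tf.variance_scaling_initializer()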

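def generate_batch(images, labels, batch_size):
    """Return a shuffled batch of image pairs, half same-class, half different-class.

    Note: the label is 1 for a same-class pair and 0 for a different-class pair,
    which is the reverse of the convention stated in the exercise text.
    """
    size1 = batch_size // 2
    size2 = batch_size - size1
    # For odd batch sizes, randomly choose which half gets the extra pair.
    if size1 != size2 and np.random.rand() > 0.5:
        size1, size2 = size2, size1
    X = []
    y = []
    # Rejection-sample pairs of distinct images that share a label.
    while len(X) < size1:
        rnd_idx1, rnd_idx2 = np.random.randint(0, len(images), 2)
        if rnd_idx1 != rnd_idx2 and labels[rnd_idx1] == labels[rnd_idx2]:
            X.append(np.array([images[rnd_idx1], images[rnd_idx2]]))
            y.append([1])
    # Fill the rest of the batch with pairs whose labels differ.
    while len(X) < batch_size:
        rnd_idx1, rnd_idx2 = np.random.randint(0, len(images), 2)
        if labels[rnd_idx1] != labels[rnd_idx2]:
            X.append(np.array([images[rnd_idx1], images[rnd_idx2]]))
            y.append([0])
    # Shuffle so that same-class and different-class pairs are interleaved.
    rnd_indices = np.random.permutation(batch_size)
    return np.array(X)[rnd_indices], np.array(y)[rnd_indices]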

def dnn(inputs, n_hidden_layers=5, n_neurons=100, name=None,
        activation=tf.nn.elu, initializer=he_init):
    with tf.variable_scope(name, "dnn"):
        for layer in range(n_hidden_layers):
            inputs = tf.layers.dense(inputs, n_neurons, activation=activation,
                                     kernel_initializer=initializer,
                                     name="hidden%d" % (layer + 1))
        return inputs

(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()
X_train = X_train.astype(np.float32).reshape(-1, 28*28) / 255.0
X_test = X_test.astype(np.float32).reshape(-1, 28*28) / 255.0
y_train = y_train.astype(np.int32)
y_test = y_test.astype(np.int32)
X_valid, X_train = X_train[:5000], X_train[5000:]
y_valid, y_train = y_train[:5000], y_train[5000:]

# Build one fixed test set of image pairs; X_test and y_test now hold pairs and pair labels.
X_test, y_test = generate_batch(X_test, y_test, batch_size=len(X_test))

n_inputs = 28 * 28 # MNIST
X = tf.placeholder(tf.float32, shape=(None, 2, n_inputs), name="X")
X1, X2 = tf.unstack(X, axis=1)
y = tf.placeholder(tf.int32, shape=[None, 1])

dnn1 = dnn(X1, name="DNN_A")
dnn2 = dnn(X2, name="DNN_B")
dnn_outputs = tf.concat([dnn1, dnn2], axis=1)

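# Extra 10-unit hidden layer on top of the concatenated outputs
# (the exercise text asks only for a single output layer).
hidden = tf.layers.dense(dnn_outputs, units=10, activation=tf.nn.elu, kernel_initializer=he_init)
logits = tf.layers.dense(hidden, units=1, kernel_initializer=he_init)
y_proba = tf.nn.sigmoid(logits)
y_pred = tf.cast(tf.greater_equal(logits, 0), tf.int32)  # logits >= 0 means y_proba >= 0.5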

y_as_float = tf.cast(y, tf.float32)
xentropy = tf.nn.sigmoid_cross_entropy_with_logits(labels=y_as_float, logits=logits)
loss = tf.reduce_mean(xentropy)

learning_rate = 0.01
momentum = 0.95

optimizer = tf.train.MomentumOptimizer(learning_rate, momentum, use_nesterov=True)
training_op = optimizer.minimize(loss)

y_pred_correct = tf.equal(y_pred, y)
accuracy = tf.reduce_mean(tf.cast(y_pred_correct, tf.float32))

init = tf.global_variables_initializer()
saver = tf.train.Saver()

n_epochs = 100
batch_size = 500

with tf.Session() as sess:
    init.run()
    for epoch in range(n_epochs):
        for iteration in range(len(X_train) // batch_size):
            X_batch, y_batch = generate_batch(X_train, y_train, batch_size)
            loss_val, _ = sess.run([loss, training_op], feed_dict={X: X_batch, y: y_batch})
        acc_test = accuracy.eval(feed_dict={X: X_test, y: y_test})
        print(epoch, "Test accuracy:", acc_test)

    save_path = saver.save(sess, "./my_digit_comparison_model.ckpt")
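
Two details of the graph above may be easier to see in isolation: `y_pred` thresholds the logits at 0 rather than the probabilities at 0.5, and the loss is computed from the logits rather than from `y_proba`. The NumPy sketch below is an illustration, not part of the original solution. The helper names and sample values are hypothetical, and the stable formula `max(x, 0) - x * z + log(1 + exp(-|x|))` is the one given in TensorFlow's documentation for `tf.nn.sigmoid_cross_entropy_with_logits`.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def stable_sigmoid_xent(labels, logits):
    # Numerically stable form: max(x, 0) - x * z + log(1 + exp(-|x|))
    return (np.maximum(logits, 0) - logits * labels
            + np.log1p(np.exp(-np.abs(logits))))

def naive_sigmoid_xent(labels, logits):
    # Textbook form; it loses precision when sigmoid saturates at 0 or 1.
    p = sigmoid(logits)
    return -(labels * np.log(p) + (1 - labels) * np.log(1 - p))

# Hypothetical logits and pair labels, chosen only for illustration.
logits = np.array([-3.0, -0.5, 0.0, 0.5, 3.0])
labels = np.array([0.0, 1.0, 1.0, 0.0, 1.0])

# Thresholding logits at 0 gives the same predictions as thresholding
# probabilities at 0.5, because sigmoid(0) == 0.5 and sigmoid is increasing.
pred_from_logits = (logits >= 0).astype(np.int32)
pred_from_proba = (sigmoid(logits) >= 0.5).astype(np.int32)
print(pred_from_logits)  # [0 0 1 1 1]

# Both loss formulas agree for moderate logits.
stable = stable_sigmoid_xent(labels, logits)
naive = naive_sigmoid_xent(labels, logits)
print(np.allclose(stable, naive))  # True
```

Working from logits avoids taking the log of a probability that has rounded to exactly 0 or 1, which is presumably why the solution feeds `logits` rather than `y_proba` into the loss.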

Log output:

0 Test accuracy: 0.5043
1 Test accuracy: 0.4917
2 Test accuracy: 0.4902
3 Test accuracy: 0.4927
4 Test accuracy: 0.518
5 Test accuracy: 0.554
6 Test accuracy: 0.6524
7 Test accuracy: 0.6931
8 Test accuracy: 0.7264
9 Test accuracy: 0.7996
10 Test accuracy: 0.8169
11 Test accuracy: 0.8372
12 Test accuracy: 0.8469
13 Test accuracy: 0.8555
14 Test accuracy: 0.8621
15 Test accuracy: 0.8702
16 Test accuracy: 0.877
17 Test accuracy: 0.8859
18 Test accuracy: 0.8919
19 Test accuracy: 0.8929
20 Test accuracy: 0.8987
21 Test accuracy: 0.9086
22 Test accuracy: 0.9147
23 Test accuracy: 0.919
24 Test accuracy: 0.9212
25 Test accuracy: 0.9259
26 Test accuracy: 0.9301
27 Test accuracy: 0.9333
28 Test accuracy: 0.9354
29 Test accuracy: 0.9382
30 Test accuracy: 0.9394
31 Test accuracy: 0.943
32 Test accuracy: 0.9464
33 Test accuracy: 0.9474
34 Test accuracy: 0.9496
35 Test accuracy: 0.9536
36 Test accuracy: 0.9492
37 Test accuracy: 0.9532
38 Test accuracy: 0.955
39 Test accuracy: 0.9543
40 Test accuracy: 0.9552
41 Test accuracy: 0.9538
42 Test accuracy: 0.9579
43 Test accuracy: 0.9579
44 Test accuracy: 0.9587
45 Test accuracy: 0.9567
46 Test accuracy: 0.9605
47 Test accuracy: 0.9631
48 Test accuracy: 0.9616
49 Test accuracy: 0.963
50 Test accuracy: 0.9627
51 Test accuracy: 0.9632
52 Test accuracy: 0.9638
53 Test accuracy: 0.9654
54 Test accuracy: 0.9645
55 Test accuracy: 0.9633
56 Test accuracy: 0.9657
57 Test accuracy: 0.966
58 Test accuracy: 0.9663
59 Test accuracy: 0.9657
60 Test accuracy: 0.9688
61 Test accuracy: 0.9664
62 Test accuracy: 0.9705
63 Test accuracy: 0.9692
64 Test accuracy: 0.9708
65 Test accuracy: 0.9687
66 Test accuracy: 0.9675
67 Test accuracy: 0.9696
68 Test accuracy: 0.9688
69 Test accuracy: 0.9689
70 Test accuracy: 0.9701
71 Test accuracy: 0.9702
72 Test accuracy: 0.9718
73 Test accuracy: 0.9731
74 Test accuracy: 0.9711
75 Test accuracy: 0.9719
76 Test accuracy: 0.9733
77 Test accuracy: 0.9709
78 Test accuracy: 0.9728
79 Test accuracy: 0.9722
80 Test accuracy: 0.9714
81 Test accuracy: 0.972
82 Test accuracy: 0.973
83 Test accuracy: 0.973
84 Test accuracy: 0.9737
85 Test accuracy: 0.9734
86 Test accuracy: 0.9729
87 Test accuracy: 0.9751
88 Test accuracy: 0.974
89 Test accuracy: 0.9725
90 Test accuracy: 0.9747
91 Test accuracy: 0.9746
92 Test accuracy: 0.9735
93 Test accuracy: 0.9733
94 Test accuracy: 0.9749
95 Test accuracy: 0.9752
96 Test accuracy: 0.976
97 Test accuracy: 0.9756
98 Test accuracy: 0.9767
99 Test accuracy: 0.9762
