Machine Learning: Extracting Features with TensorFlow and a Pretrained Model - MobileNet V1

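Traditional CV pipelines split the problem into two separate steps: feature extraction, then building and training a classifier. A CNN combines both steps in a single network. Many experiments have shown that a CNN trained on a large amount of data works well as a general-purpose feature extractor, a form of feature transfer.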
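To make the "fixed feature extractor + separate classifier" idea concrete, here is a minimal sketch that assumes the image features have already been extracted. The nearest-centroid classifier and the 4-dimensional toy vectors are illustrative choices, not part of the original post; any classifier, such as an SVM or logistic regression, could take its place.

```python
import numpy as np

def fit_centroids(features, labels):
    """Average the feature vectors of each class (this is the whole 'training' step)."""
    classes = np.unique(labels)
    centroids = np.stack([features[labels == c].mean(axis=0) for c in classes])
    return classes, centroids

def predict(features, classes, centroids):
    """Assign each feature vector to the class whose centroid is nearest (L2 distance)."""
    # dists[i, j] = ||features[i] - centroids[j]||
    dists = np.linalg.norm(features[:, None, :] - centroids[None, :, :], axis=-1)
    return classes[np.argmin(dists, axis=1)]

# Hypothetical 4-dim "features" standing in for CNN outputs (real ones would be 1024-dim)
train_x = np.array([[1.0, 0.0, 0.0, 0.0],
                    [0.9, 0.1, 0.0, 0.0],
                    [0.0, 0.0, 1.0, 0.0],
                    [0.0, 0.1, 0.9, 0.0]])
train_y = np.array([0, 0, 1, 1])

classes, centroids = fit_centroids(train_x, train_y)
test_x = np.array([[0.8, 0.2, 0.0, 0.0], [0.1, 0.0, 0.8, 0.1]])
preds = predict(test_x, classes, centroids)
print(preds)  # [0 1]
```

Because the extractor is frozen, only the cheap second step has to be trained for a new task.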

This post shows how to use TensorFlow and a pretrained model for feature extraction. The pretrained models can be taken from the official TensorFlow models repository on GitHub:

https://github.com/tensorflow/models/tree/master/research/slim

These pretrained models were trained on ImageNet; the repository provides several architectures, including VGG, ResNet, and Inception.

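Here we use a lightweight model, Mobilenet_v1, for feature extraction. First, download the pretrained checkpoint mobilenet_v1_1.0_224 (width multiplier 1.0, input size 224 x 224) in ckpt format.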

After restoring the ckpt, we can also inspect the parameters of every layer in the network and the feature map each layer produces.

First, import the required modules:

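import tensorflow as tf              # TensorFlow 1.x (tf.contrib was removed in 2.x)
import numpy as np
import glob
from nets import mobilenet_v1        # from tensorflow/models/research/slim (must be on PYTHONPATH)
slim = tf.contrib.slim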

Then define a function that takes an image path, reads and preprocesses the image, and returns it as a tensor:

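def mobi_parse_fun(x_in, y_label=1):
    # Read the raw JPEG bytes from the file path
    img_raw = tf.read_file(x_in)
    img_decode = tf.io.decode_jpeg(img_raw, channels=3)
    # mobilenet_v1_1.0_224 expects 224 x 224 inputs
    img = tf.image.resize_images(img_decode, [224, 224])
    # Scale pixel values from [0, 255] to [-1, 1]
    img = tf.cast(img, tf.float32) / 127.5 - 1.0
    # y_label is only a dummy label; it is not used for feature extraction
    return img, y_label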
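The `/ 127.5 - 1.0` step maps pixel values from [0, 255] into [-1, 1], which, as far as we know, matches the inception-style preprocessing the slim MobileNet checkpoints were trained with. A small NumPy sketch of the same arithmetic (the helper name `scale_to_unit_range` is hypothetical):

```python
import numpy as np

def scale_to_unit_range(pixels):
    """Map uint8 pixel values in [0, 255] to float32 values in [-1, 1]."""
    return pixels.astype(np.float32) / 127.5 - 1.0

out = scale_to_unit_range(np.array([0, 128, 255], dtype=np.uint8))
print(out)  # black -> -1.0, mid-gray -> about 0.0, white -> 1.0
```

Feeding unscaled [0, 255] values into the network would still run, but the features would no longer match what the pretrained weights expect.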

Next, use TensorFlow's tf.data module to build the input pipeline:

X_in = tf.placeholder(tf.string, None)    # list of image file paths, fed at run time
train_data = tf.data.Dataset.from_tensor_slices(X_in)
train_data = train_data.map(mobi_parse_fun)
train_data = train_data.batch(1)          # one image per batch

# A reinitializable iterator: the same graph can be fed with different path lists
iter_ = tf.data.Iterator.from_structure(train_data.output_types,
                                        train_data.output_shapes)
x_batch, y_batch = iter_.get_next()
train_init_op = iter_.make_initializer(train_data)
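What `.batch(1)` plus `get_next()` does can be pictured with a plain-Python analogy (this is not TensorFlow code, and `batched` is a hypothetical helper): the paths are cut into consecutive chunks, and every evaluation of the "next" element consumes one chunk. In the TF graph, that means every separate `sess.run` that depends on `x_batch` advances to the next image.

```python
def batched(items, batch_size):
    """Yield consecutive chunks of batch_size items; the last chunk may be shorter."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]

paths = ['a.jpg', 'b.jpg', 'c.jpg']
it = batched(paths, 1)
first = next(it)    # each next() consumes one batch...
second = next(it)   # ...so a second call already returns a different image
print(first, second)  # ['a.jpg'] ['b.jpg']
```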

Then build the network from its definition and set the path of the checkpoint:

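# is_training=False: batch norm uses the stored moving statistics and dropout is off
with slim.arg_scope(mobilenet_v1.mobilenet_v1_arg_scope()):
    logits, endpoints = mobilenet_v1.mobilenet_v1(x_batch, num_classes=1001,
                                                  is_training=False)

# Raw string so the Windows backslashes are not treated as escape characters
ckpt_path = r'D:\Python_Code\mobilenet_v1_1.0_224\mobilenet_v1_1.0_224.ckpt'
saver = tf.train.Saver()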

Collect the paths of the images:

img_path = r'F:\cute\*.jpg'
img_list = glob.glob(img_path)
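The Python documentation notes that whether `glob.glob` results are sorted depends on the file system. If the extracted features are later matched back to file names, sorting the list keeps the order reproducible. A small sketch, using a temporary folder with made-up file names and a hypothetical helper `list_images`:

```python
import glob
import os
import tempfile

def list_images(folder, pattern='*.jpg'):
    """Return the matching image paths in a deterministic (sorted) order."""
    return sorted(glob.glob(os.path.join(folder, pattern)))

with tempfile.TemporaryDirectory() as tmp:
    # Create empty placeholder files just to demonstrate the matching and ordering
    for name in ['b.jpg', 'a.jpg', 'notes.txt']:
        open(os.path.join(tmp, name), 'w').close()
    found = [os.path.basename(p) for p in list_images(tmp)]

print(found)  # ['a.jpg', 'b.jpg']
```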

Next, create a session and restore the model:

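with tf.Session() as sess:

    # Load the pretrained weights into the graph built above
    saver.restore(sess, ckpt_path)

    ## Inspect the parameters of every layer:
    print('print the trainable parameters: ')
    for var in tf.trainable_variables():
        print(var.name, sess.run(var).shape)

    # Feed the image paths into the dataset and start the iterator
    sess.run(train_init_op, feed_dict={X_in: img_list})

    #---------------------------------------------
    # Inspect the feature map of every layer.
    # Every sess.run() that depends on x_batch pulls a NEW image from the
    # iterator, so fetch all endpoints in one call (all for the same image).
    print('print the feature maps: ')
    feat_maps = sess.run(endpoints)
    for name_, feat_map in feat_maps.items():
        print(name_, feat_map.shape)

    # AvgPool_1a has shape (1, 1, 1, 1024); squeeze the spatial axes -> (1, 1024)
    fc_feat = tf.squeeze(endpoints['AvgPool_1a'], [1, 2])

    # Restart the iterator so feature extraction begins with the first image
    sess.run(train_init_op, feed_dict={X_in: img_list})
    for img_name in img_list:
        print(img_name)

        # Fetch the image and its feature in ONE run so that they belong together
        x_bat, y_bat, fc_feature = sess.run([x_batch, y_batch, fc_feat])
        print(x_bat.shape, y_bat.shape, fc_feature.shape)

        break  # only the first image is shown here; remove to process all images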
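The loop above stops after the first image (`break`). When features for all images are needed, a common pattern is to squeeze each pooled output and stack the results into one `(num_images, 1024)` matrix. The NumPy sketch below shows only the array manipulation; the random arrays are hypothetical stand-ins for the `AvgPool_1a` outputs, and `np.squeeze(..., axis=(1, 2))` mirrors `tf.squeeze(fc_map, [1, 2])`.

```python
import numpy as np

# Hypothetical stand-ins for the AvgPool_1a outputs of three images, each (1, 1, 1, 1024)
pooled = [np.random.rand(1, 1, 1, 1024).astype(np.float32) for _ in range(3)]

# Same as tf.squeeze(fc_map, [1, 2]): drop the two size-1 spatial axes -> (1, 1024)
vectors = [np.squeeze(p, axis=(1, 2)) for p in pooled]

# Stack the per-image features into one (num_images, 1024) matrix
feature_matrix = np.concatenate(vectors, axis=0)
print(vectors[0].shape, feature_matrix.shape)  # (1, 1024) (3, 1024)
```

Row i of `feature_matrix` then corresponds to the i-th path in `img_list`, which is one reason to keep that list in a fixed order.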

The trainable parameters of MobileNet_V1 (variable name and shape) are printed as follows:

MobilenetV1/Conv2d_0/weights:0 (3, 3, 3, 32)
MobilenetV1/Conv2d_0/BatchNorm/gamma:0 (32,)
MobilenetV1/Conv2d_0/BatchNorm/beta:0 (32,)
MobilenetV1/Conv2d_1_depthwise/depthwise_weights:0 (3, 3, 32, 1)
MobilenetV1/Conv2d_1_depthwise/BatchNorm/gamma:0 (32,)
MobilenetV1/Conv2d_1_depthwise/BatchNorm/beta:0 (32,)
MobilenetV1/Conv2d_1_pointwise/weights:0 (1, 1, 32, 64)
MobilenetV1/Conv2d_1_pointwise/BatchNorm/gamma:0 (64,)
MobilenetV1/Conv2d_1_pointwise/BatchNorm/beta:0 (64,)
MobilenetV1/Conv2d_2_depthwise/depthwise_weights:0 (3, 3, 64, 1)
MobilenetV1/Conv2d_2_depthwise/BatchNorm/gamma:0 (64,)
MobilenetV1/Conv2d_2_depthwise/BatchNorm/beta:0 (64,)
MobilenetV1/Conv2d_2_pointwise/weights:0 (1, 1, 64, 128)
MobilenetV1/Conv2d_2_pointwise/BatchNorm/gamma:0 (128,)
MobilenetV1/Conv2d_2_pointwise/BatchNorm/beta:0 (128,)
MobilenetV1/Conv2d_3_depthwise/depthwise_weights:0 (3, 3, 128, 1)
MobilenetV1/Conv2d_3_depthwise/BatchNorm/gamma:0 (128,)
MobilenetV1/Conv2d_3_depthwise/BatchNorm/beta:0 (128,)
MobilenetV1/Conv2d_3_pointwise/weights:0 (1, 1, 128, 128)
MobilenetV1/Conv2d_3_pointwise/BatchNorm/gamma:0 (128,)
MobilenetV1/Conv2d_3_pointwise/BatchNorm/beta:0 (128,)
MobilenetV1/Conv2d_4_depthwise/depthwise_weights:0 (3, 3, 128, 1)
MobilenetV1/Conv2d_4_depthwise/BatchNorm/gamma:0 (128,)
MobilenetV1/Conv2d_4_depthwise/BatchNorm/beta:0 (128,)
MobilenetV1/Conv2d_4_pointwise/weights:0 (1, 1, 128, 256)
MobilenetV1/Conv2d_4_pointwise/BatchNorm/gamma:0 (256,)
MobilenetV1/Conv2d_4_pointwise/BatchNorm/beta:0 (256,)
MobilenetV1/Conv2d_5_depthwise/depthwise_weights:0 (3, 3, 256, 1)
MobilenetV1/Conv2d_5_depthwise/BatchNorm/gamma:0 (256,)
MobilenetV1/Conv2d_5_depthwise/BatchNorm/beta:0 (256,)
MobilenetV1/Conv2d_5_pointwise/weights:0 (1, 1, 256, 256)
MobilenetV1/Conv2d_5_pointwise/BatchNorm/gamma:0 (256,)
MobilenetV1/Conv2d_5_pointwise/BatchNorm/beta:0 (256,)
MobilenetV1/Conv2d_6_depthwise/depthwise_weights:0 (3, 3, 256, 1)
MobilenetV1/Conv2d_6_depthwise/BatchNorm/gamma:0 (256,)
MobilenetV1/Conv2d_6_depthwise/BatchNorm/beta:0 (256,)
MobilenetV1/Conv2d_6_pointwise/weights:0 (1, 1, 256, 512)
MobilenetV1/Conv2d_6_pointwise/BatchNorm/gamma:0 (512,)
MobilenetV1/Conv2d_6_pointwise/BatchNorm/beta:0 (512,)
MobilenetV1/Conv2d_7_depthwise/depthwise_weights:0 (3, 3, 512, 1)
MobilenetV1/Conv2d_7_depthwise/BatchNorm/gamma:0 (512,)
MobilenetV1/Conv2d_7_depthwise/BatchNorm/beta:0 (512,)
MobilenetV1/Conv2d_7_pointwise/weights:0 (1, 1, 512, 512)
MobilenetV1/Conv2d_7_pointwise/BatchNorm/gamma:0 (512,)
MobilenetV1/Conv2d_7_pointwise/BatchNorm/beta:0 (512,)
MobilenetV1/Conv2d_8_depthwise/depthwise_weights:0 (3, 3, 512, 1)
MobilenetV1/Conv2d_8_depthwise/BatchNorm/gamma:0 (512,)
MobilenetV1/Conv2d_8_depthwise/BatchNorm/beta:0 (512,)
MobilenetV1/Conv2d_8_pointwise/weights:0 (1, 1, 512, 512)
MobilenetV1/Conv2d_8_pointwise/BatchNorm/gamma:0 (512,)
MobilenetV1/Conv2d_8_pointwise/BatchNorm/beta:0 (512,)
MobilenetV1/Conv2d_9_depthwise/depthwise_weights:0 (3, 3, 512, 1)
MobilenetV1/Conv2d_9_depthwise/BatchNorm/gamma:0 (512,)
MobilenetV1/Conv2d_9_depthwise/BatchNorm/beta:0 (512,)
MobilenetV1/Conv2d_9_pointwise/weights:0 (1, 1, 512, 512)
MobilenetV1/Conv2d_9_pointwise/BatchNorm/gamma:0 (512,)
MobilenetV1/Conv2d_9_pointwise/BatchNorm/beta:0 (512,)
MobilenetV1/Conv2d_10_depthwise/depthwise_weights:0 (3, 3, 512, 1)
MobilenetV1/Conv2d_10_depthwise/BatchNorm/gamma:0 (512,)
MobilenetV1/Conv2d_10_depthwise/BatchNorm/beta:0 (512,)
MobilenetV1/Conv2d_10_pointwise/weights:0 (1, 1, 512, 512)
MobilenetV1/Conv2d_10_pointwise/BatchNorm/gamma:0 (512,)
MobilenetV1/Conv2d_10_pointwise/BatchNorm/beta:0 (512,)
MobilenetV1/Conv2d_11_depthwise/depthwise_weights:0 (3, 3, 512, 1)
MobilenetV1/Conv2d_11_depthwise/BatchNorm/gamma:0 (512,)
MobilenetV1/Conv2d_11_depthwise/BatchNorm/beta:0 (512,)
MobilenetV1/Conv2d_11_pointwise/weights:0 (1, 1, 512, 512)
MobilenetV1/Conv2d_11_pointwise/BatchNorm/gamma:0 (512,)
MobilenetV1/Conv2d_11_pointwise/BatchNorm/beta:0 (512,)
MobilenetV1/Conv2d_12_depthwise/depthwise_weights:0 (3, 3, 512, 1)
MobilenetV1/Conv2d_12_depthwise/BatchNorm/gamma:0 (512,)
MobilenetV1/Conv2d_12_depthwise/BatchNorm/beta:0 (512,)
MobilenetV1/Conv2d_12_pointwise/weights:0 (1, 1, 512, 1024)
MobilenetV1/Conv2d_12_pointwise/BatchNorm/gamma:0 (1024,)
MobilenetV1/Conv2d_12_pointwise/BatchNorm/beta:0 (1024,)
MobilenetV1/Conv2d_13_depthwise/depthwise_weights:0 (3, 3, 1024, 1)
MobilenetV1/Conv2d_13_depthwise/BatchNorm/gamma:0 (1024,)
MobilenetV1/Conv2d_13_depthwise/BatchNorm/beta:0 (1024,)
MobilenetV1/Conv2d_13_pointwise/weights:0 (1, 1, 1024, 1024)
MobilenetV1/Conv2d_13_pointwise/BatchNorm/gamma:0 (1024,)
MobilenetV1/Conv2d_13_pointwise/BatchNorm/beta:0 (1024,)
MobilenetV1/Logits/Conv2d_1c_1x1/weights:0 (1, 1, 1024, 1001)
MobilenetV1/Logits/Conv2d_1c_1x1/biases:0 (1001,)
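The listing shows why MobileNet V1 is lightweight: each block is split into a depthwise 3 x 3 convolution (one filter per input channel, shape `(3, 3, C, 1)`) and a pointwise 1 x 1 convolution (shape `(1, 1, C, K)`). The arithmetic below compares the weight count of such a pair, using the 512 -> 512 shapes of `Conv2d_9`, with a hypothetical regular 3 x 3 convolution of the same size. Batch-norm gamma/beta are ignored.

```python
def standard_conv_params(k, c_in, c_out):
    """Weights of a regular k x k convolution (biases and batch norm ignored)."""
    return k * k * c_in * c_out

def separable_conv_params(k, c_in, c_out):
    """Depthwise k x k convolution followed by a pointwise 1 x 1 convolution."""
    depthwise = k * k * c_in * 1      # e.g. Conv2d_9_depthwise/depthwise_weights: (3, 3, 512, 1)
    pointwise = 1 * 1 * c_in * c_out  # e.g. Conv2d_9_pointwise/weights: (1, 1, 512, 512)
    return depthwise + pointwise

std = standard_conv_params(3, 512, 512)
sep = separable_conv_params(3, 512, 512)
print(std, sep, round(std / sep, 2))  # 2359296 266752 8.84
```

For this layer the separable pair needs roughly 8.8 times fewer weights than a regular 3 x 3 convolution.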

The feature maps produced by the network (endpoint name and output shape) are:

print the feature maps:
Conv2d_0 (1, 112, 112, 32)
Conv2d_1_depthwise (1, 112, 112, 32)
Conv2d_1_pointwise (1, 112, 112, 64)
Conv2d_2_depthwise (1, 56, 56, 64)
Conv2d_2_pointwise (1, 56, 56, 128)
Conv2d_3_depthwise (1, 56, 56, 128)
Conv2d_3_pointwise (1, 56, 56, 128)
Conv2d_4_depthwise (1, 28, 28, 128)
Conv2d_4_pointwise (1, 28, 28, 256)
Conv2d_5_depthwise (1, 28, 28, 256)
Conv2d_5_pointwise (1, 28, 28, 256)
Conv2d_6_depthwise (1, 14, 14, 256)
Conv2d_6_pointwise (1, 14, 14, 512)
Conv2d_7_depthwise (1, 14, 14, 512)
Conv2d_7_pointwise (1, 14, 14, 512)
Conv2d_8_depthwise (1, 14, 14, 512)
Conv2d_8_pointwise (1, 14, 14, 512)
Conv2d_9_depthwise (1, 14, 14, 512)
Conv2d_9_pointwise (1, 14, 14, 512)
Conv2d_10_depthwise (1, 14, 14, 512)
Conv2d_10_pointwise (1, 14, 14, 512)
Conv2d_11_depthwise (1, 14, 14, 512)
Conv2d_11_pointwise (1, 14, 14, 512)
Conv2d_12_depthwise (1, 7, 7, 512)
Conv2d_12_pointwise (1, 7, 7, 1024)
Conv2d_13_depthwise (1, 7, 7, 1024)
Conv2d_13_pointwise (1, 7, 7, 1024)
AvgPool_1a (1, 1, 1, 1024)
Logits (1, 1001)
Predictions (1, 1001)
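The spatial sizes in this listing follow from the strides: the resolution halves at `Conv2d_0` and at the depthwise layers of blocks 2, 4, 6 and 12, and stays the same elsewhere. A sketch that reproduces the sequence, assuming 'SAME' padding (output size = ceil(input size / stride)); the stride list is read off the listing above rather than taken from the slim source:

```python
import math

# Strides of the depthwise layers Conv2d_1 ... Conv2d_13, inferred from the shapes above
DW_STRIDES = [1, 2, 1, 2, 1, 2, 1, 1, 1, 1, 1, 2, 1]

def feature_map_sizes(input_size=224):
    """Spatial size after Conv2d_0 and after each depthwise layer ('SAME' padding)."""
    size = math.ceil(input_size / 2)   # Conv2d_0 uses stride 2: 224 -> 112
    sizes = [size]
    for stride in DW_STRIDES:
        size = math.ceil(size / stride)
        sizes.append(size)
    return sizes

print(feature_map_sizes())  # [112, 112, 56, 56, 28, 28, 14, 14, 14, 14, 14, 14, 7, 7]
```

The final 7 x 7 x 1024 map is then averaged by `AvgPool_1a` into the 1 x 1 x 1024 output.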

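As the listing shows, AvgPool_1a is the last feature map before the classifier (the Logits layer, which MobileNet V1 implements as the 1x1 convolution Conv2d_1c_1x1 instead of a fully connected layer). After squeezing its spatial dimensions, this layer yields a 1024-dimensional vector that we can use as the feature of the input image.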
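One typical use of such feature vectors is image retrieval: compare a query feature against a gallery of features with cosine similarity and return the closest image. A minimal NumPy sketch, assuming the features have already been extracted; the random vectors and the helper names `cosine_similarity` and `most_similar` are hypothetical.

```python
import numpy as np

def cosine_similarity(a, b):
    """Cosine similarity between two 1-D feature vectors."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

def most_similar(query, gallery):
    """Index of the gallery row (num_images x dim) whose feature is closest to query."""
    gallery_n = gallery / np.linalg.norm(gallery, axis=1, keepdims=True)
    query_n = query / np.linalg.norm(query)
    return int(np.argmax(gallery_n @ query_n))

# Hypothetical 1024-dim features (in practice: the squeezed AvgPool_1a outputs)
rng = np.random.default_rng(0)
gallery = rng.random((5, 1024))
query = gallery[3] + 0.01 * rng.random(1024)   # a slightly perturbed copy of image 3

idx = most_similar(query, gallery)
sim = cosine_similarity(query, gallery[3])
print(idx, sim)
```

Normalizing the rows once up front turns the whole search into a single matrix-vector product.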


Reposted from blog.csdn.net/shinian1987/article/details/88776253