tensorflow2.0入门实例三(模型加载)

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接: https://blog.csdn.net/weixin_43162240/article/details/102662593

  做东西,最重要的就是动手了,所以这篇文章动手跑了一个fcn32s和fcn8s以及deeplab v3+的例子,这个例子的数据集选用自动驾驶相关竞赛的kitti数据集, FCN8s在训练过程中用tensorflow2.0自带的评估能达到91%精确率, deeplab v3+能达到97%的准确率。这篇文章适合入门级选手,在文章中不再讲述fcn的结构,直接百度就可以搜到。
  文章使用的是tensorflow2.0框架,该框架集成了keras,在模型的训练方面极其简洁,不像tf1.x那么复杂,综合其他深度学习框架,发现这个是最适合新手使用的一种。
文章中用到的库函数,参数等均可在tensorflow2.0 api中查找到。
  文章的代码在github可以获取,地址:https://github.com/fengshilin/tf2.0-FCN

  文章的结构如下:

  1. 数据下载与分析
  2. 数据预处理(重点在label的预处理)
  3. 模型加载
  4. 模型建立(FCN与Deeplab)
  5. 模型训练与测试

3.模型加载

  很多模型需要以基础模型为backbone来构建,甚至用它们已经训练好的参数来做训练,收敛会更快,tensorflow2.0的tf.keras.applications模块集成了很多模型,常用的有vgg16与resnet50。文章中的FCN就需要用到vgg16作为backbone。
  下面我们看怎么加载vgg16模型,并配置参数,参数如下:

  • weights:None表示没有指定权重,对网络参数进行随机初始化. ‘imagenet’ 表示加载imagenet与训练的网络权重
  • include_top:表示是否加载全连接层;
  • input_tensor表示模型的输入,主要涉及输入的影像的长宽与通道数
import tensorflow as tf
from tensorflow.keras.layers import Dropout, Input

# 加载模型,参数分别表示
vgg16_model = tf.keras.applications.vgg16.VGG16(weights='imagenet', include_top=False, input_tensor=Input(shape=(160, 160, 3)))

vgg16_model.summary()  # 这一行是为了打印模型结构

打印出的模型结构如下:

Model: "vgg16"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_5 (InputLayer)         [(None, 160, 160, 3)]     0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 160, 160, 64)      1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 160, 160, 64)      36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 80, 80, 64)        0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 80, 80, 128)       73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 80, 80, 128)       147584    
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 40, 40, 128)       0         
_________________________________________________________________
block3_conv1 (Conv2D)        (None, 40, 40, 256)       295168    
_________________________________________________________________
block3_conv2 (Conv2D)        (None, 40, 40, 256)       590080    
_________________________________________________________________
block3_conv3 (Conv2D)        (None, 40, 40, 256)       590080    
_________________________________________________________________
block3_pool (MaxPooling2D)   (None, 20, 20, 256)       0         
_________________________________________________________________
block4_conv1 (Conv2D)        (None, 20, 20, 512)       1180160   
_________________________________________________________________
block4_conv2 (Conv2D)        (None, 20, 20, 512)       2359808   
_________________________________________________________________
block4_conv3 (Conv2D)        (None, 20, 20, 512)       2359808   
_________________________________________________________________
block4_pool (MaxPooling2D)   (None, 10, 10, 512)       0         
_________________________________________________________________
block5_conv1 (Conv2D)        (None, 10, 10, 512)       2359808   
_________________________________________________________________
block5_conv2 (Conv2D)        (None, 10, 10, 512)       2359808   
_________________________________________________________________
block5_conv3 (Conv2D)        (None, 10, 10, 512)       2359808   
_________________________________________________________________
block5_pool (MaxPooling2D)   (None, 5, 5, 512)         0         
=================================================================
Total params: 14,714,688
Trainable params: 14,714,688
Non-trainable params: 0
_________________________________________________________________

模型层级调用, 一共有19层,所以可以选择vgg16_model.layers[i], i从0到18

layer0 = vgg16_model.layers[0]
print(layer0)

输出

<tensorflow.python.keras.engine.input_layer.InputLayer object at 0x7fda248c9ba8>

猜你喜欢

转载自blog.csdn.net/weixin_43162240/article/details/102662593