TensorFlow 学习笔记(二)

1、
slim.get_variables(),slim.get_model_variables(),tf.global_variables(),tf.trainable_variables(),tf.all_variables()
[989,446,989,270,989]

  1. tf.GraphKeys.GLOBAL_VARIABLES:可以在多个设备上共享的变量
  2. tf.GraphKeys.TRAINABLE_VARIABLES:计算其梯度的变量
  3. model_variables():Model variables are trained or fine-tuned during learning and are loaded from a checkpoint during evaluation or inference。对于像global_step这种在训练时用到但是在推理时候用不到的变量则不是model变量
  4. regular_variables_and_model_variables = slim.get_variables()这条指令执行的是获得所有的regular变量和model变量,
  5. regular变量是什么呢::they can be saved to disk using a saver。其实就是global变量
  6. model variable和trainable variable的区别在前者含有moving_mean和moving_variance两个变量
import tensorflow as tf
import tensorflow.contrib.slim as slim
my_non_trainable = tf.get_variable("my_non_trainable",
                                   shape=(),
                                   trainable=False)
my_trainable = tf.get_variable("my_trainable",
                                   shape=(),
                                   trainable=True)
weights = slim.model_variable('weights',
                              shape=[10, 10, 3 , 3],
                              initializer=tf.truncated_normal_initializer(stddev=0.1),
                              regularizer=slim.l2_regularizer(0.05),
                              device='/CPU:0')
print('slim.get_variables',slim.get_variables())
print('slim.get_model_variables',slim.get_model_variables())
print('tf.global_variables',tf.global_variables())
print('tf.trainable_variables',tf.trainable_variables())

这里写图片描述

  • 如果建立一个变量是的trainable=False的话,那么该变量不会出现在trainable的变量集合中,也不会出现在model变量集合中,但是都会是global变量
  • 只要是tf.get_variable的方式建立的变量都不会成为model variable

那么什么样子的变量才有资格成为model variable?

import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.contrib import layers as layers_lib
input = tf.zeros([1,64,64,3])
net = slim.conv2d(input, 128, [3, 3], scope='conv1_1')
net1 = layers_lib.conv2d(
        input,
        128,
        [3,3],
        stride=1,
        padding='SAME',
        scope='conv2-1')
print('slim.get_variables',slim.get_variables())
print('slim.get_model_variables',slim.get_model_variables())
print('tf.global_variables',tf.global_variables())
print('tf.trainable_variables',tf.trainable_variables())

print('slim.get_variables',len(slim.get_variables()))
print('slim.get_model_variables',len(slim.get_model_variables()))
print('tf.global_variables',len(tf.global_variables()))
print('tf.trainable_variables',len(tf.trainable_variables()))

这里写图片描述

  • 由实验可以发现,通过建立层之类的操作获得的变量,不论是slim还是contrib.layers创立的,都会是model variable,也是trainable variable,也是global variable
  • 由slim.model_variable创建的变量是model variable,通过variable和get_variable方式建立的都不是model_variable

3、如何打印加在的模型中的都含有哪些变量
通过打印可以知道,只要saver保存的时候没有指明保存哪些变量的话,最后保存的是所有的global variable。
1)

import os
from tensorflow.python import pywrap_tensorflow

path = "/home/lzhou/tf-ll/logs/tmp/model_dump/snapshot_2.ckpt"
reader = pywrap_tensorflow.NewCheckpointReader(path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
    print("tensor_name: ", key)
    print(reader.get_tensor(key))

2)

from tensorflow.python.tools import inspect_checkpoint as chkp
path = "/home/lzhou/tf-ll/logs/tmp/model_dump/snapshot_2.ckpt"
chkp.print_tensors_in_checkpoint_file(path, tensor_name='', all_tensors=True)

猜你喜欢

转载自blog.csdn.net/u013548568/article/details/80080288