tensorflow 如何获取模型中想要的张量

当我们想要改造或者利用某一预训练模型来完成一些其它任务时,一个常用且必备的操作是从指定模型中获取到我们感兴趣的张量(tensor)。

例如我想使用一个已经训练好的CNN模型中间的某一层的结果作为特征向量来完成另一个相关任务,就需要这样的操作。

如何做到?很简单,只需两步:

1.获取到感兴趣张量的名字.

2.使用get_tensor_by_name函数获取

下面详细说明下

1.获取到感兴趣张量的名字

我们知道,张量tensor是由操作operation运行得到的。

因此实际上,你要做的就是,获取到能够生成你想要的tensor的op的名字,无论你用什么方法。

通常情况,有两种方法可以获取到op_name。

1)如果模型是你自己搭建的,你直接可以通过查看搭建网络的源代码来确认op_name。

因此,对于具有关键含义的op,创建时为其自定义一个合适的名字,或者创建scope分组是一个良好的编码习惯。

扫描二维码关注公众号,回复: 8916960 查看本文章

2)如果你对拿到的模型毫无头绪,参考并修改下面的代码打印出所有的张量信息

with tf.Graph().as_default():
    config = tf.ConfigProto()
    sess = tf.Session(config = config)
    with sess.as_default():
        meta_path = checkpoint_path + '.meta'
        saver = tf.train.import_meta_graph(meta_path)
        saver.restore(sess, checkpoint_path)

        op_list = sess.graph.get_operations()
        for op in op_list:
            print(op.name)
            print(op.values())

上面的代码会输出类似如下的结果:

...
...
resnet_v1_50/pool5
(<tf.Tensor 'resnet_v1_50/pool5:0' shape=(?, 1, 1, 2048) dtype=float32>,)
Logits/Squeeze
(<tf.Tensor 'Logits/Squeeze:0' shape=(?, 2048) dtype=float32>,)
...
...

剩下需要做的就是基于你对模型结构的理解,找到你想要的张量对应的op_name。

例如,我想要的便是resnet50模型最后一个pooling层摊平后的张量,对应的op_name为‘Logits/Squeeze:0’

2.使用get_tensor_by_name函数获取想要的张量

获取到了op_name,剩下的事情仅仅是调用get_tensor_by_name函数来拿到它,例如:

embedding = tf.get_default_graph().get_tensor_by_name('Logits/Squeeze:0')
发布了87 篇原创文章 · 获赞 325 · 访问量 52万+

猜你喜欢

转载自blog.csdn.net/u011583927/article/details/90668687