tensorflow编程实践:结构化你的模型

版权声明:本文为博主原创文章,保留著作权,未经博主允许不得转载。 https://blog.csdn.net/LoseInVain/article/details/82085185

这篇文章是翻译自《Structuring Your TensorFlow Models》,这篇文章主要描述了在TensorFlow中如何更好地结构化地定义你的网络模型,以便于更好地扩展和调试。我们将会发现,采用这种构建方法,可以将整个模型变得模块化。
如果对这篇文章有着更好的建议,请联系我,谢谢。


在TensorFlow中定义你的模型很容易导致一个臃肿复杂,难以维护的代码,这个是一个很糟糕的事情,因为深度模型本身就难以调试和差错,因此在模型的搭建过程中应该尽可能的用模块化的方法去搭建模型和结构化模型。如何以一种高可读性和可复用性的手段结构化你的代码呢?如果你急于求成,可以直接参考working example gist。也可以参考这篇关于如何在TensorFlow中实现快速原型的文章fast prototyping,结构化模型的基本思想就在这里描述了。

定义计算图

给每一个模型定义一个类是很科学而且高效的做法。那么,用什么作为这个类的接口呢?通常来说,你的模型会和一些输入数据目标数据(target)占位符(placeholder)存在关联,毕竟你需要以此来喂(feed)数据,并且提供一个接口给主程序训练(training),评估(evaluation)和推理(inference)。也就是说,我们这个类,至少要提供这几种接口才是一个较为完整的深度网络模型类:

  • 喂数据的接口, 如输入和目标的占位符等。
  • 提供给主程序调用的训练,评估和推理接口
  • (可选)一个网络中间输出结果,这种用法类似与pytorch和keras的做法,把一些成熟的网络模块化后直接输出中间处理结果,以便于更好的模块化。

我们观察一下例子:

class Model:
    def __init__(self, data, target):
        data_size = int(data.get_shape()[1])
        target_size = int(target.get_shape()[1])
        weight = tf.Variable(tf.truncated_normal([data_size, target_size]))
        bias = tf.Variable(tf.constant(0.1, shape=[target_size]))
        incoming = tf.matmul(data, weight) + bias
        self._prediction = tf.nn.softmax(incoming)
        cross_entropy = -tf.reduce_sum(target, tf.log(self._prediction))
        self._optimize = tf.train.RMSPropOptimizer(0.03).minimize(cross_entropy)
        mistakes = tf.not_equal(
            tf.argmax(target, 1), tf.argmax(self._prediction, 1))
        self._error = tf.reduce_mean(tf.cast(mistakes, tf.float32))

    @property
    def prediction(self):
        return self._prediction

    @property
    def optimize(self):
        return self._optimize

    @property
    def error(self):
        return self._error

这是一个基本的关于如何在TF中定义模型的代码,然而,这个代码还是存在一些问题的。最明显的问题是,这整个模型的计算图都定义在了一个函数里面,也就是构造器。如果你的模型变得很复杂,这个构造器将变得异常臃肿,这样既不是可读性强的,也不是可复用性强的编程习惯。(译者:而且,这里还有一个问题,如果按照以上的代码去进行整个模型的图的构建,那么不管我们是不是需要用到整个模型的每个子模型,他都会给我一股脑地预先构建出来。这样其实不是一个很好的方案,因为很多时候,模型很大,有很多分支,而且训练阶段也有不止一个阶段,每个阶段可能用到不同的分支。因此并没有必要一股脑把所有分支给构建出来,用本文的思路可以实现很好的结构化模型。

利用类属性(Properties)去结构化你的模型吧

简单地将你的构建计算图的代码从构造器中分离出来不能解决任何问题,因为每一次这个函数被调用的时候,这个计算图都会添加新的节点。这并不是我们想要的,我们需要的是这个计算图只会在我们第一次调用某个模块的构建的时候在整个计算图中添加新的节点,在第二次或者更多次的时候,不需要其再次添加相同的节点,这个被称之为惰性加载(lazy-loading)

class Model:

    def __init__(self, data, target):
        self.data = data
        self.target = target
        self._prediction = None
        self._optimize = None
        self._error = None

    @property
    def prediction(self):
        if not self._prediction:
            data_size = int(self.data.get_shape()[1])
            target_size = int(self.target.get_shape()[1])
            weight = tf.Variable(tf.truncated_normal([data_size, target_size]))
            bias = tf.Variable(tf.constant(0.1, shape=[target_size]))
            incoming = tf.matmul(self.data, weight) + bias
            self._prediction = tf.nn.softmax(incoming)
        return self._prediction

    @property
    def optimize(self):
        if not self._optimize:
            cross_entropy = -tf.reduce_sum(self.target, tf.log(self.prediction))
            optimizer = tf.train.RMSPropOptimizer(0.03)
            self._optimize = optimizer.minimize(cross_entropy)
        return self._optimize

    @property
    def error(self):
        if not self._error:
            mistakes = tf.not_equal(
                tf.argmax(self.target, 1), tf.argmax(self.prediction, 1))
            self._error = tf.reduce_mean(tf.cast(mistakes, tf.float32))
        return self._error

这个比第一个例子好多了,你的代码现在可以在类中的方法中结构化,因此你只需要单独地关注某个部分就可以了。然而,这个代码为了实现这个惰性加载的逻辑,额外多出了很多判断的分子,这个仍然是有些臃肿的,我们利用python中自带的修饰器的性质,可以进行一些修改。

惰性类属性修饰器

python是一种很灵活的语言,在下一个例子中,我将展示给你如何从上一个例子的代码中除去冗余的代码。我们将会使用一个表现得像是@property但是只会实际上调用这个函数一次的修饰器。如果你对如何定制修饰器不熟悉,也许你可以先参考这篇文章python修饰器教程。采用这种方法,可以有效地减少一些为了实现惰性加载的逻辑而额外多出的代码,如if not self._error:这一部分。我们观察一下我们需要的修饰器代码:

import functools

def lazy_property(function):
    attribute = '_cache_' + function.__name__
    @property
    @functools.wraps(function)
    def decorator(self):
        if not hasattr(self, attribute):
            setattr(self, attribute, function(self))
        return getattr(self, attribute)

    return decorator

采用这个修饰器,我们的例子可以被简化为下面的代码:

class Model:

    def __init__(self, data, target):
        self.data = data
        self.target = target
        self.prediction
        self.optimize
        self.error

    @lazy_property
    def prediction(self):
        data_size = int(self.data.get_shape()[1])
        target_size = int(self.target.get_shape()[1])
        weight = tf.Variable(tf.truncated_normal([data_size, target_size]))
        bias = tf.Variable(tf.constant(0.1, shape=[target_size]))
        incoming = tf.matmul(self.data, weight) + bias
        return tf.nn.softmax(incoming)

    @lazy_property
    def optimize(self):
        cross_entropy = -tf.reduce_sum(self.target, tf.log(self.prediction))
        optimizer = tf.train.RMSPropOptimizer(0.03)
        return optimizer.minimize(cross_entropy)

    @lazy_property
    def error(self):
        mistakes = tf.not_equal(
            tf.argmax(self.target, 1), tf.argmax(self.prediction, 1))
        return tf.reduce_mean(tf.cast(mistakes, tf.float32))

注意到我们在关于图构建的函数中都是用了这个修饰器的。需要另外注意的是,当你运行tf.initialize_variables()以初始化变量的时候,务必留意你是否已经定义了这个计算图,否则是会报错的。

更进一步,用名字空间(Scopes)组织计算图

我们现在有了一个更为简洁干净的方法去定义我们的模型,但是这个计算图仍然还是太过于拥挤了,如果你曾经用tensorboard可视化过整个计算图,你肯定明白我说的是什么意思,整个计算图将会包括很多小节点之间的互连。解决这个问题可以通过用一个“包裹”把这些互连的内容给打包起来,通过利用tf.name_scope()或者tf.variable_scope()你就可以实现这个功能,这两者的具体区别我们以后再谈,姑且看下这两个函数怎么使用。在计算图中,你可以指定某些节点被聚合在一起。我们而且,可以让我们的修饰器自动地实现这个功能,而不需要每个都人工手动完成。

import functools

def define_scope(function):
    attribute = '_cache_' + function.__name__

    @property
    @functools.wraps(function)
    def decorator(self):
        if not hasattr(self, attribute):
            with tf.variable_scope(function.__name):
                setattr(self, attribute, function(self))
        return getattr(self, attribute)

    return decorator

因为每个模块都是有功能性的特异性的,因此给每个模块一个新的名字。除此之外,这个模型和之前那个完全一样。
我们甚至可以走得更远一点,我们可以让@define_scope修饰器可以传递参数给tf.variable_scope(),比如说给定义一个默认的初始化之类的。如果你对这方面的感兴趣,请移步blog_tensorflow_scope_decorator.py

猜你喜欢

转载自blog.csdn.net/LoseInVain/article/details/82085185
今日推荐