TENSORFLOW:ESTIMATORS

Estimators

Estimators是TensorFlow的高层API,它大大简化了机器学习的编程。Estimator封装了以下功能:
– 模型训练
– 模型评价
– 模型预测
– 模型导出

TensorFlow提供了一些Estimator,你也可以开发自己的Estimator,不论是TensorFlow提供的还是你自定义的都是tf.estimator.Estimator的子类。

tf.contrib.learn.Estimator 已经废弃了,请不要再使用了

Estimator的优势


Estimator有以下好处:
– 你可以在本地或者分布式环境下运行基于Estimator的模型,而不需要改变你的模型。同样,在CPU,GPU或者TPU下运行你的模型也不需要做改动。
– 和其他开发人员共享你的模型变得简单。
– 大大方便你的开发,比你从一些TensorFlow底层API开始开发容易的多。
– Estimators本来就是基于tf.layers开发的,这样你想做些改动也比较容易。
– Estimator会自己build graph,这样就不需要你自己构建了。
– Estimator 提供了一个安全的分布式的训练环,它控制了如何和怎样去完成下边的任务
* 构建graph
* 初始化变量
* 启动队列
* 异常处理
* 创建checkpoint文件和错误恢复
* 为TensorBoard保存概要信息
用Estimator写一个应用的时候,你必须把数据输入流和model分开。这样你可以方便的来更换不同的数据源来实验。

预置的Estimators


预置的 Estimators让你基于高层的封装的TensorFlow API进行开发,你不用去自己创建计算流图和session,因为Estimator会处理所有的流程。也就是预置的Estimators会自己创建和管理Graph和Session对象。而且它可以让你通过很小的改动就可以尝试不同的模型架构。比如DNNClassifier是一个预置的Estimator,它是一个通过稠密前向传播的神经网络来进行分类预测的模型。

一个使用Estimator的程序

一个使用了预置Estimator的TensorFlow的程序,一般分为以下四个步骤:
1. 写一个或者多个数据导入的function。比如你写了两个方法,一个导入训练数据,一个导入测试数据。每个方法都返回两个object。
– 一个dictionary,keys是feature的名字,values是Tensors(或者SparseTensors)保留着对应的feature的data。
– 一个Tensor包含着一个或者多个label
比如下边的代码展示了一个input function的概要:

def input_fn(dataset):
   ...  # 操作数据, 提取feature的名字和label
   return feature_dict, label

2.定义feature的columns。 每一个tf.feature_column表示一个feature的名字,类型和任意的预处理。比如下边的一段代码创建了3个feature column,他们持有integer或者floating-point类型的数据。前两个feature简单的定义了feature的名字和类型。第三个feature同时指定了一个lambda表达式,这个表达式会在读入数据时被调用。

population = tf.feature_column.numeric_column('population')
crime_rate = tf.feature_column.numeric_column('crime_rate')
median_education = tf.feature_column.numeric_column('median_education',
                    normalizer_fn='lambda x: x - global_education_mean')

3.实例化预置的Estimator。

estimator = tf.estimator.Estimator.LinearClassifier(
    feature_columns=[population, crime_rate, median_education],
    )

4.调用 training,evalutation或者inference方法。比如所有的Estimator都提供一个train方法来训练模型。

estimator.train(input_fn=my_training_set, steps=2000)

预置Estimator带来的好处

预置的Estimator内置了best practices,提供了一下的好处:
– 默认实现了帮助我们决定计算流图的哪一部分应该在哪里运行,实现了针对单机和集群的策略。
– 提供了很好的event(summary)记录和全局有用的summaries。

如果你不用内置的Estimator,你就需要自己实现上边所说的这行功能。

自定义Estimator


不论是内置合适自定义的Estimator,最重要的就是其中的Model Function,它是一个为建模,评估,预测创建 graphs的方法。预置模型已经实现了这些。如果你要自定义一个Estimator,你就要自己实现。

建议的流程


  1. 假如存在一个合适的预置的Estimator,用它去build你自己的第一个模型,把它作为你的baseline。
  2. 用这个预置的Estimator去测试你的整个流程,包括测试你的数据的完整性和可靠性。
  3. 如果有多个预置的Estimator合适,试着替换不同的Estimator,看哪个好。
  4. 如果可能,自己动手创建你自己的Estimator去改善你的模型。

从Keras 模型创建一个Estimator

你可以把一个已经有的Keras模型转化成一个Estimator,这样可以让Keras模型拥有Estimator的能力,比如分布式计算。通过调用

tf.keras.estimator.model_to_estimator

方法来实现,比如下边的例子:

# Instantiate a Keras inception v3 model.
keras_inception_v3 = tf.keras.applications.inception_v3.InceptionV3(weights=None)
# Compile model with the optimizer, loss, and metrics you'd like to train with.
keras_inception_v3.compile(optimizer=tf.keras.optimizers.SGD(lr=0.0001, momentum=0.9),
                          loss='categorical_crossentropy',
                          metric='accuracy')
# Create an Estimator from the compiled Keras model. Note the initial model
# state of the keras model is preserved in the created Estimator.
est_inception_v3 = tf.keras.estimator.model_to_estimator(keras_model=keras_inception_v3)

# Treat the derived Estimator as you would with any other Estimator.
# First, recover the input name(s) of Keras model, so we can use them as the
# feature column name(s) of the Estimator input function:
keras_inception_v3.input_names  # print out: ['input_1']
# Once we have the input name(s), we can create the input function, for example,
# for input(s) in the format of numpy ndarray:
train_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={"input_1": train_data},
    y=train_labels,
    num_epochs=1,
    shuffle=False)
# To train, we call Estimator's train function:
est_inception_v3.train(input_fn=train_input_fn, steps=2000)

需要注意的是keras estimator的feature的columns和lables的名字是来自于对应的编译好的keras的模型。

----------------------------------------------------------------------------------

原文:http://www.rethink.fun/index.php/2018/03/07/tensorflow1-2/

猜你喜欢

转载自blog.csdn.net/lcczzu/article/details/91444789