TensorFlow Estimator in Practice

    This article uses the MNIST dataset as an example. An Estimator is typically used together with tf.data, so we first build TFRecord files and then use an Estimator for training and testing.

Article structure:

1. Directory layout

2. Creating the TFRecord files

3. Training the model with an Estimator

4. tf.estimator.Estimator() parameter reference


Directory layout:

   The data directory holds the MNIST dataset; the generated TFRecord files are also written there.

   The result directory stores the trained model checkpoints.

   make_record.py: builds the TFRecord files.

   mnist.py: trains the model with an Estimator and reports the test results.

Creating the TFRecord files:

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf

mnist = input_data.read_data_sets("./data/", one_hot=True)

def tf_record(data, labels, path):
  # data: the MNIST images, one flattened vector of 784 floats per image
  # labels: the corresponding one-hot labels, 10 floats each
  # path: where the TFRecord file is written
  writer = tf.python_io.TFRecordWriter(path)
  for example, label in zip(data, labels):
    tf_example = tf.train.Example(
        features=tf.train.Features(
            feature={
                "image": tf.train.Feature(float_list=tf.train.FloatList(value=list(example))),
                "label": tf.train.Feature(float_list=tf.train.FloatList(value=list(label)))
            }
        )
    )
    writer.write(tf_example.SerializeToString())
  writer.close()
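
To confirm the records were written as expected, a quick read-back sketch (assuming ./data/train_record has already been produced by tf_record) could look like this; the feature lengths should match what the parser in mnist.py expects.

import tensorflow as tf

# Read back the first serialized example and check the feature lengths
# (784 image floats and 10 one-hot label floats).
record_path = "./data/train_record"
for serialized in tf.python_io.tf_record_iterator(record_path):
  example = tf.train.Example.FromString(serialized)
  image = example.features.feature["image"].float_list.value
  label = example.features.feature["label"].float_list.value
  print(len(image), len(label))  # expected: 784 10
  break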

Training the model with an Estimator:

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
from make_record import tf_record
flags = tf.flags
FLAGS = flags.FLAGS
tf.logging.set_verbosity(tf.logging.INFO)

# The model: two convolutional layers followed by two fully connected layers
def create_model(images):
  def weight_variable(shape):
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)
  def bias_variable(shape):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)
  def conv2d(x, W):
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')
  def max_pool_2x2(x):
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
  image = tf.reshape(images, [-1, 28, 28, 1])
  W_conv1 = weight_variable([5, 5, 1, 32])
  b_conv1 = bias_variable([32])
  h_conv1 = tf.nn.relu(conv2d(image, W_conv1) + b_conv1)
  h_pool1 = max_pool_2x2(h_conv1)

  W_conv2 = weight_variable([5, 5, 32, 64])
  b_conv2 = bias_variable([64])
  h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
  h_pool2 = max_pool_2x2(h_conv2)

  W_fc1 = weight_variable([7 * 7 * 64, 1024])
  b_fc1 = bias_variable([1024])

  h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])


  h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

  W_fc2 = weight_variable([1024, 10])
  b_fc2 = bias_variable([10])

  prediction = tf.nn.softmax(tf.matmul(h_fc1, W_fc2) + b_fc2)
  return prediction

# This function is passed to the Estimator's model_fn argument.
# The role of each Estimator constructor argument is described at the end of the article.
def model_fn(features,mode):
  is_training = (mode == tf.estimator.ModeKeys.TRAIN)
  prediction = create_model(images=features['image'])
  if is_training:
    labels = features['label']
    # cross-entropy averaged over the batch
    loss = tf.reduce_mean(-tf.reduce_sum(labels * tf.log(prediction), axis=1))
    train_op = tf.train.AdamOptimizer(1e-4).minimize(loss, global_step=tf.train.get_or_create_global_step())
    return tf.estimator.EstimatorSpec(
      mode=tf.estimator.ModeKeys.TRAIN,
      loss=loss,
      train_op=train_op)
  if mode == tf.estimator.ModeKeys.PREDICT:
    labels = features['label']
    return tf.estimator.EstimatorSpec(
      mode=tf.estimator.ModeKeys.PREDICT,
      predictions={
        "ground_truth":tf.argmax(labels, axis=1), "prediction":tf.argmax(prediction, axis=1)
      })
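  # Sketch (not in the original code): an EVAL branch would let
  # estimator.evaluate() report loss and accuracy directly on the test records.
  if mode == tf.estimator.ModeKeys.EVAL:
    labels = features['label']
    loss = tf.reduce_mean(-tf.reduce_sum(labels * tf.log(prediction), axis=1))
    accuracy = tf.metrics.accuracy(
      labels=tf.argmax(labels, axis=1),
      predictions=tf.argmax(prediction, axis=1))
    return tf.estimator.EstimatorSpec(
      mode=tf.estimator.ModeKeys.EVAL,
      loss=loss,
      eval_metric_ops={"accuracy": accuracy})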


# Returns a function that builds the training dataset
def input_fn_builder(input_file):
  def input_fn():
    def parse(record):
      keys_to_features = {
        "image": tf.FixedLenFeature([784], tf.float32),
        "label": tf.FixedLenFeature([10], tf.float32)
      }
      parsed = tf.parse_single_example(record, keys_to_features)
      return parsed
    d = tf.data.TFRecordDataset(input_file)
    # shuffle, batch into 32, and repeat for 10 epochs
    return d.map(parse).shuffle(buffer_size=100).batch(32).repeat(10)
  return input_fn
# Returns a function that builds the test dataset (a single pass)
def input_fn_builder_test(input_file):
  def input_fn():
    def parse(record):
      keys_to_features = {
        "image": tf.FixedLenFeature([784], tf.float32),
        "label": tf.FixedLenFeature([10], tf.float32)
      }
      parsed = tf.parse_single_example(record, keys_to_features)
      return parsed

    d = tf.data.TFRecordDataset(input_file)
    return d.map(parse).shuffle(buffer_size=100).batch(32).repeat(1)
  return input_fn
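# Optional sanity check (not from the original post): once the record files
# exist, you can peek at one parsed batch to confirm its shapes, e.g.
#   dataset = input_fn_builder("./data/train_record")()
#   batch = dataset.make_one_shot_iterator().get_next()
#   with tf.Session() as sess:
#     feats = sess.run(batch)  # feats["image"].shape == (32, 784)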



def main(_):

  tf.logging.set_verbosity(tf.logging.INFO)

  tf.logging.info("******read mnist dataset******")
  mnist = input_data.read_data_sets('./data', one_hot=True)
  tf.logging.info("******make train record******")
  train = mnist.train.images
  labels = mnist.train.labels
  tf_record(train, labels, "./data/train_record")



  tf.logging.info("******make test record******")
  test = mnist.test.images
  labels = mnist.test.labels
  tf_record(test, labels, "./data/test_record")


  session_config = tf.ConfigProto(log_device_placement=True)
  session_config.gpu_options.per_process_gpu_memory_fraction = 0.5

  # Run configuration: session options (e.g. the GPU memory fraction set above)
  # and how often to save a checkpoint (every 2000 steps here).
  run_config = tf.estimator.RunConfig(session_config=session_config, save_checkpoints_steps=2000)
  # model_dir: where checkpoints and event files are saved
  # model_fn: the model function defined above
  # config: the run configuration
  estimator = tf.estimator.Estimator(
      model_dir="./result",
      model_fn=model_fn,
      config=run_config)
  input_fn = input_fn_builder("./data/train_record")
  input_fn_test = input_fn_builder_test("./data/test_record")
  # Train
  estimator.train(input_fn=input_fn)
  all_results = []
  # Predict on the test set
  for result in estimator.predict(input_fn_test, yield_single_examples=True):
      if len(all_results) % 1000 == 0:
        tf.logging.info("Processing example: %d" % (len(all_results)))
      label = int(result["ground_truth"])
      prediction = int(result['prediction'])
      all_results.append([label, prediction])
  # Compare ground truth with predictions to compute accuracy
  acc = 0.
  for idx, result in enumerate(all_results):
    if idx == 0:
      # debug output: inspect the structure of the collected results
      print(type(result))
      print(type(all_results))
      print(len(all_results))
      print(len(result))
    if result[0] == result[1]:
      acc = acc + 1
  # overall test accuracy
  print(acc / len(all_results))

if __name__ == "__main__":
  tf.app.run()
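
If the EVAL branch sketched inside model_fn above is added, the manual accuracy loop in main() can be replaced by a single evaluate() call; a minimal sketch, assuming the test record file exists:

# Requires model_fn to handle tf.estimator.ModeKeys.EVAL as sketched earlier.
metrics = estimator.evaluate(input_fn=input_fn_builder_test("./data/test_record"))
print(metrics["accuracy"], metrics["loss"])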

Results: (training output screenshot in the original post)

Saved model files: (screenshot of ./result in the original post)

tf.estimator.Estimator() parameter reference:

For this, the class docstring itself is the best reference; it is reproduced below:

class Estimator(object):
  """Estimator class to train and evaluate TensorFlow models.

  The `Estimator` object wraps a model which is specified by a `model_fn`,
  which, given inputs and a number of other parameters, returns the ops
  necessary to perform training, evaluation, or predictions.

  All outputs (checkpoints, event files, etc.) are written to `model_dir`, or a
  subdirectory thereof. If `model_dir` is not set, a temporary directory is
  used.
  
  The `config` argument can be passed `RunConfig` object containing information
  about the execution environment. It is passed on to the `model_fn`, if the
  `model_fn` has a parameter named "config" (and input functions in the same
  manner). If the `config` parameter is not passed, it is instantiated by the
  `Estimator`. Not passing config means that defaults useful for local execution
  are used. `Estimator` makes config available to the model (for instance, to
  allow specialization based on the number of workers available), and also uses
  some of its fields to control internals, especially regarding checkpointing.

  The `params` argument contains hyperparameters. It is passed to the
  `model_fn`, if the `model_fn` has a parameter named "params", and to the input
  functions in the same manner. `Estimator` only passes params along, it does
  not inspect it. The structure of `params` is therefore entirely up to the
  developer.

  None of `Estimator`'s methods can be overridden in subclasses (its
  constructor enforces this). Subclasses should use `model_fn` to configure
  the base class, and may add methods implementing specialized functionality.

  @compatibility(eager)
  Estimators are not compatible with eager execution.
  @end_compatibility
  """

  def __init__(self, model_fn, model_dir=None, config=None, params=None,
               warm_start_from=None):
    """Constructs an `Estimator` instance.

    See @{$estimators} for more information. To warm-start an `Estimator`:

    ```python
    estimator = tf.estimator.DNNClassifier(
        feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],
        hidden_units=[1024, 512, 256],
        warm_start_from="/path/to/checkpoint/dir")
    ```

    For more details on warm-start configuration, see
    @{tf.estimator.WarmStartSettings$WarmStartSettings}.

    Args:
      model_fn: Model function. Follows the signature:

        * Args:

          * `features`: This is the first item returned from the `input_fn`
                 passed to `train`, `evaluate`, and `predict`. This should be a
                 single `Tensor` or `dict` of same.
          * `labels`: This is the second item returned from the `input_fn`
                 passed to `train`, `evaluate`, and `predict`. This should be a
                 single `Tensor` or `dict` of same (for multi-head models). If
                 mode is `ModeKeys.PREDICT`, `labels=None` will be passed. If
                 the `model_fn`'s signature does not accept `mode`, the
                 `model_fn` must still be able to handle `labels=None`.
          * `mode`: Optional. Specifies if this training, evaluation or
                 prediction. See `ModeKeys`.
          * `params`: Optional `dict` of hyperparameters.  Will receive what
                 is passed to Estimator in `params` parameter. This allows
                 to configure Estimators from hyper parameter tuning.
          * `config`: Optional configuration object. Will receive what is passed
                 to Estimator in `config` parameter, or the default `config`.
                 Allows updating things in your `model_fn` based on
                 configuration such as `num_ps_replicas`, or `model_dir`.

        * Returns:
          `EstimatorSpec`

      model_dir: Directory to save model parameters, graph and etc. This can
        also be used to load checkpoints from the directory into a estimator to
        continue training a previously saved model. If `PathLike` object, the
        path will be resolved. If `None`, the model_dir in `config` will be used
        if set. If both are set, they must be same. If both are `None`, a
        temporary directory will be used.
      config: Configuration object.
      params: `dict` of hyper parameters that will be passed into `model_fn`.
              Keys are names of parameters, values are basic python types.
      warm_start_from: Optional string filepath to a checkpoint to warm-start
                       from, or a `tf.estimator.WarmStartSettings` object to
                       fully configure warm-starting.  If the string filepath is
                       provided instead of a `WarmStartSettings`, then all
                       variables are warm-started, and it is assumed that
                       vocabularies and Tensor names are unchanged.

    Raises:
      RuntimeError: If eager execution is enabled.
      ValueError: parameters of `model_fn` don't match `params`.
      ValueError: if this is called via a subclass and if that class overrides
        a member of `Estimator`.
    """

If you spot any problems, feel free to point them out in the comments! Please credit this address when reposting: https://mp.csdn.net/postedit/86678395
