Custom model

To build a custom model, we use tf.estimator.Estimator directly. tf.estimator.LinearRegressor is in fact a subclass of tf.estimator.Estimator; instead of subclassing Estimator ourselves, we simply pass Estimator a function, model_fn, that tells tf.estimator how to compute predictions, training steps, and loss. The code is as follows (TensorFlow 1.x API):

import numpy as np
import tensorflow as tf
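def model_fn(features, labels, mode):
    # Build a linear model y = W*x + b and predict values
    W = tf.get_variable("W", [1], dtype=tf.float64)
    b = tf.get_variable("b", [1], dtype=tf.float64)
    y = W * features['x'] + b

    # Loss sub-graph: sum of squared errors over the batch.
    # Note: labels is None in PREDICT mode, so as written this model_fn supports train/evaluate only.
    loss = tf.reduce_sum(tf.square(y - labels))

    # Training sub-graph: one gradient-descent step plus a global_step increment
    global_step = tf.train.get_global_step()
    optimizer = tf.train.GradientDescentOptimizer(0.01)
    train = tf.group(optimizer.minimize(loss),
                     tf.assign_add(global_step, 1))

    # EstimatorSpec connects the sub-graphs built above to train/evaluate/predict
    return tf.estimator.EstimatorSpec(mode=mode, predictions=y, loss=loss, train_op=train)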
estimator = tf.estimator.Estimator(model_fn=model_fn)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpmwNhao
INFO:tensorflow:Using config: {'_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_tf_random_seed': 1, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_save_checkpoints_steps': None, '_model_dir': '/tmp/tmpmwNhao', '_save_summary_steps': 100}
# Training labels lie exactly on y = 1 - x; the evaluation labels are slightly off that line
x_train = np.array([1., 2., 3., 4.])
y_train = np.array([0., -1., -2., -3.])
x_eval = np.array([2., 5., 8., 1.])
y_eval = np.array([-1.01, -4.1, -7., 0.])

# For training: repeat the data indefinitely (num_epochs=None) in shuffled batches of 4
input_fn = tf.estimator.inputs.numpy_input_fn({"x": x_train}, y_train, batch_size=4, num_epochs=None, shuffle=True)

# For evaluation: 1000 ordered passes over each dataset, then stop
train_input_fn = tf.estimator.inputs.numpy_input_fn({"x": x_train}, y_train, batch_size=4, num_epochs=1000, shuffle=False)

eval_input_fn = tf.estimator.inputs.numpy_input_fn({"x": x_eval}, y_eval, batch_size=4, num_epochs=1000, shuffle=False)
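numpy_input_fn turns the arrays into a stream of batches: num_epochs=None repeats the data indefinitely, while num_epochs=1000 passes over it 1000 times and then stops. With four examples and batch_size=4 that is 1000 batches, which is why each evaluate call below runs until those 1000 batches are consumed. The generator below is a simplified, hypothetical model of this behavior for illustration only; it is not how TensorFlow implements numpy_input_fn (in particular, it shuffles only within each pass, whereas TensorFlow uses a shuffling queue), and the name batches is made up.

```python
import itertools
import random


def batches(xs, ys, batch_size, num_epochs=None, shuffle=False, seed=0):
    """Hypothetical stand-in for a numpy_input_fn-style batch stream.

    num_epochs=None repeats the data forever; otherwise the stream stops
    after num_epochs passes. In this sketch, shuffling happens within each
    pass, and a smaller final batch is yielded if examples are left over.
    """
    rng = random.Random(seed)

    def example_stream():
        epochs = itertools.count() if num_epochs is None else range(num_epochs)
        for _ in epochs:
            order = list(range(len(xs)))
            if shuffle:
                rng.shuffle(order)
            for i in order:
                yield xs[i], ys[i]

    stream = example_stream()
    while True:
        chunk = list(itertools.islice(stream, batch_size))
        if not chunk:
            return  # finite stream exhausted
        yield [x for x, _ in chunk], [y for _, y in chunk]


x_train = [1., 2., 3., 4.]
y_train = [0., -1., -2., -3.]

# Like train_input_fn: 1000 passes over 4 examples with batch_size=4 -> 1000 batches
finite = list(batches(x_train, y_train, batch_size=4, num_epochs=1000))
print(len(finite))

# Like input_fn: num_epochs=None never ends, so take only the first 5 batches
endless = batches(x_train, y_train, batch_size=4, shuffle=True)
print(list(itertools.islice(endless, 5)))
```

The endless stream is why estimator.train needs an explicit steps=1000: without a step limit, training on input_fn would never run out of data.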
# Train for 1000 steps
estimator.train(input_fn=input_fn, steps=1000)
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Saving checkpoints for 1 into /tmp/tmpmwNhao/model.ckpt.
INFO:tensorflow:loss = 10.8243056552, step = 1
INFO:tensorflow:global_step/sec: 254.746
INFO:tensorflow:loss = 0.0224178061166, step = 101 (0.394 sec)
INFO:tensorflow:global_step/sec: 263.384
INFO:tensorflow:loss = 0.00172416346232, step = 201 (0.380 sec)
INFO:tensorflow:global_step/sec: 260.919
INFO:tensorflow:loss = 0.000352547614769, step = 301 (0.383 sec)
INFO:tensorflow:global_step/sec: 745.272
INFO:tensorflow:loss = 3.27611546304e-05, step = 401 (0.134 sec)
INFO:tensorflow:global_step/sec: 268.733
INFO:tensorflow:loss = 6.18242476735e-07, step = 501 (0.373 sec)
INFO:tensorflow:global_step/sec: 257.906
INFO:tensorflow:loss = 2.41339410526e-07, step = 601 (0.388 sec)
INFO:tensorflow:global_step/sec: 259.224
INFO:tensorflow:loss = 2.08072196581e-08, step = 701 (0.386 sec)
INFO:tensorflow:global_step/sec: 261.673
INFO:tensorflow:loss = 9.51939277022e-10, step = 801 (0.382 sec)
INFO:tensorflow:global_step/sec: 263.729
INFO:tensorflow:loss = 1.07657869207e-10, step = 901 (0.379 sec)
INFO:tensorflow:Saving checkpoints for 1000 into /tmp/tmpmwNhao/model.ckpt.
INFO:tensorflow:Loss for final step: 1.42408389095e-11.

<tensorflow.python.estimator.estimator.Estimator at 0x7f9d1e503c50>
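As a cross-check on what model_fn computes, here is a minimal pure-Python sketch (not TensorFlow's implementation) of the same update rule: predictions W*x + b, a summed squared-error loss, and one gradient-descent step with learning rate 0.01 per global step. It assumes W and b start at 0 (TensorFlow initializes them randomly) and that every step sees each of the four training examples exactly once; the shuffled input_fn used for training may mix examples from different passes within a batch, so this does not reproduce the logged losses exactly. The helper name train_linear is hypothetical.

```python
# Minimal pure-Python sketch of the computation model_fn defines:
#   prediction y = W*x + b, loss = sum of squared errors,
#   one gradient-descent step (learning rate 0.01) per global step.
# Assumption: W and b start at 0.0 (TensorFlow initializes them randomly).


def train_linear(xs, ys, learning_rate=0.01, steps=1000):
    """Hypothetical helper: full-batch gradient descent on y = W*x + b."""
    W, b = 0.0, 0.0
    global_step = 0
    for _ in range(steps):
        # Residuals of the current predictions
        residuals = [W * x + b - y for x, y in zip(xs, ys)]
        # Gradients of loss = sum(residual**2) with respect to W and b
        grad_W = sum(2 * r * x for r, x in zip(residuals, xs))
        grad_b = sum(2 * r for r in residuals)
        W -= learning_rate * grad_W
        b -= learning_rate * grad_b
        global_step += 1  # plays the role of tf.assign_add(global_step, 1)
    loss = sum((W * x + b - y) ** 2 for x, y in zip(xs, ys))
    return W, b, loss, global_step


x_train = [1., 2., 3., 4.]
y_train = [0., -1., -2., -3.]
W, b, loss, step = train_linear(x_train, y_train)
print(W, b, loss, step)
```

With this data and learning rate, the parameters approach W = -1 and b = 1, the line y = 1 - x on which the training labels lie, and the loss shrinks toward zero as in the training log.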
train_metrics = estimator.evaluate(input_fn=train_input_fn)
eval_metrics = estimator.evaluate(input_fn=eval_input_fn)
print("train metrics: %r" % train_metrics)
print("eval metrics: %r" % eval_metrics)
INFO:tensorflow:Starting evaluation at 2017-11-19-21:17:59
INFO:tensorflow:Restoring parameters from /tmp/tmpmwNhao/model.ckpt-1000
INFO:tensorflow:Finished evaluation at 2017-11-19-21:18:03
INFO:tensorflow:Saving dict for global step 1000: global_step = 1000, loss = 1.01341e-11
INFO:tensorflow:Starting evaluation at 2017-11-19-21:18:03
INFO:tensorflow:Restoring parameters from /tmp/tmpmwNhao/model.ckpt-1000
INFO:tensorflow:Finished evaluation at 2017-11-19-21:18:07
INFO:tensorflow:Saving dict for global step 1000: global_step = 1000, loss = 0.0101007
train metrics: {'loss': 1.0134118e-11, 'global_step': 1000}
eval metrics: {'loss': 0.010100666, 'global_step': 1000}
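The two evaluation losses can be checked by hand. The training labels lie exactly on y = 1 - x, so a model that has converged to W ≈ -1 and b ≈ 1 leaves almost no training error. The evaluation labels deviate from that line at x = 2 (by 0.01) and at x = 5 (by 0.1), so even the ideal line has a summed squared error of 0.01**2 + 0.1**2 = 0.0101 on x_eval, which matches the reported eval loss of about 0.0101. evaluate reports the mean of the per-batch losses; since eval_input_fn feeds the same four examples in order with batch_size=4, each batch should hold exactly those four examples, so the mean equals the single-batch loss. The snippet below assumes the exact values W = -1 and b = 1; the function name sse is illustrative.

```python
# Sum-of-squares loss of the line y = W*x + b, as in model_fn,
# evaluated at the assumed exact parameters W = -1.0 and b = 1.0.
def sse(W, b, xs, ys):
    return sum((W * x + b - y) ** 2 for x, y in zip(xs, ys))


x_train = [1., 2., 3., 4.]
y_train = [0., -1., -2., -3.]
x_eval = [2., 5., 8., 1.]
y_eval = [-1.01, -4.1, -7., 0.]

print(sse(-1.0, 1.0, x_train, y_train))  # 0.0: training labels lie on the line
print(sse(-1.0, 1.0, x_eval, y_eval))    # approximately 0.0101
```

So the gap between the train and eval metrics comes from the evaluation data themselves, not from a poorly fitted model.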

Reposted from blog.csdn.net/SarKerson/article/details/78625712