人工智能实践:Tensorflow笔记 # 5 神经网络优化:滑动平均


#coding:utf-8
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf
#待优化的参数
w1 = tf.Variable(0,dtype=tf.float32)
#定义NN的迭代轮数
global_step = tf.Variable(0,trainable=False)
#实例化滑动平均类
MOVING_AVERAGE_DECAY = 0.99
ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY,global_step)
#ema.apply([]) -> 括号里的参数是更新列表每次运行sess.run()就对列表中的元素求滑动平均值
ema_op = ema.apply(tf.trainable_variables())

with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    print(sess.run([w1,ema.average(w1)]))#ema.average() -> 获取w1的滑动平均值

    sess.run(tf.assign(w1,1))
    sess.run(ema_op)
    print(sess.run([w1,ema.average(w1)]))

    sess.run(tf.assign(global_step,100))
    sess.run(tf.assign(w1,10))
    sess.run(ema_op)
    print(sess.run([w1,ema.average(w1)]))

    sess.run(ema_op)
    print(sess.run([w1,ema.average(w1)]))

    sess.run(ema_op)
    print(sess.run([w1,ema.average(w1)]))

    sess.run(ema_op)
    print(sess.run([w1,ema.average(w1)]))
发布了634 篇原创文章 · 获赞 579 · 访问量 35万+

猜你喜欢

转载自blog.csdn.net/qq_33583069/article/details/103101854