#coding:utf-8
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf
# Demo of tf.train.ExponentialMovingAverage: track a shadow (moving-average)
# copy of a variable and print [w1, shadow(w1)] after each update.

# Parameter to be optimized.
w1 = tf.Variable(0, dtype=tf.float32)
# Iteration counter of the network; trainable=False keeps it out of
# tf.trainable_variables(), so no moving average is maintained for it.
global_step = tf.Variable(0, trainable=False)
# Instantiate the moving-average class.
MOVING_AVERAGE_DECAY = 0.99
# Passing global_step as num_updates makes the effective decay
# min(decay, (1 + num_updates) / (10 + num_updates)) according to the
# TensorFlow documentation, so the average moves faster at early steps.
ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
# ema.apply([...]) takes the list of variables to track; every
# sess.run(ema_op) updates the moving average of each variable in that list.
ema_op = ema.apply(tf.trainable_variables())

with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    # ema.average(w1) returns the shadow variable holding w1's moving average.
    print(sess.run([w1, ema.average(w1)]))

    # Set w1 to 1 and update the average once (global_step is still 0).
    sess.run(tf.assign(w1, 1))
    sess.run(ema_op)
    print(sess.run([w1, ema.average(w1)]))

    # Pretend 100 training steps have passed and w1 has become 10.
    sess.run(tf.assign(global_step, 100))
    sess.run(tf.assign(w1, 10))

    # Each run of ema_op moves the shadow value closer to the current w1.
    for _ in range(4):
        sess.run(ema_op)
        print(sess.run([w1, ema.average(w1)]))