[TensorFlow] Partially Restoring and Partially Training a Network

The network architecture is as follows:

Parameters: a1, a2, b1, b2

Network output: y = a1*a2*x + b1 + b2

Target function: y = x

1. Partially restoring network parameters

Use case: the network architecture has been modified, but some of the old parameters should be reused.

How to set it up: pass a var_list argument to tf.train.Saver; the resulting saver will save/restore only the variables in var_list.
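
For example, the two call forms look like this (a minimal sketch):

saver_all = tf.train.Saver()          # no var_list: handles every global variable
saver_ab1 = tf.train.Saver(var_list)  # handles only the variables in var_list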

How to use it:

(1) Saving: a1, a2, b1, b2 are initialized to 10, 20, 30, 40, respectively

import tensorflow as tf

# Target function: y = x
# i.e., at convergence a1*a2 = 1 and b1 + b2 = 0
x = tf.placeholder(tf.float32, [1])
with tf.variable_scope("AB1"):
  a1 = tf.Variable(tf.constant([10], dtype=tf.float32), name="A1")
  b1 = tf.Variable(tf.constant([30], dtype=tf.float32), name="B1")
with tf.variable_scope("AB2"):
  a2 = tf.Variable(tf.constant([20], dtype=tf.float32), name="A2")
  b2 = tf.Variable(tf.constant([40], dtype=tf.float32), name="B2")
y = a1*a2*x + b1 + b2

_y = tf.placeholder(tf.float32, [1])
loss = tf.square(y - _y)
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Collect the variables under each scope (used by the restore script below).
var_list_ab1 = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='AB1')
var_list_ab2 = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='AB2')
# No var_list is passed here, so all four variables go into the checkpoint.
saver = tf.train.Saver()
saver.save(sess, "./10203040")
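
Running this writes the checkpoint files (10203040.index, 10203040.data-*, 10203040.meta, plus a checkpoint state file) to the current directory. To double-check what was stored, you can list the checkpoint's contents; a minimal sketch, assuming a TF 1.x release that ships tf.train.list_variables:

import tensorflow as tf

# Print each stored variable's name and shape; expect AB1/A1, AB1/B1,
# AB2/A2 and AB2/B2, each with shape [1].
for name, shape in tf.train.list_variables("./10203040"):
  print(name, shape)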

(2) Restoring:

import tensorflow as tf

# Target function: y = x
# i.e., at convergence a1*a2 = 1 and b1 + b2 = 0
x = tf.placeholder(tf.float32, [1])
with tf.variable_scope("AB1"):
  a1 = tf.Variable(tf.constant([1], dtype=tf.float32), name="A1")
  b1 = tf.Variable(tf.constant([3], dtype=tf.float32), name="B1")
with tf.variable_scope("AB2"):
  a2 = tf.Variable(tf.constant([2], dtype=tf.float32), name="A2")
  b2 = tf.Variable(tf.constant([4], dtype=tf.float32), name="B2")
y = a1*a2*x + b1 + b2

_y = tf.placeholder(tf.float32, [1])
loss = tf.square(y - _y)
sess = tf.Session()
sess.run(tf.global_variables_initializer())

var_list_ab1 = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='AB1')
var_list_ab2 = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='AB2')
# Only the AB1 variables (a1, b1) are restored; a2 and b2 keep the
# values they were just initialized with.
saver = tf.train.Saver(var_list_ab1)
saver.restore(sess, "./10203040")

# Prints [10.] [2.] [30.] [4.]: a1/b1 come from the checkpoint,
# a2/b2 from their initializers.
print(sess.run([a1, a2, b1, b2]))
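
If the modified architecture also renames variables, var_list can instead be a dict mapping the name a variable was stored under to the variable that should receive its value. A sketch, assuming the new scope is called NEW (that name is made up for illustration):

with tf.variable_scope("NEW"):
  a1_new = tf.Variable(tf.constant([1], dtype=tf.float32), name="A1")
# Load the value stored as "AB1/A1" into the renamed variable NEW/A1.
saver = tf.train.Saver({"AB1/A1": a1_new})
saver.restore(sess, "./10203040")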

2. Training only some of the parameters, e.g., train only a1 and b1 while a2 and b2 stay fixed

How to set it up: pass the list of variables you want trained to the optimizer via the var_list argument of minimize(), as in the two methods below:

Method 1 (build var_list from a variable_scope):

var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='AB1')
train = tf.train.GradientDescentOptimizer(1e-1).minimize(loss, var_list=var_list)
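
An equivalent variant pulls from the TRAINABLE_VARIABLES collection instead; in this example every variable is trainable, so both collections yield the same list:

var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='AB1')
train = tf.train.GradientDescentOptimizer(1e-1).minimize(loss, var_list=var_list)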

Method 2 (listing the variables by hand):

import tensorflow as tf
import random

# Target function: y = x
# i.e., at convergence a1*a2 = 1 and b1 + b2 = 0
x = tf.placeholder(tf.float32, [1])
with tf.variable_scope("AB1"):
  a1 = tf.Variable(tf.constant([1], dtype=tf.float32), name="A1")
  b1 = tf.Variable(tf.constant([3], dtype=tf.float32), name="B1")
with tf.variable_scope("AB2"):
  a2 = tf.Variable(tf.constant([2], dtype=tf.float32), name="A2")
  b2 = tf.Variable(tf.constant([4], dtype=tf.float32), name="B2")
y = a1*a2*x + b1 + b2

_y = tf.placeholder(tf.float32, [1])
loss = tf.square(y - _y)
sess = tf.Session()
sess.run(tf.global_variables_initializer())

#var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='AB1')
var_list = [a1, b1]  # <-- build var_list by hand: only a1 and b1 get updated
train = tf.train.GradientDescentOptimizer(1e-1).minimize(loss, var_list=var_list)
while True:
  # Scale the input into [0, 1]; with larger raw inputs this learning
  # rate makes the gradient steps diverge instead of converge.
  x_val = [random.randint(0, 100) * 0.01]
  y_val = [x_val[0]]  # label equals input, since the target is y = x
  _, a1v, a2v, b1v, b2v, lossv = sess.run([train, a1, a2, b1, b2, loss],
                                          feed_dict={x: x_val, _y: y_val})
  if lossv < 1e-10:
    break
  print("train data=%s %s" % (x_val, y_val))
  print("a=%s %s\n b=%s %s\n loss=%s" % (a1v, a2v, b1v, b2v, lossv))

Reposted from blog.csdn.net/vcvycy/article/details/79520389