问题
我有一个这样的需求:我目前有一个已经训练好的encoder模型,它的输出是Tensor
类型,我想把它转换成ndarray
类型。通过查询资料,我发现可以利用sess.run()
把Tensor
转换为ndarray
,于是在我的代码里调用sess.run()
成功转换了数据类型。
但是,我这个数据转换在每一次的循环中都会调用,也就是循环中一直调用sess.run()
,于是问题来了,每循环一次,sess.run
的用时都比上一次要久,导致后面训练越来越慢。从第一次调用用时0.17s到后面第100次调用时0.27s,而且这才是100次,如果训练10000次,那不知道要等多久,所以这个问题必须解决!
问题原因
如果在某一个循环里不断建立tensorflow图节点再运行的话,会导致tensorflow运行越来越慢。具体问题请看代码注释,没有注释的代码行可以不用关注,问题代码如下:
import gym
from gym.spaces import Box
import numpy as np
from tensorflow import keras
import tensorflow as tf
import time
class MyWrapper(gym.ObservationWrapper):
def __init__(self, env, encoder, latent_dim = 2):
super().__init__(env)
self._observation_space = Box(-np.inf, np.inf, shape=(7 + latent_dim,), dtype=np.float32)
self.observation_space = self._observation_space
self.encoder = encoder # 这是我已经提前训练好的模型
tf.InteractiveSession()
self.sess = tf.get_default_session()
self.sess.run(tf.global_variables_initializer())
def observation(self, obs):
obs = np.reshape(obs, (1, -1))
latent_z_tensor = self.encoder(obs)[2] # 问题就在与这里,这行代码在调用run时,会不断的创建图节点,所以越来越慢
t=time.time() # 测试运行用时
latent_z_arr = sels.sess.run(latent_z_tensor) # 每次run时,就会把上面的图重新构建一次
print(time.time()-t) # 测试运行用时
obs = np.reshape(obs, (-1,))
latent_z_arr = np.reshape(latent_z_arr, (-1,))
obs = obs.tolist()
obs.extend(latent_z_arr.tolist())
obs = np.array(obs)
return obs
解决思路
在初始化时,就建立好图结构,使用tf.placeholder
占位符表示obs
这个变量,具体方案示例如下(可以只关注带有注释的行):
import gym
from gym.spaces import Box
import numpy as np
from tensorflow import keras
import tensorflow as tf
import time
class MyWrapper(gym.ObservationWrapper):
def __init__(self, env, encoder, latent_dim = 2):
super().__init__(env)
self._observation_space = Box(-np.inf, np.inf, shape=(7 + latent_dim,), dtype=np.float32)
self.observation_space = self._observation_space
self.encoder = encoder
tf.InteractiveSession()
self.sess = tf.get_default_session()
self.obs=tf.placeholder(dtype=tf.float32,shape=(1,7)) # 重点在于这两行代码,初始化时先构建好图,先用占位符表示obs,实际运行时只需喂数据obs就好了
self.latent_z_tensor = self.encoder(self.obs)[2] # 在初始化时构建图
self.sess.run(tf.global_variables_initializer())
def observation(self, obs):
obs = np.reshape(obs, (1, -1))
t=time.time() # 测试运行用时
latent_z_arr = self.sess.run(self.latent_z_tensor, feed_dict={
self.obs:obs}) # 这里只需喂数据,不会重新构建图了。
print(time.time()-t) # 测试运行用时
obs = np.reshape(obs, (-1,))
latent_z_arr = np.reshape(latent_z_arr, (-1,))
obs = obs.tolist()
obs.extend(latent_z_arr.tolist())
obs = np.array(obs)
return obs
现在,数据类型转换完成,代码运行慢也解决了!