在Tensorflow中把Tensor转换为ndarray时,循环中不断调用run或者eval函数,代码运行越来越慢!

问题

  我有一个这样的需求:我目前有一个已经训练好的encoder模型,它的输出是Tensor类型,我想把它转换成ndarray类型。通过查询资料,我发现可以利用sess.run()Tensor转换为ndarray,于是在我的代码里调用sess.run()成功转换了数据类型。
  但是,我这个数据转换在每一次的循环中都会调用,也就是循环中一直调用sess.run(),于是问题来了,每循环一次,sess.run的用时都比上一次要久,导致后面训练越来越慢。从第一次调用用时0.17s到后面第100次调用时0.27s,而且这才是100次,如果训练10000次,那不知道要等多久,所以这个问题必须解决!

问题原因

  如果在某一个循环里不断建立tensorflow图节点再运行的话,会导致tensorflow运行越来越慢。具体问题请看代码注释,没有注释的代码行可以不用关注,问题代码如下:

import gym
from gym.spaces import Box
import numpy as np
from tensorflow import keras
import tensorflow as tf
import time

class MyWrapper(gym.ObservationWrapper):
    def __init__(self, env, encoder, latent_dim = 2):
        super().__init__(env)
        self._observation_space = Box(-np.inf, np.inf, shape=(7 + latent_dim,), dtype=np.float32)
        self.observation_space = self._observation_space
        self.encoder = encoder # 这是我已经提前训练好的模型
        tf.InteractiveSession()
        self.sess = tf.get_default_session()
		self.sess.run(tf.global_variables_initializer())
	
    def observation(self, obs):
        obs = np.reshape(obs, (1, -1))
        latent_z_tensor = self.encoder(obs)[2] # 问题就在与这里,这行代码在调用run时,会不断的创建图节点,所以越来越慢
        
        t=time.time() # 测试运行用时
        latent_z_arr = sels.sess.run(latent_z_tensor) # 每次run时,就会把上面的图重新构建一次
        print(time.time()-t) # 测试运行用时

        obs = np.reshape(obs, (-1,))
        latent_z_arr = np.reshape(latent_z_arr, (-1,))

        obs = obs.tolist()
        obs.extend(latent_z_arr.tolist())
        obs = np.array(obs)
        return obs

解决思路

在初始化时,就建立好图结构,使用tf.placeholder占位符表示obs这个变量,具体方案示例如下(可以只关注带有注释的行):

import gym
from gym.spaces import Box
import numpy as np
from tensorflow import keras
import tensorflow as tf
import time

class MyWrapper(gym.ObservationWrapper):
    def __init__(self, env, encoder, latent_dim = 2):
        super().__init__(env)
        self._observation_space = Box(-np.inf, np.inf, shape=(7 + latent_dim,), dtype=np.float32)
        self.observation_space = self._observation_space
        self.encoder = encoder
        tf.InteractiveSession()
        self.sess = tf.get_default_session()
        self.obs=tf.placeholder(dtype=tf.float32,shape=(1,7)) # 重点在于这两行代码,初始化时先构建好图,先用占位符表示obs,实际运行时只需喂数据obs就好了
        self.latent_z_tensor = self.encoder(self.obs)[2] # 在初始化时构建图
        self.sess.run(tf.global_variables_initializer())

    def observation(self, obs):
        obs = np.reshape(obs, (1, -1))
        t=time.time() # 测试运行用时
        latent_z_arr = self.sess.run(self.latent_z_tensor, feed_dict={
    
    self.obs:obs}) # 这里只需喂数据,不会重新构建图了。
        print(time.time()-t) # 测试运行用时

        obs = np.reshape(obs, (-1,))
        latent_z_arr = np.reshape(latent_z_arr, (-1,))

        obs = obs.tolist()
        obs.extend(latent_z_arr.tolist())
        obs = np.array(obs)
        return obs

现在,数据类型转换完成,代码运行慢也解决了!

猜你喜欢

转载自blog.csdn.net/m0_59019651/article/details/125133422