A operação inadequada de carregamento de dados faz com que o treinamento do tensorflow se torne cada vez mais lento

Ao carregar dados no loop após a Sessão, tf.reshape, tf.transpose e outras operações são realizadas nos dados para transformar os dados do formato narray original em tensor, o que fará com que esta etapa seja adicionada ao gráfico de cálculo, fazendo com que o O gráfico de cálculo se torna cada vez mais complexo. Quanto maior e a razão pela qual o treinamento do Tensorflow se torna cada vez mais lento.

Carregando formato de dados: .mat
método de carregamento: scio.loadmat(file_path) #

Formato de erro:

x_in = tf.compat.vi.placeholde()
...
模型
...
file_path = 'E:\Project_Research\a1.mat'
sess = tf.Session()
sess.run(init)
for i in range(10):
	temp = scio.loadmat(file_path)
	images = temp('images') #a1.mat包含变量为‘images’的矩阵
	images = tf.reshape(images,[-1,11,11,2])
	[] = sess.run([train_op], feed_dict = {x_in: images})

Se você executar a função tf.reshape acima, verá op na variável images durante a depuração , indicando que esta etapa foi adicionada ao gráfico de cálculo, portanto, cada iteração aumentará o tamanho do gráfico de cálculo.

Formato correto:

x_in = tf.compat.vi.placeholde()
x_in = tf.reshape(x_in, [-1,11,11,2])
...
模型
...
file_path = 'E:\Project_Research\a1.mat'
sess = tf.Session()
sess.run(init)
for i in range(10):
	temp = scio.loadmat(file_path)
	images = temp('images') 
	[] = sess.run([train_op], feed_dict = {x_in: images})

Desta forma, o modelo salvo no tensorflow não será alterado.

おすすめ

転載: blog.csdn.net/sanxiaw/article/details/105638905