""" 批量训练 """ import os import numpy as np import tensorflow as tf import matplotlib.pyplot as plt os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' s = tf.Session() # 声明批量训练的数据量的大小 batch_size = 20 # 声明模型的数据、占位符和变量 x_vals = np.random.normal(1, 0.1, 100) y_vals = np.repeat(10., 100) x_data = tf.placeholder(shape=[None, 1], dtype=tf.float32) # 可显式地设置维度为 20 y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32) # 也可设置为 None A = tf.Variable(tf.random_normal(shape=[1, 1])) B = tf.Variable(tf.random_normal(shape=[1, 1])) # 初始化变量 init = tf.global_variables_initializer() s.run(init) # 在计算图中增加矩阵乘法操作 my_output = tf.matmul(x_data, A) # 批量训练时,损失函数是每个数据点 L2 损失的平均值 loss = tf.reduce_mean(tf.square(my_output - y_target)) # 声明优化器 my_opt = tf.train.GradientDescentOptimizer(0.02) train_step = my_opt.minimize(loss) # 在训练过程中,通过循环迭代优化模型算法 loss_batch = [] # 初始化一个列表,每隔 5 次迭代保存损失函数 for i in range(100): rand_index = np.random.choice(100, size=batch_size) rand_x = np.transpose([x_vals[rand_index]]) rand_y = np.transpose([y_vals[rand_index]]) s.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y}) if (i+1) % 5 == 0: print('Step # ' + str(i + 1) + ' A = ' + str(s.run(A))) temp_loss = s.run(loss, feed_dict={x_data: rand_x, y_target: rand_y}) print('Loss = ' + str(temp_loss)) loss_batch.append(temp_loss) # 为防止上一节代码中变量 A 的值的影响, 在进行随机训练时,需要将变量 A 重新初始化 init1 = tf.global_variables_initializer() s.run(init1) # 随机损失代码 loss_stochastic = [] for j in range(100): rand_index = np.random.choice(100) rand_x = [[x_vals[rand_index]]] rand_y = [[y_vals[rand_index]]] s.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y}) if (j + 1) % 5 == 0: print('Step # ' + str(j + 1) + ' A = ' + str(s.run(A))) temp_loss = s.run(loss, feed_dict={x_data: rand_x, y_target: rand_y}) print('Loss = ' + str(temp_loss)) loss_stochastic.append(temp_loss) # 绘制回归算法的随机训练损失和批量训练损失 plt.plot(range(0, 100, 5), loss_stochastic, 'b-', label='Stochastic Loss') plt.plot(range(0, 100, 5), loss_batch, 'r--', label='Batch Loss, size = 20') plt.legend(loc='upper right', prop={'size': 11}) plt.show()
TensorFlow 机器学实战指南示例代码之 TensorFlow 实现随机训练和批量训练
猜你喜欢
转载自blog.csdn.net/lingtianyulong/article/details/79279688
今日推荐
周排行