To implement mini-batch training, we need a node whose value can be replaced with a fresh batch at every iteration. Placeholder nodes provide exactly this: they hold no value themselves and simply output whatever data they are fed at run time.
>>>A = tf.placeholder(tf.float32, shape=(None, 3))
>>>B = A + 5
#shape=(None, 3) constrains the data fed to A: it must be 2-D with 3 columns,
#while None leaves the first dimension (the batch size) free to vary
Pass a value for A through the feed_dict argument of eval() and evaluate B:
>>>with tf.Session() as sess:
... B_val_1 = B.eval(feed_dict={A: [[1, 2, 3]]})
... B_val_2 = B.eval(feed_dict={A: [[4, 5, 6], [7, 8, 9]]})
>>>print(B_val_1)
[[6. 7. 8.]]
>>>print(B_val_2)
[[ 9. 10. 11.]
[12. 13. 14.]]
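The two evaluations above behave just like NumPy broadcasting: the scalar 5 is added elementwise, and the number of rows (the batch dimension) is free to vary as long as there are 3 columns. A minimal NumPy sketch of the same computation, for comparison:

```python
import numpy as np

# NumPy analog of B = A + 5: the scalar broadcasts over any batch size.
A1 = np.array([[1., 2., 3.]])                  # batch of 1
A2 = np.array([[4., 5., 6.], [7., 8., 9.]])    # batch of 2

print(A1 + 5)  # [[6. 7. 8.]]
print(A2 + 5)  # [[ 9. 10. 11.]
               #  [12. 13. 14.]]
```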
Mini-batch implementation:
X = tf.placeholder(tf.float32, shape=(None, n + 1), name='X')
y = tf.placeholder(tf.float32, shape=(None, 1), name='y')

batch_size = 100                          # set the batch size
n_batches = int(np.ceil(m / batch_size))  # number of batches per epoch

def fetch_batch(epoch, batch_index, batch_size):
    [...]  # load the data
    return X_batch, y_batch
with tf.Session() as sess:
    sess.run(init)
    for epoch in range(n_epochs):
        for batch_index in range(n_batches):
            X_batch, y_batch = fetch_batch(epoch, batch_index, batch_size)
            sess.run(train_op, feed_dict={X: X_batch, y: y_batch})
    best_theta = theta.eval()
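The body of fetch_batch is elided above. As one possible sketch (not the author's implementation): assuming the dataset fits in memory as NumPy arrays data_X and data_y with m rows, a batch can be drawn by sampling random row indices, seeded from (epoch, batch_index) so every batch is reproducible across runs. All names here (data_X, data_y, m, n) are illustrative assumptions:

```python
import numpy as np

# Hypothetical in-memory dataset: m samples, n features plus a bias column.
m, n = 1000, 3
rng0 = np.random.RandomState(42)
data_X = np.c_[np.ones((m, 1)), rng0.rand(m, n)]  # shape (m, n + 1)
data_y = rng0.rand(m, 1)                          # shape (m, 1)

def fetch_batch(epoch, batch_index, batch_size):
    # Derive a deterministic seed from the epoch and batch index,
    # then sample batch_size row indices with replacement.
    rng = np.random.RandomState(epoch * 100 + batch_index)
    indices = rng.randint(m, size=batch_size)
    return data_X[indices], data_y[indices]

X_batch, y_batch = fetch_batch(0, 0, 100)
print(X_batch.shape, y_batch.shape)  # (100, 4) (100, 1)
```

These shapes match the placeholders X (None, n + 1) and y (None, 1), so the batches can be fed directly through feed_dict.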