# Generating mini-batches for TensorFlow training (TensorFlow训练中产生batch)

import numpy as np
import time
# Toy dataset: the integers 0..783 laid out as 16 samples of 7x7 each;
# targets is an independent array with identical contents.
inputs = np.arange(16 * 7 * 7).reshape(16, 7, 7)
targets = np.arange(16 * 7 * 7).reshape(16, 7, 7)


# Variant for unlabeled data (inputs only).
def get_batchs(inputs=None, batch_size=None, shuffle=False, drop_last=True):
    """Yield successive mini-batches of ``inputs``.

    Args:
        inputs: Sequence of samples indexed along its first axis. It is
            indexed with an integer index array, so it is expected to be a
            NumPy array (or anything supporting integer-array indexing).
        batch_size: Number of samples per batch; must be a positive integer.
        shuffle: If True, samples are visited in a random order produced by
            ``np.random.shuffle`` (NumPy's global random state).
        drop_last: If True (the default, matching the original behaviour), a
            trailing batch with fewer than ``batch_size`` samples is dropped;
            if False, it is yielded as well.

    Yields:
        ``inputs[idx]`` for each consecutive slice ``idx`` of the (possibly
        shuffled) sample indices.

    Raises:
        ValueError: If ``batch_size`` is not positive.
        TypeError: If ``batch_size`` is None (not comparable with 0).
    """
    if batch_size <= 0:
        raise ValueError(
            f"batch_size must be a positive integer, got {batch_size!r}"
        )
    n_samples = len(inputs)
    indices = np.arange(n_samples)
    if shuffle:
        np.random.shuffle(indices)
    # With drop_last, stop early enough that every slice is full-sized.
    stop = n_samples - batch_size + 1 if drop_last else n_samples
    for start_idx in range(0, stop, batch_size):
        yield inputs[indices[start_idx:start_idx + batch_size]]
    
# Demo: print every shuffled batch of 10 samples drawn from the toy inputs.
for batch in get_batchs(inputs, batch_size=10, shuffle=True):
    print(batch)
  
  
# Variant for labeled data (inputs together with targets).
def get_batch(inputs=None, targets=None, batch_size=None, shuffle=False,
              drop_last=True):
    """Yield successive ``(inputs, targets)`` mini-batches drawn in lockstep.

    The same index slice is applied to both arrays, so each input stays
    paired with its target even when shuffling.

    Args:
        inputs: Samples indexed along the first axis. They are indexed with an
            integer index array, so a NumPy array is expected.
        targets: Labels aligned with ``inputs``; must have the same length.
        batch_size: Number of samples per batch; must be a positive integer.
        shuffle: If True, samples are visited in a random order produced by
            ``np.random.shuffle`` (NumPy's global random state).
        drop_last: If True (the default, matching the original behaviour), a
            trailing batch with fewer than ``batch_size`` samples is dropped;
            if False, it is yielded as well.

    Yields:
        Tuples ``(inputs[idx], targets[idx])`` for each consecutive slice
        ``idx`` of the (possibly shuffled) sample indices.

    Raises:
        ValueError: If ``inputs`` and ``targets`` differ in length, or if
            ``batch_size`` is not positive.
        TypeError: If ``batch_size`` is None (not comparable with 0).
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(inputs) != len(targets):
        raise ValueError(
            f"inputs and targets must have the same length, "
            f"got {len(inputs)} and {len(targets)}"
        )
    if batch_size <= 0:
        raise ValueError(
            f"batch_size must be a positive integer, got {batch_size!r}"
        )
    n_samples = len(inputs)
    indices = np.arange(n_samples)
    if shuffle:
        np.random.shuffle(indices)
    # With drop_last, stop early enough that every slice is full-sized.
    stop = n_samples - batch_size + 1 if drop_last else n_samples
    for start_idx in range(0, stop, batch_size):
        excerpt = indices[start_idx:start_idx + batch_size]
        yield inputs[excerpt], targets[excerpt]

# for a,b in get_batch(inputs, targets , 10, False):
#   print(a)
#   print(b)

# Reposted from (转载自): https://blog.csdn.net/qq_38826019/article/details/83149051