A detailed guide to Keras train_on_batch (its inputs and outputs, multi-GPU training with train_on_batch, and custom learning-rate schedules)

Using train_on_batch for fine-grained control of the training process

Most Keras users train their models with fit() or fit_generator(). These two APIs are friendly and convenient for people who are new to deep learning. However, they wrap the whole training loop, which gets in the way of anyone who wants to customize it (people moving from torch to Keras often prefer to write their own loop). Models that must be trained in alternating steps, such as GANs, cannot be trained directly with fit or fit_generator at all. For these cases Keras provides train_on_batch, which performs a single gradient update on one mini-batch of data.
Its advantages are:

  • Finer-grained control over the training loop, and more precise collection of loss and metric values
  • Models that are trained in separate steps, such as GANs, become straightforward to implement
  • Multi-GPU training and model saving are more convenient
  • More flexible data loading, for example with torch's DataLoader

The rest of this post describes how to use train_on_batch.

1. Input and output of train_on_batch

1.1 Input

y_pred = Model.train_on_batch(
    x,
    y=None,
    sample_weight=None,
    class_weight=None,
    reset_metrics=True,
    return_dict=False,
)
  • x: model input. A NumPy array for a single-input model, or a list of NumPy arrays for a multi-input model
  • y: labels. A NumPy array for a single-output model, or a list of NumPy arrays for a multi-output model
  • sample_weight: optional weight for each sample in the mini-batch, shape (batch_size,)
  • class_weight: optional per-class weights applied to the loss, mainly used when the classes are imbalanced. It is a dict mapping each class index to a weight, e.g. {0: 1.0, 1: 5.0}
  • reset_metrics: default True, so the returned metrics cover only this mini-batch; if False, the metrics are accumulated across batches
  • return_dict: default False, in which case y_pred is a scalar or a list (see 1.2); if True, y_pred is a dict keyed by metric name. Despite its name, y_pred holds loss and metric values, not model predictions
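The two weighting arguments take different shapes: class_weight is a dict keyed by class index, while sample_weight holds one number per sample in the batch. The sketch below, assuming integer class labels, derives a class_weight dict with the "balanced" heuristic (n_samples / (num_classes * count), the formula scikit-learn uses) and expands it into a per-sample array. Both helper names are hypothetical, not Keras functions.

```python
import numpy as np

def balanced_class_weight(labels, num_classes):
    """Hypothetical helper: map each class index to n_samples / (num_classes * count)."""
    counts = np.bincount(labels, minlength=num_classes)
    n = len(labels)
    # Classes absent from `labels` get weight 0.0 to avoid division by zero.
    return {c: (float(n / (num_classes * counts[c])) if counts[c] else 0.0)
            for c in range(num_classes)}

def per_sample_weights(labels, class_weight):
    """Hypothetical helper: expand a class_weight dict into shape (batch_size,)."""
    return np.array([class_weight[int(y)] for y in labels], dtype="float32")

labels = np.array([0, 0, 0, 1])            # imbalanced toy batch
cw = balanced_class_weight(labels, 2)      # class 1 is rarer, so it gets the larger weight
sw = per_sample_weights(labels, cw)        # one weight per sample

# Either form could then be passed on (model and x assumed to exist):
# model.train_on_batch(x, labels, class_weight=cw)
# model.train_on_batch(x, labels, sample_weight=sw)
```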

1.2 Output

  • Single-output model with one loss and no metrics: train_on_batch returns a scalar, the loss on this mini-batch. For example:
from keras.optimizers import Adam
model = keras.models.Model(inputs=inputs, outputs=outputs)
model.compile(Adam(), loss=['binary_crossentropy'])
y_pred = model.train_on_batch(x=image, y=label)
# y_pred is a scalar
  • Single-output model with one loss and n metrics: train_on_batch returns a list of length 1 + n (the loss, then the n metric values). For example:
model = keras.models.Model(inputs=inputs, outputs=outputs)
model.compile(Adam(), loss=['binary_crossentropy'], metrics=['accuracy'])
y_pred = model.train_on_batch(x=image, y=label)
# len(y_pred) == 2: y_pred[0] is the loss, y_pred[1] is the accuracy
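  • Multi-output model with n outputs (one loss each) and m metrics per output: train_on_batch returns a list of length 1 + n + m*n (the total loss, the n per-output losses, then m metric values for each output). For example:
model = keras.models.Model(inputs=inputs, outputs=[output1, output2])
model.compile(Adam(),
              loss=['binary_crossentropy', 'binary_crossentropy'],
              metrics=['accuracy'])  # in Keras 2.x a flat metric list is applied to every output
y_pred = model.train_on_batch(x=image, y=[label1, label2])
# here n = 2 and m = 1, so len(y_pred) == 5
# model.metrics_names gives the name of each value in the returned list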
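With reset_metrics=True, each return value describes only the current mini-batch, so epoch-level numbers have to be accumulated by hand. The sketch below, a hypothetical EpochAverager class rather than a Keras API, pairs each returned value with its name (for example from model.metrics_names) and keeps a mean weighted by batch size, so a smaller final batch does not skew the result. In recent tf.keras versions metrics_names may only be filled in after the first train_on_batch call, so it is safest to create the averager after that call.

```python
class EpochAverager:
    """Hypothetical helper: batch-size-weighted mean of train_on_batch results."""

    def __init__(self, names):
        self.names = list(names)              # e.g. model.metrics_names
        self.totals = [0.0] * len(self.names)
        self.count = 0

    def update(self, values, batch_size):
        # A model with one loss and no metrics returns a bare scalar; wrap it.
        if not isinstance(values, (list, tuple)):
            values = [values]
        for i, v in enumerate(values):
            self.totals[i] += float(v) * batch_size
        self.count += batch_size

    def result(self):
        return {n: t / self.count for n, t in zip(self.names, self.totals)}

avg = EpochAverager(["loss", "accuracy"])
avg.update([0.5, 0.75], batch_size=32)
avg.update([0.25, 1.0], batch_size=32)
print(avg.result())  # {'loss': 0.375, 'accuracy': 0.875}
```

Inside a training loop this would look like avg.update(model.train_on_batch(x, y), len(x)), with avg.result() read at the end of the epoch.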

2. Multi-GPU training with train_on_batch

2.1 Initializing the multi-GPU model, loading weights, compiling, and saving

Note: run training and evaluation on para_model, but save through model.

import tensorflow as tf
import keras
import os

# Select which GPUs to use
gpu = "0,1"
os.environ["CUDA_VISIBLE_DEVICES"] = gpu
gpu_num = len(gpu.split(','))

# Initialize the model
with tf.device('/cpu:0'):  # with multiple GPUs, build the template model on the CPU first
    model = YourModel(input_size, num_classes)  # YourModel: your own model-building function
    model.load_weights('*.h5')  # load existing weights here, if any
para_model = keras.utils.multi_gpu_model(model, gpus=gpu_num)  # replicate the model onto the GPUs
para_model.compile(optimizer, loss=[...], metrics=[...])  # compile the multi-GPU model

# Training and validation: call train_on_batch / test_on_batch on para_model
def train():
    para_model.train_on_batch(...)

def evaluate():
    para_model.test_on_batch(...)

# Saving: train with para_model, but save with model
# Do not use para_model.save() or para_model.save_weights(), otherwise loading the file later will cause problems
model.save('*.h5')
model.save_weights('*.h5')
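multi_gpu_model implements data parallelism: on each train_on_batch call, para_model cuts the mini-batch into gpu_num sub-batches, runs one on each GPU, and concatenates the results on the CPU. The NumPy sketch below mimics the slicing rule as we read it in the Keras 2.x source (each GPU gets batch_size // gpu_num samples and the last GPU also takes the remainder); split_batch is a hypothetical illustration, not something you need to call. It suggests choosing batch_size as a multiple of gpu_num so the GPUs get equal shares. Note also that multi_gpu_model was deprecated in later TensorFlow releases in favor of tf.distribute.MirroredStrategy.

```python
import numpy as np

def split_batch(batch, parts):
    """Hypothetical illustration of how one batch is divided among `parts` GPUs."""
    step = len(batch) // parts
    slices = []
    for i in range(parts):
        # The last replica takes everything that is left over.
        end = len(batch) if i == parts - 1 else step * (i + 1)
        slices.append(batch[step * i:end])
    return slices

x = np.zeros((10, 4))
print([s.shape[0] for s in split_batch(x, 3)])  # [3, 3, 4]
```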

3. Custom Learning Rate Adjustment Strategy

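Because callbacks cannot be used with train_on_batch, we read and set the current learning rate with keras.backend.get_value() and keras.backend.set_value(). As an example, here is the simplest step-decay schedule: every 10 epochs, the learning rate is multiplied by 0.1. (With the multi-GPU setup from section 2, use para_model.optimizer instead, since para_model is the model that was compiled.)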

import keras.backend as K

for epoch in range(100):
    train_one_epoch()
    evaluate()
    # Every 10 epochs (after epochs 10, 20, ...), multiply lr by 0.1
    if (epoch + 1) % 10 == 0:
        lr = K.get_value(model.optimizer.lr)  # read the current learning rate
        lr = lr * 0.1  # scale it by 0.1
        K.set_value(model.optimizer.lr, lr)  # write the new learning rate
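Instead of multiplying the current value in place, the rate can also be computed directly from the epoch index, which stays correct when training resumes from a checkpoint partway through. A minimal sketch; step_decay and its parameter names are hypothetical:

```python
def step_decay(initial_lr, epoch, drop=0.1, epochs_per_drop=10):
    """Learning rate for a 0-indexed epoch under step decay.

    Epochs 0-9 use initial_lr, epochs 10-19 use initial_lr * drop, and so on.
    """
    return initial_lr * drop ** (epoch // epochs_per_drop)

print([step_decay(1e-3, e) for e in (0, 9, 10, 25)])

# Usage sketch (K, model and train_one_epoch assumed to exist):
# for epoch in range(100):
#     K.set_value(model.optimizer.lr, step_decay(1e-3, epoch))
#     train_one_epoch()
```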

4. Combining Keras and torch

Torch's DataLoader is the best data-loading tool I have used so far. Part of my reason for using train_on_batch is that it lets me load data with a torch DataLoader and then train the Keras model with train_on_batch. Choosing the number of CPU workers (num_workers) and the batch_size sensibly helps keep the model training as efficiently as possible.

4.1 Pipeline: training a Keras model with a torch DataLoader and train_on_batch

import numpy as np
import torch

# Define a torch Dataset
class Dataset(torch.utils.data.Dataset):
    def __init__(self, root_list, transforms=None):
        self.root_list = root_list
        self.transforms = transforms

    def __getitem__(self, idx):
        # Assume an image classification task
        image = ...  # read a single image
        label = ...  # read its label
        if self.transforms is not None:
            image = self.transforms(image)
        return image, label  # shapes: (H, W, 3) and a scalar

    def __len__(self):
        return len(self.root_list)

# Custom collate_fn so that the DataLoader yields numpy arrays
def collate_fn(batch):
    # batch is a list of tuples: [(image, label), (image, label), ...]
    image, label = zip(*batch)
    image = np.asarray(image)  # (batch_size, H, W, 3)
    label = np.asarray(label)  # (batch_size,)
    return image, label  # if the dataset returns images as ndarrays, the loader returns ndarrays too

# Define the datasets (train_list, valid_list, test_list: your lists of samples)
train_dataset = Dataset(train_list)
valid_dataset = Dataset(valid_list)
test_dataset = Dataset(test_list)

# Define the DataLoaders. Without the custom collate_fn, batches come out
# as torch Tensors and must be converted with .numpy()
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size, shuffle=True, num_workers=4, collate_fn=collate_fn)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size, shuffle=False, num_workers=4, collate_fn=collate_fn)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size, shuffle=False, num_workers=4, collate_fn=collate_fn)

# Define train, evaluate, and test
def train():
    for i, (inputs, label) in enumerate(train_loader):
        # If inputs and label are torch Tensors, convert them to ndarrays
        # with inputs = inputs.numpy() and label = label.numpy()
        y_pred = model.train_on_batch(inputs, label)

def evaluate():
    for i, (inputs, label) in enumerate(valid_loader):
        # If inputs and label are Tensors, convert them as above
        y_pred = model.test_on_batch(inputs, label)

def test():
    for i, (inputs, label) in enumerate(test_loader):
        # If inputs and label are Tensors, convert them as above
        y_pred = model.test_on_batch(inputs, label)

def run():
    for epoch in range(num_epoch):
        train()
        evaluate()
    test()

if __name__ == "__main__":
    run()
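Because the DataLoader calls collate_fn inside its worker processes when num_workers > 0, collate_fn is also a convenient place for cheap per-batch preprocessing that then runs in parallel with training. The sketch below is a hypothetical variant, collate_fn_float, that additionally converts uint8 images to float32 in [0, 1]; whether your model expects that scaling is an assumption. Calling it directly on a few fake samples is a quick way to check the array shapes train_on_batch will receive, with no torch import needed.

```python
import numpy as np

def collate_fn_float(batch):
    """Hypothetical collate_fn variant: stack samples and scale uint8 images to [0, 1]."""
    images, labels = zip(*batch)
    images = np.stack(images).astype("float32") / 255.0  # (batch_size, H, W, 3)
    labels = np.asarray(labels, dtype="int64")           # (batch_size,)
    return images, labels

# Four fake all-white 8x8 RGB images, each labeled 1
fake = [(np.full((8, 8, 3), 255, dtype=np.uint8), 1) for _ in range(4)]
x, y = collate_fn_float(fake)
print(x.shape, x.dtype, x.max(), y.shape)  # (4, 8, 8, 3) float32 1.0 (4,)
```

In the pipeline above it would simply replace collate_fn in the DataLoader(..., collate_fn=...) calls.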

Summary

train_on_batch has other uses as well, such as GAN training, which I will not cover here. For examples, search GitHub, e.g. for keras-dcgan.


Original post: blog.csdn.net/baoxin1100/article/details/107917633