nnUNet Code Walkthrough: Optimization Strategy

Optimization strategy

nnUNetTrainer

def initialize_optimizer_and_scheduler(self):
    assert self.network is not None, "self.initialize_network must be called first"
    self.optimizer = torch.optim.Adam(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay, amsgrad=True)
    self.lr_scheduler = lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='min', factor=0.2,
                                                       patience=self.lr_scheduler_patience,
                                                       verbose=True, threshold=self.lr_scheduler_eps, threshold_mode="abs")
  • self.initial_lr = 0.01
  • self.lr_scheduler_eps = 0.001
  • self.weight_decay = 3e-5

The learning rate decays down to lr_scheduler_eps and then stops decaying; in practice it only sits between 0.01 and 0.001 for the first two or three epochs, and stays at 0.001 for the rest of training.
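ReduceLROnPlateau only lowers the learning rate when the monitored quantity stops improving, so it has to be stepped once per epoch with that quantity; in nnUNetTrainer this is a moving average of the training loss (self.train_loss_MA). Below is a minimal standalone sketch of the mechanism with the same optimizer settings; the toy model, the made-up loss values, and the small patience (nnUNet's lr_scheduler_patience defaults to 30) are only there so the decay triggers within a few steps.

import torch
from torch.optim import lr_scheduler

# Toy model/optimizer just to show how the scheduler is driven; the losses below are made up.
net = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(net.parameters(), lr=0.01, weight_decay=3e-5, amsgrad=True)
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2,
                                           patience=1, threshold=0.001, threshold_mode="abs")

for epoch, monitored_loss in enumerate([0.9, 0.7, 0.699, 0.6995, 0.699]):
    # ... one epoch of training would run here ...
    scheduler.step(monitored_loss)
    print(epoch, optimizer.param_groups[0]['lr'])
# prints 0.01, 0.01, 0.01, 0.002, 0.002: the LR is multiplied by `factor` only after the loss
# has failed to improve by more than `threshold` for more than `patience` epochs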

nnUNetTrainerV2

def initialize_optimizer_and_scheduler(self):
    assert self.network is not None, "self.initialize_network must be called first"
    self.optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,
                                     momentum=0.99, nesterov=True)
    self.lr_scheduler = None

V2 drops the scheduler in favour of polynomial ("poly") learning-rate decay, with the learning rate updated once per epoch:

def poly_lr(epoch, max_epochs, initial_lr, exponent=0.9):
    return initial_lr * (1 - epoch / max_epochs)**exponent


def maybe_update_lr(self, epoch=None):
    """
    if epoch is not None we overwrite epoch. Else we use epoch = self.epoch + 1

    (maybe_update_lr is called in on_epoch_end which is called before epoch is incremented.
    Therefore we need to do +1 here)

    :param epoch:
    :return:
    """
    if epoch is None:
        ep = self.epoch + 1
    else:
        ep = epoch
    self.optimizer.param_groups[0]['lr'] = poly_lr(ep, self.max_num_epochs, self.initial_lr, 0.9)
    self.print_to_log_file("lr:", np.round(self.optimizer.param_groups[0]['lr'], decimals=6))
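To get a feel for how slowly the poly schedule decays early on and how sharply it drops near the end, it can be evaluated directly. This is only an illustration, assuming poly_lr from above is in scope and using nnUNetTrainerV2's defaults of initial_lr = 0.01 and max_num_epochs = 1000:

# evaluate the poly schedule at a few epochs (1000-epoch run assumed)
for ep in (0, 250, 500, 750, 999):
    print(ep, round(poly_lr(ep, 1000, 0.01), 6))
# 0 -> 0.01, 250 -> ~0.0077, 500 -> ~0.0054, 750 -> ~0.0029, 999 -> ~2e-05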

Custom cosine decay

Here is the learning-rate decay strategy I usually use:

import numpy as np


def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0.):
    warmup_schedule = np.array([])
    warmup_iters = warmup_epochs * niter_per_ep
    if warmup_epochs > 0:
        warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)

    iters = np.arange(epochs * niter_per_ep - warmup_iters)
    schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))

    schedule = np.concatenate((warmup_schedule, schedule))
    assert len(schedule) == epochs * niter_per_ep
    return schedule
  • Warm up from a very small learning rate to the base learning rate (base_value)
  • Update the learning rate at every iteration rather than every epoch

The learning-rate update code is included below for reference:

for it in range(self.num_batches_per_epoch):
    # global iteration index across all epochs
    it = self.num_batches_per_epoch * self.epoch + it
    # look up the precomputed learning rate for this iteration
    param_group = self.optimizer.param_groups[0]
    param_group['lr'] = self.lr_scheduler[it]
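Putting the pieces together, here is a minimal sketch of how the schedule might be wired into a trainer: build it once up front (one value per iteration) and index it with the global iteration counter. It assumes the cosine_scheduler defined above is in scope; the toy model and the hyperparameter values are placeholders rather than nnUNet's own.

import torch

# toy model/optimizer standing in for the real network
net = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.99, nesterov=True, weight_decay=3e-5)

max_num_epochs, num_batches_per_epoch = 100, 250
# precompute one learning rate per iteration, including the warmup phase
lr_schedule = cosine_scheduler(base_value=0.01, final_value=0.001,
                               epochs=max_num_epochs, niter_per_ep=num_batches_per_epoch,
                               warmup_epochs=10, start_warmup_value=5e-4)

for epoch in range(max_num_epochs):
    for it in range(num_batches_per_epoch):
        global_it = num_batches_per_epoch * epoch + it
        optimizer.param_groups[0]['lr'] = lr_schedule[global_it]
        # ... forward pass, loss.backward(), optimizer.step() go here ...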

Comparison experiment

Assume all three schedules start from a learning rate of 0.01 and run for 100 epochs. The figure below shows the three decay curves; note that the x-axis of the cosine curve is in iterations rather than epochs.

[Figure: learning-rate decay curves for ReduceLROnPlateau, PolyScheduler, and CosineScheduler]

The experiment code is included here as well; it can be run as-is.

import math
import numpy as np
import matplotlib.pyplot as plt


def schedule1(epochs, lr, factor, threshold):
    # rough stand-in for ReduceLROnPlateau: multiply by `factor` every epoch,
    # clipped from below at `threshold`
    lr_list = [max(lr * math.pow(factor, x), threshold) for x in range(epochs)]
    return lr_list


def schedule2(epochs, lr, exponent):
    # poly decay, identical to nnUNetTrainerV2's poly_lr
    lr_list = [lr * (1 - x / epochs) ** exponent for x in range(epochs)]
    return lr_list


def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0.):
    # linear warmup followed by cosine decay, one value per iteration
    warmup_schedule = np.array([])
    warmup_iters = warmup_epochs * niter_per_ep
    if warmup_epochs > 0:
        warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)

    iters = np.arange(epochs * niter_per_ep - warmup_iters)
    schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))

    schedule = np.concatenate((warmup_schedule, schedule))
    assert len(schedule) == epochs * niter_per_ep
    return schedule


max_epochs = 100
num_batches_per_epoch = 250
initial_lr = 0.01
factor = 0.2
lr_scheduler_eps = 0.001
exponent = 0.9
warmup_epochs = 10

lr_list1 = schedule1(max_epochs, initial_lr, factor, threshold=lr_scheduler_eps)
lr_list2 = schedule2(max_epochs, initial_lr, exponent)
lr_list3 = cosine_scheduler(initial_lr, lr_scheduler_eps, max_epochs, num_batches_per_epoch,
                            warmup_epochs, start_warmup_value=5e-4)

plt.subplot(131), plt.plot(lr_list1), plt.title("ReduceLROnPlateau")
plt.subplot(132), plt.plot(lr_list2), plt.title("PolyScheduler")
plt.subplot(133), plt.plot(lr_list3), plt.title("CosineScheduler")
plt.show()

Writing this up takes effort; if you found it useful, please leave a like. More nnUNet-related posts will follow.

Reposted from blog.csdn.net/weixin_44858814/article/details/124572365