AMP (mixed precision) training in PyTorch

(1) Import the modules

from torch.cuda.amp import autocast, GradScaler
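
(Side note: recent PyTorch releases expose the same utilities under torch.amp and deprecate the torch.cuda.amp spellings. If you are on a newer version, the equivalent import would look like the following; the device type is then passed explicitly.)

from torch.amp import autocast, GradScaler
# usage then names the device type:
#   with autocast('cuda'): ...
#   scaler = GradScaler('cuda')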

(2) Create the AMP gradient scaler. The scaler multiplies the loss by a scale factor before backward so that small fp16 gradients do not underflow to zero.

scaler = GradScaler()

(3) Training: forward pass, loss, and backward pass

if opt['train']['enable_fp16']:
    optimizer.zero_grad()
    with autocast():
        # forward pass and loss run under autocast (fp16 where safe)
        output = model(input)
        train_loss = loss(output, label)
    # backward and optimizer step run outside the autocast context
    scaler.scale(train_loss).backward()  # scale the loss, then backprop
    scaler.unscale_(optimizer)           # optional: unscale grads first, e.g. for clipping
    scaler.step(optimizer)               # skips the step if grads contain inf/nan
    scaler.update()                      # adjust the scale factor for the next iteration
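
For context, the pieces above fit together as in the following minimal sketch. The model, loss, optimizer, and data shapes here are illustrative placeholders, not from the original post:

import torch
from torch.cuda.amp import autocast, GradScaler

# illustrative setup: a small linear model, MSE loss, and SGD
model = torch.nn.Linear(16, 1).cuda()
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = GradScaler()

for _ in range(100):
    input = torch.randn(32, 16, device='cuda')
    label = torch.randn(32, 1, device='cuda')
    optimizer.zero_grad()
    with autocast():                     # forward + loss in mixed precision
        output = model(input)
        train_loss = loss_fn(output, label)
    scaler.scale(train_loss).backward()  # backward on the scaled loss
    scaler.step(optimizer)               # unscales grads, skips step on inf/nan
    scaler.update()                      # adjusts the scale for the next step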


Reprinted from blog.csdn.net/mr1217704159/article/details/121351643