# Mixed-precision (AMP) training step using torch.cuda.amp.
#
# (1) Import the autocast context manager and the gradient scaler.
#     NOTE(review): on recent PyTorch, `torch.amp.autocast("cuda")` /
#     `torch.amp.GradScaler("cuda")` are the non-deprecated spellings;
#     kept as-is for compatibility with older versions.
from torch.cuda.amp import autocast, GradScaler

# (2) Create the AMP gradient scaler. It scales the loss before backward
#     so small fp16 gradients do not underflow to zero, then unscales
#     before the optimizer step.
scaler = GradScaler()

# (3) Training: forward pass -> loss -> scaled backward -> optimizer step.
#     `opt`, `model`, `input`, `label`, `loss` and `optimizer` are assumed
#     to be defined elsewhere in the training script.
if opt['train']['enable_fp16']:
    # NOTE(review): no `optimizer.zero_grad()` is visible in this chunk —
    # confirm gradients are zeroed elsewhere, otherwise they accumulate
    # across iterations.
    with autocast():
        # Forward pass and loss computation run in mixed precision.
        output = model(input)
        train_loss = loss(output, label)
    # Backward runs OUTSIDE the autocast region (as recommended by the
    # PyTorch AMP docs); autocast has already fixed the dtypes used in
    # the forward graph. Scale the loss to avoid fp16 gradient underflow.
    scaler.scale(train_loss).backward()
    # Unscale explicitly so that any gradient clipping (if added later)
    # operates on true-valued gradients; scaler.step() detects this and
    # will not unscale a second time.
    scaler.unscale_(optimizer)
    # step() skips the optimizer update if inf/NaN gradients were found.
    scaler.step(optimizer)
    # update() adjusts the scale factor for the next iteration.
    scaler.update()