(1) Import module
from torch.cuda.amp import autocast as autocast, GradScaler
(2) Create the AMP gradient scaler
# Gradient scaler used by the training step below: it scales the loss before
# backward(), then drives optimizer stepping and the scale-factor update.
scaler = GradScaler()
(3) Training: forward pass, loss computation, and backpropagation
# Mixed-precision training step, taken only when fp16 is enabled in the config.
# NOTE(review): no non-fp16 branch is visible in this snippet; confirm the
# full-precision path is handled elsewhere.
# NOTE(review): assumes optimizer.zero_grad() is called elsewhere in the
# training loop before this step - TODO confirm.
if opt['train']['enable_fp16']:
    # Only the forward pass and the loss computation run under autocast;
    # backward and the optimizer step are intentionally kept outside of it.
    with autocast():
        # model
        output = model(input)
        # loss
        train_loss = loss(output, label)
    # loss backward: scale the loss first so small fp16 gradients do not
    # underflow to zero.
    scaler.scale(train_loss).backward()
    # Unscale the gradients in place before stepping. This is the point where
    # gradient clipping would go; none is applied here.
    scaler.unscale_(optimizer)
    # step() skips the optimizer update if the gradients contain inf/NaN.
    scaler.step(optimizer)
    # Adjust the scale factor for the next iteration.
    scaler.update()