AMP (mixed precision) training in PyTorch

(1) Import the required modules

from torch.cuda.amp import autocast, GradScaler
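As a side note, newer PyTorch releases also expose a device-agnostic version of the same API under torch.amp; depending on your version, an equivalent import sketch (not from the original post) is:

# assumes a recent PyTorch release that ships the device-agnostic torch.amp API
from torch.amp import autocast
# the device type is then passed explicitly, e.g.
# with autocast(device_type='cuda'): ...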

(2) Create the AMP gradient scaler

scaler = GradScaler()
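GradScaler also accepts an enabled flag, so one sketch (assuming the opt config dict used in step (3)) is to tie the scaler to the same switch; with enabled=False the scaler calls become no-ops, so the identical training code also works for ordinary fp32 training:

# hypothetical: reuse the same config flag that gates mixed precision
scaler = GradScaler(enabled=opt['train']['enable_fp16'])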

(3) Training: forward pass, loss computation, and backpropagation

if opt['train']['enable_fp16']:
    optimizer.zero_grad()
    # run the forward pass and loss computation in mixed precision
    with autocast():
        output = model(input)
        train_loss = loss(output, label)
    # scale the loss and backpropagate outside the autocast context
    scaler.scale(train_loss).backward()
    # optional: unscale gradients here, e.g. if you want to clip them
    scaler.unscale_(optimizer)
    # step skips the update if inf/NaN gradients were detected
    scaler.step(optimizer)
    scaler.update()
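Putting the pieces together, a minimal self-contained training-loop sketch could look like the following; MyModel, dataloader, the loss, and the optimizer are hypothetical placeholders, not from the original post:

import torch
from torch.cuda.amp import autocast, GradScaler

device = torch.device('cuda')
model = MyModel().to(device)          # placeholder model
criterion = torch.nn.MSELoss()        # placeholder loss
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = GradScaler()

for input, label in dataloader:       # placeholder dataloader
    input, label = input.to(device), label.to(device)
    optimizer.zero_grad()

    # forward pass and loss in mixed precision
    with autocast():
        output = model(input)
        train_loss = criterion(output, label)

    # backward pass on the scaled loss, then optimizer step via the scaler
    scaler.scale(train_loss).backward()
    scaler.step(optimizer)
    scaler.update()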

Source: blog.csdn.net/mr1217704159/article/details/121351643