memo: PyTorch | Automatic Mixed Precision

NVIDIA Apex

An example project using it is AIM.

torch.amp

An example: Automatic Mixed Precision recipe

import torch

use_amp = True  # set to False to run the same loop in full precision

net = make_model(in_size, out_size, num_layers)
opt = torch.optim.SGD(net.parameters(), lr=0.001)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

for epoch in range(epochs):
    for input, target in zip(data, targets):

        with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
            output = net(input)
            loss = loss_fn(output, target)

        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        opt.zero_grad() # set_to_none=True here can modestly improve performance
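
The recipe above leaves make_model, loss_fn, data, targets and the sizes undefined. A minimal self-contained sketch follows, with hypothetical stand-ins for all of them (a small MLP, MSE loss, random tensors). It is only one way to fill in the blanks, not the recipe's own definitions. As a further assumption, it falls back to bfloat16 autocast on CPU and disables the GradScaler there, so that it also runs on a machine without a GPU.

```python
import torch
import torch.nn as nn


def make_model(in_size, out_size, num_layers):
    # Hypothetical stand-in: a plain MLP, not necessarily the recipe's model.
    layers = []
    for _ in range(num_layers - 1):
        layers += [nn.Linear(in_size, in_size), nn.ReLU()]
    layers.append(nn.Linear(in_size, out_size))
    return nn.Sequential(*layers)


device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = True
# Assumption: float16 on CUDA (as in the recipe), bfloat16 when only a CPU is available.
amp_dtype = torch.float16 if device == "cuda" else torch.bfloat16

# Illustrative sizes and random data; none of these values come from the recipe.
in_size, out_size, num_layers = 64, 8, 3
batch_size, num_batches, epochs = 16, 4, 2
torch.manual_seed(0)
data = [torch.randn(batch_size, in_size, device=device) for _ in range(num_batches)]
targets = [torch.randn(batch_size, out_size, device=device) for _ in range(num_batches)]

net = make_model(in_size, out_size, num_layers).to(device)
opt = torch.optim.SGD(net.parameters(), lr=0.001)
loss_fn = nn.MSELoss()
# Gradient scaling targets float16 underflow, so it is only enabled on CUDA here.
scaler = torch.cuda.amp.GradScaler(enabled=use_amp and device == "cuda")

losses = []
for epoch in range(epochs):
    for input, target in zip(data, targets):
        # Forward pass and loss run under autocast; backward runs outside it.
        with torch.autocast(device_type=device, dtype=amp_dtype, enabled=use_amp):
            output = net(input)
            loss = loss_fn(output, target)

        scaler.scale(loss).backward()  # backward on the (possibly) scaled loss
        scaler.step(opt)               # unscales gradients, skips the step on inf/NaN
        scaler.update()                # adjusts the scale factor for the next iteration
        opt.zero_grad(set_to_none=True)
        losses.append(loss.item())
```

When the scaler is disabled, scaler.scale(loss) returns the loss unchanged and scaler.step(opt) simply calls opt.step(), so the loop body stays identical with or without AMP.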

[pytorch distributed] How AMP works: automatic mixed precision
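
As a rough illustration of why the recipe pairs autocast with a GradScaler: very small gradient values underflow to zero in float16, while multiplying by a large factor first keeps them representable. The sketch below uses an arbitrary example value of 1e-8 and a factor of 2**16, which matches GradScaler's default initial scale. It imitates the idea by hand and is not how GradScaler is implemented internally.

```python
import torch

grad = torch.tensor(1e-8)                # illustrative small float32 gradient value
as_half = grad.to(torch.float16)         # below float16's smallest subnormal (~6e-8): becomes 0

scale = 2.0 ** 16                        # same value as GradScaler's default init_scale
scaled_half = (grad * scale).to(torch.float16)      # about 6.55e-4, representable in float16
recovered = scaled_half.to(torch.float32) / scale   # unscale in float32 before the optimizer step

print(as_half.item(), scaled_half.item(), recovered.item())
```

In the training loop, scaler.scale(loss) plays the role of the multiplication, and scaler.step(opt) unscales the gradients before the optimizer uses them.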