ColossalAI/tests/test_optimizer/test_adam_optim.py
Hongxin Liu ae02d4e4f7
[bf16] add bf16 support (#3882)
* [bf16] add bf16 support for fused adam (#3844)

* [bf16] fused adam kernel support bf16

* [test] update fused adam kernel test

* [test] update fused adam test

* [bf16] cpu adam and hybrid adam optimizers support bf16 (#3860)

* [bf16] implement mixed precision mixin and add bf16 support for low level zero (#3869)

* [bf16] add mixed precision mixin

* [bf16] low level zero optim support bf16

* [test] update low level zero test

* [test] fix low level zero grad acc test

* [bf16] add bf16 support for gemini (#3872)

* [bf16] gemini support bf16

* [test] update gemini bf16 test

* [doc] update gemini docstring

* [bf16] add bf16 support for plugins (#3877)

* [bf16] add bf16 support for legacy zero (#3879)

* [zero] init context support bf16

* [zero] legacy zero support bf16

* [test] add zero bf16 test

* [doc] add bf16 related docstring for legacy zero
2023-06-05 15:58:31 +08:00
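To make the new mixed-precision path concrete, here is a minimal sketch of the fp32-master-weight / bf16-gradient combination that these optimizers now accept. It assumes ColossalAI is installed with its CPU Adam extension built; HybridAdam and the lr/betas/eps/adamw_mode arguments mirror the test file below, while the nn.Linear toy model and the random gradients are purely illustrative.

import torch
import torch.nn as nn

from colossalai.nn.optimizer import HybridAdam

# fp32 master weights with bf16 gradients -- the "bfloat16 amp" combination
# covered by _ALLOWED_P_G_TYPES in the test below
model = nn.Linear(32, 32)    # parameters stay in fp32
optim = HybridAdam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8, adamw_mode=True)

for p in model.parameters():
    grad = torch.rand_like(p, dtype=torch.bfloat16)
    # PyTorch rejects a grad whose dtype differs from the param's, so temporarily
    # swap p.data, exactly like set_grad() in the test below
    orig = p.data
    p.data = grad
    p.grad = grad
    p.data = orig

optim.step()         # the bf16-aware kernel consumes bf16 grads for fp32 params
optim.zero_grad()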


from copy import deepcopy
from typing import Type, Union

import pytest
import torch
import torch.nn as nn
from torch.optim import Adam, AdamW

from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam
from tests.kit.model_zoo import model_zoo

_ALLOWED_OPTIM_DEVICES = [
    (FusedAdam, torch.device('cuda:0')),
    (CPUAdam, torch.device('cpu')),
    (CPUAdam, torch.device('cuda:0')),
    (HybridAdam, torch.device('cpu')),
    (HybridAdam, torch.device('cuda:0')),
]

_ALLOWED_P_G_TYPES = [
    (torch.float, torch.float),    # pure fp32
    (torch.float, torch.half),    # fp16 amp
    (torch.float, torch.bfloat16),    # bfloat16 amp
    # (torch.half, torch.half),    # FIXME(ver217): cpu adam kernel does not support pure fp16
    # (torch.bfloat16, torch.bfloat16),    # FIXME(ver217): cpu adam kernel does not support pure bfloat16
]

N_STEPS = 3

def setup_param_groups(bert_model: nn.Module) -> list:
    # exclude biases and LayerNorm weights from weight decay, as is standard for BERT fine-tuning
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in bert_model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": 0.1,
        },
        {
            "params": [p for n, p in bert_model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    return optimizer_grouped_parameters

def set_grad(model: nn.Module, torch_model: nn.Module, g_dtype: torch.dtype) -> None:
    # give both models identical random gradients; the colossalai model receives them in g_dtype
    for p, torch_p in zip(model.parameters(), torch_model.parameters()):
        torch_p.grad = torch.rand_like(torch_p)
        # assigning a grad whose dtype differs from the param's raises an error,
        # so temporarily swap p.data to the g_dtype tensor before attaching the grad
        orig_p = p.data
        p.data = torch_p.grad.clone().to(g_dtype)
        p.grad = p.data
        p.data = orig_p

@pytest.mark.parametrize('optim_cls, device', _ALLOWED_OPTIM_DEVICES)
@pytest.mark.parametrize('adamw', [False, True])
@pytest.mark.parametrize('p_dtype, g_dtype', _ALLOWED_P_G_TYPES)
def test_adam_optim_on_bert(optim_cls: Union[Type[FusedAdam], Type[CPUAdam], Type[HybridAdam]], device: torch.device,
                            adamw: bool, p_dtype: torch.dtype, g_dtype: torch.dtype) -> None:
    model_fn, *_ = next(iter(model_zoo.get_sub_registry('transformers_bert_for_sequence_classification').values()))
    # torch_model is the fp32 reference updated by torch.optim; model is the copy updated by the colossalai optimizer
    torch_model = model_fn().to(device)
    model = deepcopy(torch_model).to(p_dtype)
    lr = 1e-3
    beta1, beta2 = 0.9, 0.999
    eps = 1e-8
    torch_optim_cls = AdamW if adamw else Adam
    torch_optim = torch_optim_cls(setup_param_groups(torch_model), lr=lr, betas=(beta1, beta2), eps=eps)
    optim = optim_cls(setup_param_groups(model), lr=lr, betas=(beta1, beta2), eps=eps, adamw_mode=adamw)

    # loosen tolerances when reduced-precision params or grads are involved
    rtol, atol = 1e-5, 1e-5
    if p_dtype is torch.float16 or g_dtype is torch.float16:
        rtol, atol = 2e-3, 2e-3
    if p_dtype is torch.bfloat16 or g_dtype is torch.bfloat16:
        rtol, atol = 4e-3, 4e-3

    for _ in range(N_STEPS):
        set_grad(model, torch_model, g_dtype)
        torch_optim.step()
        optim.step()
        torch_optim.zero_grad()
        optim.zero_grad()
        for p, torch_p in zip(model.parameters(), torch_model.parameters()):
            # if overflow occurred, the weight would not be updated, so p should contain no NaNs
            assert not torch.isnan(p).any()
            assert torch.allclose(p.float(), torch_p, rtol=rtol, atol=atol)
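
The parametrization sweeps every (optimizer, device) pair against every (p_dtype, g_dtype) combination; the 'cuda:0' cases need a CUDA device, and the whole file can be run from the repository root with `pytest tests/test_optimizer/test_adam_optim.py`.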