mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 04:50:17 +00:00
[bf16] add bf16 support (#3882)
* [bf16] add bf16 support for fused adam (#3844) * [bf16] fused adam kernel support bf16 * [test] update fused adam kernel test * [test] update fused adam test * [bf16] cpu adam and hybrid adam optimizers support bf16 (#3860) * [bf16] implement mixed precision mixin and add bf16 support for low level zero (#3869) * [bf16] add mixed precision mixin * [bf16] low level zero optim support bf16 * [text] update low level zero test * [text] fix low level zero grad acc test * [bf16] add bf16 support for gemini (#3872) * [bf16] gemini support bf16 * [test] update gemini bf16 test * [doc] update gemini docstring * [bf16] add bf16 support for plugins (#3877) * [bf16] add bf16 support for legacy zero (#3879) * [zero] init context support bf16 * [zero] legacy zero support bf16 * [test] add zero bf16 test * [doc] add bf16 related docstring for legacy zero
This commit is contained in:
86
tests/test_optimizer/test_adam_optim.py
Normal file
86
tests/test_optimizer/test_adam_optim.py
Normal file
@@ -0,0 +1,86 @@
|
||||
from copy import deepcopy
|
||||
from typing import Type, Union
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.optim import Adam, AdamW
|
||||
|
||||
from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
|
||||
_ALLOWED_OPTIM_DEVICES = [
|
||||
(FusedAdam, torch.device('cuda:0')),
|
||||
(CPUAdam, torch.device('cpu')),
|
||||
(CPUAdam, torch.device('cuda:0')),
|
||||
(HybridAdam, torch.device('cpu')),
|
||||
(HybridAdam, torch.device('cuda:0')),
|
||||
]
|
||||
|
||||
_ALLOWED_P_G_TYPES = [
|
||||
(torch.float, torch.float), # pure fp32
|
||||
(torch.float, torch.half), # fp16 amp
|
||||
(torch.float, torch.bfloat16), # bfloat16 amp
|
||||
# (torch.half, torch.half), # FIXME(ver217): cpu adam kernel does not support pure fp16
|
||||
# (torch.bfloat16, torch.bfloat16), # FIXME(ver217): cpu adam kernel does not support pure bfloat16
|
||||
]
|
||||
|
||||
N_STEPS = 3
|
||||
|
||||
|
||||
def setup_param_groups(bert_model: nn.Module) -> list:
|
||||
no_decay = ["bias", "LayerNorm.weight"]
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [p for n, p in bert_model.named_parameters() if not any(nd in n for nd in no_decay)],
|
||||
"weight_decay": 0.1,
|
||||
},
|
||||
{
|
||||
"params": [p for n, p in bert_model.named_parameters() if any(nd in n for nd in no_decay)],
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
return optimizer_grouped_parameters
|
||||
|
||||
|
||||
def set_grad(model: nn.Module, torch_model: nn.Module, g_dtype: torch.dtype) -> None:
|
||||
for p, torch_p in zip(model.parameters(), torch_model.parameters()):
|
||||
torch_p.grad = torch.rand_like(torch_p)
|
||||
# avoid inconsistent grad and param dtype error
|
||||
orig_p = p.data
|
||||
p.data = torch_p.grad.clone().to(g_dtype)
|
||||
p.grad = p.data
|
||||
p.data = orig_p
|
||||
|
||||
|
||||
@pytest.mark.parametrize('optim_cls, device', _ALLOWED_OPTIM_DEVICES)
|
||||
@pytest.mark.parametrize('adamw', [False, True])
|
||||
@pytest.mark.parametrize('p_dtype, g_dtype', _ALLOWED_P_G_TYPES)
|
||||
def test_adam_optim_on_bert(optim_cls: Union[Type[FusedAdam], Type[CPUAdam], Type[HybridAdam]], device: torch.device,
|
||||
adamw: bool, p_dtype: torch.dtype, g_dtype: torch.dtype) -> None:
|
||||
model_fn, *_ = next(iter(model_zoo.get_sub_registry('transformers_bert_for_sequence_classification').values()))
|
||||
torch_model = model_fn().to(device)
|
||||
model = deepcopy(torch_model).to(p_dtype)
|
||||
lr = 1e-3
|
||||
beta1, beta2 = 0.9, 0.999
|
||||
eps = 1e-8
|
||||
torch_optim_cls = AdamW if adamw else Adam
|
||||
torch_optim = torch_optim_cls(setup_param_groups(torch_model), lr=lr, betas=(beta1, beta2), eps=eps)
|
||||
optim = optim_cls(setup_param_groups(model), lr=lr, betas=(beta1, beta2), eps=eps, adamw_mode=adamw)
|
||||
|
||||
rtol, atol = 1e-5, 1e-5
|
||||
if p_dtype is torch.float16 or g_dtype is torch.float16:
|
||||
rtol, atol = 2e-3, 2e-3
|
||||
if p_dtype is torch.bfloat16 or g_dtype is torch.bfloat16:
|
||||
rtol, atol = 4e-3, 4e-3
|
||||
|
||||
for _ in range(N_STEPS):
|
||||
set_grad(model, torch_model, g_dtype)
|
||||
torch_optim.step()
|
||||
optim.step()
|
||||
torch_optim.zero_grad()
|
||||
optim.zero_grad()
|
||||
for p, torch_p in zip(model.parameters(), torch_model.parameters()):
|
||||
# if overflow, the weight won't be updated. so there will be no nan in p
|
||||
assert not torch.isnan(p).any()
|
||||
assert torch.allclose(p.float(), torch_p, rtol=rtol, atol=atol)
|
Reference in New Issue
Block a user