Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-09 13:00:52 +00:00)
[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -10,17 +10,17 @@ from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam
 from tests.kit.model_zoo import model_zoo
 
 _ALLOWED_OPTIM_DEVICES = [
-    (FusedAdam, torch.device('cuda:0')),
-    (CPUAdam, torch.device('cpu')),
-    (CPUAdam, torch.device('cuda:0')),
-    (HybridAdam, torch.device('cpu')),
-    (HybridAdam, torch.device('cuda:0')),
+    (FusedAdam, torch.device("cuda:0")),
+    (CPUAdam, torch.device("cpu")),
+    (CPUAdam, torch.device("cuda:0")),
+    (HybridAdam, torch.device("cpu")),
+    (HybridAdam, torch.device("cuda:0")),
 ]
 
 _ALLOWED_P_G_TYPES = [
-    (torch.float, torch.float),  # pure fp32
-    (torch.float, torch.half),  # fp16 amp
-    (torch.float, torch.bfloat16),  # bfloat16 amp
+    (torch.float, torch.float),  # pure fp32
+    (torch.float, torch.half),  # fp16 amp
+    (torch.float, torch.bfloat16),  # bfloat16 amp
     # (torch.half, torch.half),  # FIXME(ver217): cpu adam kernel does not support pure fp16
     # (torch.bfloat16, torch.bfloat16),  # FIXME(ver217): cpu adam kernel does not support pure bfloat16
 ]
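Side note (not part of the commit): the hunk above only applies black's quote normalization from single to double quotes, so the optimizer/device and dtype tables are semantically unchanged. A minimal sketch of that assumption, requiring only that torch is importable:

```python
# Illustrative check (not from the commit): single- and double-quoted string
# literals are the same value in Python, so torch.device construction is
# unaffected by the quote normalization applied by black.
import torch

assert 'cuda:0' == "cuda:0"
assert torch.device('cuda:0') == torch.device("cuda:0")  # device objects compare equal
assert torch.device('cpu') == torch.device("cpu")
print("quote style is purely cosmetic")
```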
@@ -53,12 +53,17 @@ def set_grad(model: nn.Module, torch_model: nn.Module, g_dtype: torch.dtype) ->
         p.data = orig_p
 
 
-@pytest.mark.parametrize('optim_cls, device', _ALLOWED_OPTIM_DEVICES)
-@pytest.mark.parametrize('adamw', [False, True])
-@pytest.mark.parametrize('p_dtype, g_dtype', _ALLOWED_P_G_TYPES)
-def test_adam_optim_on_bert(optim_cls: Union[Type[FusedAdam], Type[CPUAdam], Type[HybridAdam]], device: torch.device,
-                            adamw: bool, p_dtype: torch.dtype, g_dtype: torch.dtype) -> None:
-    model_fn, *_ = next(iter(model_zoo.get_sub_registry('transformers_bert_for_sequence_classification').values()))
+@pytest.mark.parametrize("optim_cls, device", _ALLOWED_OPTIM_DEVICES)
+@pytest.mark.parametrize("adamw", [False, True])
+@pytest.mark.parametrize("p_dtype, g_dtype", _ALLOWED_P_G_TYPES)
+def test_adam_optim_on_bert(
+    optim_cls: Union[Type[FusedAdam], Type[CPUAdam], Type[HybridAdam]],
+    device: torch.device,
+    adamw: bool,
+    p_dtype: torch.dtype,
+    g_dtype: torch.dtype,
+) -> None:
+    model_fn, *_ = next(iter(model_zoo.get_sub_registry("transformers_bert_for_sequence_classification").values()))
     torch_model = model_fn().to(device)
     model = deepcopy(torch_model).to(p_dtype)
     lr = 1e-3
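Side note (not part of the commit): the three stacked @pytest.mark.parametrize decorators in the hunk above expand into the full cross product of optimizer/device pairs, the adamw flag, and (parameter dtype, gradient dtype) pairs. A standalone sketch of that expansion, using placeholder strings instead of the real optimizer classes and dtypes:

```python
# Illustrative expansion (not from the commit) of the stacked parametrizations:
# 5 optimizer/device pairs x 2 adamw flags x 3 dtype pairs = 30 test cases.
import itertools

optim_devices = [
    ("FusedAdam", "cuda:0"),
    ("CPUAdam", "cpu"),
    ("CPUAdam", "cuda:0"),
    ("HybridAdam", "cpu"),
    ("HybridAdam", "cuda:0"),
]
adamw_flags = [False, True]
p_g_types = [
    ("fp32", "fp32"),  # pure fp32
    ("fp32", "fp16"),  # fp16 amp
    ("fp32", "bf16"),  # bfloat16 amp
]

cases = list(itertools.product(optim_devices, adamw_flags, p_g_types))
print(len(cases))  # 30
```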