Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-10 13:30:19 +00:00
[CI] Cleanup Dist Optim tests with shared helper funcs (#6125)
* Refactor and clean up using common helper funcs. Tests passed
* Update comments
* Fix relative import
* Fix param fetching bug
@@ -8,6 +8,7 @@ from torch.optim import Adam, AdamW
 from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam
 from tests.kit.model_zoo import model_zoo
+from tests.test_optimizer._utils import force_assign_grad, setup_param_groups
 
 _ALLOWED_OPTIM_DEVICES = [
     (FusedAdam, torch.device("cuda:0")),
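
The two helpers imported above replace local copies that the next hunk deletes. setup_param_groups builds the usual two weight-decay groups (decay everything except biases and LayerNorm weights); its body is visible in the removed lines below. As a rough illustration that is not part of this commit, the sketch here shows how such groups are typically fed to a torch optimizer; TinyBlock and the hyperparameters are invented for the example.

import torch
import torch.nn as nn

class TinyBlock(nn.Module):
    # Submodule deliberately named "LayerNorm", mirroring the HF-BERT naming
    # that the "LayerNorm.weight" filter in setup_param_groups relies on.
    def __init__(self) -> None:
        super().__init__()
        self.dense = nn.Linear(8, 8)
        self.LayerNorm = nn.LayerNorm(8)

model = TinyBlock()
no_decay = ["bias", "LayerNorm.weight"]
param_groups = [
    # decayed group: everything except biases and LayerNorm weights
    {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
     "weight_decay": 0.1},
    # non-decayed group: biases and LayerNorm weights
    {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
     "weight_decay": 0.0},
]
optimizer = torch.optim.AdamW(param_groups, lr=1e-3)  # per-group weight_decay is respected
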
@@ -26,29 +27,11 @@ _ALLOWED_P_G_TYPES = [
 N_STEPS = 3
 
 
-def setup_param_groups(bert_model: nn.Module) -> list:
-    no_decay = ["bias", "LayerNorm.weight"]
-    optimizer_grouped_parameters = [
-        {
-            "params": [p for n, p in bert_model.named_parameters() if not any(nd in n for nd in no_decay)],
-            "weight_decay": 0.1,
-        },
-        {
-            "params": [p for n, p in bert_model.named_parameters() if any(nd in n for nd in no_decay)],
-            "weight_decay": 0.0,
-        },
-    ]
-    return optimizer_grouped_parameters
-
-
 def set_grad(model: nn.Module, torch_model: nn.Module, g_dtype: torch.dtype) -> None:
     for p, torch_p in zip(model.parameters(), torch_model.parameters()):
         torch_p.grad = torch.rand_like(torch_p)
         # avoid inconsistent grad and param dtype error
-        orig_p = p.data
-        p.data = torch_p.grad.clone().to(g_dtype)
-        p.grad = p.data
-        p.data = orig_p
+        force_assign_grad(p, g_dtype, torch_p.grad)
 
 
 @pytest.mark.parametrize("optim_cls, device", _ALLOWED_OPTIM_DEVICES)
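
The force_assign_grad call added in set_grad encapsulates the dtype workaround that the removed inline lines performed. Its real definition lives in tests/test_optimizer/_utils and is not shown in this diff; the sketch below is only an illustration of a helper with that shape, reconstructed from the lines it replaces, plus a made-up usage example.

import torch
import torch.nn as nn

def force_assign_grad(p: nn.Parameter, g_dtype: torch.dtype, grad: torch.Tensor) -> None:
    # Assign a gradient whose dtype may differ from the parameter's dtype.
    # PyTorch rejects `p.grad = t` when t's dtype does not match the parameter,
    # so the parameter data is temporarily swapped to g_dtype and then restored.
    orig_p = p.data
    p.data = grad.clone().to(g_dtype)
    p.grad = p.data
    p.data = orig_p

# Illustrative usage: give an fp32 parameter a bf16 gradient.
param = nn.Parameter(torch.zeros(4))
force_assign_grad(param, torch.bfloat16, torch.rand(4))
assert param.grad is not None and param.grad.dtype == torch.bfloat16

Keeping this trick in a shared helper means each dist-optim test no longer has to repeat the swap-and-restore dance inline, which is the point of the cleanup in this commit.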