[CI] Cleanup Dist Optim tests with shared helper funcs (#6125)

* Refactor and clean up using common helper funcs. Tests passed

* Update comments

* Fix relative import

* Fix param fetching bug
Author: Wenxuan Tan
Date: 2025-02-11 23:42:34 -06:00
Committed by: GitHub
Parent: 5c09d726a6
Commit: ec73f1b5e2
8 changed files with 142 additions and 298 deletions


@@ -8,6 +8,7 @@ from torch.optim import Adam, AdamW
 from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam
 from tests.kit.model_zoo import model_zoo
+from tests.test_optimizer._utils import force_assign_grad, setup_param_groups
 
 _ALLOWED_OPTIM_DEVICES = [
     (FusedAdam, torch.device("cuda:0")),
@@ -26,29 +27,11 @@ _ALLOWED_P_G_TYPES = [
 N_STEPS = 3
 
 
-def setup_param_groups(bert_model: nn.Module) -> list:
-    no_decay = ["bias", "LayerNorm.weight"]
-    optimizer_grouped_parameters = [
-        {
-            "params": [p for n, p in bert_model.named_parameters() if not any(nd in n for nd in no_decay)],
-            "weight_decay": 0.1,
-        },
-        {
-            "params": [p for n, p in bert_model.named_parameters() if any(nd in n for nd in no_decay)],
-            "weight_decay": 0.0,
-        },
-    ]
-    return optimizer_grouped_parameters
-
-
 def set_grad(model: nn.Module, torch_model: nn.Module, g_dtype: torch.dtype) -> None:
     for p, torch_p in zip(model.parameters(), torch_model.parameters()):
         torch_p.grad = torch.rand_like(torch_p)
         # avoid inconsistent grad and param dtype error
-        orig_p = p.data
-        p.data = torch_p.grad.clone().to(g_dtype)
-        p.grad = p.data
-        p.data = orig_p
+        force_assign_grad(p, g_dtype, torch_p.grad)
 
 
 @pytest.mark.parametrize("optim_cls, device", _ALLOWED_OPTIM_DEVICES)
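
For reference, the behavior of the shared force_assign_grad helper can be inferred from the inline code it replaces above. The following is only a minimal sketch of what such a helper in tests/test_optimizer/_utils.py would need to do; the actual implementation and parameter names in the repo may differ.

import torch


def force_assign_grad(p: torch.Tensor, g_dtype: torch.dtype, grad: torch.Tensor) -> None:
    # Sketch inferred from the removed inline code: assign `grad` (cast to
    # g_dtype) as p.grad without hitting PyTorch's "grad and param dtype
    # mismatch" error at assignment time.
    orig_p = p.data
    # temporarily swap p.data to a tensor of the target grad dtype
    p.data = grad.clone().to(g_dtype)
    # grad assignment now succeeds because dtypes match
    p.grad = p.data
    # restore the original parameter data
    p.data = orig_p

With helpers like this and the shared setup_param_groups (whose removed local copy is shown in the diff above), each test file's set_grad collapses to the short loop kept in this hunk.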