[CI] Cleanup Dist Optim tests with shared helper funcs (#6125)

* Refactor and clean up using common helper funcs. Tests passed

* Update comments

* Fix relative import

* Fix param fetching bug
Wenxuan Tan
2025-02-11 23:42:34 -06:00
committed by GitHub
parent 5c09d726a6
commit ec73f1b5e2
8 changed files with 142 additions and 298 deletions
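For context, a minimal, hypothetical sketch of how a dist-optim test consumes the shared helpers after this cleanup. The behavior described in the comments mirrors the local copies removed in the diff below; the `nn.Linear` model and plain `AdamW` optimizer are stand-ins for illustration, not the actual test setup.

```python
# Sketch only: assumes it runs inside the ColossalAI repo so the shared
# test helpers in tests/test_optimizer/_utils.py are importable.
import torch
import torch.nn as nn

from tests.test_optimizer._utils import force_assign_grad, setup_param_groups

# Stand-in model/optimizer for illustration; the real tests use BERT from the
# model zoo and the distributed optimizers under test.
model = nn.Linear(8, 8)

# setup_param_groups splits named parameters into a weight_decay=0.1 group and a
# weight_decay=0.0 group (bias / LayerNorm.weight), as in the removed local copy.
param_groups = setup_param_groups(model)
optimizer = torch.optim.AdamW(param_groups, lr=1e-3)

# force_assign_grad attaches a synthetic grad of the requested dtype (bf16 here,
# against fp32 params) by temporarily swapping p.data during the assignment, so
# PyTorch's grad/param dtype-consistency check does not fire.
for p in model.parameters():
    force_assign_grad(p, torch.bfloat16)

# The real tests then step the distributed optimizer and a torch reference and
# compare results (e.g. via check_optim_states / run_bert_test).
```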


@@ -14,7 +14,7 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_ad
 from colossalai.testing.random import seed_all
 from colossalai.zero import LowLevelZeroOptimizer
 from tests.kit.model_zoo import model_zoo
-from tests.test_optimizer._utils import check_optim_states, run_bert_test
+from tests.test_optimizer._utils import check_optim_states, force_assign_grad, run_bert_test, setup_param_groups

 _ALLOWED_P_G_TYPES = [
     (torch.float, torch.float),  # pure fp32
@@ -49,29 +49,6 @@ def assert_distributed_close(tp_model, torch_model, rtol, atol, tp_group):
             raise e


-def setup_param_groups(bert_model: nn.Module) -> list:
-    no_decay = ["bias", "LayerNorm.weight"]
-    optimizer_grouped_parameters = [
-        {
-            "params": [p for n, p in bert_model.named_parameters() if not any(nd in n for nd in no_decay)],
-            "weight_decay": 0.1,
-        },
-        {
-            "params": [p for n, p in bert_model.named_parameters() if any(nd in n for nd in no_decay)],
-            "weight_decay": 0.0,
-        },
-    ]
-    return optimizer_grouped_parameters
-
-
-def force_assign_grad(p, g_dtype, grad=None):
-    """avoid inconsistent grad and param dtype error"""
-    orig_p = p.data
-    p.data = torch.randn_like(p, device=orig_p.device, dtype=g_dtype) if grad == None else grad
-    p.grad = p.data
-    p.data = orig_p
-
-
 def set_dist_grad(
     dist_module: nn.Module,
     torch_model: nn.Module,