Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-11 13:59:08 +00:00)
[CI] Cleanup Dist Optim tests with shared helper funcs (#6125)
* Refactor and clean up using common helper funcs; tests passed
* Update comments
* Fix relative import
* Fix param-fetching bug
@@ -14,7 +14,7 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_ad
 from colossalai.testing.random import seed_all
 from colossalai.zero import LowLevelZeroOptimizer
 from tests.kit.model_zoo import model_zoo
-from tests.test_optimizer._utils import check_optim_states, run_bert_test
+from tests.test_optimizer._utils import check_optim_states, force_assign_grad, run_bert_test, setup_param_groups

 _ALLOWED_P_G_TYPES = [
     (torch.float, torch.float),  # pure fp32
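With this change the test pulls the shared helpers from tests/test_optimizer/_utils.py instead of defining them locally. Below is a minimal sketch of how a test might consume setup_param_groups after the move; the nn.Linear stand-in, the Adam optimizer, and the PYTHONPATH assumption are illustrative and not part of the actual test suite:

import torch
import torch.nn as nn

# Assumes the ColossalAI repo root is on PYTHONPATH so the tests package is importable.
from tests.test_optimizer._utils import setup_param_groups

# Illustrative stand-in model; the real tests build BERT models from tests.kit.model_zoo.
model = nn.Linear(8, 8)

# Parameters whose names match "bias" or "LayerNorm.weight" land in the no-weight-decay group.
param_groups = setup_param_groups(model)
assert param_groups[0]["weight_decay"] == 0.1  # decayed weights
assert param_groups[1]["weight_decay"] == 0.0  # biases / LayerNorm weights

optimizer = torch.optim.Adam(param_groups, lr=1e-3)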
@@ -49,29 +49,6 @@ def assert_distributed_close(tp_model, torch_model, rtol, atol, tp_group):
             raise e


-def setup_param_groups(bert_model: nn.Module) -> list:
-    no_decay = ["bias", "LayerNorm.weight"]
-    optimizer_grouped_parameters = [
-        {
-            "params": [p for n, p in bert_model.named_parameters() if not any(nd in n for nd in no_decay)],
-            "weight_decay": 0.1,
-        },
-        {
-            "params": [p for n, p in bert_model.named_parameters() if any(nd in n for nd in no_decay)],
-            "weight_decay": 0.0,
-        },
-    ]
-    return optimizer_grouped_parameters
-
-
-def force_assign_grad(p, g_dtype, grad=None):
-    """avoid inconsistent grad and param dtype error"""
-    orig_p = p.data
-    p.data = torch.randn_like(p, device=orig_p.device, dtype=g_dtype) if grad == None else grad
-    p.grad = p.data
-    p.data = orig_p
-
-
 def set_dist_grad(
     dist_module: nn.Module,
     torch_model: nn.Module,
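For reference, the moved force_assign_grad helper works by temporarily swapping p.data to a tensor of the grad dtype, so that the p.grad assignment passes PyTorch's "assigned grad has a different type" check, and then restoring the original storage. Here is a self-contained sketch: force_assign_grad is copied from the diff above (with `grad == None` written as the idiomatic `grad is None`), while the nn.Linear model and the fp32-param/bf16-grad pairing are illustrative assumptions:

import torch
import torch.nn as nn


def force_assign_grad(p, g_dtype, grad=None):
    """Attach a grad of dtype g_dtype to p without triggering the
    param/grad dtype-consistency error."""
    orig_p = p.data
    # Temporarily make p look like a g_dtype tensor so the grad assignment is accepted.
    p.data = torch.randn_like(p, device=orig_p.device, dtype=g_dtype) if grad is None else grad
    p.grad = p.data
    p.data = orig_p  # restore the original (e.g. fp32) parameter storage


# Illustrative usage: fp32 params paired with bf16 grads, mirroring the
# mixed-precision entries in _ALLOWED_P_G_TYPES.
layer = nn.Linear(4, 4)
for p in layer.parameters():
    force_assign_grad(p, torch.bfloat16)
assert all(p.dtype == torch.float32 and p.grad.dtype == torch.bfloat16 for p in layer.parameters())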