[zero] hotfix master param sync (#4618)

* [zero] add method to update master params

* [zero] update zero plugin

* [plugin] update low level zero plugin
Author: Hongxin Liu
Committed: 2023-09-05 15:04:02 +08:00 (via GitHub)
Parent: aaeb520ce3
Commit: 807e01a4ba
5 changed files with 122 additions and 45 deletions
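The fix revolves around re-syncing each rank's fp32 master shards from the (possibly fp16) working parameters after a model checkpoint is loaded. As a rough illustration of the padding-and-sharding scheme that the test below verifies, here is a minimal standalone sketch in plain PyTorch; `sync_master_from_working`, `world_size`, and `rank` are hypothetical names for this example, not the actual LowLevelZeroOptimizer API:

```python
import torch
import torch.nn.functional as F


def sync_master_from_working(working: torch.Tensor, master_shard: torch.Tensor,
                             world_size: int, rank: int) -> None:
    """Copy this rank's slice of a padded working param into its fp32 master shard.

    Hypothetical helper illustrating the padding/sharding scheme; it is not the
    actual ColossalAI implementation.
    """
    flat = working.data.view(-1)
    padding = (-flat.numel()) % world_size        # pad so the tensor splits evenly across ranks
    padded = F.pad(flat, (0, padding))
    shard = padded.chunk(world_size)[rank]        # the slice owned by this rank
    master_shard.data.copy_(shard.to(dtype=master_shard.dtype, device=master_shard.device))


# Toy usage: two "ranks", a 5-element fp16 working param, and rank 0's fp32 master shard.
working = torch.arange(5, dtype=torch.float16)
master = torch.zeros(3, dtype=torch.float32)      # ceil(5 / 2) = 3 elements per shard
sync_master_from_working(working, master, world_size=2, rank=0)
print(master)                                     # tensor([0., 1., 2.])
```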

@@ -14,6 +14,7 @@ from colossalai.testing import (
    rerun_if_address_is_in_use,
    spawn,
)
from colossalai.zero import LowLevelZeroOptimizer
# stage 1 and 2 process the optimizer/model the same way
@@ -50,6 +51,17 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool):
        booster.load_model(new_model, model_ckpt_path)
        check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)
        # check master weight
        assert isinstance(new_optimizer, LowLevelZeroOptimizer)
        working_param_id_set = set(id(p) for p in new_model.parameters())
        for p_id, master_param in new_optimizer._param_store.working_to_master_param.items():
            assert p_id in working_param_id_set
            working_param = new_optimizer._param_store.master_to_working_param[id(master_param)]
            padding = new_optimizer._param_store.get_param_padding_size(working_param)
            padded_param = torch.nn.functional.pad(working_param.data.view(-1), (0, padding))
            working_shard = padded_param.chunk(dist.get_world_size())[dist.get_rank()]
            assert torch.equal(working_shard,
                               master_param.data.view(-1).to(dtype=padded_param.dtype, device=padded_param.device))
        booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
        check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict(), False)
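The new assertions flatten each working parameter, pad it by its recorded padding size so it splits evenly across `dist.get_world_size()` ranks, take the local rank's chunk, and require it to equal the stored master parameter once cast to the same dtype and device. In other words, after `booster.load_model()` the optimizer's master shards must already match the freshly loaded working weights, which is the behavior this hotfix restores before the optimizer checkpoint is loaded and compared.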