[zero] hotfix master param sync (#4618)
* [zero] add method to update master params
* [zero] update zero plugin
* [plugin] update low level zero plugin
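The fix addresses master weights going stale when new working weights are loaded into the model: in ZeRO stage 1/2, each rank holds an fp32 master shard of every working param, so a checkpoint load must push the freshly loaded working values back into those shards. A minimal standalone sketch of that pad-and-shard sync, assuming a fixed world_size/rank in place of torch.distributed and a hypothetical sync_master_param helper (the actual method and its call site live in the low level zero optimizer and plugin):

import torch

# Hypothetical standalone sketch; names are illustrative, not the ColossalAI API.
# In the real optimizer the shard index comes from torch.distributed; here
# world_size/rank are fixed so the sketch runs without a process group.
world_size, rank = 4, 1

def sync_master_param(working_param: torch.Tensor) -> torch.Tensor:
    """Rebuild this rank's fp32 master shard from an updated working param."""
    flat = working_param.data.view(-1)
    padding = (-flat.numel()) % world_size      # pad to a multiple of world_size
    padded = torch.nn.functional.pad(flat, (0, padding))
    # Each rank keeps one contiguous shard, stored in fp32.
    return padded.chunk(world_size)[rank].clone().float()

working = torch.randn(10, dtype=torch.float16)  # numel not divisible by world_size
master_shard = sync_master_param(working)
assert master_shard.numel() == 3                # ceil(10 / 4) after padding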
@@ -14,6 +14,7 @@ from colossalai.testing import (
     rerun_if_address_is_in_use,
     spawn,
 )
+from colossalai.zero import LowLevelZeroOptimizer
 
 
 # stage 1 and 2 process the optimizer/mode the same way
@@ -50,6 +51,17 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool):
 
     booster.load_model(new_model, model_ckpt_path)
     check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)
+    # check master weight
+    assert isinstance(new_optimizer, LowLevelZeroOptimizer)
+    working_param_id_set = set(id(p) for p in new_model.parameters())
+    for p_id, master_param in new_optimizer._param_store.working_to_master_param.items():
+        assert p_id in working_param_id_set
+        working_param = new_optimizer._param_store.master_to_working_param[id(master_param)]
+        padding = new_optimizer._param_store.get_param_padding_size(working_param)
+        padded_param = torch.nn.functional.pad(working_param.data.view(-1), (0, padding))
+        working_shard = padded_param.chunk(dist.get_world_size())[dist.get_rank()]
+        assert torch.equal(working_shard,
+                           master_param.data.view(-1).to(dtype=padded_param.dtype, device=padded_param.device))
 
     booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
     check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict(), False)
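The loop in the test asserts one direction of the layout: each rank's master shard equals its slice of the padded, flattened working param. A quick sanity check of the inverse direction, under the same illustrative world_size (names are not the ColossalAI API): concatenating all shards and dropping the padding must reproduce the working param.

import torch

# Companion sketch to the assertion above; world_size and the tensor
# shape are illustrative stand-ins.
world_size = 4
working = torch.randn(2, 5)                       # numel=10, not divisible by 4

padding = (-working.numel()) % world_size         # 2 elements of padding
padded = torch.nn.functional.pad(working.view(-1), (0, padding))
shards = list(padded.chunk(world_size))           # what each rank would hold

rebuilt = torch.cat(shards)[: working.numel()].view_as(working)
assert torch.equal(rebuilt, working)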