Mirror of https://github.com/hpcaitech/ColossalAI.git
[zero] support shard optimizer state dict of zero (#4194)
* support shard optimizer of zero
* polish code
* support sync grad manually
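The change lets `booster.save_optimizer(..., shard=shard)` write the low level zero optimizer state as a sharded checkpoint and `booster.load_optimizer(...)` restore it, so the test below drops its `if not shard:` guard. The following is a minimal sketch of that round trip, mirroring the updated test; the plugin arguments, learning rate, checkpoint path, and the assumption that the process group is already initialized (e.g. via `colossalai.launch`) are illustrative choices, not taken from this commit.

    import torch
    import torch.distributed as dist
    from torchvision.models import resnet18

    from colossalai.booster import Booster
    from colossalai.booster.plugin import LowLevelZeroPlugin
    from colossalai.nn.optimizer import HybridAdam


    def sharded_optimizer_round_trip(tempdir: str) -> None:
        # Assumes torch.distributed / colossalai.launch has already set up the process group.
        plugin = LowLevelZeroPlugin(stage=2)  # illustrative stage; the test takes this as a parameter
        booster = Booster(plugin=plugin)

        model = resnet18().cuda()
        optimizer = HybridAdam(model.parameters(), lr=1e-3)
        criterion = lambda x: x.mean()
        model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

        # Run one step so the optimizer has state worth checkpointing.
        loss = criterion(model(torch.randn(1, 3, 224, 224, device="cuda")))
        booster.backward(loss, optimizer)
        optimizer.step()

        # shard=True is what this PR enables for the zero optimizer state dict.
        booster.save_optimizer(optimizer, f"{tempdir}/optimizer", shard=True)
        dist.barrier()  # let every rank finish writing its shard before anyone loads

        # Rebuild a fresh model/optimizer pair and restore the sharded state.
        new_model = resnet18().cuda()
        new_optimizer = HybridAdam(new_model.parameters(), lr=1e-3)
        new_model, new_optimizer, _, _, _ = booster.boost(new_model, new_optimizer)
        booster.load_optimizer(new_optimizer, f"{tempdir}/optimizer")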
@@ -38,9 +38,8 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool):
         optimizer_ckpt_path = f"{tempdir}/optimizer"
         # lr scheduler is tested in test_torch_ddp_checkpoint_io.py and low level zero does not change it, we can skip it here
         booster.save_model(model, model_ckpt_path, shard=shard)
-        if not shard:
-            # TODO(ver217): optimizer checkpointing is not supported for sharded checkpoint
-            booster.save_optimizer(optimizer, optimizer_ckpt_path)
+        booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard)
+
         dist.barrier()
 
         new_model = resnet18()
@@ -49,9 +48,9 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool):
 
         booster.load_model(new_model, model_ckpt_path)
         check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)
-        if not shard:
-            booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
-            check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
+
+        booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
+        check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
 
 
 def run_dist(rank, world_size, port):
@@ -62,3 +61,7 @@ def run_dist(rank, world_size, port):
 @rerun_if_address_is_in_use()
 def test_low_level_zero_checkpointIO():
     spawn(run_dist, 2)
+
+
+if __name__ == "__main__":
+    test_low_level_zero_checkpointIO()
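The body of `run_dist` is not part of this diff. For orientation, `spawn` and `@rerun_if_address_is_in_use` come from `colossalai.testing`, and the worker signature `(rank, world_size, port)` shown above is what `spawn` expects. A sketch of the usual shape of such a worker follows; the `host` value and the explicit `stage`/`shard` arguments are assumptions for illustration, since the real file drives those combinations elsewhere.

    import colossalai
    from colossalai.testing import rerun_if_address_is_in_use, spawn


    def run_dist(rank, world_size, port):
        # spawn() supplies rank, world_size and a free port; host is an assumed placeholder.
        colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port)
        # Exercise the checkpoint round trip from the diffed test file with one explicit configuration.
        check_low_level_zero_checkpointIO(stage=2, shard=True)  # defined in the test file shown above


    @rerun_if_address_is_in_use()
    def test_low_level_zero_checkpointIO():
        spawn(run_dist, 2)  # two ranks, as in the diff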