[zero] support shard optimizer state dict of zero (#4194)

* support sharded optimizer state dict of zero

* polish code

* support sync grad manually
Author: LuGY
Date: 2023-07-11 18:03:13 +08:00
Committed by: Hongxin Liu
Parent: dd7cc58299
Commit: 1a49a5ea00
4 changed files with 239 additions and 68 deletions


@@ -38,9 +38,8 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool):
         optimizer_ckpt_path = f"{tempdir}/optimizer"
         # lr scheduler is tested in test_torch_ddp_checkpoint_io.py and low level zero does not change it, we can skip it here
         booster.save_model(model, model_ckpt_path, shard=shard)
-        if not shard:
-            # TODO(ver217): optimizer checkpointing is not supported for sharded checkpoint
-            booster.save_optimizer(optimizer, optimizer_ckpt_path)
+        booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard)
         dist.barrier()

         new_model = resnet18()
@@ -49,9 +48,9 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool):
         booster.load_model(new_model, model_ckpt_path)
         check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)
-        if not shard:
-            booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
-            check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
+        booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
+        check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)


 def run_dist(rank, world_size, port):
@@ -62,3 +61,7 @@ def run_dist(rank, world_size, port):
 @rerun_if_address_is_in_use()
 def test_low_level_zero_checkpointIO():
     spawn(run_dist, 2)
+
+
+if __name__ == "__main__":
+    test_low_level_zero_checkpointIO()
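
For reference, here is a minimal usage sketch of the behavior this commit enables, mirroring the test above: saving the ZeRO optimizer state as a sharded checkpoint through the Booster API and loading it back. Only save_optimizer(..., shard=...) and load_optimizer(...) are confirmed by the diff; the surrounding setup (colossalai.launch_from_torch, the ZeRO stage, the checkpoint path, resnet18/HybridAdam) is illustrative and assumed rather than taken from this commit.

# Sketch only: the launch call, stage choice, and paths below are assumptions.
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.nn.optimizer import HybridAdam
from torchvision.models import resnet18

colossalai.launch_from_torch(config={})  # assumes a `torchrun` launch

plugin = LowLevelZeroPlugin(stage=2)     # ZeRO stage is an arbitrary choice here
booster = Booster(plugin=plugin)

model = resnet18().cuda()
optimizer = HybridAdam(model.parameters(), lr=1e-3)
criterion = lambda out: out.mean()
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

# run one step so the optimizer actually has state worth checkpointing
out = model(torch.randn(4, 3, 224, 224, device="cuda"))
booster.backward(criterion(out), optimizer)
optimizer.step()

# shard=True is the path exercised by the updated test: the optimizer state
# is written as a sharded checkpoint instead of a single monolithic file
booster.save_optimizer(optimizer, "./ckpt/optimizer", shard=True)

# loading goes through the same Booster API against a fresh model/optimizer
new_model = resnet18().cuda()
new_optimizer = HybridAdam(new_model.parameters(), lr=1e-3)
new_model, new_optimizer, _, _, _ = booster.boost(new_model, new_optimizer)
booster.load_optimizer(new_optimizer, "./ckpt/optimizer")

As in the test, the reloaded state dict can then be compared against the original (e.g. with check_state_dict_equal) to verify the round trip.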