[gemini] support amp o3 for gemini (#4872)

* [gemini] support no reuse fp16 chunk

* [gemini] support no master weight for optim

* [gemini] support no master weight for gemini ddp

* [test] update gemini tests

* [test] update gemini tests

* [plugin] update gemini plugin

* [test] fix gemini checkpointio test

* [test] fix gemini checkpoint io
Author: Hongxin Liu
Date: 2023-10-12 10:39:08 +08:00
Committed by: GitHub
Parent: c1fab951e7
Commit: df63564184

15 changed files with 222 additions and 114 deletions
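Context for the changes above: the commit messages describe dropping the fp32 master-weight copy that Gemini's optimizer normally keeps, i.e. pure-fp16 training in the spirit of the legacy Apex "O3" opt level named in the title. A minimal sketch of how this might be enabled through the Gemini plugin follows; the master_weights=False and precision="fp16" arguments are assumptions inferred from the commit messages, not a confirmed API, and the boost/backward calls follow the usual Booster workflow.

# Hedged sketch (launch with torchrun): assumes GeminiPlugin accepts a
# master_weights flag matching the "no master weight" changes above.
import torch
import torch.nn as nn
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

colossalai.launch_from_torch(config={})

plugin = GeminiPlugin(precision="fp16", master_weights=False)  # assumed flags
booster = Booster(plugin=plugin)

model = nn.Linear(128, 128)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)

x = torch.randn(4, 128).cuda()
loss = model(x).sum()
booster.backward(loss, optimizer)  # plugin-aware backward
optimizer.step()

Skipping the fp32 master copy trades some numerical robustness for lower optimizer memory, which is the usual O3-style trade-off.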


@@ -60,9 +60,10 @@ def exam_torch_load_from_gemini(shard: bool, model_name: str):
     # Add prefix to get aligned with pytorch parameter names.
     check_state_dict_equal(
-        model.state_dict(only_rank_0=False, prefix="module.module.", dtype=torch.float32),
+        model.state_dict(only_rank_0=False, prefix="module.module."),
         new_model.state_dict(),
         False,
+        ignore_dtype=True,
     )
     new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
@@ -125,9 +126,10 @@ def exam_gemini_load_from_torch(shard: bool, model_name: str):
     # Add prefix to get aligned with pytorch parameter names.
     check_state_dict_equal(
-        new_model.state_dict(only_rank_0=False, prefix="module.module.", dtype=torch.float32),
+        new_model.state_dict(only_rank_0=False, prefix="module.module."),
         model.state_dict(),
         False,
+        ignore_dtype=True,
     )
     new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
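Both hunks make the same adjustment: instead of forcing the Gemini state dict to fp32 at save time via dtype=torch.float32, the comparison itself now tolerates a dtype mismatch through ignore_dtype=True, since without master weights the checkpoint may hold fp16 tensors while the reference PyTorch model stays fp32. A rough sketch of what such a dtype-insensitive comparison could look like (a hypothetical helper, not the project's actual check_state_dict_equal):

import torch

def compare_state_dicts(sd_a, sd_b, ignore_dtype=False):
    # Hypothetical helper mirroring the ignore_dtype=True behavior used above.
    assert set(sd_a.keys()) == set(sd_b.keys()), "key sets differ"
    for name, a in sd_a.items():
        b = sd_b[name]
        if ignore_dtype:
            # Cast both sides to fp32 so an fp16 checkpoint can be compared
            # against an fp32 reference without tripping on the dtype.
            a, b = a.float(), b.float()
        else:
            assert a.dtype == b.dtype, f"dtype mismatch at {name}"
        assert torch.allclose(a, b), f"value mismatch at {name}"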