Mirror of https://github.com/hpcaitech/ColossalAI.git
[gemini] support amp o3 for gemini (#4872)
* [gemini] support no reuse fp16 chunk
* [gemini] support no master weight for optim
* [gemini] support no master weight for gemini ddp
* [test] update gemini tests
* [test] update gemini tests
* [plugin] update gemini plugin
* [test] fix gemini checkpointio test
* [test] fix gemini checkpoint io
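For context, a minimal sketch of how this mode might be enabled through the plugin API. The `master_weights` keyword and the rest of the configuration are assumptions inferred from the commit description, not taken from this diff; check the `GeminiPlugin` signature in the installed ColossalAI version.

```python
# Hedged sketch: running Gemini in fp16 without fp32 master weights (amp O3-style).
# The master_weights flag is an assumption based on the commit message; a distributed
# environment (e.g. via colossalai.launch) is assumed to be initialized before this runs.
import torch
import torch.nn as nn
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

plugin = GeminiPlugin(
    precision="fp16",       # keep model parameters and gradients in half precision
    master_weights=False,   # assumed flag: skip the fp32 master copy in the optimizer
)
booster = Booster(plugin=plugin)

model = nn.Linear(32, 32)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

# boost() wraps the model and optimizer so they run under the Gemini plugin.
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
```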
```diff
@@ -60,9 +60,10 @@ def exam_torch_load_from_gemini(shard: bool, model_name: str):
 
     # Add prefix to get aligned with pytorch parameter names.
     check_state_dict_equal(
-        model.state_dict(only_rank_0=False, prefix="module.module.", dtype=torch.float32),
+        model.state_dict(only_rank_0=False, prefix="module.module."),
         new_model.state_dict(),
         False,
+        ignore_dtype=True,
     )
 
     new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
```
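With amp O3 the Gemini model can keep fp16 parameters end to end, so the test no longer upcasts the Gemini state dict to float32 before comparing; it asks the comparison helper to ignore dtype instead. Below is a self-contained sketch of what such a dtype-agnostic comparison can look like; the helper name and its `ignore_dtype` argument come from the diff above, but this body is an illustration, not ColossalAI's implementation.

```python
import torch

def check_state_dict_equal(sd_a, sd_b, ignore_device=True, ignore_dtype=False):
    # Illustrative sketch only; the real test helper in ColossalAI may differ.
    assert sd_a.keys() == sd_b.keys(), "parameter names differ"
    for name, a in sd_a.items():
        b = sd_b[name]
        if ignore_device:
            a, b = a.cpu(), b.cpu()
        if ignore_dtype:
            # Cast both sides to a common dtype so an fp16 Gemini state dict
            # can be compared value-by-value against an fp32 torch model.
            a, b = a.to(torch.float32), b.to(torch.float32)
        assert a.dtype == b.dtype, f"dtype mismatch in {name}"
        assert torch.allclose(a, b, atol=1e-3), f"value mismatch in {name}"

# Usage, assuming (as in this sketch) that the third positional argument is ignore_device.
sd_fp16 = {"weight": torch.randn(4, 4).half()}
sd_fp32 = {"weight": sd_fp16["weight"].float()}
check_state_dict_equal(sd_fp16, sd_fp32, False, ignore_dtype=True)
```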
```diff
@@ -125,9 +126,10 @@ def exam_gemini_load_from_torch(shard: bool, model_name: str):
 
     # Add prefix to get aligned with pytorch parameter names.
     check_state_dict_equal(
-        new_model.state_dict(only_rank_0=False, prefix="module.module.", dtype=torch.float32),
+        new_model.state_dict(only_rank_0=False, prefix="module.module."),
         model.state_dict(),
         False,
+        ignore_dtype=True,
     )
 
     new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
```
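Both hunks come from tests that round-trip checkpoints between a Gemini-wrapped model and a plain PyTorch one. The reason dtype can simply be ignored on load is that `nn.Module.load_state_dict` copies values into the destination parameters' existing dtype; a small torch-only illustration (the module used here is arbitrary):

```python
import torch
import torch.nn as nn

# An fp16 state dict (standing in for a Gemini amp-O3 model) loaded into an fp32 model:
# load_state_dict copies values with a dtype cast, so only the values need to match.
src = nn.Linear(8, 8).half()   # fp16 "checkpoint source"
dst = nn.Linear(8, 8)          # plain fp32 torch model

ckpt = src.state_dict()        # fp16 tensors
dst.load_state_dict(ckpt)      # values are cast into the fp32 parameters on copy

for name, p in dst.named_parameters():
    assert p.dtype == torch.float32
    assert torch.equal(p.data, ckpt[name].float())
print("fp16 checkpoint loaded into fp32 model; values match once dtype is ignored")
```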