mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-10 21:40:02 +00:00
[ckpt] Add async ckpt api (#6136)
* fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix
This commit is contained in:
@@ -63,10 +63,15 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
|
||||
model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2
|
||||
|
||||
booster.save_model(
|
||||
bert_model, pretrained_path, True, True, "", (model_size / 3), use_safetensors=use_safetensors
|
||||
bert_model,
|
||||
pretrained_path,
|
||||
True,
|
||||
True,
|
||||
"",
|
||||
(model_size / 3),
|
||||
use_safetensors=use_safetensors,
|
||||
)
|
||||
dist.barrier()
|
||||
|
||||
new_bert_model = BertForSequenceClassification.from_pretrained(pretrained_path)
|
||||
check_state_dict_equal(bert_model.state_dict(only_rank_0=False), new_bert_model.state_dict())
|
||||
|
||||
@@ -119,7 +124,12 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
|
||||
with shared_tempdir() as tempdir:
|
||||
model_ckpt_path = f"{tempdir}/model"
|
||||
optimizer_ckpt_path = f"{tempdir}/optimizer"
|
||||
booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
|
||||
booster.save_model(
|
||||
model,
|
||||
model_ckpt_path,
|
||||
shard=shard,
|
||||
size_per_shard=size_per_shard,
|
||||
)
|
||||
|
||||
booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
|
||||
dist.barrier()
|
||||
|
Reference in New Issue
Block a user