fix

commit 2aa6e44355
parent 65c487499e
@@ -13,8 +13,9 @@ from colossalai.testing import check_state_dict_equal, parameterize, rerun_if_address_is_in_use
 @parameterize("shard", [True, False])
+@parameterize("use_async", [True, False])
 @parameterize("size_per_shard", [16, 128])
-def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int):
+def check_torch_ddp_checkpointIO(shard: bool, use_async: bool, size_per_shard: int):
     plugin = TorchDDPPlugin()
     booster = Booster(plugin=plugin)
     model = resnet18()
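The stacked @parameterize decorators from colossalai.testing drive the test body over the cross-product of the listed values, so the new use_async axis doubles the number of runs. A minimal sketch of that expansion, assuming the decorator iterates every combination when the wrapped function is called once (the demo name and print body are illustrative, not part of the commit):

# Sketch: stacked parameterize decorators expand into the cross-product
# of their value lists, so one call to demo() runs 2 * 2 * 2 = 8
# combinations of (shard, use_async, size_per_shard).
from colossalai.testing import parameterize

@parameterize("shard", [True, False])
@parameterize("use_async", [True, False])
@parameterize("size_per_shard", [16, 128])
def demo(shard: bool, use_async: bool, size_per_shard: int):
    print(f"shard={shard} use_async={use_async} size_per_shard={size_per_shard}")

demo()  # a single invocation drives all eight parameter combinations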
@@ -39,7 +40,11 @@ def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int):
         model_ckpt_path = f"{tempdir}/model"
         optimizer_ckpt_path = f"{tempdir}/optimizer"
         lr_scheduler_ckpt_path = f"{tempdir}/lr_scheduler"
-        booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
+        if not use_async:
+            model_ckpt_path = f"{model_ckpt_path}.pt"
+        if use_async:
+            model_ckpt_path = f"{model_ckpt_path}.safetensors"
+        booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async)
         booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
         booster.save_lr_scheduler(scheduler, lr_scheduler_ckpt_path)
         dist.barrier()
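The hunk above picks the checkpoint filename extension from the save mode: the synchronous path writes torch's .pt file, the async path writes .safetensors, and the new use_async flag is forwarded to booster.save_model. A standalone sketch of that save path, assuming an initialized process group and an already-boosted model/optimizer/scheduler; the save_checkpoints helper name is hypothetical, and only the booster calls shown in the diff are assumed to exist:

# Hedged sketch of the save pattern exercised by this hunk.
import tempfile

import torch.distributed as dist


def save_checkpoints(booster, model, optimizer, scheduler,
                     shard: bool, use_async: bool, size_per_shard: int):
    # Assumes dist.init_process_group(...) has already been called.
    with tempfile.TemporaryDirectory() as tempdir:
        model_ckpt_path = f"{tempdir}/model"
        # Unsharded checkpoints are single files, so the extension matters:
        # synchronous saves go to .pt, async saves to .safetensors.
        if not use_async:
            model_ckpt_path = f"{model_ckpt_path}.pt"
        else:
            model_ckpt_path = f"{model_ckpt_path}.safetensors"
        booster.save_model(model, model_ckpt_path, shard=shard,
                           size_per_shard=size_per_shard, use_async=use_async)
        booster.save_optimizer(optimizer, f"{tempdir}/optimizer",
                               shard=shard, size_per_shard=size_per_shard)
        booster.save_lr_scheduler(scheduler, f"{tempdir}/lr_scheduler")
        dist.barrier()  # every rank finishes writing before the tempdir is removed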