[ckpt] Add async ckpt api (#6136)

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix
Author: Wang Binluo
Date: 2024-11-15 18:19:16 +08:00
Committed by: Hongxin Liu
Parent: d4a436051d
Commit: 8e08c27e19
12 changed files with 174 additions and 86 deletions

@@ -63,10 +63,15 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool):
         model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2
         booster.save_model(
-            bert_model, pretrained_path, True, True, "", (model_size / 3), use_safetensors=use_safetensors
+            bert_model,
+            pretrained_path,
+            True,
+            True,
+            "",
+            (model_size / 3),
+            use_safetensors=use_safetensors,
         )
         dist.barrier()
         new_bert_model = BertForSequenceClassification.from_pretrained(pretrained_path)
         check_state_dict_equal(bert_model.state_dict(only_rank_0=False), new_bert_model.state_dict())
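Since the call above passes several bare positionals, here is a keyword-argument reading of it for clarity. Only shard, size_per_shard, and use_safetensors are confirmed by keyword usage elsewhere in this commit; gather_dtensor and prefix are assumed names for the remaining positionals, so treat this as a sketch rather than the definitive signature.

# Hedged keyword-argument reading of the positional save_model call above.
# `gather_dtensor` and `prefix` are assumed parameter names (not confirmed by
# this diff); the other keywords appear verbatim elsewhere in the changed tests.
booster.save_model(
    bert_model,
    pretrained_path,
    shard=True,                       # write the checkpoint as multiple shard files
    gather_dtensor=True,              # assumed: gather distributed tensors before saving
    prefix="",                        # assumed: filename prefix for the shard files
    size_per_shard=model_size / 3,    # shard size cap (MB), so the model splits into several files
    use_safetensors=use_safetensors,  # toggle safetensors vs. torch.save format
)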
@@ -119,7 +124,12 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int):
     with shared_tempdir() as tempdir:
         model_ckpt_path = f"{tempdir}/model"
         optimizer_ckpt_path = f"{tempdir}/optimizer"
-        booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
+        booster.save_model(
+            model,
+            model_ckpt_path,
+            shard=shard,
+            size_per_shard=size_per_shard,
+        )
         booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
         dist.barrier()

@@ -26,9 +26,10 @@ from tests.kit.model_zoo import model_zoo
 # only test 2 is fine
 @clear_cache_before_run()
 @parameterize("stage", [2])
-@parameterize("shard", [True, False])
+@parameterize("shard", [False, True])
 @parameterize("offload", [False, True])
-def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool):
+@parameterize("use_async", [False, True])
+def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, use_async: bool):
     plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=32, cpu_offload=offload)
     booster = Booster(plugin=plugin)
     model = resnet18()
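The new @parameterize("use_async", [False, True]) axis means the test body now runs for every combination of shard, offload, and use_async at ZeRO stage 2. The snippet below only enumerates those combinations for illustration; the actual expansion is handled inside colossalai.testing.parameterize.

from itertools import product

# Illustration of the parameter sweep produced by the stacked @parameterize
# decorators above (1 stage x 2 shard x 2 offload x 2 use_async = 8 runs).
for stage, shard, offload, use_async in product([2], [False, True], [False, True], [False, True]):
    print(f"stage={stage}, shard={shard}, offload={offload}, use_async={use_async}")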
@@ -41,13 +42,26 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool):
     loss = criterion(output)
     booster.backward(loss, optimizer)
     optimizer.step()
     with shared_tempdir() as tempdir:
         model_ckpt_path = f"{tempdir}/model"
         optimizer_ckpt_path = f"{tempdir}/optimizer"
-        # lr scheduler is tested in test_torch_ddp_checkpoint_io.py and low level zero does not change it, we can skip it here
-        booster.save_model(model, model_ckpt_path, shard=shard)
-        booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard)
+        if not shard and not use_async:
+            model_ckpt_path = f"{model_ckpt_path}.pt"
+        if not shard and use_async:
+            model_ckpt_path = f"{model_ckpt_path}.safetensors"
+        booster.save_model(
+            model,
+            model_ckpt_path,
+            shard=shard,
+            use_async=use_async,
+        )
+        # lr scheduler is tested in test_torch_ddp_checkpoint_io.py and low level zero does not change it, we can skip it here
+        booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard)
+        booster.checkpoint_io._sync_d2h()
+        booster.checkpoint_io._sync_io()
         dist.barrier()
         new_model = resnet18()
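Taken together, the hunk above shows the intended shape of the new async checkpoint API: pass use_async=True to booster.save_model (an unsharded async save is written as a .safetensors file rather than a .pt file), and call checkpoint_io._sync_d2h() followed by checkpoint_io._sync_io() to wait for the device-to-host copy and the file write before anything reads the checkpoint. Below is a minimal usage sketch along those lines; it assumes the distributed environment has already been initialized (as the tests do via their spawn helper) and is not taken verbatim from this commit.

import torch
from torchvision.models import resnet18

from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

# Minimal sketch of the async checkpoint flow exercised by the test above.
# Assumes colossalai.launch (or an equivalent distributed init) has already run.
plugin = LowLevelZeroPlugin(stage=2, max_norm=1.0, initial_scale=32)
booster = Booster(plugin=plugin)

model = resnet18()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)

# ... run forward / backward / optimizer.step() as usual ...

# Unsharded async saves land in a .safetensors file (see the path handling above).
ckpt_path = "model.safetensors"
booster.save_model(model, ckpt_path, shard=False, use_async=True)

# save_model returns before the data is fully on disk; other work can proceed here.

# Wait for the async checkpoint to finish before reading the file or exiting:
booster.checkpoint_io._sync_d2h()  # wait for device-to-host copies
booster.checkpoint_io._sync_io()   # wait for the file writes to complete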
@@ -71,6 +85,7 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool):
         booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
         check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict())
+    torch.cuda.empty_cache()