[checkpointio]support asyncio for 3d (#6152)

* fix

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update utils.py

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
flybird11111
2024-12-23 10:24:22 +08:00
committed by GitHub
parent aaafb38851
commit 130229fdcb
17 changed files with 776 additions and 188 deletions

View File

@@ -35,7 +35,10 @@ OPTIM_PLACEMENT_CONFIGS = [
@parameterize("use_safetensors", [False, True])
@parameterize("tp_size", [1, 2])
@parameterize("zero_size", [2])
def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool, tp_size: int, zero_size: int):
@parameterize("use_async", [False, True])
def exam_state_dict_with_origin(
placement_config, model_name, use_safetensors: bool, tp_size: int, zero_size: int, use_async: bool
):
from transformers import BertForSequenceClassification
(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
@@ -70,7 +73,10 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
"",
(model_size / 3),
use_safetensors=use_safetensors,
use_async=use_async,
)
booster.checkpoint_io._sync_d2h()
booster.checkpoint_io._sync_io()
dist.barrier()
new_bert_model = BertForSequenceClassification.from_pretrained(pretrained_path)
check_state_dict_equal(bert_model.state_dict(only_rank_0=False), new_bert_model.state_dict())
@@ -83,7 +89,10 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
@parameterize("size_per_shard", [32])
@parameterize("tp_size", [1, 2])
@parameterize("zero_size", [2])
def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int, tp_size: int, zero_size: int):
@parameterize("use_async", [False, True])
def exam_state_dict(
placement_config, shard: bool, model_name: str, size_per_shard: int, tp_size: int, zero_size: int, use_async: bool
):
(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
criterion = lambda x: x.mean()
enable_flash_attention = True if tp_size > 1 else False
@@ -124,14 +133,18 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
with shared_tempdir() as tempdir:
model_ckpt_path = f"{tempdir}/model"
optimizer_ckpt_path = f"{tempdir}/optimizer"
booster.save_model(
model,
model_ckpt_path,
shard=shard,
size_per_shard=size_per_shard,
)
booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
if not shard and use_async:
model_ckpt_path = f"{model_ckpt_path}.safetensors"
optimizer_ckpt_path = f"{optimizer_ckpt_path}.safetensors"
booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async)
booster.save_optimizer(
optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async
)
booster.checkpoint_io._sync_d2h()
booster.checkpoint_io._sync_io()
dist.barrier()
booster.load_model(new_model, model_ckpt_path)
@@ -155,8 +168,18 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
loss = criterion(output[output_key])
booster.backward(loss, new_optimizer)
new_optimizer.step()
booster.save_model(new_model, model_ckpt_path, shard=shard)
booster.save_optimizer(new_optimizer, optimizer_ckpt_path, shard=shard)
with shared_tempdir() as new_tempdir:
model_ckpt_path = f"{new_tempdir}/model"
optimizer_ckpt_path = f"{new_tempdir}/optimizer"
if not shard and use_async:
model_ckpt_path = f"{model_ckpt_path}.safetensors"
optimizer_ckpt_path = f"{optimizer_ckpt_path}.safetensors"
booster.save_model(new_model, model_ckpt_path, shard=shard, use_async=use_async)
booster.save_optimizer(new_optimizer, optimizer_ckpt_path, shard=shard, use_async=use_async)
booster.checkpoint_io._sync_d2h()
booster.checkpoint_io._sync_io()
def exam_lazy_from_pretrained():