This commit is contained in:
wangbluo 2024-11-19 17:29:22 +08:00
parent 0009a4dbbe
commit 65c487499e

View File

@ -10,9 +10,8 @@ from colossalai.booster import Booster
if version.parse(torch.__version__) >= version.parse("1.12.0"): if version.parse(torch.__version__) >= version.parse("1.12.0"):
from colossalai.booster.plugin import TorchFSDPPlugin from colossalai.booster.plugin import TorchFSDPPlugin
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
def compare_nested_dict(dict1, dict2): def compare_nested_dict(dict1, dict2):
@ -43,7 +42,9 @@ def compare_nested_dict(dict1, dict2):
return True return True
def check_torch_fsdp_ckpt(): @parameterize("shard", [True, False])
@parameterize("use_async", [True, False])
def check_torch_fsdp_ckpt(shard: bool, use_async: bool):
model = resnet18() model = resnet18()
plugin = TorchFSDPPlugin() plugin = TorchFSDPPlugin()
booster = Booster(plugin=plugin) booster = Booster(plugin=plugin)
@ -64,10 +65,13 @@ def check_torch_fsdp_ckpt():
with shared_tempdir() as tempdir: with shared_tempdir() as tempdir:
model_ckpt_path = f"{tempdir}/model" model_ckpt_path = f"{tempdir}/model"
optim_ckpt_path = f"{tempdir}/optimizer" optim_ckpt_path = f"{tempdir}/optimizer"
if not use_async:
model_ckpt_path = f"{model_ckpt_path}.pt"
if use_async:
model_ckpt_path = f"{model_ckpt_path}.safetensors"
run_model() run_model()
booster.save_model(fsdp_model, model_ckpt_path, shard=False) booster.save_model(fsdp_model, model_ckpt_path, shard=shard, use_async=use_async)
booster.save_optimizer(optimizer, optim_ckpt_path, shard=False) booster.save_optimizer(optimizer, optim_ckpt_path, shard=False)
full_msd = fsdp_model.state_dict() full_msd = fsdp_model.state_dict()
@ -100,44 +104,6 @@ def check_torch_fsdp_ckpt():
outputs_sec = fsdp_model(inputs) outputs_sec = fsdp_model(inputs)
assert criterion(outputs_sec) == criterion(outputs) assert criterion(outputs_sec) == criterion(outputs)
with shared_tempdir() as tempdir:
model_ckpt_path = f"{tempdir}/model"
optim_ckpt_path = f"{tempdir}/optimizer"
run_model()
booster.save_model(fsdp_model, model_ckpt_path, shard=True)
booster.save_optimizer(optimizer, optim_ckpt_path, shard=True)
full_msd = fsdp_model.unwrap().state_dict()
full_osd = FSDP.full_optim_state_dict(optimizer.unwrap_model().unwrap(), optim=optimizer)
import copy
sharded_osd = copy.deepcopy(full_osd)
run_model()
full_msd_updated = fsdp_model.unwrap().state_dict()
full_osd_updated = FSDP.full_optim_state_dict(optimizer.unwrap_model().unwrap(), optim=optimizer)
# cost much time led to timeout
# assert not compare_nested_dict(full_osd_updated, sharded_osd)
# assert not compare_nested_dict(full_msd_updated, full_msd)
outputs_first = fsdp_model(inputs)
assert criterion(outputs_first) != criterion(outputs)
booster.load_model(fsdp_model, model_ckpt_path)
booster.load_optimizer(optimizer, optim_ckpt_path)
full_msd_restore = fsdp_model.unwrap().state_dict()
sharded_osd_restore = FSDP.full_optim_state_dict(optimizer.unwrap_model().unwrap(), optim=optimizer)
assert compare_nested_dict(sharded_osd, sharded_osd_restore)
assert compare_nested_dict(full_msd_restore, full_msd)
outputs_sec = fsdp_model(inputs)
assert criterion(outputs_sec) == criterion(outputs)
def run_dist(rank, world_size, port): def run_dist(rank, world_size, port):
# init dist env # init dist env