diff --git a/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py b/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py
index 12b70cc04..2b17eef98 100644
--- a/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py
+++ b/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py
@@ -10,9 +10,8 @@ from colossalai.booster import Booster
 
 if version.parse(torch.__version__) >= version.parse("1.12.0"):
     from colossalai.booster.plugin import TorchFSDPPlugin
-    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 
-from colossalai.testing import rerun_if_address_is_in_use, spawn
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 
 
 def compare_nested_dict(dict1, dict2):
@@ -43,7 +42,9 @@ def compare_nested_dict(dict1, dict2):
     return True
 
 
-def check_torch_fsdp_ckpt():
+@parameterize("shard", [True, False])
+@parameterize("use_async", [True, False])
+def check_torch_fsdp_ckpt(shard: bool, use_async: bool):
     model = resnet18()
     plugin = TorchFSDPPlugin()
     booster = Booster(plugin=plugin)
@@ -64,10 +65,13 @@ def check_torch_fsdp_ckpt():
     with shared_tempdir() as tempdir:
         model_ckpt_path = f"{tempdir}/model"
         optim_ckpt_path = f"{tempdir}/optimizer"
-
+        if not use_async:
+            model_ckpt_path = f"{model_ckpt_path}.pt"
+        if use_async:
+            model_ckpt_path = f"{model_ckpt_path}.safetensors"
         run_model()
 
-        booster.save_model(fsdp_model, model_ckpt_path, shard=False)
+        booster.save_model(fsdp_model, model_ckpt_path, shard=shard, use_async=use_async)
         booster.save_optimizer(optimizer, optim_ckpt_path, shard=False)
 
         full_msd = fsdp_model.state_dict()
@@ -100,44 +104,6 @@ def check_torch_fsdp_ckpt():
         outputs_sec = fsdp_model(inputs)
         assert criterion(outputs_sec) == criterion(outputs)
 
-    with shared_tempdir() as tempdir:
-        model_ckpt_path = f"{tempdir}/model"
-        optim_ckpt_path = f"{tempdir}/optimizer"
-
-        run_model()
-
-        booster.save_model(fsdp_model, model_ckpt_path, shard=True)
-        booster.save_optimizer(optimizer, optim_ckpt_path, shard=True)
-
-        full_msd = fsdp_model.unwrap().state_dict()
-        full_osd = FSDP.full_optim_state_dict(optimizer.unwrap_model().unwrap(), optim=optimizer)
-
-        import copy
-
-        sharded_osd = copy.deepcopy(full_osd)
-
-        run_model()
-
-        full_msd_updated = fsdp_model.unwrap().state_dict()
-        full_osd_updated = FSDP.full_optim_state_dict(optimizer.unwrap_model().unwrap(), optim=optimizer)
-
-        # cost much time led to timeout
-        # assert not compare_nested_dict(full_osd_updated, sharded_osd)
-        # assert not compare_nested_dict(full_msd_updated, full_msd)
-        outputs_first = fsdp_model(inputs)
-        assert criterion(outputs_first) != criterion(outputs)
-
-        booster.load_model(fsdp_model, model_ckpt_path)
-        booster.load_optimizer(optimizer, optim_ckpt_path)
-
-        full_msd_restore = fsdp_model.unwrap().state_dict()
-        sharded_osd_restore = FSDP.full_optim_state_dict(optimizer.unwrap_model().unwrap(), optim=optimizer)
-
-        assert compare_nested_dict(sharded_osd, sharded_osd_restore)
-        assert compare_nested_dict(full_msd_restore, full_msd)
-        outputs_sec = fsdp_model(inputs)
-        assert criterion(outputs_sec) == criterion(outputs)
-
 
 def run_dist(rank, world_size, port):
     # init dist env