mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-06 18:43:58 +00:00
fix
This commit is contained in:
parent
0009a4dbbe
commit
65c487499e
@ -10,9 +10,8 @@ from colossalai.booster import Booster
|
||||
|
||||
if version.parse(torch.__version__) >= version.parse("1.12.0"):
|
||||
from colossalai.booster.plugin import TorchFSDPPlugin
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
|
||||
|
||||
def compare_nested_dict(dict1, dict2):
|
||||
@ -43,7 +42,9 @@ def compare_nested_dict(dict1, dict2):
|
||||
return True
|
||||
|
||||
|
||||
def check_torch_fsdp_ckpt():
|
||||
@parameterize("shard", [True, False])
|
||||
@parameterize("use_async", [True, False])
|
||||
def check_torch_fsdp_ckpt(shard: bool, use_async: bool):
|
||||
model = resnet18()
|
||||
plugin = TorchFSDPPlugin()
|
||||
booster = Booster(plugin=plugin)
|
||||
@ -64,10 +65,13 @@ def check_torch_fsdp_ckpt():
|
||||
with shared_tempdir() as tempdir:
|
||||
model_ckpt_path = f"{tempdir}/model"
|
||||
optim_ckpt_path = f"{tempdir}/optimizer"
|
||||
|
||||
if not use_async:
|
||||
model_ckpt_path = f"{model_ckpt_path}.pt"
|
||||
if use_async:
|
||||
model_ckpt_path = f"{model_ckpt_path}.safetensors"
|
||||
run_model()
|
||||
|
||||
booster.save_model(fsdp_model, model_ckpt_path, shard=False)
|
||||
booster.save_model(fsdp_model, model_ckpt_path, shard=shard, use_async=use_async)
|
||||
booster.save_optimizer(optimizer, optim_ckpt_path, shard=False)
|
||||
|
||||
full_msd = fsdp_model.state_dict()
|
||||
@ -100,44 +104,6 @@ def check_torch_fsdp_ckpt():
|
||||
outputs_sec = fsdp_model(inputs)
|
||||
assert criterion(outputs_sec) == criterion(outputs)
|
||||
|
||||
with shared_tempdir() as tempdir:
|
||||
model_ckpt_path = f"{tempdir}/model"
|
||||
optim_ckpt_path = f"{tempdir}/optimizer"
|
||||
|
||||
run_model()
|
||||
|
||||
booster.save_model(fsdp_model, model_ckpt_path, shard=True)
|
||||
booster.save_optimizer(optimizer, optim_ckpt_path, shard=True)
|
||||
|
||||
full_msd = fsdp_model.unwrap().state_dict()
|
||||
full_osd = FSDP.full_optim_state_dict(optimizer.unwrap_model().unwrap(), optim=optimizer)
|
||||
|
||||
import copy
|
||||
|
||||
sharded_osd = copy.deepcopy(full_osd)
|
||||
|
||||
run_model()
|
||||
|
||||
full_msd_updated = fsdp_model.unwrap().state_dict()
|
||||
full_osd_updated = FSDP.full_optim_state_dict(optimizer.unwrap_model().unwrap(), optim=optimizer)
|
||||
|
||||
# cost much time led to timeout
|
||||
# assert not compare_nested_dict(full_osd_updated, sharded_osd)
|
||||
# assert not compare_nested_dict(full_msd_updated, full_msd)
|
||||
outputs_first = fsdp_model(inputs)
|
||||
assert criterion(outputs_first) != criterion(outputs)
|
||||
|
||||
booster.load_model(fsdp_model, model_ckpt_path)
|
||||
booster.load_optimizer(optimizer, optim_ckpt_path)
|
||||
|
||||
full_msd_restore = fsdp_model.unwrap().state_dict()
|
||||
sharded_osd_restore = FSDP.full_optim_state_dict(optimizer.unwrap_model().unwrap(), optim=optimizer)
|
||||
|
||||
assert compare_nested_dict(sharded_osd, sharded_osd_restore)
|
||||
assert compare_nested_dict(full_msd_restore, full_msd)
|
||||
outputs_sec = fsdp_model(inputs)
|
||||
assert criterion(outputs_sec) == criterion(outputs)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
# init dist env
|
||||
|
Loading…
Reference in New Issue
Block a user