Mirror of https://github.com/hpcaitech/ColossalAI.git
commit 65c487499e
parent 0009a4dbbe

    fix
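Summary of the change: check_torch_fsdp_ckpt is parameterized over `shard` and `use_async`, the model checkpoint path gets a `.pt` or `.safetensors` suffix to match the synchronous or asynchronous save path, and the separate hand-written sharded-checkpoint block at the end of the test (together with the now-unused `FSDP` import) is removed.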
@@ -10,9 +10,8 @@ from colossalai.booster import Booster
 
 if version.parse(torch.__version__) >= version.parse("1.12.0"):
     from colossalai.booster.plugin import TorchFSDPPlugin
-    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 
-from colossalai.testing import rerun_if_address_is_in_use, spawn
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 
 
 def compare_nested_dict(dict1, dict2):
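Note: the `FullyShardedDataParallel` import is dropped because its only remaining uses (`FSDP.full_optim_state_dict` calls) sit in the block deleted at the end of this diff, while `parameterize` is newly required by the decorators added in the next hunk.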
@@ -43,7 +42,9 @@ def compare_nested_dict(dict1, dict2):
     return True
 
 
-def check_torch_fsdp_ckpt():
+@parameterize("shard", [True, False])
+@parameterize("use_async", [True, False])
+def check_torch_fsdp_ckpt(shard: bool, use_async: bool):
     model = resnet18()
     plugin = TorchFSDPPlugin()
     booster = Booster(plugin=plugin)
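For context, stacked `colossalai.testing.parameterize` decorators re-run the wrapped function over the Cartesian product of the listed values, so this single test body now covers all four (shard, use_async) combinations. A minimal sketch of that behavior, assuming the decorator's usual semantics (the name `demo` is illustrative, not from the commit):

    from colossalai.testing import parameterize

    @parameterize("shard", [True, False])
    @parameterize("use_async", [True, False])
    def demo(shard: bool, use_async: bool):
        # Invoked once per combination when called with no arguments.
        print(f"shard={shard}, use_async={use_async}")

    demo()  # runs all four (shard, use_async) combinations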
@@ -64,10 +65,13 @@ def check_torch_fsdp_ckpt():
     with shared_tempdir() as tempdir:
         model_ckpt_path = f"{tempdir}/model"
         optim_ckpt_path = f"{tempdir}/optimizer"
-
+        if not use_async:
+            model_ckpt_path = f"{model_ckpt_path}.pt"
+        if use_async:
+            model_ckpt_path = f"{model_ckpt_path}.safetensors"
         run_model()
 
-        booster.save_model(fsdp_model, model_ckpt_path, shard=False)
+        booster.save_model(fsdp_model, model_ckpt_path, shard=shard, use_async=use_async)
         booster.save_optimizer(optimizer, optim_ckpt_path, shard=False)
 
         full_msd = fsdp_model.state_dict()
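The suffix logic mirrors the on-disk format: the synchronous path keeps the usual PyTorch `.pt` checkpoint, while the asynchronous path writes `.safetensors`. Such a file can be inspected independently of ColossalAI; a minimal sketch, assuming the async write has completed, using the standard `safetensors` API (the helper name is hypothetical):

    from safetensors.torch import load_file

    def checkpoint_keys(path: str) -> list[str]:
        """Return the parameter names stored in a .safetensors checkpoint."""
        state = load_file(path)  # dict[str, torch.Tensor]
        return sorted(state)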
@@ -100,44 +104,6 @@ def check_torch_fsdp_ckpt():
         outputs_sec = fsdp_model(inputs)
         assert criterion(outputs_sec) == criterion(outputs)
 
-    with shared_tempdir() as tempdir:
-        model_ckpt_path = f"{tempdir}/model"
-        optim_ckpt_path = f"{tempdir}/optimizer"
-
-        run_model()
-
-        booster.save_model(fsdp_model, model_ckpt_path, shard=True)
-        booster.save_optimizer(optimizer, optim_ckpt_path, shard=True)
-
-        full_msd = fsdp_model.unwrap().state_dict()
-        full_osd = FSDP.full_optim_state_dict(optimizer.unwrap_model().unwrap(), optim=optimizer)
-
-        import copy
-
-        sharded_osd = copy.deepcopy(full_osd)
-
-        run_model()
-
-        full_msd_updated = fsdp_model.unwrap().state_dict()
-        full_osd_updated = FSDP.full_optim_state_dict(optimizer.unwrap_model().unwrap(), optim=optimizer)
-
-        # cost much time led to timeout
-        # assert not compare_nested_dict(full_osd_updated, sharded_osd)
-        # assert not compare_nested_dict(full_msd_updated, full_msd)
-        outputs_first = fsdp_model(inputs)
-        assert criterion(outputs_first) != criterion(outputs)
-
-        booster.load_model(fsdp_model, model_ckpt_path)
-        booster.load_optimizer(optimizer, optim_ckpt_path)
-
-        full_msd_restore = fsdp_model.unwrap().state_dict()
-        sharded_osd_restore = FSDP.full_optim_state_dict(optimizer.unwrap_model().unwrap(), optim=optimizer)
-
-        assert compare_nested_dict(sharded_osd, sharded_osd_restore)
-        assert compare_nested_dict(full_msd_restore, full_msd)
-        outputs_sec = fsdp_model(inputs)
-        assert criterion(outputs_sec) == criterion(outputs)
-
 
 def run_dist(rank, world_size, port):
     # init dist env
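For context, this file follows ColossalAI's usual distributed-test pattern: `spawn` forks one process per rank, and each rank runs `run_dist`, which initializes the distributed environment and then calls the parameterized check. A rough sketch of that pattern, under the assumption of the current config-less `colossalai.launch` signature (the world size of 2 is illustrative):

    import colossalai
    from colossalai.testing import rerun_if_address_is_in_use, spawn

    def run_dist(rank, world_size, port):
        # Init dist env; older releases also required a leading `config` dict.
        colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port)
        check_torch_fsdp_ckpt()  # parameterize injects shard/use_async

    @rerun_if_address_is_in_use()
    def test_torch_fsdp_ckpt():
        spawn(run_dist, 2)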