[booster] torch fsdp fix ckpt (#3788)

This commit is contained in:
wukong1992
2023-05-23 16:58:45 +08:00
committed by GitHub
parent 9265f2d4d7
commit 6b305a99d6
5 changed files with 230 additions and 186 deletions

View File

@@ -1,3 +1,4 @@
from pathlib import Path
from typing import Callable, Iterable, Iterator, List, Optional, Tuple, Union
import torch
@@ -5,30 +6,18 @@ import torch.nn as nn
from packaging import version
from torch.distributed import ProcessGroup
if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
torch.__version__) < version.parse('2.0.0'):
if version.parse(torch.__version__) >= version.parse('1.12.0'):
from torch.distributed.fsdp import FullStateDictConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
from torch.distributed.fsdp.fully_sharded_data_parallel import (
BackwardPrefetch,
CPUOffload,
MixedPrecision,
ShardingStrategy,
)
elif version.parse(torch.__version__) >= version.parse('2.0.0'):
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._init_utils import ProcessGroupType
from torch.distributed.fsdp.api import (
BackwardPrefetch,
CPUOffload,
FullOptimStateDictConfig,
FullStateDictConfig,
MixedPrecision,
ShardingStrategy,
StateDictType,
)
from torch.distributed.fsdp.wrap import _FSDPPolicy
else:
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
@@ -36,7 +25,7 @@ from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO, utils
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
@@ -51,102 +40,71 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
super().__init__()
self.coordinator = DistCoordinator()
def __set_model_optim_state(
self,
model,
state_dict_type,
state_dict_config,
optim_state_dict_config,
):
return FSDP.set_state_dict_type(model, state_dict_type, state_dict_config, optim_state_dict_config)
def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
checkpoint = utils.load_state_dict(checkpoint)
model.load_state_dict(checkpoint)
def load_sharded_model(self, model: nn.Module, checkpoint: str):
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
checkpoint = utils.load_state_dict(checkpoint)
fsdp_model = optimizer.unwrap_model()
sharded_osd = FSDP.scatter_full_optim_state_dict(checkpoint, fsdp_model)
optimizer.load_state_dict(sharded_osd)
# TODO(jishaomin): implement this method as it can be supported by Huggingface model
raise NotImplementedError("Torch FSDP sharded model checkpoint is not supported yet.")
def load_sharded_optimizer(self, model: nn.Module, optimizer: Optimizer, checkpoint: str):
# TODO(jishaomin): implement this method as it can be supported by Huggingface model
raise NotImplementedError("Torch FSDP sharded model checkpoint is not supported yet.")
def save_sharded_model(self, model: nn.Module, checkpoint: str):
# TODO(jishaomin): implement this method as it can be supported by Huggingface model
raise NotImplementedError("Torch FSDP sharded model checkpoint is not supported yet.")
def save_sharded_optimizer(self, model: nn.Module, optimizer: Optimizer, checkpoint: str):
# TODO(jishaomin): implement this method as it can be supported by Huggingface model
raise NotImplementedError("Torch FSDP sharded model checkpoint is not supported yet.")
def load_unsharded_model(self, model: nn.Module, checkpoint: str):
"""
Load model from checkpoint with automatic unwrapping.
"""
# the model should be unwrapped in self.load_model via ModelWrapper.unwrap
if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
torch.__version__) < version.parse('2.0.0'):
full_state_dict = self.load_state_dict(checkpoint)
elif version.parse(torch.__version__) >= version.parse('2.0.0'):
full_state_dict = self.load_state_dict(checkpoint)
self.__set_model_optim_state(model, StateDictType.FULL_STATE_DICT, FullStateDictConfig(rank0_only=True))
full_state_dict = model.state_dict()
else:
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
model.load_state_dict(full_state_dict)
def load_unsharded_optimizer(self, model: nn.Module, optim: Optimizer, checkpoint: str):
"""
Load Optimizer from checkpoint with automatic unwrapping.
"""
if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
torch.__version__) < version.parse('2.0.0'):
optim_full_state_dict = self.load_state_dict(checkpoint)
elif version.parse(torch.__version__) >= version.parse('2.0.0'):
optim_full_state_dict = self.load_state_dict(checkpoint)
FSDP.full_optim_state_dict_to_load(optim_full_state_dict, model, optim)
else:
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
optim.load_state_dict(optim_full_state_dict)
def save_unsharded_model(self, model: nn.Module, checkpoint: str):
def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
"""
Save model to checkpoint but only on master process.
"""
# the model should be unwrapped in self.load_model via ModelWrapper.unwrap
cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
full_model_state = model.state_dict()
utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors)
if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
torch.__version__) < version.parse('2.0.0'):
cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
model_state_dict = model.state_dict()
elif version.parse(torch.__version__) >= version.parse('2.0.0'):
self.__set_model_optim_state(model, StateDictType.FULL_STATE_DICT, FullStateDictConfig(rank0_only=True))
model_state_dict = model.state_dict()
else:
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
self.save_checkpoint(model_state_dict, checkpoint)
def save_unsharded_optimizer(self, model: nn.Module, optimizer: Optimizer, checkpoint: str):
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
"""
Save optimizer to checkpoint but only on master process.
"""
assert isinstance(optimizer, FSDPOptimizerWrapper)
fsdp_model = optimizer.unwrap_model()
full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True)
utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False)
if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
torch.__version__) < version.parse('2.0.0'):
optim_state_dict = FSDP.full_optim_state_dict(model=model, optim=optimizer)
elif version.parse(torch.__version__) >= version.parse('2.0.0'):
self.__set_model_optim_state(model, StateDictType.FULL_STATE_DICT,
FullOptimStateDictConfig(rank0_only=True))
optim_state_dict = FSDP.optim_state_dict(model, optimizer)
else:
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
self.save_checkpoint(optim_state_dict, checkpoint)
def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, variant: Optional[str],
size_per_shard: int, use_safetensors: bool):
"""
Save model to checkpoint but only on master process.
"""
raise NotImplementedError("Sharded model checkpoint is not supported yet.")
def load_sharded_model(self,
model: nn.Module,
checkpoint_index_file: Path,
strict: bool = False,
use_safetensors: bool = False,
load_sub_module: bool = True):
"""
Load model to checkpoint but only on master process.
"""
raise NotImplementedError("Sharded model checkpoint is not supported yet.")
def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
"""
Save optimizer to checkpoint but only on master process.
"""
raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int):
"""
Load optimizer to checkpoint but only on master process.
"""
raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
"""
Save model to checkpoint but only on master process.
"""
if self.coordinator.is_master():
super().save_lr_scheduler(lr_scheduler, checkpoint)
class TorchFSDPModel(ModelWrapper):
@@ -156,7 +114,17 @@ class TorchFSDPModel(ModelWrapper):
self.module = FSDP(module, *args, **kwargs)
def unwrap(self):
return self.module.module
return self.module
class FSDPOptimizerWrapper(OptimizerWrapper):
def __init__(self, optimizer: Optimizer, model: nn.Module):
self.model = model
super().__init__(optimizer)
def unwrap_model(self) -> nn.Module:
return self.model
class TorchFSDPPlugin(DPPluginBase):
@@ -178,8 +146,7 @@ class TorchFSDPPlugin(DPPluginBase):
See https://pytorch.org/docs/stable/fsdp.html for details.
"""
if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
torch.__version__) < version.parse('2.0.0'):
if version.parse(torch.__version__) >= version.parse('1.12.0'):
def __init__(
self,
@@ -191,7 +158,6 @@ class TorchFSDPPlugin(DPPluginBase):
mixed_precision: Optional[MixedPrecision] = None,
ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
param_init_fn: Optional[Callable[[nn.Module], None]] = None,
device_id: Optional[Union[int, torch.device]] = None,
sync_module_states: bool = False,
):
super().__init__()
@@ -203,42 +169,7 @@ class TorchFSDPPlugin(DPPluginBase):
mixed_precision=mixed_precision,
ignored_modules=ignored_modules,
param_init_fn=param_init_fn,
device_id=device_id,
sync_module_states=sync_module_states)
elif version.parse(torch.__version__) >= version.parse('2.0.0'):
def __init__(
self,
process_group: ProcessGroupType = None,
sharding_strategy: Optional[ShardingStrategy] = None,
cpu_offload: Optional[CPUOffload] = None,
auto_wrap_policy: Optional[Union[Callable, _FSDPPolicy]] = None,
backward_prefetch: Optional[BackwardPrefetch] = BackwardPrefetch.BACKWARD_PRE,
mixed_precision: Optional[MixedPrecision] = None,
ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
param_init_fn: Optional[Callable[[nn.Module], None]] = None,
device_id: Optional[Union[int, torch.device]] = None,
sync_module_states: bool = False,
forward_prefetch: bool = False,
limit_all_gathers: bool = False,
use_orig_params: bool = False,
ignored_parameters: Optional[Iterable[torch.nn.Parameter]] = None,
):
super().__init__()
self.fsdp_kwargs = dict(process_group=process_group,
sharding_strategy=sharding_strategy,
cpu_offload=cpu_offload,
auto_wrap_policy=auto_wrap_policy,
backward_prefetch=backward_prefetch,
mixed_precision=mixed_precision,
ignored_modules=ignored_modules,
param_init_fn=param_init_fn,
device_id=device_id,
sync_module_states=sync_module_states,
forward_prefetch=forward_prefetch,
limit_all_gathers=limit_all_gathers,
use_orig_params=use_orig_params,
ignored_parameters=ignored_parameters)
else:
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
@@ -269,14 +200,14 @@ class TorchFSDPPlugin(DPPluginBase):
lr_scheduler: LRScheduler = None,
) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
model = model.cuda()
# wrap the model with PyTorch FSDP
model = TorchFSDPModel(model, **self.fsdp_kwargs)
fsdp_model = TorchFSDPModel(model, device_id=torch.cuda.current_device(), **self.fsdp_kwargs)
optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults)
if not isinstance(optimizer, OptimizerWrapper):
optimizer = OptimizerWrapper(optimizer)
if not isinstance(optimizer, FSDPOptimizerWrapper):
optimizer = FSDPOptimizerWrapper(optimizer, fsdp_model)
return model, optimizer, criterion, dataloader, lr_scheduler
return fsdp_model, optimizer, criterion, dataloader, lr_scheduler
def control_checkpoint_io(self) -> bool:
return True