mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 01:55:12 +00:00
[booster] torch fsdp fix ckpt (#3788)
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
from pathlib import Path
|
||||
from typing import Callable, Iterable, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
@@ -5,30 +6,18 @@ import torch.nn as nn
|
||||
from packaging import version
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
|
||||
torch.__version__) < version.parse('2.0.0'):
|
||||
|
||||
if version.parse(torch.__version__) >= version.parse('1.12.0'):
|
||||
from torch.distributed.fsdp import FullStateDictConfig
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
from torch.distributed.fsdp import StateDictType
|
||||
from torch.distributed.fsdp.fully_sharded_data_parallel import (
|
||||
BackwardPrefetch,
|
||||
CPUOffload,
|
||||
MixedPrecision,
|
||||
ShardingStrategy,
|
||||
)
|
||||
elif version.parse(torch.__version__) >= version.parse('2.0.0'):
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
from torch.distributed.fsdp._init_utils import ProcessGroupType
|
||||
from torch.distributed.fsdp.api import (
|
||||
BackwardPrefetch,
|
||||
CPUOffload,
|
||||
FullOptimStateDictConfig,
|
||||
FullStateDictConfig,
|
||||
MixedPrecision,
|
||||
ShardingStrategy,
|
||||
StateDictType,
|
||||
)
|
||||
from torch.distributed.fsdp.wrap import _FSDPPolicy
|
||||
else:
|
||||
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
|
||||
|
||||
@@ -36,7 +25,7 @@ from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
|
||||
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO, utils
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||
|
||||
@@ -51,102 +40,71 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
|
||||
super().__init__()
|
||||
self.coordinator = DistCoordinator()
|
||||
|
||||
def __set_model_optim_state(
|
||||
self,
|
||||
model,
|
||||
state_dict_type,
|
||||
state_dict_config,
|
||||
optim_state_dict_config,
|
||||
):
|
||||
return FSDP.set_state_dict_type(model, state_dict_type, state_dict_config, optim_state_dict_config)
|
||||
def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
|
||||
checkpoint = utils.load_state_dict(checkpoint)
|
||||
model.load_state_dict(checkpoint)
|
||||
|
||||
def load_sharded_model(self, model: nn.Module, checkpoint: str):
|
||||
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
|
||||
checkpoint = utils.load_state_dict(checkpoint)
|
||||
fsdp_model = optimizer.unwrap_model()
|
||||
sharded_osd = FSDP.scatter_full_optim_state_dict(checkpoint, fsdp_model)
|
||||
optimizer.load_state_dict(sharded_osd)
|
||||
|
||||
# TODO(jishaomin): implement this method as it can be supported by Huggingface model
|
||||
raise NotImplementedError("Torch FSDP sharded model checkpoint is not supported yet.")
|
||||
|
||||
def load_sharded_optimizer(self, model: nn.Module, optimizer: Optimizer, checkpoint: str):
|
||||
|
||||
# TODO(jishaomin): implement this method as it can be supported by Huggingface model
|
||||
raise NotImplementedError("Torch FSDP sharded model checkpoint is not supported yet.")
|
||||
|
||||
def save_sharded_model(self, model: nn.Module, checkpoint: str):
|
||||
|
||||
# TODO(jishaomin): implement this method as it can be supported by Huggingface model
|
||||
raise NotImplementedError("Torch FSDP sharded model checkpoint is not supported yet.")
|
||||
|
||||
def save_sharded_optimizer(self, model: nn.Module, optimizer: Optimizer, checkpoint: str):
|
||||
|
||||
# TODO(jishaomin): implement this method as it can be supported by Huggingface model
|
||||
raise NotImplementedError("Torch FSDP sharded model checkpoint is not supported yet.")
|
||||
|
||||
def load_unsharded_model(self, model: nn.Module, checkpoint: str):
|
||||
"""
|
||||
Load model from checkpoint with automatic unwrapping.
|
||||
"""
|
||||
# the model should be unwrapped in self.load_model via ModelWrapper.unwrap
|
||||
|
||||
if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
|
||||
torch.__version__) < version.parse('2.0.0'):
|
||||
full_state_dict = self.load_state_dict(checkpoint)
|
||||
elif version.parse(torch.__version__) >= version.parse('2.0.0'):
|
||||
full_state_dict = self.load_state_dict(checkpoint)
|
||||
self.__set_model_optim_state(model, StateDictType.FULL_STATE_DICT, FullStateDictConfig(rank0_only=True))
|
||||
full_state_dict = model.state_dict()
|
||||
else:
|
||||
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
|
||||
|
||||
model.load_state_dict(full_state_dict)
|
||||
|
||||
def load_unsharded_optimizer(self, model: nn.Module, optim: Optimizer, checkpoint: str):
|
||||
"""
|
||||
Load Optimizer from checkpoint with automatic unwrapping.
|
||||
"""
|
||||
|
||||
if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
|
||||
torch.__version__) < version.parse('2.0.0'):
|
||||
optim_full_state_dict = self.load_state_dict(checkpoint)
|
||||
elif version.parse(torch.__version__) >= version.parse('2.0.0'):
|
||||
optim_full_state_dict = self.load_state_dict(checkpoint)
|
||||
FSDP.full_optim_state_dict_to_load(optim_full_state_dict, model, optim)
|
||||
else:
|
||||
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
|
||||
|
||||
optim.load_state_dict(optim_full_state_dict)
|
||||
|
||||
def save_unsharded_model(self, model: nn.Module, checkpoint: str):
|
||||
def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
|
||||
"""
|
||||
Save model to checkpoint but only on master process.
|
||||
"""
|
||||
# the model should be unwrapped in self.load_model via ModelWrapper.unwrap
|
||||
cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
|
||||
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
|
||||
full_model_state = model.state_dict()
|
||||
utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors)
|
||||
|
||||
if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
|
||||
torch.__version__) < version.parse('2.0.0'):
|
||||
cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
|
||||
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
|
||||
model_state_dict = model.state_dict()
|
||||
elif version.parse(torch.__version__) >= version.parse('2.0.0'):
|
||||
self.__set_model_optim_state(model, StateDictType.FULL_STATE_DICT, FullStateDictConfig(rank0_only=True))
|
||||
model_state_dict = model.state_dict()
|
||||
else:
|
||||
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
|
||||
self.save_checkpoint(model_state_dict, checkpoint)
|
||||
|
||||
def save_unsharded_optimizer(self, model: nn.Module, optimizer: Optimizer, checkpoint: str):
|
||||
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
|
||||
"""
|
||||
Save optimizer to checkpoint but only on master process.
|
||||
"""
|
||||
assert isinstance(optimizer, FSDPOptimizerWrapper)
|
||||
fsdp_model = optimizer.unwrap_model()
|
||||
full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True)
|
||||
utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False)
|
||||
|
||||
if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
|
||||
torch.__version__) < version.parse('2.0.0'):
|
||||
optim_state_dict = FSDP.full_optim_state_dict(model=model, optim=optimizer)
|
||||
elif version.parse(torch.__version__) >= version.parse('2.0.0'):
|
||||
self.__set_model_optim_state(model, StateDictType.FULL_STATE_DICT,
|
||||
FullOptimStateDictConfig(rank0_only=True))
|
||||
optim_state_dict = FSDP.optim_state_dict(model, optimizer)
|
||||
else:
|
||||
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
|
||||
self.save_checkpoint(optim_state_dict, checkpoint)
|
||||
def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, variant: Optional[str],
|
||||
size_per_shard: int, use_safetensors: bool):
|
||||
"""
|
||||
Save model to checkpoint but only on master process.
|
||||
"""
|
||||
raise NotImplementedError("Sharded model checkpoint is not supported yet.")
|
||||
|
||||
def load_sharded_model(self,
|
||||
model: nn.Module,
|
||||
checkpoint_index_file: Path,
|
||||
strict: bool = False,
|
||||
use_safetensors: bool = False,
|
||||
load_sub_module: bool = True):
|
||||
"""
|
||||
Load model to checkpoint but only on master process.
|
||||
"""
|
||||
raise NotImplementedError("Sharded model checkpoint is not supported yet.")
|
||||
|
||||
def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
|
||||
"""
|
||||
Save optimizer to checkpoint but only on master process.
|
||||
"""
|
||||
raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
|
||||
|
||||
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int):
|
||||
"""
|
||||
Load optimizer to checkpoint but only on master process.
|
||||
"""
|
||||
raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
|
||||
|
||||
def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
|
||||
"""
|
||||
Save model to checkpoint but only on master process.
|
||||
"""
|
||||
if self.coordinator.is_master():
|
||||
super().save_lr_scheduler(lr_scheduler, checkpoint)
|
||||
|
||||
|
||||
class TorchFSDPModel(ModelWrapper):
|
||||
@@ -156,7 +114,17 @@ class TorchFSDPModel(ModelWrapper):
|
||||
self.module = FSDP(module, *args, **kwargs)
|
||||
|
||||
def unwrap(self):
|
||||
return self.module.module
|
||||
return self.module
|
||||
|
||||
|
||||
class FSDPOptimizerWrapper(OptimizerWrapper):
|
||||
|
||||
def __init__(self, optimizer: Optimizer, model: nn.Module):
|
||||
self.model = model
|
||||
super().__init__(optimizer)
|
||||
|
||||
def unwrap_model(self) -> nn.Module:
|
||||
return self.model
|
||||
|
||||
|
||||
class TorchFSDPPlugin(DPPluginBase):
|
||||
@@ -178,8 +146,7 @@ class TorchFSDPPlugin(DPPluginBase):
|
||||
See https://pytorch.org/docs/stable/fsdp.html for details.
|
||||
"""
|
||||
|
||||
if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
|
||||
torch.__version__) < version.parse('2.0.0'):
|
||||
if version.parse(torch.__version__) >= version.parse('1.12.0'):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -191,7 +158,6 @@ class TorchFSDPPlugin(DPPluginBase):
|
||||
mixed_precision: Optional[MixedPrecision] = None,
|
||||
ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
|
||||
param_init_fn: Optional[Callable[[nn.Module], None]] = None,
|
||||
device_id: Optional[Union[int, torch.device]] = None,
|
||||
sync_module_states: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -203,42 +169,7 @@ class TorchFSDPPlugin(DPPluginBase):
|
||||
mixed_precision=mixed_precision,
|
||||
ignored_modules=ignored_modules,
|
||||
param_init_fn=param_init_fn,
|
||||
device_id=device_id,
|
||||
sync_module_states=sync_module_states)
|
||||
elif version.parse(torch.__version__) >= version.parse('2.0.0'):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
process_group: ProcessGroupType = None,
|
||||
sharding_strategy: Optional[ShardingStrategy] = None,
|
||||
cpu_offload: Optional[CPUOffload] = None,
|
||||
auto_wrap_policy: Optional[Union[Callable, _FSDPPolicy]] = None,
|
||||
backward_prefetch: Optional[BackwardPrefetch] = BackwardPrefetch.BACKWARD_PRE,
|
||||
mixed_precision: Optional[MixedPrecision] = None,
|
||||
ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
|
||||
param_init_fn: Optional[Callable[[nn.Module], None]] = None,
|
||||
device_id: Optional[Union[int, torch.device]] = None,
|
||||
sync_module_states: bool = False,
|
||||
forward_prefetch: bool = False,
|
||||
limit_all_gathers: bool = False,
|
||||
use_orig_params: bool = False,
|
||||
ignored_parameters: Optional[Iterable[torch.nn.Parameter]] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.fsdp_kwargs = dict(process_group=process_group,
|
||||
sharding_strategy=sharding_strategy,
|
||||
cpu_offload=cpu_offload,
|
||||
auto_wrap_policy=auto_wrap_policy,
|
||||
backward_prefetch=backward_prefetch,
|
||||
mixed_precision=mixed_precision,
|
||||
ignored_modules=ignored_modules,
|
||||
param_init_fn=param_init_fn,
|
||||
device_id=device_id,
|
||||
sync_module_states=sync_module_states,
|
||||
forward_prefetch=forward_prefetch,
|
||||
limit_all_gathers=limit_all_gathers,
|
||||
use_orig_params=use_orig_params,
|
||||
ignored_parameters=ignored_parameters)
|
||||
else:
|
||||
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
|
||||
|
||||
@@ -269,14 +200,14 @@ class TorchFSDPPlugin(DPPluginBase):
|
||||
lr_scheduler: LRScheduler = None,
|
||||
) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
|
||||
|
||||
model = model.cuda()
|
||||
# wrap the model with PyTorch FSDP
|
||||
model = TorchFSDPModel(model, **self.fsdp_kwargs)
|
||||
fsdp_model = TorchFSDPModel(model, device_id=torch.cuda.current_device(), **self.fsdp_kwargs)
|
||||
optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults)
|
||||
|
||||
if not isinstance(optimizer, OptimizerWrapper):
|
||||
optimizer = OptimizerWrapper(optimizer)
|
||||
if not isinstance(optimizer, FSDPOptimizerWrapper):
|
||||
optimizer = FSDPOptimizerWrapper(optimizer, fsdp_model)
|
||||
|
||||
return model, optimizer, criterion, dataloader, lr_scheduler
|
||||
return fsdp_model, optimizer, criterion, dataloader, lr_scheduler
|
||||
|
||||
def control_checkpoint_io(self) -> bool:
|
||||
return True
|
||||
|
Reference in New Issue
Block a user