[shardformer] fix master param sync for hybrid plugin/rewrite unwrapping logic (#4758)

* fix master param sync for hybrid plugin

* rewrite unwrap for ddp/fsdp

* rewrite unwrap for zero/gemini

* rewrite unwrap for hybrid plugin

* fix gemini unwrap

* fix bugs
Baizhou Zhang 2023-09-20 18:29:37 +08:00 committed by GitHub
parent 7b9b86441f
commit c0a033700c
14 changed files with 141 additions and 171 deletions

View File

@@ -2,7 +2,7 @@ from typing import Dict, List
 import torch
 from torch import Tensor
-from torch.nn import Parameter
+from torch.nn import Module, Parameter
 from torch.optim import Optimizer
 
 from colossalai.interface import OptimizerWrapper
@@ -152,3 +152,18 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
                 if p is working_param:
                     continue
                 working_param.data.copy_(p.data)
+
+    def update_master_params(self, model: Module):
+        # Update master params from working params
+        with torch.no_grad():
+            for p in model.parameters():
+                if (p is None) or (p not in self.working_to_master_map):
+                    continue
+                master_param = self.working_to_master_map[p]
+                master_param.data.copy_(p.data)
+
+    def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
+        return {id(working_p): master_p for working_p, master_p in self.working_to_master_map.items()}
+
+    def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
+        return {id(master_p): working_p for master_p, working_p in self.master_to_working_map.items()}
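For orientation, here is a minimal, self-contained sketch of the pattern this hunk introduces: the optimizer keeps tensor-keyed working↔master maps internally and exposes them to checkpoint code as id-keyed maps, while `update_master_params` refreshes the fp32 copies after a checkpoint load. The class and names below are illustrative only, not part of the diff.

```python
import torch

class ToyMixedPrecisionBookkeeping:
    """Toy stand-in for a mixed-precision optimizer's bookkeeping (illustrative only)."""

    def __init__(self, working_params):
        # master params are fp32 copies of the (possibly fp16) working params
        self.working_to_master_map = {p: p.detach().clone().float() for p in working_params}
        self.master_to_working_map = {m: w for w, m in self.working_to_master_map.items()}

    def update_master_params(self, params):
        # copy freshly loaded working weights into the fp32 master copies
        with torch.no_grad():
            for p in params:
                if p in self.working_to_master_map:
                    self.working_to_master_map[p].copy_(p.float())

    def get_master_to_working_map(self):
        # id-keyed view, matching what checkpoint IO code expects
        return {id(m): w for m, w in self.master_to_working_map.items()}


working = [torch.randn(4, dtype=torch.float16, requires_grad=True)]
bk = ToyMixedPrecisionBookkeeping(working)
working[0].data.fill_(1.0)          # pretend a checkpoint was just loaded
bk.update_master_params(working)    # master copy now matches the working param
```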

View File

@@ -139,7 +139,7 @@ class Booster:
         if self.plugin and not self.plugin.control_device():
             # transform model for accelerator
-            model = self.accelerator.configure(model)
+            model = self.accelerator.configure_model(model)
 
         if self.mixed_precision and (self.plugin is None or self.plugin and not self.plugin.control_precision()):
             # transform model for mixed precision

View File

@@ -44,6 +44,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
         As there is communication when getting state dict, model.state_dict() must be called on all processes.
         """
+        assert isinstance(model, GeminiDDP), "Please boost the model before saving!"
         state_dict = model.state_dict(only_rank_0=True)
         if self.coordinator.is_master():
             save_state_dict(state_dict, checkpoint, use_safetensors)
@@ -53,24 +54,27 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         Load model from checkpoint with automatic unwrapping.
         The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
         """
+        assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
         super().load_unsharded_model(model, checkpoint, strict=strict)
 
-    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
+    def save_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool):
         """
         Save unsharded optimizer state dict to checkpoint.
         After calling optimizer.state_dict(), the complete optimizer states will be collected on master rank.
         As there is communication when getting state dict, optimizer.state_dict() must be called on all processes.
         The saving process will only be executed by master rank.
         """
+        assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!"
         state_dict = optimizer.state_dict()
         if self.coordinator.is_master():
             save_state_dict(state_dict, checkpoint, use_safetensors=False)
 
-    def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str):
+    def load_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str):
         """
         Loading unsharded optimizer from checkpoint file.
         For each process, only loading optimizer states of parameters it controls.
         """
+        assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!"
         super().load_unsharded_optimizer(optimizer, checkpoint)
 
     def save_sharded_model(
@@ -86,6 +90,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         Save sharded model.
         As there is communication when getting state dict, model.state_dict() must be called on all processes.
         """
+        assert isinstance(model, GeminiDDP), "Please boost the model before saving!"
         if os.path.isfile(checkpoint_path):
             logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
             return
@@ -111,7 +116,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         if self.coordinator.is_master():
             index_file.append_meta_data("total_size", total_size)
             index_file.write_index_file(save_index_file)
-            save_config_file(model.module, checkpoint_path)
+            save_config_file(model.unwrap(), checkpoint_path)
             logging.info(
                 f"The model is split into checkpoint shards. "
                 f"You can find where each parameters has been saved in the "
@@ -124,17 +129,17 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         """
         Load shard model, load model from multiple files.
         """
+        assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
         return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)
 
     def save_sharded_optimizer(
-        self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
+        self, optimizer: GeminiOptimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
     ):
         """
         Save sharded optimizer state dict to checkpoint folder.
         As there is communication when getting state dict, this must be called on all processes.
         """
-        assert isinstance(optimizer, GeminiOptimizer)
+        assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!"
 
         if os.path.isfile(checkpoint):
             logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
@@ -176,12 +181,12 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
             f"index located at {save_index_file}."
         )
 
-    def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint_index_file: Path, prefix: str):
+    def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_file: Path, prefix: str):
         """
         Loading sharded optimizer from checkpoint folder, with index file given.
         For each process, only loading optimizer states of parameters it controls.
         """
+        assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!"
 
         if not os.path.isfile(checkpoint_index_file):
             logging.error(f"Provided path ({checkpoint_index_file}) should be a file")
@@ -383,7 +388,7 @@ class GeminiPlugin(DPPluginBase):
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
             optimizer = GeminiOptimizer(
-                optimizer, model.unwrap(), **self.zero_optim_config, **self.optim_kwargs, verbose=self.verbose
+                optimizer, model, **self.zero_optim_config, **self.optim_kwargs, verbose=self.verbose
             )
         return model, optimizer, criterion, dataloader, lr_scheduler

View File

@@ -1,6 +1,7 @@
 import random
 from contextlib import nullcontext
 from functools import partial
+from types import MethodType
 from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union
 
 import numpy as np
@@ -165,6 +166,15 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
             init_pipeline_optimizer(optim, model)
         super().__init__(optim)
 
+    def update_master_params(self, model: Module):
+        pass
+
+    def get_working_to_master_map(self):
+        return None
+
+    def get_master_to_working_map(self):
+        return None
+
 
 class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
     def __init__(
@@ -466,9 +476,6 @@ class HybridParallelPlugin(PipelinePluginBase):
                         max_norm=self.max_norm,
                         **self.amp_config,
                     )
-                    self.checkpoint_io.link_master_and_working_param(
-                        optimizer.working_to_master_map, optimizer.master_to_working_map
-                    )
                 else:
                     optimizer = HybridParallelNaiveOptimizer(
                         optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info
@@ -488,10 +495,8 @@ class HybridParallelPlugin(PipelinePluginBase):
                     **self.zero_config,
                     **self.amp_config,
                 )
-                self.checkpoint_io.link_master_and_working_param(
-                    optimizer._param_store.working_to_master_param, optimizer._param_store.master_to_working_param
-                )
-
+            # inject update_master_params
+            model.update_master_params = MethodType(optimizer.update_master_params, model)
         return model, optimizer, criterion, dataloader, lr_scheduler
 
     def execute_pipeline(
@@ -567,8 +572,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         )
 
     def get_checkpoint_io(self) -> CheckpointIO:
-        self.checkpoint_io = HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
-        return self.checkpoint_io
+        return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
 
     def no_sync(self, model: Module) -> Iterator[None]:
         raise NotImplementedError
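A small standalone illustration of the injection pattern used above: `types.MethodType` binds a routine owned by the optimizer onto the model object, so generic checkpoint code can later call `model.update_master_params()` without knowing which optimizer wrapper is in use. The `Model` and `Optim` classes below are made up for the example.

```python
from types import MethodType

class Model:
    """Stand-in for a wrapped model (illustrative only)."""
    def __init__(self):
        self.weights_synced = False

class Optim:
    """Stand-in for an optimizer that knows how to refresh its master copies."""
    def update_master_params(self, model):
        # in the real plugin this copies working params into fp32 master params
        model.weights_synced = True

model, optim = Model(), Optim()

# bind the optimizer's routine onto the model, mirroring the plugin's injection
model.update_master_params = MethodType(optim.update_master_params, model)

# checkpoint code can now call it as a plain method of the model
model.update_master_params()
assert model.weights_synced
```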

View File

@@ -22,7 +22,6 @@ from colossalai.checkpoint_io.utils import (
     save_param_groups,
     save_state_dict,
     sharded_optimizer_loading_epilogue,
-    unwrap_optimizer,
 )
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.utils import get_current_device
@@ -65,10 +64,6 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
             kwargs = tree_map(self.convert_fn, kwargs)
         return super().forward(*args, **kwargs)
 
-    def unwrap(self):
-        # TODO(ver217): this is a workaround for loading model
-        return self
-
 
 class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
     def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
@@ -79,7 +74,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
             checkpoint (str): Path to save checkpoint
             gather_dtensor (bool): Whether to gather_dtensor, not used
         """
-
+        assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before saving!"
         # the `state_dict` in LowLevelZeroOptimizer has communication
         # if only the master rank collect state_dict and save,
         # the communication on each rank would not match
@@ -109,6 +104,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
             prefix (str): Perfix of file to save
             size_per_shard (int): Max file size of each file that store state tensors
         """
+        assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before saving!"
         if os.path.isfile(checkpoint):
             logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
             return
@@ -160,9 +156,8 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
             index_file_path (str): Path to the index file
             prefix (str): Not used.
         """
-        # If optimizer is wrapped, unwrap it.
-        if isinstance(optimizer, OptimizerWrapper):
-            optimizer = unwrap_optimizer(optimizer)
+        assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before Loading!"
+        optimizer = optimizer.unwrap()
 
         # Read checkpoint index file.
         ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
@@ -194,44 +189,23 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
                     v_list = v.split(v.numel() // self.coordinator.world_size)
                     state_dict[param_idx][k] = v_list[self.coordinator.rank].detach().clone()
         load_states_into_optimizer(optimizer, state_dict, id_map)
         sharded_optimizer_loading_epilogue(optimizer)
 
-    def save_unsharded_model(
-        self, model: LowLevelZeroModel, checkpoint: str, gather_dtensor: bool, use_safetensors: bool
-    ):
-        assert isinstance(model, LowLevelZeroModel)
-        super().save_unsharded_model(model.module, checkpoint, gather_dtensor, use_safetensors)
-
-    def save_sharded_model(
-        self,
-        model: nn.Module,
-        checkpoint_path: str,
-        gather_dtensor: bool = True,
-        prefix: Optional[str] = None,
-        max_shard_size: int = 1024,
-        use_safetensors: bool = False,
-    ):
-        assert isinstance(model, LowLevelZeroModel)
-        super().save_sharded_model(
-            model.module, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors
-        )
-
-    def load_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, strict: bool = True):
-        assert isinstance(model, LowLevelZeroModel)
-        super().load_unsharded_model(model.module, checkpoint, strict)
+    def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
+        assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
+        super().load_unsharded_model(model, checkpoint, strict)
         model.update_master_params()
 
     def load_sharded_model(
         self,
-        model: LowLevelZeroModel,
+        model: ModelWrapper,
         checkpoint_index_file: Path,
         strict: bool = False,
         use_safetensors: bool = False,
         load_sub_module: bool = True,
     ):
-        assert isinstance(model, LowLevelZeroModel)
-        super().load_sharded_model(model.module, checkpoint_index_file, strict, use_safetensors, load_sub_module)
+        assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
+        super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module)
         model.update_master_params()

View File

@@ -20,24 +20,33 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         super().__init__()
         self.coordinator = DistCoordinator()
 
-    def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
+    def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
         """
-        Load model from checkpoint with automatic unwrapping.
+        Load model from checkpoint.
         """
-        # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
-        return super().load_unsharded_model(model, checkpoint, strict=strict)
+        assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
+        super().load_unsharded_model(model.unwrap(), checkpoint, strict=strict)
 
-    def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+    def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
         """
         Save model to checkpoint but only on master process.
         """
+        assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
         if self.coordinator.is_master():
-            super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
+            super().save_unsharded_model(model.unwrap(), checkpoint, gather_dtensor, use_safetensors)
 
-    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
+    def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
+        """
+        Load optimizer from checkpoint.
+        """
+        assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
+        super().load_unsharded_optimizer(optimizer, checkpoint)
+
+    def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
         """
         Save optimizer to checkpoint but only on master process.
         """
+        assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
         if self.coordinator.is_master():
             super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
@@ -50,7 +59,7 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
 
     def save_sharded_model(
         self,
-        model: nn.Module,
+        model: ModelWrapper,
         checkpoint_path: str,
         gather_dtensor: bool = True,
         prefix: Optional[str] = None,
@@ -60,22 +69,52 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         """
         Save model to checkpoint but only on master process.
         """
+        assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
         if self.coordinator.is_master():
-            super().save_sharded_model(model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors)
+            super().save_sharded_model(
+                model.unwrap(), checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors
+            )
+
+    def load_sharded_model(
+        self,
+        model: ModelWrapper,
+        checkpoint_index_file: str,
+        strict: bool = False,
+        use_safetensors: bool = False,
+        load_sub_module: bool = True,
+    ):
+        """
+        Load model from sharded checkpoint.
+        """
+        assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
+        super().load_sharded_model(model.unwrap(), checkpoint_index_file, strict, use_safetensors, load_sub_module)
 
     def save_sharded_optimizer(
         self,
-        optimizer: Optimizer,
+        optimizer: OptimizerWrapper,
         checkpoint: str,
         gather_dtensor: bool = True,
         prefix: Optional[str] = None,
         size_per_shard: int = 1024,
     ):
         """
-        Save optimizer to checkpoint but only on master process.
+        Save optimizer to sharded checkpoint but only on master process.
         """
+        assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
         if self.coordinator.is_master():
-            super().save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
+            super().save_sharded_optimizer(optimizer.unwrap(), checkpoint, gather_dtensor, prefix, size_per_shard)
+
+    def load_sharded_optimizer(
+        self,
+        optimizer: Optimizer,
+        index_file_path: str,
+        prefix: Optional[str] = None,
+    ):
+        """
+        Load optimizer from sharded checkpoint.
+        """
+        assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
+        super().load_sharded_optimizer(optimizer.unwrap(), index_file_path, prefix)
 
 
 class TorchDDPModel(ModelWrapper):
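The pattern repeated throughout these checkpoint IO classes is: require the boosted wrapper up front, then delegate to the base implementation on the unwrapped module. A minimal sketch of that contract, using made-up `ModelWrapper` and `ToyCheckpointIO` classes rather than the real ColossalAI ones:

```python
import torch
import torch.nn as nn

class ModelWrapper(nn.Module):
    """Toy stand-in for a boosted model wrapper (illustrative only)."""
    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def unwrap(self) -> nn.Module:
        return self.module

class ToyCheckpointIO:
    def save_unsharded_model(self, model: ModelWrapper, checkpoint: str):
        # insist on the boosted wrapper, then operate on the bare module
        assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
        torch.save(model.unwrap().state_dict(), checkpoint)

    def load_unsharded_model(self, model: ModelWrapper, checkpoint: str):
        assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
        model.unwrap().load_state_dict(torch.load(checkpoint))

wrapped = ModelWrapper(nn.Linear(4, 4))
io = ToyCheckpointIO()
io.save_unsharded_model(wrapped, "/tmp/toy_model.pt")
io.load_unsharded_model(wrapped, "/tmp/toy_model.pt")
```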

View File

@@ -39,31 +39,35 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         super().__init__()
         self.coordinator = DistCoordinator()
 
-    def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
+    def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool):
+        assert isinstance(model, TorchFSDPModel), "Please boost the model before loading!"
+        model = model.unwrap()
         checkpoint = utils.load_state_dict(checkpoint)
         model.load_state_dict(checkpoint)
 
-    def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
+    def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: Path):
+        assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before loading!"
         checkpoint = utils.load_state_dict(checkpoint)
         fsdp_model = optimizer.unwrap_model()
         sharded_osd = FSDP.scatter_full_optim_state_dict(checkpoint, fsdp_model)
         optimizer.load_state_dict(sharded_osd)
 
-    def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+    def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
         """
         Save model to checkpoint but only on master process.
         """
-        # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
+        assert isinstance(model, TorchFSDPModel), "Please boost the model before saving!"
+        model = model.unwrap()
         cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
         with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
             full_model_state = model.state_dict()
         utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors)
 
-    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
+    def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
         """
         Save optimizer to checkpoint but only on master process.
         """
-        assert isinstance(optimizer, FSDPOptimizerWrapper)
+        assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before saving!"
         fsdp_model = optimizer.unwrap_model()
         full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True)
         utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False)

View File

@@ -87,9 +87,6 @@ class CheckpointIO(ABC):
         # return the origin model instead of the unwrapped model
         origin_model = model
 
-        if isinstance(model, ModelWrapper):
-            model = model.unwrap()
-
         if index_file_exists:
             self.load_sharded_model(model, index_file_path, strict)
         else:
@@ -134,9 +131,6 @@ class CheckpointIO(ABC):
             use_safetensors (bool): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved
         """
 
-        if isinstance(model, ModelWrapper):
-            model = model.unwrap()
-
         if shard:
             self.save_sharded_model(model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors)
         else:

View File

@@ -8,8 +8,6 @@ from typing import Optional
 import torch.nn as nn
 from torch.optim import Optimizer
 
-from colossalai.interface import OptimizerWrapper
-
 from .checkpoint_io_base import CheckpointIO
 from .index_file import CheckpointIndexFile
 from .utils import (
@@ -28,7 +26,6 @@ from .utils import (
     shard_model_checkpoint,
     shard_optimizer_checkpoint,
     sharded_optimizer_loading_epilogue,
-    unwrap_optimizer,
 )
 
 __all__ = ["GeneralCheckpointIO"]
@@ -58,10 +55,6 @@ class GeneralCheckpointIO(CheckpointIO):
         Load sharded optimizer with the given path to index file.
         """
 
-        # If optimizer is wrapped, unwrap it.
-        if isinstance(optimizer, OptimizerWrapper):
-            optimizer = unwrap_optimizer(optimizer)
-
         # Read checkpoint index file.
         ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
@@ -98,10 +91,6 @@ class GeneralCheckpointIO(CheckpointIO):
         - Multiple files (pytorch_optim-000XX.bin) that store state tensors of optimizer in a sharding way
         """
 
-        # If optimizer is wrapped, unwrap it.
-        if isinstance(optimizer, OptimizerWrapper):
-            optimizer = unwrap_optimizer(optimizer)
-
         if os.path.isfile(checkpoint):
             logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
             return

View File

@@ -3,7 +3,7 @@ import logging
 import os
 from pathlib import Path
 from shutil import rmtree
-from typing import Dict, Iterator, Optional, OrderedDict, Tuple, Union
+from typing import Dict, Iterator, Optional, OrderedDict, Tuple
 
 import torch
 import torch.distributed as dist
@@ -13,7 +13,7 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 
 from colossalai.cluster import DistCoordinator
-from colossalai.interface import OptimizerWrapper
+from colossalai.interface import ModelWrapper, OptimizerWrapper
 
 from .general_checkpoint_io import GeneralCheckpointIO
 from .index_file import CheckpointIndexFile
@@ -71,8 +71,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         self.tp_size = dist.get_world_size(tp_group)
         self.use_zero = zero_stage > 0
         self.verbose = verbose
-        self.working_to_master_map = None
-        self.master_to_working_map = None
         self.coordinator = DistCoordinator()
 
     @staticmethod
@@ -159,7 +157,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
 
     def save_sharded_model(
         self,
-        model: nn.Module,
+        model: ModelWrapper,
         checkpoint: str,
         gather_dtensor: bool = True,
         prefix: Optional[str] = None,
@@ -184,6 +182,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
         """
 
+        assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
+        model = model.unwrap()
+
         if os.path.isfile(checkpoint):
             logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
             return
@@ -279,7 +280,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             f"index located at {final_index_file_path}."
         )
 
-    def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False):
+    def load_sharded_model(self, model: ModelWrapper, checkpoint_index_file: Path, strict: bool = False):
         """
         Load sharded model with the given path to index file of checkpoint folder.
 
@@ -289,6 +290,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             strict (bool, optional): For name matching during loading state_dict. Defaults to False.
                 This argument should be manually set to False since params on same device might be stored in different files.
         """
+        assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
+        model_before_wrapping = model  # backup for model before wrapping
+        model = model.unwrap()
 
         # Check whether the checkpoint uses safetensors.
         use_safetensors = False
@@ -347,23 +351,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             _load(extra_state_key)
 
         # Update master params if mixed-precision training is enabled.
-        with torch.no_grad():
-            if self.working_to_master_map is not None:
-                for param in model.parameters():
-                    if (param is None) or (id(param) not in self.working_to_master_map):
-                        continue
-                    master_param = self.working_to_master_map[id(param)]
-                    if self.use_zero:
-                        # master_param is sharded under Zero setting
-                        padding_size = (self.dp_size - param.numel() % self.dp_size) % self.dp_size
-                        if padding_size > 0:
-                            padded_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
-                        else:
-                            padded_param = param.data.view(-1)
-                        sharded_param = padded_param.split(padded_param.numel() // self.dp_size)[self.dp_rank]
-                        master_param.data.copy_(sharded_param.data)
-                    else:
-                        master_param.data.copy_(param.data)
+        model_before_wrapping.update_master_params()
 
         if self.verbose:
             logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
@@ -392,6 +380,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             prefix (str): Perfix of file to save
             size_per_shard (int): Max file size of each file shard that store state tensors
         """
+        assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
         if os.path.isfile(checkpoint):
             logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
             return
@@ -410,7 +399,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             use_zero=self.use_zero,
             dp_group=self.dp_group,
             tp_group=self.tp_group,
-            master_to_working_map=self.master_to_working_map,
+            master_to_working_map=optimizer.get_master_to_working_map(),
             size_per_shard=size_per_shard,
         )
         states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
@@ -511,6 +500,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             checkpoint_index_file (str): Path to the index file of checkpointing folder.
             prefix (str): Not used.
         """
+        assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
 
         def _get_param_id_from_optimizer_param(
             param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
@@ -525,9 +515,10 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         # When Zero is used, the mapped parameter objects should be fp32 master parameters.
         # IDs should be obtained through saved param2id mapping earlier saved in optimizer.param_info.
         id_map = {}
+        master_to_working_map = optimizer.get_master_to_working_map()
         for pg in optimizer.optim.param_groups:
             for param in pg["params"]:
-                param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map)
+                param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
                 id_map[param_id] = param
 
         # Read checkpoint index file.
@@ -560,7 +551,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             for param in pg["params"]:
                 if param is None:
                     continue
-                param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map)
+                param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
                 if param_id not in weight_map:
                     continue
                 filename = weight_map[param_id]
@@ -577,8 +568,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         # Then shard the loaded optimizer states if using tp/zero.
         for param, state in optimizer.optim.state.items():
             device = param.device
-            if self.master_to_working_map is not None:
-                working_param = self.master_to_working_map[id(param)]
+            if master_to_working_map is not None:
+                working_param = master_to_working_map[id(param)]
             else:
                 working_param = param
             original_shape = optimizer.param_info["param2shape"][id(working_param)]
@@ -614,42 +605,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         if self.coordinator.is_master():
             super().save_lr_scheduler(lr_scheduler, checkpoint)
 
-    def link_master_and_working_param(
-        self,
-        working_to_master_map: Dict[Union[int, torch.Tensor], torch.Tensor],
-        master_to_working_map: Dict[Union[int, torch.Tensor], torch.Tensor],
-    ):
-        """
-        Create mappings between working params (for forward/backward) and master params (for optimizer update) with passed in mappings.
-        This mapping can only be created when mixied precision is used.
-        The created mappings should be mappings from integer parameter addresses to parameter objects.
-
-        Args:
-            working_to_master_map (Dict[Union[int, torch.Tensor], torch.Tensor]): A mapping from working parameters objects/addresses to master parameter objects.
-            master_to_working_map (Dict[Union[int, torch.Tensor], torch.Tensor]): A mapping from master parameters objects/addresses to working parameter objects.
-        """
-        self.working_to_master_map = dict()
-        for k, v in working_to_master_map.items():
-            if isinstance(k, torch.Tensor):
-                self.working_to_master_map[id(k)] = v
-            elif isinstance(k, int):
-                self.working_to_master_map[k] = v
-            else:
-                raise ValueError(
-                    f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!"
-                )
-
-        self.master_to_working_map = dict()
-        for k, v in master_to_working_map.items():
-            if isinstance(k, torch.Tensor):
-                self.master_to_working_map[id(k)] = v
-            elif isinstance(k, int):
-                self.master_to_working_map[k] = v
-            else:
-                raise ValueError(
-                    f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!"
-                )
-
     @staticmethod
     def gather_from_sharded_optimizer_state(
         state: OrderedDict,
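The loop removed above (pad the flattened working param to a multiple of the DP world size, then keep only this rank's shard) now lives behind the optimizer's `update_master_params`. A standalone sketch of that sharding arithmetic, with a made-up helper name and toy sizes:

```python
import torch
import torch.nn.functional as F

def shard_for_master_copy(working_param: torch.Tensor, dp_size: int, dp_rank: int) -> torch.Tensor:
    """Return the slice of a working param that a ZeRO master copy on `dp_rank` would hold."""
    flat = working_param.detach().view(-1)
    padding_size = (dp_size - flat.numel() % dp_size) % dp_size
    if padding_size > 0:
        flat = F.pad(flat, [0, padding_size])  # pad so the tensor splits evenly across ranks
    return flat.split(flat.numel() // dp_size)[dp_rank]

# e.g. a 10-element param sharded across 4 ranks -> each shard holds 3 padded elements
p = torch.arange(10, dtype=torch.float32)
shards = [shard_for_master_copy(p, dp_size=4, dp_rank=r) for r in range(4)]
assert all(s.numel() == 3 for s in shards)
```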

View File

@@ -11,7 +11,6 @@ import torch
 import torch.nn as nn
 from torch.optim import Optimizer
 
-from colossalai.interface import OptimizerWrapper
 from colossalai.tensor.d_tensor import (
     is_customized_distributed_tensor,
     is_distributed_tensor,
@@ -122,14 +121,6 @@ def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Siz
 # ======================================
 # Helper classes and functions for saving shard file
 # ======================================
-def unwrap_optimizer(optimizer: OptimizerWrapper):
-    """
-    Unwrap a wrapped optimizer.
-    This method should be used before saving/loading it to/from sharded checkpoints.
-    """
-    unwrapped_optim = optimizer.optim
-    return unwrapped_optim
-
-
 class StateDictSharder:

View File

@@ -186,10 +186,6 @@ class GeminiDDP(ModelWrapper):
         for p in params_to_ignore:
             p._ddp_to_ignore = True
 
-    def unwrap(self):
-        # as save/load state dict is overwrited, only return self
-        return self
-
     def _get_non_persistent_buffers_set(
         self, module, memo: Optional[Set[nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True
     ):

View File

@@ -648,3 +648,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                 if padding_size > 0:
                     working_param = torch.nn.functional.pad(working_param, [0, padding_size])
                 master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
+
+    def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
+        return self._param_store.working_to_master_param
+
+    def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
+        return self._param_store.master_to_working_param

View File

@@ -61,9 +61,7 @@ def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per
     new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)
 
     if plugin_type == "gemini":
-        check_state_dict_equal(
-            model.unwrap().state_dict(only_rank_0=False), new_model.unwrap().state_dict(only_rank_0=False), False
-        )
+        check_state_dict_equal(model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), False)
     else:
         check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False)
     dist.barrier()