From c0a033700c7027286ecd8a7bcbaafc6f794323ad Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Wed, 20 Sep 2023 18:29:37 +0800 Subject: [PATCH] [shardformer] fix master param sync for hybrid plugin/rewrite unwrapping logic (#4758) * fix master param sync for hybrid plugin * rewrite unwrap for ddp/fsdp * rewrite unwrap for zero/gemini * rewrite unwrap for hybrid plugin * fix geemini unwrap * fix bugs --- .../naive_amp/mixed_precision_optimizer.py | 17 +++- colossalai/booster/booster.py | 2 +- colossalai/booster/plugin/gemini_plugin.py | 23 +++-- .../booster/plugin/hybrid_parallel_plugin.py | 22 +++-- .../booster/plugin/low_level_zero_plugin.py | 46 +++------- colossalai/booster/plugin/torch_ddp_plugin.py | 63 +++++++++++--- .../booster/plugin/torch_fsdp_plugin.py | 16 ++-- .../checkpoint_io/checkpoint_io_base.py | 6 -- .../checkpoint_io/general_checkpoint_io.py | 11 --- .../hybrid_parallel_checkpoint_io.py | 83 +++++-------------- colossalai/checkpoint_io/utils.py | 9 -- colossalai/zero/gemini/gemini_ddp.py | 4 - colossalai/zero/low_level/low_level_optim.py | 6 ++ .../test_plugins_huggingface_compatibility.py | 4 +- 14 files changed, 141 insertions(+), 171 deletions(-) diff --git a/colossalai/amp/naive_amp/mixed_precision_optimizer.py b/colossalai/amp/naive_amp/mixed_precision_optimizer.py index 6a192cc5c..501a843f6 100644 --- a/colossalai/amp/naive_amp/mixed_precision_optimizer.py +++ b/colossalai/amp/naive_amp/mixed_precision_optimizer.py @@ -2,7 +2,7 @@ from typing import Dict, List import torch from torch import Tensor -from torch.nn import Parameter +from torch.nn import Module, Parameter from torch.optim import Optimizer from colossalai.interface import OptimizerWrapper @@ -152,3 +152,18 @@ class MixedPrecisionOptimizer(OptimizerWrapper): if p is working_param: continue working_param.data.copy_(p.data) + + def update_master_params(self, model: Module): + # Update master params from working params + with torch.no_grad(): + for p in model.parameters(): + if (p is None) or (p not in self.working_to_master_map): + continue + master_param = self.working_to_master_map[p] + master_param.data.copy_(p.data) + + def get_working_to_master_map(self) -> Dict[int, torch.Tensor]: + return {id(working_p): master_p for working_p, master_p in self.working_to_master_map.items()} + + def get_master_to_working_map(self) -> Dict[int, torch.Tensor]: + return {id(master_p): working_p for master_p, working_p in self.master_to_working_map.items()} diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py index 2aee72cbf..8d6b0b42e 100644 --- a/colossalai/booster/booster.py +++ b/colossalai/booster/booster.py @@ -139,7 +139,7 @@ class Booster: if self.plugin and not self.plugin.control_device(): # transform model for accelerator - model = self.accelerator.configure(model) + model = self.accelerator.configure_model(model) if self.mixed_precision and (self.plugin is None or self.plugin and not self.plugin.control_precision()): # transform model for mixed precision diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 83a00d4ee..abf3a907b 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -44,6 +44,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO): The model should be unwrapped in self.load_model via ModelWrapper.unwrap. As there is communication when getting state dict, model.state_dict() must be called on all processes. 
""" + assert isinstance(model, GeminiDDP), "Please boost the model before saving!" state_dict = model.state_dict(only_rank_0=True) if self.coordinator.is_master(): save_state_dict(state_dict, checkpoint, use_safetensors) @@ -53,24 +54,27 @@ class GeminiCheckpointIO(GeneralCheckpointIO): Load model from checkpoint with automatic unwrapping. The model should be unwrapped in self.load_model via ModelWrapper.unwrap. """ + assert isinstance(model, GeminiDDP), "Please boost the model before loading!" super().load_unsharded_model(model, checkpoint, strict=strict) - def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool): + def save_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool): """ Save unsharded optimizer state dict to checkpoint. After calling optimizer.state_dict(), the complete optimizer states will be collected on master rank. As there is communication when getting state dict, optimizer.state_dict() must be called on all processes. The saving process will only be executed by master rank. """ + assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!" state_dict = optimizer.state_dict() if self.coordinator.is_master(): save_state_dict(state_dict, checkpoint, use_safetensors=False) - def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str): + def load_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str): """ Loading unsharded optimizer from checkpoint file. For each process, only loading optimizer states of parameters it controls. """ + assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!" super().load_unsharded_optimizer(optimizer, checkpoint) def save_sharded_model( @@ -86,6 +90,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO): Save sharded model. As there is communication when getting state dict, model.state_dict() must be called on all processes. """ + assert isinstance(model, GeminiDDP), "Please boost the model before saving!" if os.path.isfile(checkpoint_path): logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file") return @@ -111,7 +116,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO): if self.coordinator.is_master(): index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) - save_config_file(model.module, checkpoint_path) + save_config_file(model.unwrap(), checkpoint_path) logging.info( f"The model is split into checkpoint shards. " f"You can find where each parameters has been saved in the " @@ -124,17 +129,17 @@ class GeminiCheckpointIO(GeneralCheckpointIO): """ Load shard model, load model from multiple files. """ + assert isinstance(model, GeminiDDP), "Please boost the model before loading!" return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False) def save_sharded_optimizer( - self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int + self, optimizer: GeminiOptimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int ): """ Save sharded optimizer state dict to checkpoint folder. As there is communication when getting state dict, this must be called on all processes. """ - - assert isinstance(optimizer, GeminiOptimizer) + assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!" 
if os.path.isfile(checkpoint): logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") @@ -176,12 +181,12 @@ class GeminiCheckpointIO(GeneralCheckpointIO): f"index located at {save_index_file}." ) - def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint_index_file: Path, prefix: str): + def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_file: Path, prefix: str): """ Loading sharded optimizer from checkpoint folder, with index file given. For each process, only loading optimizer states of parameters it controls. """ - + assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!" if not os.path.isfile(checkpoint_index_file): logging.error(f"Provided path ({checkpoint_index_file}) should be a file") @@ -383,7 +388,7 @@ class GeminiPlugin(DPPluginBase): if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): optimizer = GeminiOptimizer( - optimizer, model.unwrap(), **self.zero_optim_config, **self.optim_kwargs, verbose=self.verbose + optimizer, model, **self.zero_optim_config, **self.optim_kwargs, verbose=self.verbose ) return model, optimizer, criterion, dataloader, lr_scheduler diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index c1693fa8d..46930887b 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1,6 +1,7 @@ import random from contextlib import nullcontext from functools import partial +from types import MethodType from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union import numpy as np @@ -165,6 +166,15 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper): init_pipeline_optimizer(optim, model) super().__init__(optim) + def update_master_params(self, model: Module): + pass + + def get_working_to_master_map(self): + return None + + def get_master_to_working_map(self): + return None + class HybridParallelAMPOptimizer(MixedPrecisionOptimizer): def __init__( @@ -466,9 +476,6 @@ class HybridParallelPlugin(PipelinePluginBase): max_norm=self.max_norm, **self.amp_config, ) - self.checkpoint_io.link_master_and_working_param( - optimizer.working_to_master_map, optimizer.master_to_working_map - ) else: optimizer = HybridParallelNaiveOptimizer( optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info @@ -488,10 +495,8 @@ class HybridParallelPlugin(PipelinePluginBase): **self.zero_config, **self.amp_config, ) - self.checkpoint_io.link_master_and_working_param( - optimizer._param_store.working_to_master_param, optimizer._param_store.master_to_working_param - ) - + # inject update_master_params + model.update_master_params = MethodType(optimizer.update_master_params, model) return model, optimizer, criterion, dataloader, lr_scheduler def execute_pipeline( @@ -567,8 +572,7 @@ class HybridParallelPlugin(PipelinePluginBase): ) def get_checkpoint_io(self) -> CheckpointIO: - self.checkpoint_io = HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) - return self.checkpoint_io + return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) def no_sync(self, model: Module) -> Iterator[None]: raise NotImplementedError diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index 86adee7fe..457c720f6 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py 
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -22,7 +22,6 @@ from colossalai.checkpoint_io.utils import ( save_param_groups, save_state_dict, sharded_optimizer_loading_epilogue, - unwrap_optimizer, ) from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper from colossalai.utils import get_current_device @@ -65,10 +64,6 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin): kwargs = tree_map(self.convert_fn, kwargs) return super().forward(*args, **kwargs) - def unwrap(self): - # TODO(ver217): this is a workaround for loading model - return self - class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False): @@ -79,7 +74,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): checkpoint (str): Path to save checkpoint gather_dtensor (bool): Whether to gather_dtensor, not used """ - + assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before saving!" # the `state_dict` in LowLevelZeroOptimizer has communication # if only the master rank collect state_dict and save, # the communication on each rank would not match @@ -109,6 +104,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): prefix (str): Perfix of file to save size_per_shard (int): Max file size of each file that store state tensors """ + assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before saving!" if os.path.isfile(checkpoint): logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") return @@ -160,9 +156,8 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): index_file_path (str): Path to the index file prefix (str): Not used. """ - # If optimizer is wrapped, unwrap it. - if isinstance(optimizer, OptimizerWrapper): - optimizer = unwrap_optimizer(optimizer) + assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before Loading!" + optimizer = optimizer.unwrap() # Read checkpoint index file. ckpt_index_file = CheckpointIndexFile.from_file(index_file_path) @@ -194,44 +189,23 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): v_list = v.split(v.numel() // self.coordinator.world_size) state_dict[param_idx][k] = v_list[self.coordinator.rank].detach().clone() load_states_into_optimizer(optimizer, state_dict, id_map) - sharded_optimizer_loading_epilogue(optimizer) - def save_unsharded_model( - self, model: LowLevelZeroModel, checkpoint: str, gather_dtensor: bool, use_safetensors: bool - ): - assert isinstance(model, LowLevelZeroModel) - super().save_unsharded_model(model.module, checkpoint, gather_dtensor, use_safetensors) - - def save_sharded_model( - self, - model: nn.Module, - checkpoint_path: str, - gather_dtensor: bool = True, - prefix: Optional[str] = None, - max_shard_size: int = 1024, - use_safetensors: bool = False, - ): - assert isinstance(model, LowLevelZeroModel) - super().save_sharded_model( - model.module, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors - ) - - def load_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, strict: bool = True): - assert isinstance(model, LowLevelZeroModel) - super().load_unsharded_model(model.module, checkpoint, strict) + def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True): + assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!" 
+ super().load_unsharded_model(model, checkpoint, strict) model.update_master_params() def load_sharded_model( self, - model: LowLevelZeroModel, + model: ModelWrapper, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False, load_sub_module: bool = True, ): - assert isinstance(model, LowLevelZeroModel) - super().load_sharded_model(model.module, checkpoint_index_file, strict, use_safetensors, load_sub_module) + assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!" + super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module) model.update_master_params() diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py index 30d34e7dd..41d7c0635 100644 --- a/colossalai/booster/plugin/torch_ddp_plugin.py +++ b/colossalai/booster/plugin/torch_ddp_plugin.py @@ -20,24 +20,33 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO): super().__init__() self.coordinator = DistCoordinator() - def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = True): + def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True): """ - Load model from checkpoint with automatic unwrapping. + Load model from checkpoint. """ - # the model should be unwrapped in self.load_model via ModelWrapper.unwrap - return super().load_unsharded_model(model, checkpoint, strict=strict) + assert isinstance(model, ModelWrapper), "Please boost the model before loading!" + super().load_unsharded_model(model.unwrap(), checkpoint, strict=strict) - def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): + def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): """ Save model to checkpoint but only on master process. """ + assert isinstance(model, ModelWrapper), "Please boost the model before saving!" if self.coordinator.is_master(): - super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors) + super().save_unsharded_model(model.unwrap(), checkpoint, gather_dtensor, use_safetensors) - def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool): + def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str): + """ + Load optimizer from checkpoint. + """ + assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!" + super().load_unsharded_optimizer(optimizer, checkpoint) + + def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool): """ Save optimizer to checkpoint but only on master process. """ + assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!" if self.coordinator.is_master(): super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor) @@ -50,7 +59,7 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO): def save_sharded_model( self, - model: nn.Module, + model: ModelWrapper, checkpoint_path: str, gather_dtensor: bool = True, prefix: Optional[str] = None, @@ -60,22 +69,52 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO): """ Save model to checkpoint but only on master process. """ + assert isinstance(model, ModelWrapper), "Please boost the model before saving!" 
if self.coordinator.is_master(): - super().save_sharded_model(model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors) + super().save_sharded_model( + model.unwrap(), checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors + ) + + def load_sharded_model( + self, + model: ModelWrapper, + checkpoint_index_file: str, + strict: bool = False, + use_safetensors: bool = False, + load_sub_module: bool = True, + ): + """ + Load model from sharded checkpoint. + """ + assert isinstance(model, ModelWrapper), "Please boost the model before loading!" + super().load_sharded_model(model.unwrap(), checkpoint_index_file, strict, use_safetensors, load_sub_module) def save_sharded_optimizer( self, - optimizer: Optimizer, + optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = True, prefix: Optional[str] = None, size_per_shard: int = 1024, ): """ - Save optimizer to checkpoint but only on master process. + Save optimizer to sharded checkpoint but only on master process. """ + assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!" if self.coordinator.is_master(): - super().save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard) + super().save_sharded_optimizer(optimizer.unwrap(), checkpoint, gather_dtensor, prefix, size_per_shard) + + def load_sharded_optimizer( + self, + optimizer: Optimizer, + index_file_path: str, + prefix: Optional[str] = None, + ): + """ + Load optimizer from sharded checkpoint. + """ + assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!" + super().load_sharded_optimizer(optimizer.unwrap(), index_file_path, prefix) class TorchDDPModel(ModelWrapper): diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py index d12b784b4..1e3762b79 100644 --- a/colossalai/booster/plugin/torch_fsdp_plugin.py +++ b/colossalai/booster/plugin/torch_fsdp_plugin.py @@ -39,31 +39,35 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO): super().__init__() self.coordinator = DistCoordinator() - def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool): + def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool): + assert isinstance(model, TorchFSDPModel), "Please boost the model before loading!" + model = model.unwrap() checkpoint = utils.load_state_dict(checkpoint) model.load_state_dict(checkpoint) - def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path): + def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: Path): + assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before loading!" checkpoint = utils.load_state_dict(checkpoint) fsdp_model = optimizer.unwrap_model() sharded_osd = FSDP.scatter_full_optim_state_dict(checkpoint, fsdp_model) optimizer.load_state_dict(sharded_osd) - def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): + def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): """ Save model to checkpoint but only on master process. """ - # the model should be unwrapped in self.load_model via ModelWrapper.unwrap + assert isinstance(model, TorchFSDPModel), "Please boost the model before saving!" 
+ model = model.unwrap() cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg): full_model_state = model.state_dict() utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors) - def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool): + def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool): """ Save optimizer to checkpoint but only on master process. """ - assert isinstance(optimizer, FSDPOptimizerWrapper) + assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before saving!" fsdp_model = optimizer.unwrap_model() full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True) utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False) diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py index f8ce8f4e5..780117598 100644 --- a/colossalai/checkpoint_io/checkpoint_io_base.py +++ b/colossalai/checkpoint_io/checkpoint_io_base.py @@ -87,9 +87,6 @@ class CheckpointIO(ABC): # return the origin model instead of the unwrapped model origin_model = model - if isinstance(model, ModelWrapper): - model = model.unwrap() - if index_file_exists: self.load_sharded_model(model, index_file_path, strict) else: @@ -134,9 +131,6 @@ class CheckpointIO(ABC): use_safetensors (bool): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved """ - if isinstance(model, ModelWrapper): - model = model.unwrap() - if shard: self.save_sharded_model(model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors) else: diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index b0e593e90..a652d9b45 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -8,8 +8,6 @@ from typing import Optional import torch.nn as nn from torch.optim import Optimizer -from colossalai.interface import OptimizerWrapper - from .checkpoint_io_base import CheckpointIO from .index_file import CheckpointIndexFile from .utils import ( @@ -28,7 +26,6 @@ from .utils import ( shard_model_checkpoint, shard_optimizer_checkpoint, sharded_optimizer_loading_epilogue, - unwrap_optimizer, ) __all__ = ["GeneralCheckpointIO"] @@ -58,10 +55,6 @@ class GeneralCheckpointIO(CheckpointIO): Load sharded optimizer with the given path to index file. """ - # If optimizer is wrapped, unwrap it. - if isinstance(optimizer, OptimizerWrapper): - optimizer = unwrap_optimizer(optimizer) - # Read checkpoint index file. ckpt_index_file = CheckpointIndexFile.from_file(index_file_path) @@ -98,10 +91,6 @@ class GeneralCheckpointIO(CheckpointIO): - Multiple files (pytorch_optim-000XX.bin) that store state tensors of optimizer in a sharding way """ - # If optimizer is wrapped, unwrap it. 
- if isinstance(optimizer, OptimizerWrapper): - optimizer = unwrap_optimizer(optimizer) - if os.path.isfile(checkpoint): logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") return diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 18c59a880..41e53b3b3 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -3,7 +3,7 @@ import logging import os from pathlib import Path from shutil import rmtree -from typing import Dict, Iterator, Optional, OrderedDict, Tuple, Union +from typing import Dict, Iterator, Optional, OrderedDict, Tuple import torch import torch.distributed as dist @@ -13,7 +13,7 @@ from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from colossalai.cluster import DistCoordinator -from colossalai.interface import OptimizerWrapper +from colossalai.interface import ModelWrapper, OptimizerWrapper from .general_checkpoint_io import GeneralCheckpointIO from .index_file import CheckpointIndexFile @@ -71,8 +71,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): self.tp_size = dist.get_world_size(tp_group) self.use_zero = zero_stage > 0 self.verbose = verbose - self.working_to_master_map = None - self.master_to_working_map = None self.coordinator = DistCoordinator() @staticmethod @@ -159,7 +157,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): def save_sharded_model( self, - model: nn.Module, + model: ModelWrapper, checkpoint: str, gather_dtensor: bool = True, prefix: Optional[str] = None, @@ -184,6 +182,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False. """ + assert isinstance(model, ModelWrapper), "Please boost the model before saving!" + model = model.unwrap() + if os.path.isfile(checkpoint): logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") return @@ -279,7 +280,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): f"index located at {final_index_file_path}." ) - def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False): + def load_sharded_model(self, model: ModelWrapper, checkpoint_index_file: Path, strict: bool = False): """ Load sharded model with the given path to index file of checkpoint folder. @@ -289,6 +290,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): strict (bool, optional): For name matching during loading state_dict. Defaults to False. This argument should be manually set to False since params on same device might be stored in different files. """ + assert isinstance(model, ModelWrapper), "Please boost the model before loading!" + model_before_wrapping = model # backup for model before wrapping + model = model.unwrap() # Check whether the checkpoint uses safetensors. use_safetensors = False @@ -347,23 +351,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): _load(extra_state_key) # Update master params if mixed-precision training is enabled. 
- with torch.no_grad(): - if self.working_to_master_map is not None: - for param in model.parameters(): - if (param is None) or (id(param) not in self.working_to_master_map): - continue - master_param = self.working_to_master_map[id(param)] - if self.use_zero: - # master_param is sharded under Zero setting - padding_size = (self.dp_size - param.numel() % self.dp_size) % self.dp_size - if padding_size > 0: - padded_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size]) - else: - padded_param = param.data.view(-1) - sharded_param = padded_param.split(padded_param.numel() // self.dp_size)[self.dp_rank] - master_param.data.copy_(sharded_param.data) - else: - master_param.data.copy_(param.data) + model_before_wrapping.update_master_params() if self.verbose: logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.") @@ -392,6 +380,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): prefix (str): Perfix of file to save size_per_shard (int): Max file size of each file shard that store state tensors """ + assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!" if os.path.isfile(checkpoint): logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") return @@ -410,7 +399,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): use_zero=self.use_zero, dp_group=self.dp_group, tp_group=self.tp_group, - master_to_working_map=self.master_to_working_map, + master_to_working_map=optimizer.get_master_to_working_map(), size_per_shard=size_per_shard, ) states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix) @@ -511,6 +500,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): checkpoint_index_file (str): Path to the index file of checkpointing folder. prefix (str): Not used. """ + assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!" def _get_param_id_from_optimizer_param( param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None @@ -525,9 +515,10 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): # When Zero is used, the mapped parameter objects should be fp32 master parameters. # IDs should be obtained through saved param2id mapping earlier saved in optimizer.param_info. id_map = {} + master_to_working_map = optimizer.get_master_to_working_map() for pg in optimizer.optim.param_groups: for param in pg["params"]: - param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map) + param_id = _get_param_id_from_optimizer_param(param, master_to_working_map) id_map[param_id] = param # Read checkpoint index file. @@ -560,7 +551,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): for param in pg["params"]: if param is None: continue - param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map) + param_id = _get_param_id_from_optimizer_param(param, master_to_working_map) if param_id not in weight_map: continue filename = weight_map[param_id] @@ -577,8 +568,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): # Then shard the loaded optimizer states if using tp/zero. 
for param, state in optimizer.optim.state.items(): device = param.device - if self.master_to_working_map is not None: - working_param = self.master_to_working_map[id(param)] + if master_to_working_map is not None: + working_param = master_to_working_map[id(param)] else: working_param = param original_shape = optimizer.param_info["param2shape"][id(working_param)] @@ -614,42 +605,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): if self.coordinator.is_master(): super().save_lr_scheduler(lr_scheduler, checkpoint) - def link_master_and_working_param( - self, - working_to_master_map: Dict[Union[int, torch.Tensor], torch.Tensor], - master_to_working_map: Dict[Union[int, torch.Tensor], torch.Tensor], - ): - """ - Create mappings between working params (for forward/backward) and master params (for optimizer update) with passed in mappings. - This mapping can only be created when mixied precision is used. - The created mappings should be mappings from integer parameter addresses to parameter objects. - - Args: - working_to_master_map (Dict[Union[int, torch.Tensor], torch.Tensor]): A mapping from working parameters objects/addresses to master parameter objects. - master_to_working_map (Dict[Union[int, torch.Tensor], torch.Tensor]): A mapping from master parameters objects/addresses to working parameter objects. - """ - self.working_to_master_map = dict() - for k, v in working_to_master_map.items(): - if isinstance(k, torch.Tensor): - self.working_to_master_map[id(k)] = v - elif isinstance(k, int): - self.working_to_master_map[k] = v - else: - raise ValueError( - f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!" - ) - - self.master_to_working_map = dict() - for k, v in master_to_working_map.items(): - if isinstance(k, torch.Tensor): - self.master_to_working_map[id(k)] = v - elif isinstance(k, int): - self.master_to_working_map[k] = v - else: - raise ValueError( - f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!" - ) - @staticmethod def gather_from_sharded_optimizer_state( state: OrderedDict, diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index c22b76dd4..d2f4a0bca 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -11,7 +11,6 @@ import torch import torch.nn as nn from torch.optim import Optimizer -from colossalai.interface import OptimizerWrapper from colossalai.tensor.d_tensor import ( is_customized_distributed_tensor, is_distributed_tensor, @@ -122,14 +121,6 @@ def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Siz # ====================================== # Helper classes and functions for saving shard file # ====================================== -def unwrap_optimizer(optimizer: OptimizerWrapper): - """ - Unwrap a wrapped optimizer. - This method should be used before saving/loading it to/from sharded checkpoints. 
- """ - - unwrapped_optim = optimizer.optim - return unwrapped_optim class StateDictSharder: diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 580b497ce..0ba9e53cf 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -186,10 +186,6 @@ class GeminiDDP(ModelWrapper): for p in params_to_ignore: p._ddp_to_ignore = True - def unwrap(self): - # as save/load state dict is overwrited, only return self - return self - def _get_non_persistent_buffers_set( self, module, memo: Optional[Set[nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True ): diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 1bf5302ef..72df93ace 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -648,3 +648,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper): if padding_size > 0: working_param = torch.nn.functional.pad(working_param, [0, padding_size]) master_param.copy_(working_param.chunk(self._world_size)[self._local_rank]) + + def get_working_to_master_map(self) -> Dict[int, torch.Tensor]: + return self._param_store.working_to_master_param + + def get_master_to_working_map(self) -> Dict[int, torch.Tensor]: + return self._param_store.master_to_working_param diff --git a/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py b/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py index c3c30e666..a6f67e0d7 100644 --- a/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py +++ b/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py @@ -61,9 +61,7 @@ def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion) if plugin_type == "gemini": - check_state_dict_equal( - model.unwrap().state_dict(only_rank_0=False), new_model.unwrap().state_dict(only_rank_0=False), False - ) + check_state_dict_equal(model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), False) else: check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False) dist.barrier()