diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 125a9ccca..fc04f3ecd 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -16,7 +16,7 @@ from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 
 from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer
-from colossalai.checkpoint_io import CheckpointIO, HypridParallelCheckpointIO
+from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
@@ -513,7 +513,7 @@ class HybridParallelPlugin(PipelinePluginBase):
                           **_kwargs)
 
     def get_checkpoint_io(self) -> CheckpointIO:
-        self.checkpoint_io = HypridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
+        self.checkpoint_io = HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
         return self.checkpoint_io
 
     def no_sync(self, model: Module) -> Iterator[None]:
diff --git a/colossalai/checkpoint_io/__init__.py b/colossalai/checkpoint_io/__init__.py
index 07b1f81da..e1aa6543e 100644
--- a/colossalai/checkpoint_io/__init__.py
+++ b/colossalai/checkpoint_io/__init__.py
@@ -1,6 +1,6 @@
 from .checkpoint_io_base import CheckpointIO
 from .general_checkpoint_io import GeneralCheckpointIO
-from .hybrid_parallel_checkpoint_io import HypridParallelCheckpointIO
+from .hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO
 from .index_file import CheckpointIndexFile
 
 __all__ = ['CheckpointIO', 'CheckpointIndexFile', 'GeneralCheckpointIO', 'HybridParallelCheckpointIO']
diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
index fef5b0d16..6eee3ace0 100644
--- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
+++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -39,7 +39,7 @@ except ImportError:
 _EXTRA_STATE_KEY_SUFFIX = '_extra_state'
 
 
-class HypridParallelCheckpointIO(GeneralCheckpointIO):
+class HybridParallelCheckpointIO(GeneralCheckpointIO):
     """
     CheckpointIO for Hybrid Parallel Training.
 
@@ -136,7 +136,7 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO):
 
             param_id = param_info['param2id'][id(working_param)]
             original_shape = param_info['param2shape'][id(working_param)]
-            state_ = HypridParallelCheckpointIO.gather_from_sharded_optimizer_state(state,
+            state_ = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state(state,
                                                                                     working_param,
                                                                                     original_shape=original_shape,
                                                                                     dp_group=dp_group,
@@ -189,7 +189,7 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO):
 
         # Then collect the sharded parameters & buffers along tp_group.
         # Only devices with tp_rank == 0 are responsible for model saving.
-        state_dict_shard = HypridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard)
+        state_dict_shard = HybridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard)
         weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
         index_file = CheckpointIndexFile(checkpoint)
         control_saving = (self.tp_rank == 0)
@@ -385,7 +385,7 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO):
 
         # Then collect the sharded states along dp_group(if using zero)/tp_group.
         # Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving.
-        state_dict_shard = HypridParallelCheckpointIO._optimizer_sharder(
+        state_dict_shard = HybridParallelCheckpointIO._optimizer_sharder(
             optimizer,
             use_zero=self.use_zero,
             dp_group=self.dp_group,
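
For downstream code, the only visible change is the corrected spelling of the class name; the constructor signature and behaviour are untouched. A minimal usage sketch after the rename is shown below. It assumes a distributed environment has already been initialised (e.g. via colossalai.launch_from_torch), and the HybridParallelPlugin arguments (tp_size, pp_size) are illustrative assumptions, not part of this diff.

    from colossalai.booster.plugin import HybridParallelPlugin
    from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO

    # Illustrative parallelism configuration; tp_size * pp_size must divide the launched world size.
    plugin = HybridParallelPlugin(tp_size=2, pp_size=1)

    # get_checkpoint_io() now instantiates the correctly named class (see the plugin hunk above).
    checkpoint_io = plugin.get_checkpoint_io()
    assert isinstance(checkpoint_io, CheckpointIO)
    assert isinstance(checkpoint_io, HybridParallelCheckpointIO)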