Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-02 17:46:42 +00:00
[shardformer] support sharded checkpoint IO for models of HybridParallelPlugin (#4506)
* add APIs
* implement save_sharded_model
* add test for hybrid checkpointio
* implement naive loading for sharded model
* implement efficient sharded model loading
* open a new file for hybrid checkpoint_io
* small fix
* fix circular importing
* fix docstring
* arrange arguments and apis
* small fix
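The sharded checkpoint IO added by this commit is reached through the usual Booster workflow rather than used directly. Below is a minimal usage sketch, assuming a torchrun launch and the Booster / HybridParallelPlugin APIs of this release; the launch config dict, the parallel sizes, and the size_per_shard value are illustrative assumptions, not anything prescribed by the commit.

# Minimal sketch: save/load a sharded model checkpoint through the plugin's checkpoint IO.
# Run with torchrun; argument values are illustrative.
import colossalai
import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch(config={})

plugin = HybridParallelPlugin(tp_size=2, pp_size=1)  # dp size is derived from the world size
booster = Booster(plugin=plugin)

model = torch.nn.Linear(1024, 1024)
optimizer = torch.optim.Adam(model.parameters())
model, optimizer, *_ = booster.boost(model, optimizer)

# shard=True routes the save through the plugin's checkpoint IO (HypridParallelCheckpointIO
# after this commit); size_per_shard caps each shard file, in MB.
booster.save_model(model, "ckpt_dir", shard=True, size_per_shard=1024)
booster.load_model(model, "ckpt_dir")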
@@ -16,7 +16,7 @@ from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 
 from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer
-from colossalai.checkpoint_io import CheckpointIO
+from colossalai.checkpoint_io import CheckpointIO, HypridParallelCheckpointIO
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
@@ -292,6 +292,7 @@ class HybridParallelPlugin(PipelinePluginBase):
             self.schedule = OneForwardOneBackwardSchedule(num_microbatches, self.stage_manager)
         self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
+        self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
         self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
         self.shard_config = ShardConfig(tensor_parallel_process_group=self.tp_group,
                                         pipeline_stage_manager=self.stage_manager,
                                         enable_tensor_parallelism=self.tp_size > 1,
@@ -460,7 +461,7 @@ class HybridParallelPlugin(PipelinePluginBase):
                           **_kwargs)
 
     def get_checkpoint_io(self) -> CheckpointIO:
-        return None
+        return HypridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group)
 
     def no_sync(self, model: Module) -> Iterator[None]:
         raise NotImplementedError
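With the diff above, get_checkpoint_io() hands Booster a HypridParallelCheckpointIO built from the data-, pipeline-, and tensor-parallel groups, and the newly added self.dp_group exists precisely so it can be passed along. The snippet below is not ColossalAI's implementation, only a conceptual sketch of what save_sharded_model has to do: split a state dict into size-capped shard files and write an index mapping each parameter to its shard, which is what enables the "efficient sharded model loading" mentioned in the commit message (a loader reads the index and fetches only the shards it needs). The cross-group coordination performed by the real class is omitted, and the Hugging Face-style file naming is an assumption.

# Conceptual sketch only: single-process sharded save, not ColossalAI's actual code.
import json
import os

import torch


def save_sharded_state_dict(state_dict, checkpoint_dir, max_shard_size=1 << 30):
    """Split a state dict into shard files no larger than max_shard_size bytes."""
    os.makedirs(checkpoint_dir, exist_ok=True)

    shards, current, current_size = [], {}, 0
    for name, tensor in state_dict.items():
        tensor_size = tensor.numel() * tensor.element_size()
        # Close the current shard once adding this tensor would exceed the cap.
        if current and current_size + tensor_size > max_shard_size:
            shards.append(current)
            current, current_size = {}, 0
        current[name] = tensor
        current_size += tensor_size
    if current:
        shards.append(current)

    total = len(shards)
    weight_map = {}
    for i, shard in enumerate(shards, start=1):
        shard_file = f"pytorch_model-{i:05d}-of-{total:05d}.bin"
        torch.save(shard, os.path.join(checkpoint_dir, shard_file))
        for name in shard:
            weight_map[name] = shard_file

    # A loader can read the index first and fetch only the shards it needs.
    index = {"metadata": {"total_shards": total}, "weight_map": weight_map}
    with open(os.path.join(checkpoint_dir, "pytorch_model.bin.index.json"), "w") as f:
        json.dump(index, f, indent=2)


# Example: shard a small model's weights into ~1 MB files.
save_sharded_state_dict(torch.nn.Linear(512, 512).state_dict(), "ckpt_dir", max_shard_size=1 << 20)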