[hotfix] fix typo in hybrid parallel io (#4697)

This commit is contained in:
Baizhou Zhang
2023-09-12 17:32:19 +08:00
committed by GitHub
parent 8844691f4b
commit d8ceeac14e
3 changed files with 7 additions and 7 deletions

View File

@@ -16,7 +16,7 @@ from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer
from colossalai.checkpoint_io import CheckpointIO, HypridParallelCheckpointIO
from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
@@ -513,7 +513,7 @@ class HybridParallelPlugin(PipelinePluginBase):
**_kwargs)
def get_checkpoint_io(self) -> CheckpointIO:
self.checkpoint_io = HypridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
self.checkpoint_io = HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
return self.checkpoint_io
def no_sync(self, model: Module) -> Iterator[None]: