Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-03 18:19:58 +00:00
[moe] support optimizer checkpoint (#5015)
* Refactor MoE Manager setup method
* unshard optim ckpt
* optim io
* update transformer version
* update requirements
* update ckpt
* update ckpt
* update ckpt
* fix engine
* fix engine
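The diff below routes checkpointing through the renamed MoECheckpintIO and drops the explicit master/working-parameter linking in the plugin. As a rough illustration of what "support optimizer checkpoint" enables, here is a minimal sketch (not part of this commit) of saving and reloading optimizer state through the Booster API; the plugin constructor arguments, the import path, and the checkpoint path are assumptions and may differ across ColossalAI versions.

# Hedged sketch: exercising optimizer checkpointing via the Booster API.
# The plugin arguments and the "./moe_optim_ckpt" path are illustrative
# assumptions; adjust them to your ColossalAI version. Run under torchrun.
import torch
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin

colossalai.launch_from_torch(config={})

# Any model works for the sketch; a real run would use an MoE model.
model = nn.Linear(8, 8).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Parallel sizes here are placeholders; pick the sizes for your cluster.
plugin = MoeHybridParallelPlugin(tp_size=1, pp_size=1, precision="fp16", zero_stage=1)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)

# The booster delegates to the plugin's checkpoint IO (MoECheckpintIO after
# this commit), so optimizer states can be saved and restored like weights.
booster.save_optimizer(optimizer, "./moe_optim_ckpt", shard=True)
booster.load_optimizer(optimizer, "./moe_optim_ckpt")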
@@ -22,7 +22,7 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
 )
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
-from colossalai.moe import MoeCheckpintIO
+from colossalai.moe import MoECheckpintIO
 from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig
@@ -322,8 +322,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             **_kwargs,
         )
 
-    def get_checkpoint_io(self) -> MoeCheckpintIO:
-        self.checkpoint_io = MoeCheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
+    def get_checkpoint_io(self) -> MoECheckpintIO:
+        self.checkpoint_io = MoECheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
         return self.checkpoint_io
 
     def configure(
@@ -359,9 +359,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                     max_norm=self.max_norm,
                     **self.amp_config,
                 )
-                self.checkpoint_io.link_master_and_working_param(
-                    optimizer.working_to_master_map, optimizer.master_to_working_map
-                )
             else:
                 optimizer = HybridParallelNaiveOptimizer(
                     optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info