[Feature] auto-cast optimizers to distributed version (#5746)
* auto-cast optimizers to distributed
* fix galore casting
* logger

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
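A minimal usage sketch of what this change means on the caller side. It assumes the process group has already been initialized (e.g. via colossalai.launch_from_torch); GaLoreAdamW8bit and build_model are illustrative names, not taken from this diff:

# Sketch only: shows the auto-cast behavior added by this commit, not a verbatim example from the repo.
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.nn.optimizer import GaLoreAdamW8bit  # assumed name of the non-distributed GaLore optimizer

plugin = HybridParallelPlugin(tp_size=2, pp_size=1)
booster = Booster(plugin=plugin)

model = build_model()                                 # placeholder for any nn.Module factory
optimizer = GaLoreAdamW8bit(model.parameters(), lr=1e-3)

# With this commit, HybridParallelPlugin.configure() calls cast_to_distributed()
# on the optimizer it receives, so the optimizer handed back by boost() is the
# distributed implementation (e.g. DistGaloreAwamW) whenever one exists.
model, optimizer, *_ = booster.boost(model, optimizer)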
@@ -27,7 +27,7 @@ from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.interface.optimizer import DistributedOptim
-from colossalai.nn.optimizer import DistGaloreAwamW
+from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
 from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
@@ -1179,6 +1179,10 @@ class HybridParallelPlugin(PipelinePluginBase):
         # TODO: Support Galore + ZeRO
         zero_stage = self.zero_stage
         zero_config = deepcopy(self.zero_config)
+
+        # Replace with distributed implementation if exists
+        optimizer = cast_to_distributed(optimizer)
+
         if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0:
             warnings.warn("Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.")
             zero_config["partition_grad"] = False
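The body of cast_to_distributed is not part of this hunk; the sketch below only illustrates the kind of lookup-and-rebuild the call site implies. The registry key, the helper name, and the constructor call are assumptions; only DistGaloreAwamW comes from the imports above.

from colossalai.nn.optimizer import DistGaloreAwamW

# Hypothetical registry of non-distributed -> distributed optimizer classes;
# the real mapping lives inside colossalai.nn.optimizer.
_DISTRIBUTED_VARIANTS = {
    "GaLoreAdamW8bit": DistGaloreAwamW,  # the only pairing implied by this diff
}

def cast_to_distributed_sketch(optimizer):
    """Return a distributed counterpart of `optimizer` if one is registered, else the original."""
    dist_cls = _DISTRIBUTED_VARIANTS.get(type(optimizer).__name__)
    if dist_cls is None:
        # No distributed counterpart: hand the optimizer back unchanged.
        return optimizer
    # Rebuild from the original param groups so lr, betas, weight decay and any
    # GaLore-specific per-group settings carry over to the distributed version.
    return dist_cls(optimizer.param_groups, **optimizer.defaults)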