[Feature] auto-cast optimizers to distributed version (#5746)

* auto-cast optimizers to distributed

* fix galore casting

* logger

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Author: Edenzzzz
Date: 2024-05-24 17:24:16 +08:00
Committed by: GitHub
Parent: 2fc85abf43
Commit: 5f8c0a0ac3
13 changed files with 61 additions and 31 deletions


@@ -27,7 +27,7 @@ from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.interface.optimizer import DistributedOptim
-from colossalai.nn.optimizer import DistGaloreAwamW
+from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
 from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
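For readers skimming the diff, below is a minimal sketch of what a cast helper of this kind might do. The real `cast_to_distributed` ships in `colossalai.nn.optimizer`; the `cast_to_distributed_sketch` name, the `registry` argument, and the rebuild strategy are illustrative assumptions, not the library's implementation.

```python
# Illustrative sketch only (assumptions, not ColossalAI's implementation):
# swap a plain optimizer for a registered distributed counterpart, keeping
# the user's param groups and hyperparameters intact.
from typing import Dict, Type

from torch.optim import Optimizer


def cast_to_distributed_sketch(
    optimizer: Optimizer,
    registry: Dict[Type[Optimizer], Type[Optimizer]],
) -> Optimizer:
    """Return a distributed version of `optimizer` if one is registered, else the original."""
    dist_cls = registry.get(type(optimizer))
    if dist_cls is None:
        # No distributed counterpart known: leave the user's optimizer untouched.
        return optimizer
    # Rebuild from the existing param groups; per-group settings override the defaults.
    return dist_cls(optimizer.param_groups, **optimizer.defaults)
```

Keying the lookup on the concrete optimizer type keeps the cast a no-op for optimizers without a distributed implementation, which is why the plugin can call it unconditionally before the Galore/ZeRO check below.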
@@ -1179,6 +1179,10 @@ class HybridParallelPlugin(PipelinePluginBase):
         # TODO: Support Galore + ZeRO
         zero_stage = self.zero_stage
         zero_config = deepcopy(self.zero_config)
+
+        # Replace with distributed implementation if exists
+        optimizer = cast_to_distributed(optimizer)
+
         if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0:
             warnings.warn("Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.")
             zero_config["partition_grad"] = False