[Feature] auto-cast optimizers to distributed version (#5746)

* auto-cast optimizers to distributed

* fix galore casting

* logger

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Edenzzzz
2024-05-24 17:24:16 +08:00
committed by GitHub
parent 2fc85abf43
commit 5f8c0a0ac3
13 changed files with 61 additions and 31 deletions


@@ -32,7 +32,7 @@ from colossalai.checkpoint_io.utils import (
 )
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.interface.optimizer import DistributedOptim
-from colossalai.nn.optimizer import DistGaloreAwamW
+from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
 from colossalai.quantization import BnbQuantizationConfig, quantize_model
 from colossalai.zero import LowLevelZeroOptimizer
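
Note: the only change in this hunk is the new cast_to_distributed import. Conceptually, the helper looks up a registered distributed counterpart for the optimizer's class (e.g. a plain GaLore optimizer mapped to DistGaloreAwamW) and rebuilds the optimizer from its existing param groups; if nothing is registered, the original optimizer is returned unchanged. Below is a minimal sketch of that idea only — the registry name _PLAIN_TO_DISTRIBUTED, the logging setup, and the constructor call are illustrative assumptions, not ColossalAI's actual implementation.

import logging

from torch.optim import Optimizer

logger = logging.getLogger(__name__)

# Hypothetical registry mapping a plain optimizer class to its distributed
# counterpart. The real lookup lives in colossalai.nn.optimizer; the name and
# contents here are assumptions for illustration.
_PLAIN_TO_DISTRIBUTED = {}


def cast_to_distributed_sketch(optimizer: Optimizer) -> Optimizer:
    """Return a distributed version of `optimizer` if one is registered, else the original."""
    dist_cls = _PLAIN_TO_DISTRIBUTED.get(type(optimizer))
    if dist_cls is None:
        return optimizer  # no distributed counterpart: hand back the original
    logger.info("Casting %s to %s", type(optimizer).__name__, dist_cls.__name__)
    # Rebuild from the existing param_groups so lr, betas, etc. carry over.
    return dist_cls(optimizer.param_groups)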
@@ -437,6 +437,10 @@ class LowLevelZeroPlugin(DPPluginBase):
         zero_stage = self.stage
         zero_optim_kwargs = {**self.zero_optim_kwargs}
         dp_size = dist.get_world_size()
+
+        # Replace with the distributed implementation if exists
+        optimizer = cast_to_distributed(optimizer)
+
         if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and dp_size > 0:
             warnings.warn("Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.")
             zero_optim_kwargs["partition_grad"] = False