Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-06 11:32:10 +00:00
[Feature] auto-cast optimizers to distributed version (#5746)
* auto-cast optimizers to distributed
* fix galore casting
* logger

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
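To make the commit message concrete, below is a minimal sketch (not part of this commit) of what auto-casting looks like from the user side: a plain GaLoreAdamW8bit is handed to the booster, and the plugin is expected to swap it for the distributed DistGaloreAwamW. The plugin choice, keyword arguments, and the .optim attribute used for the final check are illustrative assumptions, not code from this commit.

import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.nn.optimizer import DistGaloreAwamW, GaLoreAdamW8bit
from colossalai.nn.optimizer.galore import get_galore_param_groups

# Run under torchrun; older ColossalAI releases may still expect a config dict here.
colossalai.launch_from_torch()

model = nn.Linear(128, 128).cuda()

# Build the plain, single-device GaLore optimizer ...
optimizer = GaLoreAdamW8bit(
    get_galore_param_groups(model, weight_decay=0, rank=4), lr=1e-3
)

# ... and let the booster/plugin cast it to the distributed version during boost.
booster = Booster(plugin=LowLevelZeroPlugin(stage=2))
model, optimizer, *_ = booster.boost(model, optimizer)

# .optim is assumed to expose the wrapped inner optimizer.
print(type(optimizer.optim))  # expected: DistGaloreAwamW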
@@ -17,7 +17,7 @@ from colossalai.booster.plugin import HybridParallelPlugin, LowLevelZeroPlugin
 from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
 from colossalai.checkpoint_io.utils import gather_distributed_param
 from colossalai.lazy import LazyInitContext
-from colossalai.nn.optimizer import DistGaloreAwamW
+from colossalai.nn.optimizer import GaLoreAdamW8bit
 from colossalai.nn.optimizer.galore import get_galore_param_groups
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
@@ -130,7 +130,7 @@ def build_model_from_hybrid_plugin(
     if use_lazy_init:
         ctx.materialize(org_model)
     org_model = org_model.cuda()
-    if sharded_optim_class == DistGaloreAwamW:
+    if optim_class == GaLoreAdamW8bit:
         # Disable clipping and block-wise quantization
         org_optimizer = optim_class(
             get_galore_param_groups(org_model, weight_decay=0, rank=4),
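The constructor call above is cut off by the hunk boundary and is left as-is. For illustration only, a hypothetical completion is sketched below to show what "disable clipping and block-wise quantization" could look like, assuming GaLoreAdamW8bit forwards bitsandbytes-style keyword arguments (percentile_clipping, block_wise); the values are not taken from the commit.

import torch.nn as nn

from colossalai.nn.optimizer import GaLoreAdamW8bit
from colossalai.nn.optimizer.galore import get_galore_param_groups

org_model = nn.Linear(64, 64).cuda()

# Hypothetical baseline optimizer for a correctness comparison against the
# distributed DistGaloreAwamW; keyword arguments are assumptions.
org_optimizer = GaLoreAdamW8bit(
    get_galore_param_groups(org_model, weight_decay=0, rank=4),
    lr=1e-3,
    percentile_clipping=100,  # clip at the 100th percentile, i.e. effectively no clipping
    block_wise=False,         # disable block-wise quantization of optimizer states
)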