[Feature] auto-cast optimizers to distributed version (#5746)

* auto-cast optimizers to distributed

* fix galore casting

* logger

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
This commit is contained in:
Edenzzzz
2024-05-24 17:24:16 +08:00
committed by GitHub
parent 2fc85abf43
commit 5f8c0a0ac3
13 changed files with 61 additions and 31 deletions

View File

@@ -552,7 +552,7 @@ def exam_bert_test_on_lowlevelzero_plugin(test_config):
sharded_optimizer,
criterion,
booster,
) = build_model_from_low_level_zero_plugin(model_fn, loss_fn, test_config, Adafactor, DistributedAdaFactor)
) = build_model_from_low_level_zero_plugin(model_fn, loss_fn, test_config, Adafactor, Adafactor)
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_low_level_zero_plugin(
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster

View File

@@ -416,7 +416,7 @@ def exam_bert_test_on_hybrid_plugin(test_config):
sharded_optimizer,
criterion,
booster,
) = build_model_from_hybrid_plugin(model_fn, loss_fn, test_config, CAME, DistributedCAME)
) = build_model_from_hybrid_plugin(model_fn, loss_fn, test_config, CAME, CAME)
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster

View File

@@ -306,8 +306,8 @@ def check_dist_galore(rank, world_size, port):
global coordinator
coordinator = DistCoordinator()
run_dist_galore_basic()
coordinator.print_on_master("Basic backward tests passed")
# run_dist_galore_basic()
# coordinator.print_on_master("Basic backward tests passed")
coordinator.print_on_master("Skipping forward-backward tests due to SVD instability")
# run_dist_galore_fwd_bwd()
@@ -319,7 +319,7 @@ def check_dist_galore(rank, world_size, port):
)
for config in test_config:
try:
run_bert_test(test_config=config, optim_class=GaLoreAdamW8bit, sharded_optim_class=DistGaloreAwamW)
run_bert_test(test_config=config, optim_class=GaLoreAdamW8bit, sharded_optim_class=GaLoreAdamW8bit)
except Exception as e:
print(e)
dist.barrier()

View File

@@ -289,7 +289,7 @@ def check_dist_lamb(rank, world_size, port):
run_dist_lamb_fwd_bwd()
coordinator.print_on_master("Forward-backward tests passed")
run_bert_test(optim_class=Lamb, sharded_optim_class=DistributedLamb)
run_bert_test(optim_class=Lamb, sharded_optim_class=Lamb)
print(f"rank {rank} tests passed :)")

View File

@@ -17,7 +17,7 @@ from colossalai.booster.plugin import HybridParallelPlugin, LowLevelZeroPlugin
from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
from colossalai.checkpoint_io.utils import gather_distributed_param
from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import DistGaloreAwamW
from colossalai.nn.optimizer import GaLoreAdamW8bit
from colossalai.nn.optimizer.galore import get_galore_param_groups
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
@@ -130,7 +130,7 @@ def build_model_from_hybrid_plugin(
if use_lazy_init:
ctx.materialize(org_model)
org_model = org_model.cuda()
if sharded_optim_class == DistGaloreAwamW:
if optim_class == GaLoreAdamW8bit:
# Disable clipping and block-wise quantization
org_optimizer = optim_class(
get_galore_param_groups(org_model, weight_decay=0, rank=4),