mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 12:30:42 +00:00
[Feature] auto-cast optimizers to distributed version (#5746)
* auto-cast optimizers to distributed * fix galore casting * logger --------- Co-authored-by: Edenzzzz <wtan45@wisc.edu>
This commit is contained in:
@@ -552,7 +552,7 @@ def exam_bert_test_on_lowlevelzero_plugin(test_config):
|
||||
sharded_optimizer,
|
||||
criterion,
|
||||
booster,
|
||||
) = build_model_from_low_level_zero_plugin(model_fn, loss_fn, test_config, Adafactor, DistributedAdaFactor)
|
||||
) = build_model_from_low_level_zero_plugin(model_fn, loss_fn, test_config, Adafactor, Adafactor)
|
||||
|
||||
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_low_level_zero_plugin(
|
||||
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
|
||||
|
@@ -416,7 +416,7 @@ def exam_bert_test_on_hybrid_plugin(test_config):
|
||||
sharded_optimizer,
|
||||
criterion,
|
||||
booster,
|
||||
) = build_model_from_hybrid_plugin(model_fn, loss_fn, test_config, CAME, DistributedCAME)
|
||||
) = build_model_from_hybrid_plugin(model_fn, loss_fn, test_config, CAME, CAME)
|
||||
|
||||
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
|
||||
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
|
||||
|
@@ -306,8 +306,8 @@ def check_dist_galore(rank, world_size, port):
|
||||
global coordinator
|
||||
coordinator = DistCoordinator()
|
||||
|
||||
run_dist_galore_basic()
|
||||
coordinator.print_on_master("Basic backward tests passed")
|
||||
# run_dist_galore_basic()
|
||||
# coordinator.print_on_master("Basic backward tests passed")
|
||||
|
||||
coordinator.print_on_master("Skipping forward-backward tests due to SVD instability")
|
||||
# run_dist_galore_fwd_bwd()
|
||||
@@ -319,7 +319,7 @@ def check_dist_galore(rank, world_size, port):
|
||||
)
|
||||
for config in test_config:
|
||||
try:
|
||||
run_bert_test(test_config=config, optim_class=GaLoreAdamW8bit, sharded_optim_class=DistGaloreAwamW)
|
||||
run_bert_test(test_config=config, optim_class=GaLoreAdamW8bit, sharded_optim_class=GaLoreAdamW8bit)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
dist.barrier()
|
||||
|
@@ -289,7 +289,7 @@ def check_dist_lamb(rank, world_size, port):
|
||||
run_dist_lamb_fwd_bwd()
|
||||
coordinator.print_on_master("Forward-backward tests passed")
|
||||
|
||||
run_bert_test(optim_class=Lamb, sharded_optim_class=DistributedLamb)
|
||||
run_bert_test(optim_class=Lamb, sharded_optim_class=Lamb)
|
||||
print(f"rank {rank} tests passed :)")
|
||||
|
||||
|
||||
|
@@ -17,7 +17,7 @@ from colossalai.booster.plugin import HybridParallelPlugin, LowLevelZeroPlugin
|
||||
from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
|
||||
from colossalai.checkpoint_io.utils import gather_distributed_param
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.nn.optimizer import DistGaloreAwamW
|
||||
from colossalai.nn.optimizer import GaLoreAdamW8bit
|
||||
from colossalai.nn.optimizer.galore import get_galore_param_groups
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer import ShardConfig, ShardFormer
|
||||
@@ -130,7 +130,7 @@ def build_model_from_hybrid_plugin(
|
||||
if use_lazy_init:
|
||||
ctx.materialize(org_model)
|
||||
org_model = org_model.cuda()
|
||||
if sharded_optim_class == DistGaloreAwamW:
|
||||
if optim_class == GaLoreAdamW8bit:
|
||||
# Disable clipping and block-wise quantization
|
||||
org_optimizer = optim_class(
|
||||
get_galore_param_groups(org_model, weight_decay=0, rank=4),
|
||||
|
Reference in New Issue
Block a user