[Feature] auto-cast optimizers to distributed version (#5746)

* auto-cast optimizers to distributed

* fix galore casting

* logger

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Author: Edenzzzz
Date: 2024-05-24 17:24:16 +08:00
Committed by: GitHub
Parent: 2fc85abf43
Commit: 5f8c0a0ac3

13 changed files with 61 additions and 31 deletions
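
For context on the change below: this commit lets callers hand a plugin an ordinary single-device optimizer, and the plugin swaps it for its distributed counterpart (e.g. GaLoreAdamW8bit -> DistGaloreAwamW). What follows is a minimal sketch of that casting idea, assuming a registry-style mapping; the dict and the cast_to_distributed name are illustrative, not necessarily the exact code this commit adds.

# Sketch only: the registry and helper name are assumptions for illustration,
# not the verbatim API introduced by this commit.
from colossalai.nn.optimizer import DistGaloreAwamW, GaLoreAdamW8bit

# Hypothetical mapping from single-device optimizers to distributed versions.
_OPTIM_TO_DIST = {
    GaLoreAdamW8bit: DistGaloreAwamW,
}

def cast_to_distributed(optim):
    """Return the distributed counterpart of `optim` if one is registered."""
    dist_cls = _OPTIM_TO_DIST.get(type(optim))
    if dist_cls is None:
        return optim  # no distributed version registered; use the optimizer as-is
    # Rebuild from the same param groups so hyperparameters carry over.
    return dist_cls(optim.param_groups)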


@@ -306,8 +306,8 @@ def check_dist_galore(rank, world_size, port):
     global coordinator
     coordinator = DistCoordinator()
-    run_dist_galore_basic()
-    coordinator.print_on_master("Basic backward tests passed")
+    # run_dist_galore_basic()
+    # coordinator.print_on_master("Basic backward tests passed")
     coordinator.print_on_master("Skipping forward-backward tests due to SVD instability")
     # run_dist_galore_fwd_bwd()
@@ -319,7 +319,7 @@ def check_dist_galore(rank, world_size, port):
     )
     for config in test_config:
         try:
-            run_bert_test(test_config=config, optim_class=GaLoreAdamW8bit, sharded_optim_class=DistGaloreAwamW)
+            run_bert_test(test_config=config, optim_class=GaLoreAdamW8bit, sharded_optim_class=GaLoreAdamW8bit)
         except Exception as e:
             print(e)
     dist.barrier()
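
With auto-casting in place, the test passes the plain GaLoreAdamW8bit for both the single-device and the sharded role and relies on the plugin to produce DistGaloreAwamW internally. A caller-side sketch of the intended usage, assuming a standard Booster setup (the model and plugin arguments are illustrative placeholders):

import torch.nn as nn
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.nn.optimizer import GaLoreAdamW8bit

# Assumes the distributed environment is already initialized,
# e.g. via colossalai.launch_from_torch().
model = nn.Linear(128, 128).cuda()
# Plain single-device GaLore optimizer; boost() is expected to swap it
# for the distributed DistGaloreAwamW automatically.
optimizer = GaLoreAdamW8bit(model.parameters(), lr=1e-3)

plugin = HybridParallelPlugin(tp_size=2, pp_size=1)  # illustrative parallel config
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)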