Mirror of https://github.com/hpcaitech/ColossalAI.git
[Feature] Distributed optimizers: Lamb, Galore, CAME and Adafactor (#5694)
* [feat] Add distributed lamb; minor fixes in DeviceMesh (#5476)
  * init: add dist lamb; add debiasing for lamb
  * dist lamb tester mostly done
  * all tests passed
  * add comments
  * all tests passed; removed debugging statements
  * moved setup_distributed inside plugin; added dist layout caching
  * organize better
  Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [hotfix] Improve tester precision by removing ZeRO on vanilla lamb (#5576)
  Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [optim] Add distributed CAME (#5526)
  * test CAME under LowLevelZeroOptimizer wrapper
  * test CAME TP row and col pass
  * test CAME zero pass
  * came zero add master and worker param id convert
  * came zero test pass
  * came zero test pass
  * test distributed came passed
  * reform code, modify some expressions and add comments
  * minor fix of test came
  * minor fix of dist_came and test
  * [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
  * minor fix of dist_came and test
  * rebase dist-optim
  * rebase dist-optim
  * fix remaining comments
  * add test dist came using booster api
  Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [optim] Distributed Adafactor (#5484)
  * [feature] solve conflict; update optimizer readme
  * [feature] update optimizer readme
  * [fix] fix testcase
  * [feature] add transformer-bert to testcase; solve a bug related to indivisible shapes (induced when use_zero is on and tp is row-parallel)
  * [feature] add transformers_bert model zoo in testcase
  * [feature] add user documentation to docs/source/feature
  * [feature] add API reference & sample to optimizer readme; add state check for bert example
  * [feature] modify user documentation
  * [fix] fix readme format issue
  * [fix] add zero=0 in testcase; cached argument in dict
  * [fix] fix precision issue
  * [feature] add distributed rms
  * [feature] remove useless comment in testcase
  * [fix] remove useless test; open zero test; remove fp16 test in bert example
  * [feature] extract distributed rms function
  * [feature] add booster + LowLevelZeroPlugin in test
  * [feature] add Start_with_booster_API case in md; add supporting information in md
  * [fix] also remove state movement in base adafactor
  * [feature] extract factor function
  * [feature] add LowLevelZeroPlugin test
  * [fix] add tp=False and zero=True in logic
  * [fix] fix use_zero logic
  * [feature] add row residue logic in column-parallel factor
  * [feature] add check optim state func
  * [feature] remove duplicate logic
  * [feature] update optim state check func and fix precision test bug
  * [fix] update/fix optim state; precision issue still exists
  * [fix] add use_zero check in _rms; add plugin support info in readme; add Dist Adafactor init info
  * [feature] removed prints & comments in utils
  * [feature] update readme
  * [feature] add LowLevelZeroPlugin test with Bert model zoo
  * [fix] fix logic in _rms
  * [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
  * [fix] remove comments in testcase
  * [feature] add zh-Hans readme
  Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] Refactor dist came; fix precision error; add low level zero test with bert model zoo (#5676)
  * [feature] daily update
  * [fix] fix dist came
  * [feature] refactor dist came; fix precision error; add low level zero test with bert model zoo
  * [fix] open rms; fix low level zero test; fix dist came test function name
  * [fix] remove redundant test
  * [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
  Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] Add Galore (Adam, Adafactor) and distributed GaloreAdamW8bit (#5570)
  * init: add dist lamb; add debiasing for lamb
  * dist lamb tester mostly done
  * all tests passed
  * add comments
  * all tests passed; removed debugging statements
  * moved setup_distributed inside plugin; added dist layout caching
  * organize better
  * update comments
  * add initial distributed galore
  * add initial distributed galore
  * add galore set param utils; change setup_distributed interface
  * projected grad precision passed
  * basic precision tests passed
  * tests passed; located svd precision issue in fwd-bwd; banned these tests
  * Plugin DP + TP tests passed
  * move get_shard_dim to d_tensor
  * add comments
  * remove useless files
  * remove useless files
  * fix zero typo
  * improve interface
  * remove moe changes
  * [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
  * fix import
  * fix deepcopy
  * update came & adafactor to main
  * fix param map
  * fix typo
  Co-authored-by: Edenzzzz <wtan45@wisc.edu>
  Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Hotfix] Remove one buggy test case from dist_adafactor for now (#5692)
  Co-authored-by: Edenzzzz <wtan45@wisc.edu>
  Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: chongqichuizi875 <107315010+chongqichuizi875@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: duanjunwen <54985467+duanjunwen@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
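For orientation, the sketch below shows roughly how one of the new distributed optimizers might be wired into the Booster API that the test diff below exercises. The toy model, the plugin sizes, and the launch call are illustrative assumptions; only the imports and the `get_galore_param_groups` / `DistGaloreAwamW` construction mirror what the diff itself uses.

```python
# Hedged usage sketch (not taken from this PR): a toy model, assumed plugin sizes,
# and a version-dependent launch call around the optimizer construction seen in the diff.
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.nn.optimizer import DistGaloreAwamW
from colossalai.nn.optimizer.galore import get_galore_param_groups

colossalai.launch_from_torch()  # assumes a torchrun launch; exact signature varies by ColossalAI version
model = nn.Linear(256, 256).cuda()  # placeholder stand-in for a real model

# Build GaLore low-rank parameter groups and feed them to the distributed 8-bit GaLore AdamW
# (class name as imported in the test file).
optimizer = DistGaloreAwamW(
    get_galore_param_groups(model, weight_decay=0, rank=4),
    lr=1e-3,
)

# Assumed parallel layout: e.g. 4 GPUs as tp=2 x dp=2, with ZeRO-1 for the DP group.
plugin = HybridParallelPlugin(tp_size=2, pp_size=1, zero_stage=1)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)
```

Per the commit message above, `setup_distributed` was moved inside the plugin, so it is the `booster.boost` call that hands the distributed optimizer its tensor- and data-parallel process groups.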
@@ -11,11 +11,14 @@ from torch.nn import Module
 from torch.optim import Adam, Optimizer
 from torch.testing import assert_close

+from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
-from colossalai.booster.plugin import HybridParallelPlugin
+from colossalai.booster.plugin import HybridParallelPlugin, LowLevelZeroPlugin
 from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
 from colossalai.checkpoint_io.utils import gather_distributed_param
 from colossalai.lazy import LazyInitContext
+from colossalai.nn.optimizer import DistGaloreAwamW
+from colossalai.nn.optimizer.galore import get_galore_param_groups
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
 from colossalai.shardformer._utils import getattr_
@@ -113,7 +116,9 @@ def check_state_dict(org_model: Module, sharded_model: Module, name: str = ""):
         assert torch.equal(v, shard_v), f"{name} {k} value mismatch"


-def build_model_from_hybrid_plugin(model_fn: Callable, loss_fn: Callable, test_config: Dict[str, Any]):
+def build_model_from_hybrid_plugin(
+    model_fn: Callable, loss_fn: Callable, test_config: Dict[str, Any], optim_class=Adam, sharded_optim_class=Adam
+):
     use_lazy_init = False
     if "use_lazy_init" in test_config:
         use_lazy_init = test_config.pop("use_lazy_init")
@@ -125,8 +130,25 @@ def build_model_from_hybrid_plugin(model_fn: Callable, loss_fn: Callable, test_c
     if use_lazy_init:
         ctx.materialize(org_model)
     org_model = org_model.cuda()
-    org_optimizer = Adam(org_model.parameters(), lr=1e-3)
-    sharded_optimizer = Adam(sharded_model.parameters(), lr=1e-3)
+    if sharded_optim_class == DistGaloreAwamW:
+        # Disable clipping and block-wise quantization
+        org_optimizer = optim_class(
+            get_galore_param_groups(org_model, weight_decay=0, rank=4),
+            lr=1e-3,
+            percentile_clipping=101,
+            block_wise=False,
+            min_8bit_size=1e10,
+        )
+        sharded_optimizer = sharded_optim_class(
+            get_galore_param_groups(sharded_model, weight_decay=0, rank=4),
+            lr=1e-3,
+            percentile_clipping=101,
+            block_wise=False,
+            min_8bit_size=1e10,
+        )
+    else:
+        org_optimizer = optim_class(org_model.parameters(), lr=1e-3)
+        sharded_optimizer = sharded_optim_class(sharded_model.parameters(), lr=1e-3)
     criterion = loss_fn

     plugin = HybridParallelPlugin(**test_config)
@@ -143,6 +165,32 @@ def build_model_from_hybrid_plugin(model_fn: Callable, loss_fn: Callable, test_c
     )


+def build_model_from_low_level_zero_plugin(
+    model_fn: Callable, loss_fn: Callable, test_config: Dict[str, Any], optim_class=Adam, sharded_optim_class=Adam
+):
+    use_lazy_init = False
+    if "use_lazy_init" in test_config:
+        use_lazy_init = test_config.pop("use_lazy_init")
+
+    ctx = LazyInitContext() if use_lazy_init else nullcontext()
+    with ctx:
+        org_model = model_fn()
+        sharded_model = copy.deepcopy(org_model)
+    if use_lazy_init:
+        ctx.materialize(org_model)
+
+    org_model = org_model.cuda()
+    org_optimizer = optim_class(org_model.parameters(), lr=1e-3)
+    sharded_optimizer = sharded_optim_class(sharded_model.parameters(), lr=1e-3)
+    criterion = loss_fn
+
+    plugin = LowLevelZeroPlugin(**test_config)
+    booster = Booster(plugin=plugin)
+
+    sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion)
+    return org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster
+
+
 def run_forward_backward_with_hybrid_plugin(
     org_model: Module,
     sharded_model: Module,
@@ -209,6 +257,44 @@ def run_forward_backward_with_hybrid_plugin(
     return org_loss, org_output, sharded_loss, sharded_output


+def run_forward_backward_with_low_level_zero_plugin(
+    org_model: Module,
+    sharded_model: Module,
+    sharded_optimizer: Optimizer,
+    data_gen_fn: Callable,
+    output_transform_fn: Callable,
+    criterion: Callable,
+    booster: Booster,
+):
+    get_accelerator().get_current_device()
+    org_model.cuda()
+    sharded_model.cuda()
+
+    def _criterion(outputs, inputs):
+        outputs = output_transform_fn(outputs)
+        loss = criterion(outputs)
+        return loss
+
+    data = data_gen_fn()
+
+    # data = {
+    #     k: v.to(device) if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()
+    # }
+    data = {k: v.cuda() for k, v in data.items()}
+
+    sharded_model.train()
+    sharded_output = sharded_model(**data)
+    sharded_loss = criterion(sharded_output)
+    sharded_optimizer.backward(sharded_loss)
+
+    org_model.train()
+    org_output = org_model(**data)
+    org_loss = criterion(org_output)
+    org_loss.backward()
+
+    return org_loss, org_output, sharded_loss, sharded_output
+
+
 def check_output_hidden_state(
     org_output: Tensor,
     sharded_output: Tensor,
@@ -312,6 +398,9 @@ def check_grad(
         org_grad = getattr_(org_model, suffix).weight.grad
         shard_grad = getattr_(sharded_model, suffix).weight.grad
         shard_weight = getattr_(sharded_model, suffix).weight
+        # if verbose and dist.get_rank() == 0:
+        #     print("shard_weight", shard_weight)
+        #     print("org_grad", org_grad)
         if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
             shard_grad_list = [torch.zeros_like(shard_grad).to("cuda") for _ in range(dist.get_world_size(tp_group))]
             dist.all_gather(shard_grad_list, shard_grad, tp_group)
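The final hunk is truncated at the all-gather. For context, this style of gradient check is usually completed by concatenating the gathered shards along the sharded dimension and comparing the result against the unsharded gradient with `assert_close` (imported at the top of the file). The helper below is an illustrative sketch of that pattern, not code from this diff; the `dim`, `atol`, and `rtol` parameters are assumptions.

```python
# Illustrative gather-and-compare sketch (assumed helper, not from this diff).
import torch
import torch.distributed as dist
from torch.testing import assert_close


def gather_and_compare(org_grad: torch.Tensor, shard_grad: torch.Tensor, tp_group,
                       dim: int = 0, atol: float = 1e-5, rtol: float = 1e-3) -> None:
    # Collect every rank's gradient shard over the tensor-parallel group.
    shard_grad_list = [torch.zeros_like(shard_grad) for _ in range(dist.get_world_size(tp_group))]
    dist.all_gather(shard_grad_list, shard_grad, tp_group)
    # Rebuild the full gradient and compare it with the single-device reference.
    all_shard_grad = torch.cat(shard_grad_list, dim=dim)
    assert_close(org_grad.float(), all_shard_grad.float(), atol=atol, rtol=rtol)
```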