[Feature] Distributed optimizers: Lamb, Galore, CAME and Adafactor (#5694)

* [feat] Add distributed lamb; minor fixes in DeviceMesh (#5476)

* init: add dist lamb; add debiasing for lamb

* dist lamb tester mostly done

* all tests passed

* add comments

* all tests passed. Removed debugging statements

* moved setup_distributed inside plugin. Added dist layout caching (a usage sketch follows this sub-commit)

* organize better

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
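
The dist Lamb sub-commit above moves setup_distributed into the plugin, so a user only constructs the optimizer and boosts it. Below is a minimal usage sketch under that assumption; the `DistributedLamb` export, the plugin arguments, and the launch step are assumptions based on this PR, not a verbatim example from the repo.

```python
# Minimal usage sketch (assumed names: DistributedLamb, HybridParallelPlugin args).
import torch.nn as nn

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.nn.optimizer import DistributedLamb  # assumed export added by this PR

# A distributed environment is assumed to be initialized already,
# e.g. via colossalai.launch_from_torch under torchrun.
model = nn.Linear(1024, 1024)
optimizer = DistributedLamb(model.parameters(), lr=1e-3)

# The plugin is expected to call the optimizer's setup_distributed(...)
# internally during boost(), so no manual layout wiring is needed here.
plugin = HybridParallelPlugin(tp_size=2, pp_size=1)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)
```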

* [hotfix] Improve tester precision by removing ZeRO on vanilla lamb (#5576)

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [optim] add distributed came (#5526)

* test CAME under LowLevelZeroOptimizer wrapper

* test CAME TP row and col pass

* test CAME zero pass

* came zero: add master and worker param id conversion

* came zero test pass

* came zero test pass

* test distributed came passed

* refactor code, modify some expressions, and add comments

* minor fix of test came

* minor fix of dist_came and test

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor fix of dist_came and test

* rebase dist-optim

* rebase dist-optim

* fix remaining comments

* add test for dist came using the booster api (see the sketch after this sub-commit)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
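
The distributed CAME tests above run the optimizer through the booster API and compare it with unsharded CAME. Below is a hedged sketch of that kind of parity check; the `CAME`/`DistributedCAME` exports, plugin arguments, and tolerance are assumptions, and a launched distributed environment plus a CUDA device are presumed.

```python
# Sketch of the kind of parity check the tests perform (names and args are assumptions).
import copy

import torch
import torch.nn as nn

from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.nn.optimizer import CAME, DistributedCAME  # assumed exports

torch.manual_seed(0)
model = nn.Linear(64, 64).cuda()
ref_model = copy.deepcopy(model)

dist_optim = DistributedCAME(model.parameters(), lr=1e-3)
ref_optim = CAME(ref_model.parameters(), lr=1e-3)

# Wrap the distributed optimizer with ZeRO through the booster API.
booster = Booster(plugin=LowLevelZeroPlugin(stage=2, precision="fp32"))
model, dist_optim, *_ = booster.boost(model, dist_optim)

x = torch.randn(8, 64).cuda()
booster.backward(model(x).sum(), dist_optim)
dist_optim.step()

ref_model(x).sum().backward()
ref_optim.step()
# After gathering the ZeRO shards, both models' parameters should agree
# within a small tolerance.
```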

* [optim] Distributed Adafactor (#5484)

* [feature] solve conflict; update optimizer readme;

* [feature] update optimizer readme;

* [fix] fix testcase;

* [feature] Add transformers-bert to testcase; solve a bug related to indivisible shapes (introduced when use_zero is enabled and TP is row-parallel);

* [feature] Add transformers_bert model zoo in testcase;

* [feature] add user documentation to docs/source/feature.

* [feature] add API Reference & Sample to optimizer Readme; add state check for bert exam;

* [feature] modify user documentation;

* [fix] fix readme format issue;

* [fix] add zero=0 in testcase; cached argument in dict;

* [fix] fix precision issue;

* [feature] add distributed rms;

* [feature] remove useless comment in testcase;

* [fix] Remove useless test; open zero test; remove fp16 test in bert exam;

* [feature] Extract distributed rms function (a conceptual sketch follows this sub-commit);

* [feature] add booster + lowlevelzeroPlugin in test;

* [feature] add Start_with_booster_API case in md; add Supporting Information in md;

* [fix] Also remove state movement in base adafactor;

* [feature] extract factor function;

* [feature] add LowLevelZeroPlugin test;

* [fix] add tp=False and zero=True in logic;

* [fix] fix use zero logic;

* [feature] add row residue logic in column parallel factor;

* [feature] add check optim state func;

* [feature] Remove duplicate logic;

* [feature] update optim state check func and fix precision test bug;

* [fix] update/fix optim state; precision issue still exists;

* [fix] Add use_zero check in _rms; Add plugin support info in Readme; Add Dist Adafactor init Info;

* [feature] removed print & comments in utils;

* [feature] update Readme;

* [feature] add LowLevelZeroPlugin test with Bert model zoo;

* [fix] fix logic in _rms;

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [fix] remove comments in testcase;

* [feature] add zh-Hans Readme;

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
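
Several of the Adafactor items above revolve around a distributed `_rms` helper (extracting it, adding a use_zero check, fixing its logic). The sketch below shows one way an RMS can be computed when a parameter's elements are sharded across a process group; it is a conceptual illustration with assumed names, not the repo's helper.

```python
# Conceptual sketch: RMS of a parameter whose elements are sharded across ranks,
# as a distributed Adafactor/CAME update-clipping term needs.
import torch
import torch.distributed as dist

def sharded_rms(local_tensor: torch.Tensor, group=None) -> torch.Tensor:
    """RMS over all shards: sqrt(all-reduced sum(x^2) / all-reduced numel)."""
    sq_sum = local_tensor.float().pow(2).sum()
    numel = torch.tensor(float(local_tensor.numel()), device=local_tensor.device)
    stats = torch.stack([sq_sum, numel])
    # A single all-reduce carries both the squared sum and the element count.
    dist.all_reduce(stats, op=dist.ReduceOp.SUM, group=group)
    return (stats[0] / stats[1]).sqrt()
```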

* [Feature] Refactor dist came; fix precision error; add low level zero test with bert model zoo (#5676)

* [feature] daily update;

* [fix] fix dist came;

* [feature] refactor dist came; fix precision error; add low level zero test with bert model zoo;

* [fix] open rms; fix low level zero test; fix dist came test function name;

* [fix] remove redundant test;

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] Add Galore (Adam, Adafactor) and distributed GaloreAdamW8bit (#5570)

* init: add dist lamb; add debiasing for lamb

* dist lamb tester mostly done

* all tests passed

* add comments

* all tests passed. Removed debugging statements

* moved setup_distributed inside plugin. Added dist layout caching

* organize better

* update comments

* add initial distributed galore (the low-rank projection idea is sketched after this sub-commit)

* add initial distributed galore

* add galore set param utils; change setup_distributed interface

* projected grad precision passed

* basic precision tests passed

* tests passed; located SVD precision issue in fwd-bwd; disabled these tests

* Plugin DP + TP tests passed

* move get_shard_dim to d_tensor

* add comments

* remove useless files

* remove useless files

* fix zero typo

* improve interface

* remove moe changes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix import

* fix deepcopy

* update came & adafactor to main

* fix param map

* fix typo

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
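
For readers unfamiliar with GaLore, the sub-commit above adds optimizers that keep Adam-style statistics in a low-rank subspace of the gradient. The following is a self-contained conceptual sketch of that projection step (projector refresh via SVD, moment update in rank-r space, projection back); names and hyperparameters are assumptions, bias correction is omitted, and none of this is the repo's implementation.

```python
# Conceptual GaLore step for a 2-D parameter (not the repo's code; bias correction omitted).
import torch

@torch.no_grad()
def galore_step(param: torch.Tensor, grad: torch.Tensor, state: dict,
                rank: int = 4, lr: float = 1e-3, beta1: float = 0.9,
                beta2: float = 0.999, eps: float = 1e-8, update_proj_gap: int = 200):
    step = state.setdefault("step", 0)
    grad32 = grad.float()
    if step % update_proj_gap == 0:
        # Refresh the projector from the top-r left singular vectors of the gradient.
        u, _, _ = torch.linalg.svd(grad32, full_matrices=False)
        state["proj"] = u[:, :rank]                        # (m, r)
    proj = state["proj"]
    low_rank_grad = proj.t() @ grad32                      # (r, n) projected gradient
    m = state.setdefault("exp_avg", torch.zeros_like(low_rank_grad))
    v = state.setdefault("exp_avg_sq", torch.zeros_like(low_rank_grad))
    m.mul_(beta1).add_(low_rank_grad, alpha=1 - beta1)     # Adam moments live in (r, n)
    v.mul_(beta2).addcmul_(low_rank_grad, low_rank_grad, value=1 - beta2)
    update = proj @ (m / (v.sqrt() + eps))                 # back to the full (m, n) shape
    param.add_(update.to(param.dtype), alpha=-lr)
    state["step"] = step + 1
```

Distributing this, as the sub-commit does for GaloreAdamW8bit, additionally requires the projector and the low-rank states to agree with how the working parameters are sharded across TP and ZeRO ranks.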

* [Hotfix] Remove one buggy test case from dist_adafactor for now (#5692)


Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: chongqichuizi875 <107315010+chongqichuizi875@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: duanjunwen <54985467+duanjunwen@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Author: Edenzzzz
Date: 2024-05-14 13:52:45 +08:00
Committed by: GitHub
Parent: 393c8f5b7f
Commit: 43995ee436
30 changed files with 4821 additions and 42 deletions

@@ -8,7 +8,10 @@ from types import MethodType
 from typing import Callable, Dict, Iterator, List, Optional, Tuple

 import torch
 import torch.distributed
+import torch.distributed as dist
 import torch.nn as nn
+from torch.distributed.distributed_c10d import _get_default_group
+from torch.nn import Parameter
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
@@ -28,6 +31,8 @@ from colossalai.checkpoint_io.utils import (
     sharded_optimizer_loading_epilogue,
 )
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
+from colossalai.interface.optimizer import DistributedOptim
+from colossalai.nn.optimizer import DistGaloreAwamW
 from colossalai.quantization import BnbQuantizationConfig, quantize_model
 from colossalai.zero import LowLevelZeroOptimizer
@@ -428,13 +433,31 @@ class LowLevelZeroPlugin(DPPluginBase):
         if not isinstance(model, ModelWrapper):
             model = LowLevelZeroModel(model, self.precision)

+        # TODO: Support Galore + ZeRO
+        zero_stage = self.stage
+        zero_optim_kwargs = {**self.zero_optim_kwargs}
+        dp_size = dist.get_world_size()
+        if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and dp_size > 0:
+            warnings.warn("Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.")
+            zero_optim_kwargs["partition_grad"] = False
+            zero_stage = 0
+
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
             optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(
-                optimizer, **self.zero_optim_kwargs, verbose=self.verbose
+                optimizer, **zero_optim_kwargs, verbose=self.verbose
             )
             # inject update_master_params
             model.update_master_params = MethodType(optimizer.update_master_params, model)

+        # Setup optimizers that require global states
+        optim = optimizer.optim
+        is_zero = dp_size > 1 and zero_stage > 0
+        dp_group = _get_default_group()  # Use the whole world
+        if isinstance(optim, DistributedOptim):
+            shard_to_param = optimizer.get_master_to_working_map()
+            padding_map = optimizer.get_param_padding_map()
+            optim.setup_distributed(None, dp_group, shard_to_param, padding_map, is_zero)
+
         return model, optimizer, criterion, dataloader, lr_scheduler

     def control_checkpoint_io(self) -> bool:
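
The configure() hunk above is the hook that wires a DistributedOptim into the LowLevelZero path: it passes the dp group, the master-to-working parameter map, the ZeRO padding map, and an is_zero flag via setup_distributed. A skeleton of an optimizer on the receiving end might look like the following; the method name and argument order come from the call in the diff, while the parameter names and body are illustrative assumptions.

```python
# Minimal sketch of an optimizer that fits the hook called in configure() above.
# Only the method name and argument order are taken from the diff; the parameter
# names and body are illustrative assumptions.
from typing import Dict, Optional

from torch.distributed import ProcessGroup

from colossalai.interface.optimizer import DistributedOptim


class MyDistOptimizer(DistributedOptim):
    def setup_distributed(
        self,
        tp_group: Optional[ProcessGroup],                # None in the ZeRO-only path above
        dp_group: Optional[ProcessGroup],                # default group, i.e. the whole world
        shard_to_working_param: Optional[Dict] = None,   # master shard -> full working param
        padding_map: Optional[Dict] = None,              # per-param padding added by ZeRO
        is_zero: bool = False,                           # True when dp_size > 1 and stage > 0
    ):
        # Stash the layout so step() can reduce or gather its factored statistics
        # consistently with how the working parameters are sharded.
        self.tp_group, self.dp_group = tp_group, dp_group
        self.shard_to_working_param = shard_to_working_param or {}
        self.padding_map = padding_map or {}
        self.is_zero = is_zero
```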