# Distributed Optimizers

Author: [Wenxuan Tan](https://github.com/Edenzzzz), [Junwen Duan](https://github.com/duanjunwen), [Renjie Mao](https://github.com/chongqichuizi875)

**Related Papers**
- [Adafactor: Adaptive Learning Rates with Sublinear Memory Cost](https://arxiv.org/abs/1804.04235)
- [CAME: Confidence-guided Adaptive Memory Efficient Optimization](https://arxiv.org/abs/2307.02047)
- [GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection](https://arxiv.org/abs/2403.03507)
- [Large Batch Optimization for Deep Learning: Training BERT in 76 minutes](https://arxiv.org/pdf/1904.00962)

## Introduction

Apart from the widely adopted Adam and SGD, many modern optimizers require layer-wise statistics to update parameters efficiently, and are therefore not directly applicable to parallel settings where model layers are sharded across multiple devices. We provide optimized distributed implementations with minimal extra communication, as well as seamless integration with Tensor Parallelism, DDP and ZeRO through plugins.

## Optimizers

Adafactor is a first-order Adam variant that uses non-negative matrix factorization (NMF) to reduce the memory footprint. CAME improves on it by introducing a confidence matrix to correct the NMF approximation. GaLore further reduces memory by projecting gradients into a low-rank space and applying 8-bit block-wise quantization. Lamb enables huge batch sizes without losing accuracy via a layer-wise adaptive update bounded by the inverse of its Lipschitz constant.
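
As a schematic sketch of the underlying ideas (the notation follows the original papers rather than the exact ColossalAI implementation), Adafactor keeps only row and column statistics R_t and C_t of the squared gradients instead of the full second-moment matrix, while Lamb rescales each layer's Adam-style update u_t (with weight decay lambda) by a trust ratio:

```latex
% Adafactor (sketch): factored second-moment estimate for a gradient matrix G_t of shape n x m
R_t = \hat{\beta}_{2t} R_{t-1} + (1 - \hat{\beta}_{2t}) \, (G_t \odot G_t) \, \mathbf{1}_m
C_t = \hat{\beta}_{2t} C_{t-1} + (1 - \hat{\beta}_{2t}) \, \mathbf{1}_n^\top (G_t \odot G_t)
\hat{V}_t = \frac{R_t C_t}{\mathbf{1}_n^\top R_t}

% Lamb (sketch): layer-wise trust ratio applied to an Adam-style update u_t with weight decay \lambda
w_{t+1} = w_t - \eta \, \frac{\phi(\lVert w_t \rVert)}{\lVert u_t + \lambda w_t \rVert} \, (u_t + \lambda w_t)
```
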
## API Reference

{{ autodoc:colossalai.nn.optimizer.distributed_adafactor.DistributedAdaFactor }}

{{ autodoc:colossalai.nn.optimizer.distributed_lamb.DistributedLamb }}

{{ autodoc:colossalai.nn.optimizer.distributed_galore.DistGaloreAwamW }}

{{ autodoc:colossalai.nn.optimizer.distributed_came.DistributedCAME }}

## Hands-On Practice

We now demonstrate how to use Distributed Adafactor with the booster API, combining Tensor Parallel and ZeRO 2 on 4 GPUs.

### step 1. Import libraries

```python
from transformers import LlamaModel, LlamaConfig
from colossalai.nn.optimizer.distributed_adafactor import DistributedAdaFactor
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
import colossalai
import torch
```

### step 2. Initialize Distributed Environment and Parallelism Group

We first initialize the distributed environment. For this demo, we launch with `colossalai run --nproc_per_node 4`. You can refer to [Launch Colossal-AI](../basics/launch_colossalai.md) for other launch options.

```python
colossalai.launch_from_torch()
```

### step 3. Initialize Module and Optimizer

Next, build the model and the distributed optimizer. Here we initialize a Llama model from Hugging Face Transformers and wrap its parameters with `DistributedAdaFactor`.

```python
# Init Llama from huggingface
configuration = LlamaConfig()
model = LlamaModel(configuration).cuda()
criterion = lambda x: x.mean()
dist_optim = DistributedAdaFactor(model.parameters())
```

### step 4. Init Booster

```python
plugin = HybridParallelPlugin(tp_size=2, zero_stage=2, pp_size=1, enable_all_optimization=True)
booster = Booster(plugin=plugin)
# You should also pass in your own dataloader; see the sketch after this block.
model, dist_optim, criterion, dataloader, _ = booster.boost(model, dist_optim, criterion)
```
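
The call above leaves the dataloader out. As a hypothetical sketch (the toy dataset, shapes, and variable names below are made up for illustration), you could build a plain PyTorch dataloader and pass it to `booster.boost` instead:

```python
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset of random token ids with an all-ones attention mask (illustration only).
vocab_size, seq_len, num_samples = 32000, 100, 64
token_ids = torch.randint(0, vocab_size, (num_samples, seq_len))
masks = torch.ones_like(token_ids)
train_dataloader = DataLoader(TensorDataset(token_ids, masks), batch_size=4, shuffle=True)

# booster.boost also accepts a dataloader and returns a (possibly wrapped) version of it.
model, dist_optim, criterion, train_dataloader, _ = booster.boost(
    model, dist_optim, criterion, dataloader=train_dataloader
)
```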

### step 5. Train Your Model

```python
steps = 10
for step in range(steps):
    # Dummy inputs for demonstration; replace them with batches from your dataloader.
    input_ids = torch.ones(1, 100, device="cuda", dtype=torch.int)
    attention_mask = input_ids.clone()
    outputs = model(input_ids, attention_mask)
    loss = criterion(outputs.last_hidden_state)
    booster.backward(loss, dist_optim)
    dist_optim.step()
    dist_optim.zero_grad()
```

### GaLore special handling

For GaLore, we need to specify the projection rank for each parameter group, as well as the quantization and paged-optimizer parameters. Please refer to bitsandbytes for quantization details. Support for ZeRO is underway.

```python
from colossalai.nn.optimizer.galore import get_galore_param_groups
from colossalai.nn.optimizer import DistGaloreAwamW

# Example hyper-parameters; tune them for your own model.
lr, beta1, beta2, eps = 1e-3, 0.9, 0.999, 1e-8

optim = DistGaloreAwamW(
    get_galore_param_groups(model, decay=1e-2, rank=8),
    lr=lr,
    betas=(beta1, beta2),
    eps=eps,
    nbits=8,
    percentile_clipping=100,
    block_wise=True,
    min_8bit_size=4096,
)
```

## Plugin Compatibility

<table>
  <tr>
    <th nowrap="nowrap">Model/Feature</th>
    <th nowrap="nowrap" align="center" title="Lamb">Lamb</th>
    <th nowrap="nowrap" align="center" title="GaLore">GaLore</th>
    <th nowrap="nowrap" align="center" title="Adafactor">Adafactor</th>
    <th nowrap="nowrap" align="center" title="CAME">CAME</th>
  </tr>
  <tr>
    <td nowrap="nowrap">Hybrid Parallel<br />Plugin</td>
    <td nowrap="nowrap" align="center">✔️</td>
    <td nowrap="nowrap" align="center">✔️</td>
    <td nowrap="nowrap" align="center">✔️</td>
    <td nowrap="nowrap" align="center">✔️</td>
  </tr>
  <tr>
    <td nowrap="nowrap">Low Level Zero<br />Plugin</td>
    <td nowrap="nowrap" align="center">✔️</td>
    <td nowrap="nowrap" align="center">❌</td>
    <td nowrap="nowrap" align="center">✔️</td>
    <td nowrap="nowrap" align="center">✔️</td>
  </tr>
  <tr>
    <td nowrap="nowrap">Torch DDP<br />Plugin</td>
    <td nowrap="nowrap" align="center">✔️</td>
    <td nowrap="nowrap" align="center">✔️</td>
    <td nowrap="nowrap" align="center">✔️</td>
    <td nowrap="nowrap" align="center">✔️</td>
  </tr>
  <tr>
    <td nowrap="nowrap">Gemini<br />Plugin</td>
    <td nowrap="nowrap" align="center">❌</td>
    <td nowrap="nowrap" align="center">❌</td>
    <td nowrap="nowrap" align="center">❌</td>
    <td nowrap="nowrap" align="center">❌</td>
  </tr>
  <tr>
    <td nowrap="nowrap">Moe Hybrid<br />Plugin</td>
    <td nowrap="nowrap" align="center">❌</td>
    <td nowrap="nowrap" align="center">❌</td>
    <td nowrap="nowrap" align="center">❌</td>
    <td nowrap="nowrap" align="center">❌</td>
  </tr>
</table>
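
For example, the table indicates that DistributedCAME can be combined with the Low Level Zero plugin. The following is only a minimal sketch under that assumption; the toy config and variable names are illustrative and not part of the original example:

```python
from transformers import LlamaConfig, LlamaModel

from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.nn.optimizer.distributed_came import DistributedCAME

# A deliberately tiny Llama config so the sketch stays lightweight (illustrative values).
toy_config = LlamaConfig(hidden_size=256, intermediate_size=512, num_hidden_layers=2, num_attention_heads=4)
came_model = LlamaModel(toy_config).cuda()
came_optim = DistributedCAME(came_model.parameters(), lr=2e-4)

# Boosting with the Low Level Zero plugin wraps the optimizer with ZeRO-2.
zero_booster = Booster(plugin=LowLevelZeroPlugin(stage=2))
came_model, came_optim, _, _, _ = zero_booster.boost(came_model, came_optim)
```
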
<!-- doc-test-command: colossalai run --nproc_per_node 4 distributed_optimizers.py -->