[Feature] Distributed optimizers: Lamb, Galore, CAME and Adafactor (#5694)
* [feat] Add distributed lamb; minor fixes in DeviceMesh (#5476)
  * init: add dist lamb; add debiasing for lamb
  * dist lamb tester mostly done
  * all tests passed
  * add comments
  * all tests passed. Removed debugging statements
  * moved setup_distributed inside plugin. Added dist layout caching
  * organize better
  ---------
  Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [hotfix] Improve tester precision by removing ZeRO on vanilla lamb (#5576)
  Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [optim] add distributed came (#5526)
  * test CAME under LowLevelZeroOptimizer wrapper
  * test CAME TP row and col pass
  * test CAME zero pass
  * came zero add master and worker param id convert
  * came zero test pass
  * came zero test pass
  * test distributed came passed
  * reform code, Modify some expressions and add comments
  * minor fix of test came
  * minor fix of dist_came and test
  * [pre-commit.ci] auto fixes from pre-commit.com hooks
    for more information, see https://pre-commit.ci
  * minor fix of dist_came and test
  * rebase dist-optim
  * rebase dist-optim
  * fix remaining comments
  * add test dist came using booster api
  ---------
  Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [optim] Distributed Adafactor (#5484)
  * [feature] solve conflict; update optimizer readme;
  * [feature] update optimize readme;
  * [fix] fix testcase;
  * [feature] Add transformer-bert to testcase; solve a bug related to indivisible shape (induction in use_zero and tp is row parallel);
  * [feature] Add transformers_bert model zoo in testcase;
  * [feature] add user documentation to docs/source/feature.
  * [feature] add API Reference & Sample to optimizer Readme; add state check for bert exam;
  * [feature] modify user documentation;
  * [fix] fix readme format issue;
  * [fix] add zero=0 in testcase; cached augment in dict;
  * [fix] fix percision issue;
  * [feature] add distributed rms;
  * [feature] remove useless comment in testcase;
  * [fix] Remove useless test; open zero test; remove fp16 test in bert exam;
  * [feature] Extract distributed rms function;
  * [feature] add booster + lowlevelzeroPlugin in test;
  * [feature] add Start_with_booster_API case in md; add Supporting Information in md;
  * [fix] Also remove state movement in base adafactor;
  * [feature] extract factor function;
  * [feature] add LowLevelZeroPlugin test;
  * [fix] add tp=False and zero=True in logic;
  * [fix] fix use zero logic;
  * [feature] add row residue logic in column parallel factor;
  * [feature] add check optim state func;
  * [feature] Remove duplicate logic;
  * [feature] update optim state check func and percision test bug;
  * [fix] update/fix optim state; Still exist percision issue;
  * [fix] Add use_zero check in _rms; Add plugin support info in Readme; Add Dist Adafactor init Info;
  * [feature] removed print & comments in utils;
  * [feature] uodate Readme;
  * [feature] add LowLevelZeroPlugin test with Bert model zoo;
  * [fix] fix logic in _rms;
  * [pre-commit.ci] auto fixes from pre-commit.com hooks
    for more information, see https://pre-commit.ci
  * [fix] remove comments in testcase;
  * [feature] add zh-Han Readme;
  ---------
  Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] refractor dist came; fix percision error; add low level zero test with bert model zoo; (#5676)
  * [feature] daily update;
  * [fix] fix dist came;
  * [feature] refractor dist came; fix percision error; add low level zero test with bert model zoo;
  * [fix] open rms; fix low level zero test; fix dist came test function name;
  * [fix] remove redundant test;
  * [pre-commit.ci] auto fixes from pre-commit.com hooks
    for more information, see https://pre-commit.ci
  ---------
  Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] Add Galore (Adam, Adafactor) and distributed GaloreAdamW8bit (#5570)
  * init: add dist lamb; add debiasing for lamb
  * dist lamb tester mostly done
  * all tests passed
  * add comments
  * all tests passed. Removed debugging statements
  * moved setup_distributed inside plugin. Added dist layout caching
  * organize better
  * update comments
  * add initial distributed galore
  * add initial distributed galore
  * add galore set param utils; change setup_distributed interface
  * projected grad precision passed
  * basic precision tests passed
  * tests passed; located svd precision issue in fwd-bwd; banned these tests
  * Plugin DP + TP tests passed
  * move get_shard_dim to d_tensor
  * add comments
  * remove useless files
  * remove useless files
  * fix zero typo
  * improve interface
  * remove moe changes
  * [pre-commit.ci] auto fixes from pre-commit.com hooks
    for more information, see https://pre-commit.ci
  * fix import
  * fix deepcopy
  * update came & adafactor to main
  * fix param map
  * fix typo
  ---------
  Co-authored-by: Edenzzzz <wtan45@wisc.edu>
  Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Hotfix] Remove one buggy test case from dist_adafactor for now (#5692)
  Co-authored-by: Edenzzzz <wtan45@wisc.edu>
  Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: chongqichuizi875 <107315010+chongqichuizi875@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: duanjunwen <54985467+duanjunwen@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
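Roughly how these distributed optimizers are meant to be driven through the booster API mentioned above — a minimal sketch assuming the LowLevelZeroPlugin path; the launch call, the param-group construction, and the helper names are illustrative and may differ from the shipped utilities:

import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.nn.optimizer.distributed_galore import DistGaloreAwamW

colossalai.launch_from_torch()            # run under torchrun; older versions may need config={}
plugin = LowLevelZeroPlugin(stage=2)      # ZeRO 2; tensor parallel goes through HybridParallelPlugin
booster = Booster(plugin=plugin)

model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512)).cuda()
# Tag GaLore-eligible 2D weights with a "rank" entry (the repo also references a
# get_galore_param_groups helper for this); other params get the plain 8-bit AdamW update.
param_groups = [
    {"params": [m.weight for m in model if isinstance(m, nn.Linear)], "rank": 8},
    {"params": [m.bias for m in model if isinstance(m, nn.Linear)]},
]
optimizer = DistGaloreAwamW(param_groups, lr=1e-3, weight_decay=1e-2)
# boost() wraps the optimizer and calls its setup_distributed() with the plugin's process groups.
model, optimizer, *_ = booster.boost(model, optimizer)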
colossalai/nn/optimizer/distributed_galore.py  (new file, 279 lines)
@@ -0,0 +1,279 @@
""" adapted from https://github.com/jiaweizzhao/GaLore/blob/master/galore_torch/adamw8bit.py"""
import warnings
from collections import defaultdict
from typing import Dict, Optional

import torch
import torch.distributed as dist
import torch.nn.functional as F
from bitsandbytes.optim.optimizer import Optimizer2State

from colossalai.interface.optimizer import DistributedOptim
from colossalai.tensor.d_tensor import get_shard_dim_1d, is_distributed_tensor

from .galore import GaLoreProjector, make_low_rank_buffer

__all__ = ["DistGaloreAwamW"]
# Mark sharded dimension


class DistGaloreAwamW(DistributedOptim, Optimizer2State):
    r"""Implements GaLore, an optimizer-agnostic gradient compression technique, on top of 8-bit AdamW.
    It largely compresses gradients via low-rank projection and is claimed to be insensitive to hyperparams like lr.
    Supports Tensor Parallel and ZeRO stage 1 and 2 via booster and plugin.
    Proposed in `GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection`
    https://arxiv.org/abs/2403.03507

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups.
        lr (float, optional): learning rate. (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its norm. (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability. (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0.01)
        nbits (int, optional): number of bits for quantizing optimizer states. Only 32 and 8 are supported. (default: 8)
        min_8bit_size (`int`, defaults to 4096):
            The minimum number of elements of the parameter tensors for 8-bit optimization.
        percentile_clipping (`int`, defaults to 100):
            Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
        block_wise (`bool`, defaults to `True`):
            Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
        is_paged (`bool`, defaults to `False`):
            Whether the optimizer is a paged optimizer (handles memory spikes via CPU-GPU transfer) or not.
    """

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=1e-2,
        nbits=8,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
        is_paged=False,
    ):
        super().__init__(
            "adam",
            params,
            lr,
            betas,
            eps,
            weight_decay,
            nbits,
            None,
            min_8bit_size,
            percentile_clipping,
            block_wise,
            is_paged=is_paged,
        )
        self.tp_size = 1
        self.dp_size = 1
        self.is_dist = {}
        proj_none = all(["rank" not in group for group in self.param_groups])
        if proj_none:
            warnings.warn(
                "Will not apply GaLore as rank isn't in any param group. If you forgot to, try get_galore_param_groups"
            )

        # Default from the paper
        for group in self.param_groups:
            if "rank" in group:
                group["update_proj_gap"] = group.get("update_proj_gap", 200)
                group["proj_type"] = group.get("proj_type", "std")
                group["scale"] = group.get("scale", 0.25)

    def setup_distributed(
        self,
        tp_group: Optional[dist.ProcessGroup] = None,
        dp_group: Optional[dist.ProcessGroup] = None,
        shard_to_working_param: Optional[Dict] = {},
        padding_map: Optional[Dict] = defaultdict(int),
        is_zero: Optional[bool] = False,
    ):
        """Setup process groups for TP and ZeRO 2.

        Arguments:
            tp_group (dist.ProcessGroup): Tensor Parallel process group
            dp_group (dist.ProcessGroup): ZeRO 2 process group
            shard_to_working_param (Dict): ZeRO 2 feeds the optimizer a sharded param view as grads are sharded.
                This maps from id(view) to working params used in forward & backward.
            padding_map (Dict): Padding size of each param from ZeRO's param store. Required if ZeRO is used.
            is_zero (bool): Whether ZeRO 2 is used.
        """
        assert dist.is_initialized(), "You forgot to initialize the distributed backend..."

        self.tp_group = tp_group
        self.dp_group = dp_group
        if tp_group is not None:
            self.tp_size = dist.get_world_size(tp_group)
        if dp_group is not None:
            self.dp_size = dist.get_world_size(dp_group)

        self.shard_to_working_param = shard_to_working_param if shard_to_working_param is not None else {}
        self.is_zero = is_zero and self.dp_size > 1
        self.padding_map = padding_map if padding_map is not None else defaultdict(int)
        if is_zero:
            assert self.padding_map is not defaultdict(
                int
            ), "We can't do SVD without knowing ZeRO's per-param padding size"
        self.distributed_on = self.tp_size > 0 or self.dp_size > 0

        # Cache working param layout
        self.shard_dim = {}
        for group in self.param_groups:
            for p in group["params"]:
                # w/o ZeRO: master param = working param
                self.shard_to_working_param[id(p)] = self.shard_to_working_param.get(id(p), p)
                if id(p) not in self.padding_map:
                    self.padding_map[id(p)] = 0

                self.is_dist[id(p)] = is_distributed_tensor(self.shard_to_working_param[id(p)])
                if is_distributed_tensor(self.shard_to_working_param[id(p)]):
                    self.shard_dim[id(p)] = get_shard_dim_1d(self.shard_to_working_param[id(p)])
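
    # How a ZeRO/TP plugin is expected to call the method above (illustrative;
    # the two map-returning helpers on the wrapping ZeRO optimizer are assumptions):
    #
    #   optim.setup_distributed(
    #       tp_group=tp_process_group,                                   # None without tensor parallel
    #       dp_group=zero_process_group,                                 # None without ZeRO
    #       shard_to_working_param=zero_optim.get_master_to_working_map(),
    #       padding_map=zero_optim.get_param_padding_map(),
    #       is_zero=True,
    #   )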

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        if not self.initialized:
            self.check_overrides()
            self.to_gpu()
            self.initialized = True

        for gindex, group in enumerate(self.param_groups):
            for pindex, p in enumerate(group["params"]):
                if p.grad is None:
                    continue
                state = self.state[p]

                if "step" not in state:
                    state["step"] = 0

                # GaLore Projection
                if "rank" in group:
                    if "projector" not in state:
                        state["projector"] = GaLoreProjector(
                            group["rank"],
                            scale=group["scale"],
                            update_proj_gap=group["update_proj_gap"],
                            proj_type=group["proj_type"],
                        )
                    # decoupled weight decay
                    if "weight_decay" in group and group["weight_decay"] > 0:
                        group["weight_decay_saved"] = group["weight_decay"]
                        group["weight_decay"] = 0

                    grad = p.grad
                    working_shape = list(self.shard_to_working_param[id(p)].shape)
                    padding = self.padding_map[id(p)]

                    # All-gather grads for projection step
                    if self.distributed_on:
                        # Gather because the ZeRO 1 & 2 implementations don't retain full grads
                        if self.is_zero:
                            # (m, n).flatten().chunk(dp_size) equals (m / dp_size, n).flatten()
                            working_shape[0] //= self.dp_size
                            # Gather grads for projection
                            if state["step"] % group["update_proj_gap"] == 0:
                                all_grads = [
                                    torch.empty_like(grad, dtype=p.grad.dtype, device=p.grad.device)
                                    for _ in range(self.dp_size)
                                ]
                                dist.all_gather(all_grads, grad, self.dp_group)
                                grad = torch.cat(all_grads)
                                # To working param shape
                                if padding > 0:
                                    grad = grad[:-padding]
                                working_shape[0] *= self.dp_size
                            grad = grad.reshape(working_shape)  # unflatten
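                            # Worked example: a (4, 6) working param with dp_size 2 and no padding
                            # leaves each rank a flat 12-element shard. On SVD-refresh steps the
                            # all_gather/cat above restores all 24 elements and working_shape is
                            # scaled back to (4, 6); on other steps the local shard is viewed as (2, 6).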

                        # Gather TP grads
                        if self.is_dist[id(p)] and state["step"] % group["update_proj_gap"] == 0:
                            all_grads = [
                                torch.empty_like(grad, dtype=p.grad.dtype, device=p.grad.device)
                                for _ in range(self.tp_size)
                            ]
                            dist.all_gather(all_grads, grad.contiguous(), self.tp_group)
                            grad = torch.cat(all_grads, dim=self.shard_dim[id(p)])

                    # Compute SVD. Will use a subset of singular vectors when grads are sharded.
                    grad = state["projector"].project(grad, state["step"])

                    # Re-shard gathered grads after SVD
                    if self.distributed_on and state["step"] % group["update_proj_gap"] == 0:
                        # TP
                        if self.is_dist[id(p)]:
                            grad = grad.chunk(self.tp_size, dim=self.shard_dim[id(p)])[dist.get_rank(self.tp_group)]
                        # ZeRO
                        # TODO: this might not work with padding, e.g. (3, 3) with dp size 2
                        # Need extra logic in ZeRO to pad nRows/nCols to be divisible by dp_size
                        if self.is_zero:
                            grad = grad.chunk(self.dp_size)[dist.get_rank(self.dp_group)]
                        grad = grad.contiguous()  # avoid bitsandbytes update error

                    working_shape = grad.shape
                    # To flattened master param shape
                    grad = self.to_master_shape(grad, padding)
                    make_low_rank_buffer(p, grad)

                if "state1" not in state:
                    self.init_state(group, p, gindex, pindex)

                self.prefetch_state(p)
                self.update_step(group, p, gindex, pindex)
                torch.cuda.synchronize()

                # Project back to working param shape
                if "rank" in group:
                    # Unpad
                    if self.is_zero:
                        if padding > 0:
                            p.data = p.data[:-padding]
                        p.data = p.data.reshape(working_shape)

                    p.data = state["projector"].project_back(p.data)
                    # Re-flatten grads for ZeRO
                    p.data = self.to_master_shape(p.data, padding)
                    p.data = p.saved_data.add_(p.data)

                    # apply decoupled weight decay
                    if "weight_decay_saved" in group:
                        p.data.add_(p.data, alpha=-group["lr"] * group["weight_decay_saved"])
                        group["weight_decay"] = group["weight_decay_saved"]
                        del group["weight_decay_saved"]

        if self.is_paged:
            # all paged operations are asynchronous; we need
            # to sync to make sure all tensors are in the right state
            torch.cuda.synchronize()

        return loss

    def to_master_shape(self, data, padding):
        """Pad to master (optimizer) param shape"""
        if not self.is_zero:
            return data
        data = data.view(-1)
        if padding > 0:
            data = F.pad(data, [0, padding])
        return data

    def __del__(self):
        """Avoid buffer memory leak"""
        for group in self.param_groups:
            for p in group["params"]:
                if hasattr(p, "saved_data"):
                    del p.saved_data