ColossalAI/colossalai/nn/optimizer/distributed_lamb.py
Edenzzzz 43995ee436
[Feature] Distributed optimizers: Lamb, Galore, CAME and Adafactor ()
* [feat] Add distributed lamb; minor fixes in DeviceMesh ()

* init: add dist lamb; add debiasing for lamb

* dist lamb tester mostly done

* all tests passed

* add comments

* all tests passed. Removed debugging statements

* moved setup_distributed inside plugin. Added dist layout caching

* organize better

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [hotfix] Improve tester precision by removing ZeRO on vanilla lamb ()

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [optim] add distributed came ()

* test CAME under LowLevelZeroOptimizer wrapper

* test CAME TP row and col pass

* test CAME zero pass

* came zero add master and worker param id convert

* came zero test pass

* came zero test pass

* test distributed came passed

* refactor code, modify some expressions, and add comments

* minor fix of test came

* minor fix of dist_came and test

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor fix of dist_came and test

* rebase dist-optim

* rebase dist-optim

* fix remaining comments

* add test dist came using booster api

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [optim] Distributed Adafactor ()

* [feature] solve conflict; update optimizer readme;

* [feature] update optimize readme;

* [fix] fix testcase;

* [feature] Add transformer-bert to testcase; solve a bug related to indivisible shapes (induced when use_zero is on and TP is row-parallel);

* [feature] Add transformers_bert model zoo in testcase;

* [feature] add user documentation to docs/source/feature.

* [feature] add API Reference & Sample to optimizer Readme; add state check for bert exam;

* [feature] modify user documentation;

* [fix] fix readme format issue;

* [fix] add zero=0 in testcase; cache argument in dict;

* [fix] fix precision issue;

* [feature] add distributed rms;

* [feature] remove useless comment in testcase;

* [fix] Remove useless test; open zero test; remove fp16 test in bert exam;

* [feature] Extract distributed rms function;

* [feature] add booster + LowLevelZeroPlugin in test;

* [feature] add Start_with_booster_API case in md; add Supporting Information in md;

* [fix] Also remove state movement in base adafactor;

* [feature] extract factor function;

* [feature] add LowLevelZeroPlugin test;

* [fix] add tp=False and zero=True in logic;

* [fix] fix use zero logic;

* [feature] add row residue logic in column parallel factor;

* [feature] add check optim state func;

* [feature] Remove duplicate logic;

* [feature] update optim state check func and fix precision test bug;

* [fix] update/fix optim state; precision issue still exists;

* [fix] Add use_zero check in _rms; Add plugin support info in Readme; Add Dist Adafactor init Info;

* [feature] removed print & comments in utils;

* [feature] update Readme;

* [feature] add LowLevelZeroPlugin test with Bert model zoo;

* [fix] fix logic in _rms;

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [fix] remove comments in testcase;

* [feature] add zh-Han Readme;

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] refactor dist came; fix precision error; add low level zero test with bert model zoo; ()

* [feature] daily update;

* [fix] fix dist came;

* [feature] refactor dist came; fix precision error; add low level zero test with bert model zoo;

* [fix] open rms; fix low level zero test; fix dist came test function name;

* [fix] remove redundant test;

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] Add Galore (Adam, Adafactor) and distributed GaloreAdamW8bit ()

* init: add dist lamb; add debiasing for lamb

* dist lamb tester mostly done

* all tests passed

* add comments

* all tests passed. Removed debugging statements

* moved setup_distributed inside plugin. Added dist layout caching

* organize better

* update comments

* add initial distributed galore

* add initial distributed galore

* add galore set param utils; change setup_distributed interface

* projected grad precision passed

* basic precision tests passed

* tests passed; located svd precision issue in fwd-bwd; banned these tests

* Plugin DP + TP tests passed

* move get_shard_dim to d_tensor

* add comments

* remove useless files

* remove useless files

* fix zero typo

* improve interface

* remove moe changes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix import

* fix deepcopy

* update came & adafactor to main

* fix param map

* fix typo

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Hotfix] Remove one buggy test case from dist_adafactor for now ()


Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: chongqichuizi875 <107315010+chongqichuizi875@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: duanjunwen <54985467+duanjunwen@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
2024-05-14 13:52:45 +08:00


# Disclaimer: Modified from https://github.com/NUS-HPC-AI-Lab/pytorch-lamb/blob/master/optim/lamb.py
from typing import Dict, Optional

import torch
import torch.distributed as dist

from colossalai.interface.optimizer import DistributedOptim
from colossalai.tensor.d_tensor import is_distributed_tensor

__all__ = ["DistributedLamb"]

class DistributedLamb(DistributedOptim):
    r"""Implements the Lamb algorithm, with extra support for ZeRO 2 and Tensor Parallel.
    Proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.

    It's recommended to use this with HybridParallelPlugin/ZeRO plugin and booster,
    which will take care of setup_distributed.
    Example with 4 devices:
        >>> optim = DistributedLamb(model.parameters(), lr=1e-3)
        >>> proc_mesh = ProcessGroupMesh(tp_size, zero_size)
        >>> tp_group = proc_mesh.get_group_along_axis(0)
        >>> dp_group = proc_mesh.get_group_along_axis(1)
        >>> optim.setup_distributed(tp_group, dp_group)

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-6)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)

    .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
        https://arxiv.org/abs/1904.00962
    """

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-6,
        weight_decay=0,
        bias_correction=True,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))

        # Process groups are assigned later via setup_distributed (usually by the plugin/booster).
        self.shard_to_working_param = {}
        self.tp_size = self.dp_size = 1
        self.is_zero = False

        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
        super().__init__(params, defaults)

    def setup_distributed(
        self,
        tp_group: Optional[dist.ProcessGroup] = None,
        dp_group: Optional[dist.ProcessGroup] = None,
        shard_to_working_param: Optional[Dict] = {},
        padding_map=None,
        is_zero: Optional[bool] = False,
    ):
        """Assign process groups for TP and ZeRO 2.

        Arguments:
            tp_group (dist.ProcessGroup): Tensor Parallel process group
            dp_group (dist.ProcessGroup): ZeRO 2 process group
            shard_to_working_param (Dict): ZeRO 2 feeds the optimizer a sharded param view as grads are sharded.
                This maps from id(view) to working params used in forward & backward.
            padding_map: An empty interface placeholder
            is_zero (bool): Whether to use ZeRO 2.
        """
        self.tp_group = tp_group
        self.dp_group = dp_group
        if tp_group is not None:
            self.tp_size = dist.get_world_size(tp_group)
        if dp_group is not None:
            self.dp_size = dist.get_world_size(dp_group)

        self.shard_to_working_param = shard_to_working_param if shard_to_working_param is not None else {}
        self.is_zero = is_zero
        self.is_dist = {}

        # Cache parameter layout
        for group in self.param_groups:
            for p in group["params"]:
                # w/o ZeRO: master param = working param
                self.shard_to_working_param[id(p)] = self.shard_to_working_param.get(id(p), p)

                self.is_dist[p] = (
                    is_distributed_tensor(p)
                    if self.dp_size <= 1
                    else is_distributed_tensor(self.shard_to_working_param.get(id(p), None))
                )

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue

                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError("Lamb does not support sparse gradients, consider SparseAdam instead.")

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                beta1, beta2 = group["betas"]
                state["step"] += 1

                # Decay the first and second moment running average coefficient
                # m_t
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                # v_t
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                scaled_lr = group["lr"]
                if group["bias_correction"]:
                    bias_correction1 = 1 - beta1 ** state["step"]
                    bias_correction2 = 1 - beta2 ** state["step"]
                    # Apply debiasing to lr to avoid broadcast
                    scaled_lr *= (bias_correction2**0.5) / bias_correction1
                    # exp_avg.div_(bias_correction1)
                    # exp_avg_sq.div_(bias_correction2)

                update = exp_avg / exp_avg_sq.sqrt().add(group["eps"])
                if group["weight_decay"] != 0:
                    update.add_(p.data, alpha=group["weight_decay"])

                # Compute global layer-wise trust ratio
                if self.is_dist[p] or self.is_zero:
                    p_local = p
                    g_sum = (update**2).sum()
                    if self.dp_size > 1 and self.is_zero:
                        # ZeRO 2 doesn't shard params, so the full param norm needs no communication;
                        # only the sharded update norm is all-reduced over the DP group.
                        dist.all_reduce(g_sum, group=self.dp_group)
                        p_local = self.shard_to_working_param[id(p)]

                    w_sum = (p_local**2).sum()
                    sums = torch.stack([w_sum, g_sum])

                    # Get global l2 norms
                    if self.tp_size > 1:
                        dist.all_reduce(sums, group=self.tp_group)
                    w_norm, g_norm = sums.sqrt().chunk(2)
                else:
                    # Fall back to vanilla Lamb
                    w_norm = torch.norm(p)
                    g_norm = torch.norm(update)

                trust_ratio = torch.where(w_norm > 0 and g_norm > 0, (w_norm / g_norm), 1.0).item()

                scaled_lr *= trust_ratio
                p.data.add_(update, alpha=-scaled_lr)

        return loss
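
A minimal single-process usage sketch follows (an illustration based on this file, not part of it): even without TP or ZeRO groups, setup_distributed() must still be called once so that the parameter-layout cache (is_dist) is populated; step() then takes the vanilla Lamb branch for the trust ratio.

import torch

from colossalai.nn.optimizer.distributed_lamb import DistributedLamb

model = torch.nn.Linear(8, 2)
optim = DistributedLamb(model.parameters(), lr=1e-3, weight_decay=1e-2)
optim.setup_distributed()  # no TP/ZeRO groups: tp_size == dp_size == 1, every param cached as non-distributed

loss = model(torch.randn(4, 8)).sum()
loss.backward()
optim.step()  # uses the "Fall back to vanilla Lamb" trust-ratio path
optim.zero_grad()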