ColossalAI/colossalai/interface/optimizer.py
Edenzzzz 43995ee436
[Feature] Distributed optimizers: Lamb, Galore, CAME and Adafactor (#5694)
* [feat] Add distributed lamb; minor fixes in DeviceMesh (#5476)

* init: add dist lamb; add debiasing for lamb

* dist lamb tester mostly done

* all tests passed

* add comments

* all tests passed. Removed debugging statements

* moved setup_distributed inside plugin. Added dist layout caching

* organize better

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [hotfix] Improve tester precision by removing ZeRO on vanilla lamb (#5576)

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [optim] add distributed came (#5526)

* test CAME under LowLevelZeroOptimizer wrapper

* test CAME TP row and col pass

* test CAME zero pass

* came zero add master and worker param id convert

* came zero test pass

* came zero test pass

* test distributed came passed

* reform code, Modify some expressions and add comments

* minor fix of test came

* minor fix of dist_came and test

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor fix of dist_came and test

* rebase dist-optim

* rebase dist-optim

* fix remaining comments

* add test dist came using booster api

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [optim] Distributed Adafactor (#5484)

* [feature] solve conflict; update optimizer readme;

* [feature] update optimize readme;

* [fix] fix testcase;

* [feature] Add transformer-bert to testcase;solve a bug related to indivisible shape (induction in use_zero and tp is row parallel);

* [feature] Add transformers_bert model zoo in testcase;

* [feature] add user documentation to docs/source/feature.

* [feature] add API Reference & Sample to optimizer Readme; add state check for bert exam;

* [feature] modify user documentation;

* [fix] fix readme format issue;

* [fix] add zero=0 in testcase; cached augment in dict;

* [fix] fix precision issue;

* [feature] add distributed rms;

* [feature] remove useless comment in testcase;

* [fix] Remove useless test; open zero test; remove fp16 test in bert exam;

* [feature] Extract distributed rms function;

* [feature] add booster + lowlevelzeroPlugin in test;

* [feature] add Start_with_booster_API case in md; add Supporting Information in md;

* [fix] Also remove state movement in base adafactor;

* [feature] extract factor function;

* [feature] add LowLevelZeroPlugin test;

* [fix] add tp=False and zero=True in logic;

* [fix] fix use zero logic;

* [feature] add row residue logic in column parallel factor;

* [feature] add check optim state func;

* [feature] Remove duplicate logic;

* [feature] update optim state check func and precision test bug;

* [fix] update/fix optim state; precision issue still exists;

* [fix] Add use_zero check in _rms; Add plugin support info in Readme; Add Dist Adafactor init Info;

* [feature] removed print & comments in utils;

* [feature] update Readme;

* [feature] add LowLevelZeroPlugin test with Bert model zoo;

* [fix] fix logic in _rms;

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [fix] remove comments in testcase;

* [feature] add zh-Han Readme;

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] refactor dist came; fix precision error; add low level zero test with bert model zoo; (#5676)

* [feature] daily update;

* [fix] fix dist came;

* [feature] refactor dist came; fix precision error; add low level zero test with bert model zoo;

* [fix] open rms; fix low level zero test; fix dist came test function name;

* [fix] remove redundant test;

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] Add Galore (Adam, Adafactor) and distributed GaloreAdamW8bit (#5570)

* init: add dist lamb; add debiasing for lamb

* dist lamb tester mostly done

* all tests passed

* add comments

* all tests passed. Removed debugging statements

* moved setup_distributed inside plugin. Added dist layout caching

* organize better

* update comments

* add initial distributed galore

* add initial distributed galore

* add galore set param utils; change setup_distributed interface

* projected grad precision passed

* basic precision tests passed

* tests passed; located svd precision issue in fwd-bwd; banned these tests

* Plugin DP + TP tests passed

* move get_shard_dim to d_tensor

* add comments

* remove useless files

* remove useless files

* fix zero typo

* improve interface

* remove moe changes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix import

* fix deepcopy

* update came & adafactor to main

* fix param map

* fix typo

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Hotfix] Remove one buggy test case from dist_adafactor for now (#5692)


Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: chongqichuizi875 <107315010+chongqichuizi875@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: duanjunwen <54985467+duanjunwen@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
2024-05-14 13:52:45 +08:00


from typing import Dict, Optional, Union

import torch
import torch.distributed as dist
import torch.nn as nn
from torch import Tensor
from torch.optim import Optimizer


class OptimizerWrapper:
    """
    A standard interface for optimizers wrapped by the Booster.

    Args:
        optim (Optimizer): The optimizer to be wrapped.
    """

    def __init__(self, optim: Optimizer):
        self.optim = optim

    @property
    def parameters(self):
        # Flatten the parameters of every param group into a single list.
        params = []
        for group in self.param_groups:
            params += group["params"]
        return params

    @property
    def param_groups(self):
        return self.optim.param_groups

    @property
    def defaults(self):
        return self.optim.defaults

    def add_param_group(self, *args, **kwargs):
        return self.optim.add_param_group(*args, **kwargs)

    def step(self, *args, **kwargs):
        """
        Performs a single optimization step.
        """
        return self.optim.step(*args, **kwargs)

    def zero_grad(self, *args, **kwargs):
        """
        Clears the gradients of all optimized `torch.Tensor`.
        """
        self.optim.zero_grad(*args, **kwargs)

    def backward(self, loss: Tensor, *args, **kwargs):
        """
        Performs a backward pass on the loss.
        """
        loss.backward(*args, **kwargs)

    def backward_by_grad(self, tensor: Tensor, grad: Tensor):
        """
        Performs a backward pass from `tensor`, using `grad` as its upstream gradient.
        """
        torch.autograd.backward(tensor, grad)

    def state_dict(self):
        """
        Returns the optimizer state.
        """
        return self.optim.state_dict()

    def load_state_dict(self, *args, **kwargs):
        """
        Loads the optimizer state.
        """
        self.optim.load_state_dict(*args, **kwargs)

    def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
        """
        Clips gradient of an iterable of parameters at specified min and max values.

        Args:
            clip_value (float or int): maximum allowed value of the gradients. Gradients are clipped in the range
                [-clip_value, clip_value].

        Note:
            In PyTorch 2.0 and above, you can pass in foreach=True as kwargs to clip_grad_value_ to use the
            faster implementation. Please refer to the PyTorch documentation for more details.
        """
        nn.utils.clip_grad_value_(self.parameters, clip_value, *args, **kwargs)

    def clip_grad_by_norm(
        self,
        max_norm: Union[float, int],
        norm_type: Union[float, int] = 2.0,
        error_if_nonfinite: bool = False,
        *args,
        **kwargs,
    ) -> Tensor:
        """
        Clips gradient norm of an iterable of parameters.

        Args:
            max_norm (float or int): max norm of the gradients
            norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm.
            error_if_nonfinite (bool): if True, an error is raised if the total norm is non-finite. Default: False

        Note:
            In PyTorch 2.0 and above, you can pass in foreach=True as kwargs to clip_grad_norm_ to use the
            faster implementation. Please refer to the PyTorch documentation for more details.
        """
        norm = nn.utils.clip_grad_norm_(self.parameters, max_norm, norm_type, error_if_nonfinite, *args, **kwargs)
        return norm

    def scale_loss(self, loss: Tensor):
        """
        Scales the loss for mixed precision training.

        Note: Only available for optimizers with mixed precision training.

        Args:
            loss (Tensor): The loss to be scaled.
        """
        raise NotImplementedError(
            "The method scale_loss is only available for optimizers with mixed precision training"
        )

    def unscale_grad(self):
        """
        Unscales the gradients for mixed precision training.

        Note: Only available for optimizers with mixed precision training.
        """
        raise NotImplementedError(
            "The method unscale_grad is only available for optimizers with mixed precision training"
        )

    def unwrap(self):
        """
        Unwraps the optimizer for checkpoint saving/loading.
        """
        return self.optim


class DistributedOptim(Optimizer):
    def setup_distributed(
        self,
        tp_group: Optional[dist.ProcessGroup] = None,
        dp_group: Optional[dist.ProcessGroup] = None,
        shard_to_working_param: Optional[Dict] = {},
        padding_map: Optional[Dict] = None,
        is_zero: Optional[bool] = False,
    ):
        """Assign process groups for TP and ZeRO 2.

        Arguments:
            tp_group (dist.ProcessGroup): Tensor Parallel process group
            dp_group (dist.ProcessGroup): ZeRO stage 2 process group
            shard_to_working_param (Dict): ZeRO stage 2 feeds the optimizer a sharded param view to match grad shape.
                This maps from id(view) to model params used in forward & backward.
            padding_map (Dict): Per-param padding from ZeRO stage 2
            is_zero (bool): Whether to use ZeRO stage 2.
        """
        raise NotImplementedError("setup_distributed for TP/DP isn't supported by this optimizer yet!")
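
A usage sketch for the wrapper interface above, offered as an illustration rather than as part of the file. To stay runnable without ColossalAI installed, it defines `MiniOptimizerWrapper`, a hypothetical trimmed stand-in that mirrors only the `parameters`, `backward`, `clip_grad_by_norm`, `step`, `zero_grad` and `unwrap` methods shown in `OptimizerWrapper`. The model, learning rate and `run_one_step` helper are likewise assumptions made for the example. It shows the call order a training loop would typically follow: clear gradients, run backward through the wrapper, clip, then step.

```python
import torch
import torch.nn as nn
from torch import Tensor
from torch.optim import SGD, Optimizer


class MiniOptimizerWrapper:
    """Hypothetical, trimmed stand-in for OptimizerWrapper (illustration only)."""

    def __init__(self, optim: Optimizer):
        self.optim = optim

    @property
    def parameters(self):
        # Flatten all param groups into one list, as the wrapper above does.
        return [p for group in self.optim.param_groups for p in group["params"]]

    def zero_grad(self, *args, **kwargs):
        self.optim.zero_grad(*args, **kwargs)

    def backward(self, loss: Tensor, *args, **kwargs):
        # Routing backward through the wrapper lets subclasses add loss scaling.
        loss.backward(*args, **kwargs)

    def clip_grad_by_norm(self, max_norm, norm_type=2.0):
        # Returns the total gradient norm measured before clipping.
        return nn.utils.clip_grad_norm_(self.parameters, max_norm, norm_type)

    def step(self, *args, **kwargs):
        return self.optim.step(*args, **kwargs)

    def unwrap(self):
        return self.optim


def run_one_step():
    """Run a single training step on a toy linear model (assumed setup)."""
    torch.manual_seed(0)
    model = nn.Linear(4, 1)
    optim = MiniOptimizerWrapper(SGD(model.parameters(), lr=0.1))
    before = [p.detach().clone() for p in optim.parameters]

    loss = model(torch.randn(8, 4)).pow(2).mean()
    optim.zero_grad()
    optim.backward(loss)
    total_norm = optim.clip_grad_by_norm(max_norm=1.0)
    optim.step()

    changed = any(not torch.equal(b, p) for b, p in zip(before, optim.parameters))
    return total_norm, changed, optim


if __name__ == "__main__":
    total_norm, changed, optim = run_one_step()
    print(f"grad norm before clipping: {total_norm.item():.4f}, params updated: {changed}")
```

Calling `backward` on the wrapper instead of on the loss directly is the design choice worth noting: a mixed precision subclass can then override `backward`, `scale_loss` and `unscale_grad` without any change to the training loop.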