mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-04-29 04:05:35 +00:00
* [feat] Add distributed lamb; minor fixes in DeviceMesh (#5476) * init: add dist lamb; add debiasing for lamb * dist lamb tester mostly done * all tests passed * add comments * all tests passed. Removed debugging statements * moved setup_distributed inside plugin. Added dist layout caching * organize better --------- Co-authored-by: Edenzzzz <wtan45@wisc.edu> * [hotfix] Improve tester precision by removing ZeRO on vanilla lamb (#5576) Co-authored-by: Edenzzzz <wtan45@wisc.edu> * [optim] add distributed came (#5526) * test CAME under LowLevelZeroOptimizer wrapper * test CAME TP row and col pass * test CAME zero pass * came zero add master and worker param id convert * came zero test pass * came zero test pass * test distributed came passed * reform code, Modify some expressions and add comments * minor fix of test came * minor fix of dist_came and test * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * minor fix of dist_came and test * rebase dist-optim * rebase dist-optim * fix remaining comments * add test dist came using booster api --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [optim] Distributed Adafactor (#5484) * [feature] solve conflict; update optimizer readme; * [feature] update optimize readme; * [fix] fix testcase; * [feature] Add transformer-bert to testcase;solve a bug related to indivisible shape (induction in use_zero and tp is row parallel); * [feature] Add transformers_bert model zoo in testcase; * [feature] add user documentation to docs/source/feature. 
* [feature] add API Reference & Sample to optimizer Readme; add state check for bert exam; * [feature] modify user documentation; * [fix] fix readme format issue; * [fix] add zero=0 in testcase; cached augment in dict; * [fix] fix precision issue; * [feature] add distributed rms; * [feature] remove useless comment in testcase; * [fix] Remove useless test; open zero test; remove fp16 test in bert exam; * [feature] Extract distributed rms function; * [feature] add booster + lowlevelzeroPlugin in test; * [feature] add Start_with_booster_API case in md; add Supporting Information in md; * [fix] Also remove state movement in base adafactor; * [feature] extract factor function; * [feature] add LowLevelZeroPlugin test; * [fix] add tp=False and zero=True in logic; * [fix] fix use zero logic; * [feature] add row residue logic in column parallel factor; * [feature] add check optim state func; * [feature] Remove duplicate logic; * [feature] update optim state check func and precision test bug; * [fix] update/fix optim state; Still exist precision issue; * [fix] Add use_zero check in _rms; Add plugin support info in Readme; Add Dist Adafactor init Info; * [feature] removed print & comments in utils; * [feature] update Readme; * [feature] add LowLevelZeroPlugin test with Bert model zoo; * [fix] fix logic in _rms; * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [fix] remove comments in testcase; * [feature] add zh-Han Readme; --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Feature] refactor dist came; fix precision error; add low level zero test with bert model zoo; (#5676) * [feature] daily update; * [fix] fix dist came; * [feature] refactor dist came; fix precision error; add low level zero test with bert model zoo; * [fix] open rms; fix low level zero test; fix dist came test function name; * [fix] remove redundant test; * [pre-commit.ci] auto fixes from 
pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Feature] Add Galore (Adam, Adafactor) and distributed GaloreAdamW8bit (#5570) * init: add dist lamb; add debiasing for lamb * dist lamb tester mostly done * all tests passed * add comments * all tests passed. Removed debugging statements * moved setup_distributed inside plugin. Added dist layout caching * organize better * update comments * add initial distributed galore * add initial distributed galore * add galore set param utils; change setup_distributed interface * projected grad precision passed * basic precision tests passed * tests passed; located svd precision issue in fwd-bwd; banned these tests * Plugin DP + TP tests passed * move get_shard_dim to d_tensor * add comments * remove useless files * remove useless files * fix zero typo * improve interface * remove moe changes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix import * fix deepcopy * update came & adafactor to main * fix param map * fix typo --------- Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Hotfix] Remove one buggy test case from dist_adafactor for now (#5692) Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --------- Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: chongqichuizi875 <107315010+chongqichuizi875@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: duanjunwen <54985467+duanjunwen@users.noreply.github.com> Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
159 lines
5.0 KiB
Python
159 lines
5.0 KiB
Python
from typing import Dict, Optional, Union
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
import torch.nn as nn
|
|
from torch import Tensor
|
|
from torch.optim import Optimizer
|
|
|
|
|
|
class OptimizerWrapper:
    """
    A standard interface for optimizers wrapped by the Booster.

    Most methods delegate directly to the wrapped optimizer. Gradient clipping is
    applied with ``torch.nn.utils`` to every parameter found in ``param_groups``.
    The mixed precision hooks (``scale_loss`` and ``unscale_grad``) raise
    ``NotImplementedError`` here; wrappers that support mixed precision training
    are expected to override them.

    Args:
        optim (Optimizer): The optimizer to be wrapped.
    """

    def __init__(self, optim: Optimizer):
        self.optim = optim

    @property
    def parameters(self):
        """
        All parameters of the wrapped optimizer, flattened across param groups in group order.
        """
        return [param for group in self.param_groups for param in group["params"]]

    @property
    def param_groups(self):
        """
        The wrapped optimizer's ``param_groups`` (the same object, not a copy).
        """
        return self.optim.param_groups

    @property
    def defaults(self):
        """
        The wrapped optimizer's ``defaults``.
        """
        return self.optim.defaults

    def add_param_group(self, *args, **kwargs):
        """
        Adds a param group to the wrapped optimizer; all arguments are forwarded unchanged.
        """
        return self.optim.add_param_group(*args, **kwargs)

    def step(self, *args, **kwargs):
        """
        Performs a single optimization step.

        Arguments are forwarded to the wrapped optimizer's ``step`` and its return value
        (e.g. the loss when a closure is passed) is returned.
        """
        return self.optim.step(*args, **kwargs)

    def zero_grad(self, *args, **kwargs):
        """
        Clears the gradients of all optimized ``torch.Tensor`` objects.

        Arguments (e.g. ``set_to_none``) are forwarded to the wrapped optimizer's ``zero_grad``.
        """
        self.optim.zero_grad(*args, **kwargs)

    def backward(self, loss: Tensor, *args, **kwargs):
        """
        Performs a backward pass on the loss.

        Extra arguments are forwarded to ``loss.backward`` (e.g. ``retain_graph``).
        """
        loss.backward(*args, **kwargs)

    def backward_by_grad(self, tensor: Tensor, grad: Tensor):
        """
        Performs a backward pass from ``tensor`` using ``grad`` as its upstream gradient.

        Unlike ``backward``, ``tensor`` does not have to be a scalar; ``grad`` plays the
        role of ``grad_tensors`` in ``torch.autograd.backward`` and must match its shape.

        Args:
            tensor (Tensor): The output to backpropagate from.
            grad (Tensor): Gradient of the objective with respect to ``tensor``.
        """
        torch.autograd.backward(tensor, grad)

    def state_dict(self):
        """
        Returns the optimizer state.
        """
        return self.optim.state_dict()

    def load_state_dict(self, *args, **kwargs):
        """
        Loads the optimizer state.
        """
        self.optim.load_state_dict(*args, **kwargs)

    def clip_grad_by_value(self, clip_value: Union[float, int], *args, **kwargs) -> None:
        """
        Clips the gradients of all wrapped parameters in place to a fixed value range.

        Args:
            clip_value (float or int): maximum allowed value of the gradients. Gradients are clipped in the range
                ``[-clip_value, clip_value]``.

        Note:
            In PyTorch 2.0 and above, you can pass in foreach=True as kwargs to clip_grad_value_ to use the
            faster implementation. Please refer to the PyTorch documentation for more details.
        """
        nn.utils.clip_grad_value_(self.parameters, clip_value, *args, **kwargs)

    def clip_grad_by_norm(
        self,
        max_norm: Union[float, int],
        norm_type: Union[float, int] = 2.0,
        error_if_nonfinite: bool = False,
        *args,
        **kwargs,
    ) -> Tensor:
        """
        Clips the gradient norm of all wrapped parameters in place.

        Args:
            max_norm (float or int): max norm of the gradients
            norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm.
            error_if_nonfinite (bool): if True, an error is raised if the total norm is non-finite. Default: False

        Returns:
            Tensor: the total norm of the gradients (viewed as a single vector), as returned by
            ``torch.nn.utils.clip_grad_norm_``.

        Note:
            In PyTorch 2.0 and above, you can pass in foreach=True as kwargs to clip_grad_norm_ to use the
            faster implementation. Please refer to the PyTorch documentation for more details.
        """
        norm = nn.utils.clip_grad_norm_(self.parameters, max_norm, norm_type, error_if_nonfinite, *args, **kwargs)
        return norm

    def scale_loss(self, loss: Tensor):
        """
        Scales the loss for mixed precision training.

        Note: Only available for optimizers with mixed precision training.

        Args:
            loss (Tensor): The loss to be scaled.

        Raises:
            NotImplementedError: always, in this base wrapper.
        """
        raise NotImplementedError(
            "The method scale_loss is only available for optimizers with mixed precision training"
        )

    def unscale_grad(self):
        """
        Unscale the gradients for mixed precision training.

        Note: Only available for optimizers with mixed precision training.

        Raises:
            NotImplementedError: always, in this base wrapper.
        """
        raise NotImplementedError(
            "The method unscale_grad is only available for optimizers with mixed precision training"
        )

    def unwrap(self):
        """
        Unwrap the optimizer for checkpoint saving/loading.

        Returns:
            The wrapped optimizer object itself.
        """
        return self.optim
|
|
|
|
|
|
class DistributedOptim(Optimizer):
    """
    Base class for optimizers that can be configured for tensor parallelism and ZeRO stage 2.

    Subclasses are expected to override ``setup_distributed``; the base implementation
    only raises ``NotImplementedError``.
    """

    def setup_distributed(
        self,
        tp_group: Optional[dist.ProcessGroup] = None,
        dp_group: Optional[dist.ProcessGroup] = None,
        shard_to_working_param: Optional[Dict] = None,
        padding_map: Optional[Dict] = None,
        is_zero: bool = False,
    ):
        """Assign process groups for TP and ZeRO 2.

        Args:
            tp_group (dist.ProcessGroup): Tensor Parallel process group
            dp_group (dist.ProcessGroup): ZeRO stage 2 process group
            shard_to_working_param (Dict): ZeRO stage 2 feeds the optimizer a sharded param view to match grad shape.
                This maps from id(view) to model params used in forward & backward. ``None`` means an empty
                mapping; overriding implementations should normalize it (e.g. ``shard_to_working_param or {}``)
                rather than relying on a shared mutable default.
            padding_map (Dict): Per-param padding from ZeRO stage 2
            is_zero (bool): Whether to use ZeRO stage 2.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("setup_distributed for TP/DP isn't supported by this optimizer yet!")
|