[bf16] add bf16 support (#3882)

* [bf16] add bf16 support for fused adam (#3844)

* [bf16] fused adam kernel support bf16

* [test] update fused adam kernel test

* [test] update fused adam test

* [bf16] cpu adam and hybrid adam optimizers support bf16 (#3860)

* [bf16] implement mixed precision mixin and add bf16 support for low level zero (#3869)

* [bf16] add mixed precision mixin

* [bf16] low level zero optim support bf16

* [test] update low level zero test

* [test] fix low level zero grad acc test

* [bf16] add bf16 support for gemini (#3872)

* [bf16] gemini support bf16

* [test] update gemini bf16 test

* [doc] update gemini docstring

* [bf16] add bf16 support for plugins (#3877)

* [bf16] add bf16 support for legacy zero (#3879)

* [zero] init context support bf16

* [zero] legacy zero support bf16

* [test] add zero bf16 test

* [doc] add bf16 related docstring for legacy zero
This commit is contained in:
Hongxin Liu
2023-06-05 15:58:31 +08:00
committed by GitHub
parent 07cb21142f
commit ae02d4e4f7
27 changed files with 738 additions and 525 deletions

View File

@@ -0,0 +1,9 @@
from .base import MixedPrecisionMixin
from .bf16 import BF16MixedPrecisionMixin
from .fp16 import FP16MixedPrecisionMixin
# Public names exported by this package: the base mixin plus the FP16 and
# BF16 variants imported above.
__all__ = [
'MixedPrecisionMixin',
'FP16MixedPrecisionMixin',
'BF16MixedPrecisionMixin',
]

View File

@@ -0,0 +1,91 @@
from abc import ABC, abstractmethod
import torch
from torch import Tensor
class MixedPrecisionMixin(ABC):
    """Abstract set of hooks that a mixed precision optimizer calls.

    A mixed precision optimizer holds an instance of a concrete subclass and
    invokes its hooks around ``backward``, ``step`` and ``zero_grad``.

    Attributes:
        dtype (torch.dtype): The expected dtype of the gradients.

    Examples:
        ```python
        class MyMixedPrecisionOptimizer(OptimizerWrapper):
            def __init__(self, optim: Optimizer):
                super().__init__(optim)
                # Use a concrete subclass here; MixedPrecisionMixin itself is abstract.
                self.mixed_precision = ConcreteMixedPrecisionMixin()

            def backward(self, loss):
                loss = self.mixed_precision.pre_backward(loss)
                loss.backward()

            def backward_by_grad(self, tensor, grad):
                grad = self.mixed_precision.pre_backward_by_grad(tensor, grad)
                tensor.backward(grad)

            def step(self):
                if self.mixed_precision.should_skip_step():
                    self.zero_grad()
                    return
                div_scale = self.mixed_precision.get_grad_div_scale()
                # maybe clip grad here
                # maybe scale grad here
                self.optim.step()

            def zero_grad(self):
                self.mixed_precision.pre_zero_grad()
                return self.optim.zero_grad()
        ```
    """

    dtype: torch.dtype

    @abstractmethod
    def pre_backward(self, loss: Tensor) -> Tensor:
        """Hook invoked before ``loss.backward()``.

        Args:
            loss (Tensor): Loss value.

        Returns:
            Tensor: The loss to run backward on (possibly scaled).
        """

    @abstractmethod
    def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor:
        """Hook invoked before ``tensor.backward(grad)``.

        This is helpful for pipeline parallelism.

        Args:
            tensor (Tensor): Tensor to backward.
            grad (Tensor): Gradient of the tensor.

        Returns:
            Tensor: The gradient to backward with (possibly scaled).
        """

    @abstractmethod
    def should_skip_step(self) -> bool:
        """Hook invoked before the optimizer step.

        Returns:
            bool: ``True`` if the step should be skipped.
        """

    @abstractmethod
    def pre_zero_grad(self) -> None:
        """Hook invoked before ``zero_grad``."""

    @abstractmethod
    def get_grad_div_scale(self) -> float:
        """Hook invoked before the step or gradient clipping.

        For efficiency this method is not required to unscale the gradients
        itself; it only reports the divisor.

        Returns:
            float: A divisor for gradient clipping or step.
        """

View File

@@ -0,0 +1,23 @@
import torch
from torch import Tensor
from .base import MixedPrecisionMixin
class BF16MixedPrecisionMixin(MixedPrecisionMixin):
    """Mixed precision hooks for bfloat16 gradients.

    Every hook is a pass-through: the loss and gradients are returned
    unchanged, no step is ever skipped and the gradient divisor is ``1.0``.
    """

    dtype = torch.bfloat16

    def pre_backward(self, loss: Tensor) -> Tensor:
        """Return ``loss`` unchanged."""
        return loss

    def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor:
        """Return ``grad`` unchanged; ``tensor`` is not used."""
        return grad

    def should_skip_step(self) -> bool:
        """Always report that the step should run."""
        return False

    def pre_zero_grad(self) -> None:
        """Do nothing before ``zero_grad``."""
        return None

    def get_grad_div_scale(self) -> float:
        """Return ``1.0``, i.e. gradients need no division."""
        return 1.0

View File

@@ -0,0 +1,84 @@
from abc import abstractmethod
from enum import Enum
import torch
import torch.distributed as dist
from torch import Tensor
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.utils import get_current_device
from .base import MixedPrecisionMixin
class OptimState(Enum):
    """Marker for whether gradients are treated as scaled or unscaled."""

    # Recorded once a backward pass has been prepared.
    SCALED = 0
    # Initial state; also recorded after the divisor is handed out or an
    # overflow is detected.
    UNSCALED = 1
class FP16MixedPrecisionMixin(MixedPrecisionMixin):
    """Mixed precision hooks for float16 gradients with a dynamic loss scale.

    The loss is multiplied by the current loss scale before backward. The
    scaler is updated on every ``should_skip_step`` call. ``get_grad_div_scale``
    returns the loss scale so the caller can divide the gradients by it.
    Subclasses must implement ``check_local_overflow``.

    Args:
        initial_scale (float): Forwarded to ``DynamicGradScaler``.
        min_scale (float): Forwarded to ``DynamicGradScaler``.
        growth_factor (float): Forwarded to ``DynamicGradScaler``.
        backoff_factor (float): Forwarded to ``DynamicGradScaler``.
        growth_interval (int): Forwarded to ``DynamicGradScaler``.
        hysteresis (int): Forwarded to ``DynamicGradScaler``.
        max_scale (float): Forwarded to ``DynamicGradScaler``.
    """

    dtype = torch.float16

    def __init__(self,
                 initial_scale: float = 2**16,
                 min_scale: float = 1,
                 growth_factor: float = 2,
                 backoff_factor: float = 0.5,
                 growth_interval: int = 1000,
                 hysteresis: int = 2,
                 max_scale: float = 2**32) -> None:
        super().__init__()
        scaler_kwargs = dict(
            initial_scale=initial_scale,
            min_scale=min_scale,
            growth_factor=growth_factor,
            backoff_factor=backoff_factor,
            growth_interval=growth_interval,
            hysteresis=hysteresis,
            max_scale=max_scale,
        )
        self.grad_scaler = DynamicGradScaler(**scaler_kwargs)
        self.optim_state = OptimState.UNSCALED
        # One-element float tensor reused as the overflow flag for all-reduce.
        device = get_current_device()
        self.found_overflow = torch.zeros(1, dtype=torch.float, device=device)

    @property
    def loss_scale(self) -> float:
        """Current loss scale read from the grad scaler as a Python number."""
        current_scale = self.grad_scaler.scale
        return current_scale.item()

    @abstractmethod
    def check_local_overflow(self) -> bool:
        """Report whether the local process saw an overflow.

        Subclasses must provide the implementation.

        Returns:
            bool: Whether there is overflow in the local process.
        """

    def check_overflow(self) -> bool:
        """Report whether any process saw an overflow.

        The local result is written into ``self.found_overflow`` and combined
        across processes with a MAX all-reduce.

        Returns:
            bool: ``True`` if the reduced flag is greater than zero.
        """
        flag = self.found_overflow
        # Reset the flag first so a previous overflow record does not leak in.
        flag.fill_(0.0)
        local_overflow = self.check_local_overflow()
        if local_overflow:
            flag.fill_(1.0)
        dist.all_reduce(flag, op=dist.ReduceOp.MAX)
        return flag.item() > 0

    def pre_backward(self, loss: Tensor) -> Tensor:
        """Multiply ``loss`` by the loss scale and mark the state as scaled."""
        scaled_loss = self.loss_scale * loss
        self.optim_state = OptimState.SCALED
        return scaled_loss

    def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor:
        """Mark the state as scaled and return ``grad`` unchanged."""
        self.optim_state = OptimState.SCALED
        return grad

    def should_skip_step(self) -> bool:
        """Check for overflow, update the grad scaler and report the result.

        Returns:
            bool: The overflow result; when it is truthy the state is reset
            to unscaled.
        """
        overflowed = self.check_overflow()
        self.grad_scaler.update(overflowed)
        if overflowed:
            self.optim_state = OptimState.UNSCALED
        return overflowed

    def pre_zero_grad(self) -> None:
        """Do nothing before ``zero_grad``."""
        return None

    def get_grad_div_scale(self) -> float:
        """Return the loss scale as the gradient divisor.

        The state must be scaled when this is called; it is switched to
        unscaled before the loss scale is returned.

        Returns:
            float: The current loss scale.
        """
        assert self.optim_state == OptimState.SCALED, 'grads should be scaled before clipping'
        self.optim_state = OptimState.UNSCALED
        return self.loss_scale