Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-14 05:33:23 +00:00)
[bf16] add bf16 support (#3882)
* [bf16] add bf16 support for fused adam (#3844)
  * [bf16] fused adam kernel support bf16
  * [test] update fused adam kernel test
  * [test] update fused adam test
* [bf16] cpu adam and hybrid adam optimizers support bf16 (#3860)
* [bf16] implement mixed precision mixin and add bf16 support for low level zero (#3869)
  * [bf16] add mixed precision mixin
  * [bf16] low level zero optim support bf16
  * [test] update low level zero test
  * [test] fix low level zero grad acc test
* [bf16] add bf16 support for gemini (#3872)
  * [bf16] gemini support bf16
  * [test] update gemini bf16 test
  * [doc] update gemini docstring
* [bf16] add bf16 support for plugins (#3877)
* [bf16] add bf16 support for legacy zero (#3879)
  * [zero] init context support bf16
  * [zero] legacy zero support bf16
  * [test] add zero bf16 test
  * [doc] add bf16 related docstring for legacy zero
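The fused/CPU Adam entries above mean those optimizers can consume bf16 parameters and gradients directly. As a rough, illustrative sketch only (the import path `colossalai.nn.optimizer.FusedAdam` is assumed, the toy model and hyperparameters are placeholders, and the fused CUDA kernel must be available), a bf16 training step might look like:

```python
import torch
from colossalai.nn.optimizer import FusedAdam  # assumed import path

# Toy model cast to bf16 so its gradients are produced in bf16.
model = torch.nn.Linear(64, 64).cuda().to(torch.bfloat16)
optimizer = FusedAdam(model.parameters(), lr=1e-3)

x = torch.randn(8, 64, device="cuda", dtype=torch.bfloat16)
loss = model(x).float().sum()   # accumulate the loss in fp32
loss.backward()                 # bf16 grads flow to the fused kernel
optimizer.step()
optimizer.zero_grad()
```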
colossalai/amp/naive_amp/mixed_precision_mixin/base.py (new file, 91 lines added)
@@ -0,0 +1,91 @@
from abc import ABC, abstractmethod

import torch
from torch import Tensor


class MixedPrecisionMixin(ABC):
    """A helper class for mixed precision training. This mixin is used in mixed precision optimizers.

    Attributes:
        dtype (torch.dtype): The expected dtype of the gradients.

    Examples:
        ```python
        class MyMixedPrecisionOptimizer(OptimizerWrapper):
            def __init__(self, optim: Optimizer):
                super().__init__(optim)
                self.mixed_precision = MixedPrecisionMixin()

            def backward(self, loss):
                loss = self.mixed_precision.pre_backward(loss)
                loss.backward()

            def backward_by_grad(self, tensor, grad):
                grad = self.mixed_precision.pre_backward_by_grad(tensor, grad)
                tensor.backward(grad)

            def step(self):
                if self.mixed_precision.should_skip_step():
                    self.zero_grad()
                    return
                div_scale = self.mixed_precision.get_grad_div_scale()
                # maybe clip grad here
                # maybe scale grad here
                self.optim.step()

            def zero_grad(self):
                self.mixed_precision.pre_zero_grad()
                return self.optim.zero_grad()
        ```
    """
    dtype: torch.dtype

    @abstractmethod
    def pre_backward(self, loss: Tensor) -> Tensor:
        """Called before backward.

        Args:
            loss (Tensor): Loss value.

        Returns:
            Tensor: Loss value (possibly scaled).
        """
        pass

    @abstractmethod
    def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor:
        """Called before backward by grad. This is helpful for pipeline parallelism.

        Args:
            tensor (Tensor): Tensor to backward.
            grad (Tensor): Gradient of the tensor.

        Returns:
            Tensor: Gradient of the tensor (possibly scaled).
        """
        pass

    @abstractmethod
    def should_skip_step(self) -> bool:
        """Called before step.

        Returns:
            bool: Whether to skip the step.
        """
        pass

    @abstractmethod
    def pre_zero_grad(self) -> None:
        """Called before zero_grad."""
        pass

    @abstractmethod
    def get_grad_div_scale(self) -> float:
        """Called before step or clip_grad. For efficiency, this method may leave gradients unscaled and instead return a divisor to be applied when clipping or stepping.

        Returns:
            float: A divisor for gradient clipping or step.
        """
        pass
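For concreteness, here is a minimal sketch of how a bf16 flavor of this interface could be implemented. It is illustrative only: the class name is made up for this example, and it simply encodes the fact that bf16 shares fp32's exponent range, so no loss scaling (and hence no overflow-driven step skipping) is needed.

```python
import torch
from torch import Tensor

from colossalai.amp.naive_amp.mixed_precision_mixin.base import MixedPrecisionMixin


class NaiveBF16Mixin(MixedPrecisionMixin):
    """Illustrative bf16 mixin: every hook is effectively a no-op."""
    dtype = torch.bfloat16

    def pre_backward(self, loss: Tensor) -> Tensor:
        # bf16 keeps fp32's exponent range, so the loss is not scaled.
        return loss

    def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor:
        # Gradients are likewise passed through unchanged.
        return grad

    def should_skip_step(self) -> bool:
        # Without a dynamic loss scaler there is no overflow to detect.
        return False

    def pre_zero_grad(self) -> None:
        # Nothing to reset before zero_grad.
        pass

    def get_grad_div_scale(self) -> float:
        # Gradients were never scaled, so the divisor is 1.
        return 1.0
```

An fp16 counterpart would instead scale the loss in `pre_backward`, report overflows through `should_skip_step`, and return the live scale from `get_grad_div_scale`; the plugin, Gemini, and low-level zero changes listed in the commit message are what let callers choose between the two dtypes.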