[bf16] add bf16 support (#3882)

* [bf16] add bf16 support for fused adam (#3844)

* [bf16] fused adam kernel support bf16

* [test] update fused adam kernel test

* [test] update fused adam test

* [bf16] cpu adam and hybrid adam optimizers support bf16 (#3860)

* [bf16] implement mixed precision mixin and add bf16 support for low level zero (#3869)

* [bf16] add mixed precision mixin

* [bf16] low level zero optim support bf16

* [text] update low level zero test

* [text] fix low level zero grad acc test

* [bf16] add bf16 support for gemini (#3872)

* [bf16] gemini support bf16

* [test] update gemini bf16 test

* [doc] update gemini docstring

* [bf16] add bf16 support for plugins (#3877)

* [bf16] add bf16 support for legacy zero (#3879)

* [zero] init context support bf16

* [zero] legacy zero support bf16

* [test] add zero bf16 test

* [doc] add bf16 related docstring for legacy zero
This commit is contained in:
Hongxin Liu
2023-06-05 15:58:31 +08:00
committed by GitHub
parent 07cb21142f
commit ae02d4e4f7
27 changed files with 738 additions and 525 deletions

View File

@@ -93,8 +93,7 @@ class CPUAdam(NVMeOptimizer):
bias_correction1,
bias_correction2,
use_adamw=False):
# FIXME(ver217): remove the below line when replace torch adam with fused adam
grad = grad.float()
grad = grad.to(data.dtype)
if weight_decay != 0:
if use_adamw:
@@ -133,10 +132,12 @@ class CPUAdam(NVMeOptimizer):
if len(state) == 0:
state['step'] = 0
# FIXME(ver217): CPU adam kernel only supports fp32 states now
assert p.dtype is torch.float, "CPUAdam only support fp32 parameters"
# gradient momentums
state['exp_avg'] = torch.zeros_like(p, dtype=torch.float, device=target_device)
state['exp_avg'] = torch.zeros_like(p, device=target_device)
# gradient variances
state['exp_avg_sq'] = torch.zeros_like(p, dtype=torch.float, device=target_device)
state['exp_avg_sq'] = torch.zeros_like(p, device=target_device)
self._post_state_init(p)
state['step'] += 1
@@ -147,9 +148,17 @@ class CPUAdam(NVMeOptimizer):
assert state['exp_avg'].device.type == 'cpu', "exp_avg should stay on cpu"
assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg should stay on cpu"
self._pre_update(p, 'exp_avg', 'exp_avg_sq')
self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], group['weight_decay'],
group['bias_correction'], p.data, p.grad.data, state['exp_avg'],
state['exp_avg_sq'], div_scale)
if p.grad.dtype is torch.bfloat16:
# cpu adam kernel does not support bf16 now
bias_correction1 = 1 - beta1**state['step']
bias_correction2 = 1 - beta2**state['step']
self.torch_adam_update(p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'], group['lr'],
beta1, beta2, group['eps'], group['weight_decay'], bias_correction1,
bias_correction2, self.adamw_mode)
else:
self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'],
group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
state['exp_avg'], state['exp_avg_sq'], div_scale)
self._post_update(p, 'exp_avg', 'exp_avg_sq')
elif target_device.type == 'cuda':
assert div_scale == -1, "div_scale should remain default"

View File

@@ -134,8 +134,8 @@ class FusedAdam(torch.optim.Optimizer):
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p)
if p.dtype not in [torch.float16, torch.float32]:
raise RuntimeError('FusedAdam only support fp16 and fp32.')
if p.dtype not in [torch.float16, torch.float32, torch.bfloat16]:
raise RuntimeError('FusedAdam only support fp16, fp32 and bf16.')
g_l.append(p.grad.data)
p_l.append(p.data)

View File

@@ -1,16 +1,17 @@
from typing import Any, Optional
import torch
from torch.optim import Adam
from colossalai.kernel.op_builder import CPUAdamBuilder, FusedOptimBuilder
from colossalai.kernel.op_builder import FusedOptimBuilder
from colossalai.registry import OPTIMIZERS
from colossalai.utils import multi_tensor_applier
from .nvme_optimizer import NVMeOptimizer
from .cpu_adam import CPUAdam
@OPTIMIZERS.register_module
class HybridAdam(NVMeOptimizer):
class HybridAdam(CPUAdam):
"""Implements Adam algorithm.
Supports parameters updating on both GPU and CPU, depanding on the device of parameters.
@@ -74,15 +75,9 @@ class HybridAdam(NVMeOptimizer):
nvme_offload_dir: Optional[str] = None,
**defaults: Any):
default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
super(HybridAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
self.adamw_mode = adamw_mode
# build during runtime if not found
cpu_optim = CPUAdamBuilder().load()
super().__init__(model_params, lr, bias_correction, betas, eps, weight_decay, adamw_mode, nvme_offload_fraction,
nvme_offload_dir)
fused_optim = FusedOptimBuilder().load()
self.cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)
self.gpu_adam_op = fused_optim.multi_tensor_adam
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
@@ -108,10 +103,12 @@ class HybridAdam(NVMeOptimizer):
if len(state) == 0:
state['step'] = 0
# FIXME(ver217): CPU adam kernel only supports fp32 states now
assert p.dtype is torch.float, "HybridAdam only support fp32 parameters"
# gradient momentums
state['exp_avg'] = torch.zeros_like(p, dtype=torch.float, device=target_device)
state['exp_avg'] = torch.zeros_like(p, device=target_device)
# gradient variances
state['exp_avg_sq'] = torch.zeros_like(p, dtype=torch.float, device=target_device)
state['exp_avg_sq'] = torch.zeros_like(p, device=target_device)
self._post_state_init(p)
state['step'] += 1
@@ -122,9 +119,17 @@ class HybridAdam(NVMeOptimizer):
assert state['exp_avg'].device.type == 'cpu', "exp_avg should stay on cpu"
assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg should stay on cpu"
self._pre_update(p, 'exp_avg', 'exp_avg_sq')
self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], group['weight_decay'],
group['bias_correction'], p.data, p.grad.data, state['exp_avg'],
state['exp_avg_sq'], div_scale)
if p.grad.dtype is torch.bfloat16:
# cpu adam kernel does not support bf16 now
bias_correction1 = 1 - beta1**state['step']
bias_correction2 = 1 - beta2**state['step']
self.torch_adam_update(p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'], group['lr'],
beta1, beta2, group['eps'], group['weight_decay'], bias_correction1,
bias_correction2, self.adamw_mode)
else:
self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'],
group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
state['exp_avg'], state['exp_avg_sq'], div_scale)
self._post_update(p, 'exp_avg', 'exp_avg_sq')
elif target_device.type == 'cuda':