[bf16] add bf16 support (#3882)

* [bf16] add bf16 support for fused adam (#3844)

* [bf16] fused adam kernel support bf16

* [test] update fused adam kernel test

* [test] update fused adam test

* [bf16] cpu adam and hybrid adam optimizers support bf16 (#3860)

* [bf16] implement mixed precision mixin and add bf16 support for low level zero (#3869)

* [bf16] add mixed precision mixin

* [bf16] low level zero optim support bf16

* [test] update low level zero test

* [test] fix low level zero grad acc test

* [bf16] add bf16 support for gemini (#3872)

* [bf16] gemini support bf16

* [test] update gemini bf16 test

* [doc] update gemini docstring

* [bf16] add bf16 support for plugins (#3877)

* [bf16] add bf16 support for legacy zero (#3879)

* [zero] init context support bf16

* [zero] legacy zero support bf16

* [test] add zero bf16 test

* [doc] add bf16 related docstring for legacy zero
Author: Hongxin Liu
Date: 2023-06-05 15:58:31 +08:00 (committed by GitHub)
Commit: ae02d4e4f7 (parent: 07cb21142f)
27 changed files with 738 additions and 525 deletions
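As background for the bf16 support described above, here is a minimal sketch in plain PyTorch (torch.autocast, not the ColossalAI API added by this PR) of why bf16 training can drop loss scaling: bf16 keeps fp32's 8-bit exponent, so gradients rarely underflow and no GradScaler is needed.

import torch
import torch.nn as nn

# Minimal sketch (plain PyTorch, not the ColossalAI API from this PR).
model = nn.Linear(128, 64)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

x = torch.rand(32, 128)
with torch.autocast(device_type='cpu', dtype=torch.bfloat16):
    loss = model(x).float().mean()

# bf16 shares fp32's exponent range, so no GradScaler / unscale step is used.
loss.backward()
optimizer.step()
optimizer.zero_grad()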


@@ -82,7 +82,6 @@ def exam_zero_1_2_grad_acc():
 
 def exam_zero_1_grad_acc():
     local_rank = torch.distributed.get_rank()
-    grad_scale = 32
     seed_all(2008)
 
     # create models
@@ -101,7 +100,6 @@ def exam_zero_1_grad_acc():
     # level 1 and 2 will produce exactly the same results
     zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
                                            overlap_communication=False,
-                                           initial_scale=grad_scale,
                                            reduce_bucket_size=262144,
                                            clip_grad_norm=1.0)
@@ -128,9 +126,8 @@ def exam_zero_1_grad_acc():
         if check_flag:
             # check grad
             for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
-                unscale_grad = z1p.grad / grad_scale
-                # print(n, p.shape, torch.max(torch.abs(p.grad - unscale_grad)))
-                assert torch.equal(p.grad, unscale_grad)
+                assert torch.equal(p.grad, z1p.grad)
 
     zero_optimizer._sync_grad()
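The three hunks above drop the static loss scale from the grad-accumulation test: without initial_scale, the sharded optimizer's gradients match the reference gradients directly and no longer need to be divided by grad_scale before the comparison. A small sketch of the relationship the removed lines were checking (hypothetical tensors, not the real test fixtures):

import torch

# With a static loss scale, grads produced under the scaled loss are the
# reference grads multiplied by the scale, so the old test unscaled first.
grad_scale = 32
ref_grad = torch.rand(4, 4)            # grad of the unscaled reference model
scaled_grad = ref_grad * grad_scale    # grad produced under a scaled loss

assert torch.equal(ref_grad, scaled_grad / grad_scale)
# After this change the test sets no initial_scale, so it simply asserts
# torch.equal(p.grad, z1p.grad)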


@@ -7,7 +7,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.testing import assert_close
 
 import colossalai
-from colossalai.testing import rerun_if_address_is_in_use, spawn
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.testing.random import seed_all
 from colossalai.zero import LowLevelZeroOptimizer
@@ -25,15 +25,18 @@ class MlpModel(nn.Module):
         return x
 
 
-def half_close(a, b, loose=False):
+def loose_close(a, b, dtype: torch.dtype = torch.float32):
     rtol = None
     atol = None
-    if loose:
+    if dtype is torch.float16:
         rtol = 5e-2
         atol = 5e-4
+    elif dtype is torch.bfloat16:
+        rtol = 4e-3
+        atol = 4e-3
 
-    a = a.detach().half()
-    b = b.detach().half()
+    a = a.detach().to(dtype)
+    b = b.detach().to(dtype)
 
     assert_close(a, b, rtol=rtol, atol=atol)
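For reference, the new loose_close helper as a standalone runnable snippet, with the dtype-specific tolerances copied from the hunk above and a small usage example:

import torch
from torch.testing import assert_close

# Standalone copy of loose_close from the diff, plus an example call.
def loose_close(a, b, dtype: torch.dtype = torch.float32):
    rtol, atol = None, None
    if dtype is torch.float16:
        rtol, atol = 5e-2, 5e-4
    elif dtype is torch.bfloat16:
        rtol, atol = 4e-3, 4e-3
    a = a.detach().to(dtype)
    b = b.detach().to(dtype)
    assert_close(a, b, rtol=rtol, atol=atol)

# Values that agree to within the dtype's resolution pass the check.
x = torch.rand(8)
loose_close(x, x + 1e-4, dtype=torch.bfloat16)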
@@ -96,7 +99,8 @@ def exam_zero_1_2():
         assert torch.equal(z1p.data, z2p.data)
 
 
-def exam_zero_1_torch_ddp():
+@parameterize('dtype', [torch.float16, torch.bfloat16])
+def exam_zero_1_torch_ddp(dtype: torch.dtype):
     """
     In this test, two pairs of model and optimizers are created.
     1. zero: use sharded optimizer and fp16 parameters
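The DDP comparison test is now parameterized over fp16 and bf16 via the imported parameterize decorator. A simplified stand-in (for illustration only, not the actual colossalai.testing implementation) behaves roughly like this:

import torch
from functools import wraps

# Hypothetical, simplified parameterize: call the wrapped test once per value.
def parameterize(arg_name, values):
    def decorator(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            for value in values:
                fn(*args, **{**kwargs, arg_name: value})
        return wrapper
    return decorator

@parameterize('dtype', [torch.float16, torch.bfloat16])
def exam_dtypes(dtype: torch.dtype):
    print(f'running test body with dtype={dtype}')

exam_dtypes()    # runs the body once for fp16 and once for bf16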
@@ -109,15 +113,10 @@ def exam_zero_1_torch_ddp():
     seed_all(1453)
 
     # create models
-    zero_model = MlpModel()
-    torch_model = copy.deepcopy(zero_model)
+    torch_model = MlpModel().cuda()
+    zero_model = copy.deepcopy(torch_model).to(dtype)
 
-    zero_model = zero_model.cuda().half()
-    torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0)
-    torch_model = torch_model.cuda()
-
-    # for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
-    #     half_close(p.data, z1p.data)
+    torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0).cuda()
 
     # create optimizer
     zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)
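The setup now builds the fp32 reference model first and deep-copies it into the low-precision replica, so both start from identical weights before casting. A plain-torch sketch of that pattern (CPU tensors and an arbitrary module, not the test's MlpModel):

import copy
import torch
import torch.nn as nn
from torch.testing import assert_close

# Build the fp32 reference first, then deep-copy and cast the replica.
dtype = torch.bfloat16
torch_model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 128))
low_precision_model = copy.deepcopy(torch_model).to(dtype)

for p, lp in zip(torch_model.parameters(), low_precision_model.parameters()):
    # the bf16 copy matches the fp32 original up to bf16 rounding error
    assert_close(p.data, lp.data.float(), rtol=4e-3, atol=4e-3)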
@@ -137,11 +136,11 @@ def exam_zero_1_torch_ddp():
     input_data = torch.rand(32, 128).cuda()
 
     # zero-dp forward
-    zero_output = zero_model(input_data.half())
+    zero_output = zero_model(input_data.to(dtype))
 
     # torch-ddp forward
     torch_output = torch_model(input_data)
-    half_close(zero_output, torch_output, loose=True)
+    loose_close(zero_output, torch_output, dtype=dtype)
 
     # zero-dp backward
     zero_optimizer.backward(zero_output.mean().float(), sync_grad=False)
@@ -151,7 +150,7 @@ def exam_zero_1_torch_ddp():
 
     # check grad
     for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
-        half_close(p.grad, z1p.grad, loose=True)
+        loose_close(p.grad, z1p.grad, dtype=dtype)
 
     # zero-dp step
     zero_optimizer._sync_grad()
@@ -163,7 +162,7 @@ def exam_zero_1_torch_ddp():
     # check updated param
     for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
         # print(n, torch.max(torch.abs(p.data - z1p.data)))
-        half_close(p.data, z1p.data, loose=True)
+        loose_close(p.data, z1p.data, dtype=dtype)
 
 
 def run_dist(rank, world_size, port):
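Condensed, the parameterized test above follows this pattern; the sketch below is a single-process plain-torch version on CPU (SGD instead of the sharded ZeRO optimizer, no DDP), only to show the order of the gradient and updated-parameter checks:

import copy
import torch
import torch.nn as nn
from torch.testing import assert_close

dtype = torch.bfloat16
ref = nn.Linear(128, 128)                       # fp32 reference
rep = copy.deepcopy(ref).to(dtype)              # low-precision replica
ref_opt = torch.optim.SGD(ref.parameters(), lr=1)
rep_opt = torch.optim.SGD(rep.parameters(), lr=1)

x = torch.rand(32, 128)
ref(x).mean().backward()                        # reference backward in fp32
rep(x.to(dtype)).mean().float().backward()      # replica backward in bf16

# check grad (tolerances mirror loose_close's bf16 values)
for p, zp in zip(ref.parameters(), rep.parameters()):
    assert_close(p.grad, zp.grad.float(), rtol=4e-3, atol=4e-3)

ref_opt.step()
rep_opt.step()

# check updated param
for p, zp in zip(ref.parameters(), rep.parameters()):
    assert_close(p.data, zp.data.float(), rtol=4e-3, atol=4e-3)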