mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 17:17:05 +00:00
[bf16] add bf16 support (#3882)
* [bf16] add bf16 support for fused adam (#3844) * [bf16] fused adam kernel support bf16 * [test] update fused adam kernel test * [test] update fused adam test * [bf16] cpu adam and hybrid adam optimizers support bf16 (#3860) * [bf16] implement mixed precision mixin and add bf16 support for low level zero (#3869) * [bf16] add mixed precision mixin * [bf16] low level zero optim support bf16 * [text] update low level zero test * [text] fix low level zero grad acc test * [bf16] add bf16 support for gemini (#3872) * [bf16] gemini support bf16 * [test] update gemini bf16 test * [doc] update gemini docstring * [bf16] add bf16 support for plugins (#3877) * [bf16] add bf16 support for legacy zero (#3879) * [zero] init context support bf16 * [zero] legacy zero support bf16 * [test] add zero bf16 test * [doc] add bf16 related docstring for legacy zero
This commit is contained in:
@@ -21,23 +21,40 @@ TEST_MODELS = ['gpt2']
|
||||
# these models are too small, all parameters in these models are compacted into one chunk
|
||||
EXAMPLE_MODELS = ['albert', 'beit', 'bert', 'hanging_param_model', 'nested_model', 'repeated_computed_layers']
|
||||
|
||||
# bfloat16 cannot represent them exactly
|
||||
BF16_IGNORED_KEYS = [
|
||||
'albert.embeddings.word_embeddings.weight',
|
||||
'albert.embeddings.position_embeddings.weight',
|
||||
'masked_bias',
|
||||
]
|
||||
|
||||
def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
|
||||
zero_dict = model.state_dict(only_rank_0=False)
|
||||
|
||||
def check_param(model: ZeroDDP, torch_model: torch.nn.Module, dtype: torch.dtype):
|
||||
zero_dict = model.state_dict(only_rank_0=False, dtype=dtype)
|
||||
torch_dict = torch_model.state_dict()
|
||||
|
||||
for key, value in torch_dict.items():
|
||||
# key is 'module.model.PARAMETER', so we truncate it
|
||||
key = key[7:]
|
||||
assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
|
||||
temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
|
||||
temp_zero_value = zero_dict[key].to(device=value.device)
|
||||
if dtype is torch.bfloat16 and any(k in key for k in BF16_IGNORED_KEYS):
|
||||
continue
|
||||
rtol, atol = 1e-3, 4e-3
|
||||
if dtype is torch.bfloat16:
|
||||
rtol, atol = 4e-3, 8e-3
|
||||
# debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
|
||||
assert_close(value, temp_zero_value, rtol=1e-3, atol=4e-3)
|
||||
assert_close(value.float(),
|
||||
temp_zero_value.float(),
|
||||
rtol=rtol,
|
||||
atol=atol,
|
||||
msg=lambda s: s + f'\n{key}\n{temp_zero_value.dtype}')
|
||||
|
||||
|
||||
@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
|
||||
@parameterize('model_name', TEST_MODELS)
|
||||
def exam_model_step(placement_policy, model_name: str):
|
||||
@parameterize('mixed_precision', [torch.half, torch.bfloat16])
|
||||
def exam_model_step(placement_policy, model_name: str, mixed_precision: torch.dtype):
|
||||
set_seed(42)
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
@@ -65,7 +82,7 @@ def exam_model_step(placement_policy, model_name: str):
|
||||
init_device = None
|
||||
chunk_manager = ChunkManager(config_dict, init_device=init_device)
|
||||
gemini_manager = GeminiManager(placement_policy, chunk_manager)
|
||||
model = ZeroDDP(model, gemini_manager, pin_memory=True)
|
||||
model = ZeroDDP(model, gemini_manager, pin_memory=True, mixed_precision=mixed_precision)
|
||||
|
||||
optimizer = HybridAdam(model.parameters(), lr=1e-3)
|
||||
zero_optim = ZeroOptimizer(optimizer, model, initial_scale=128)
|
||||
@@ -74,6 +91,7 @@ def exam_model_step(placement_policy, model_name: str):
|
||||
torch_model.eval()
|
||||
|
||||
set_seed(dist.get_rank() * 3 + 128)
|
||||
rtol, atol = 1e-4, 1e-5
|
||||
for i, (input_ids, label) in enumerate(train_dataloader):
|
||||
if i > 2:
|
||||
break
|
||||
@@ -83,17 +101,18 @@ def exam_model_step(placement_policy, model_name: str):
|
||||
|
||||
torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim)
|
||||
loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim)
|
||||
assert_close(torch_loss, loss)
|
||||
assert_close(torch_loss, loss, rtol=rtol, atol=atol)
|
||||
|
||||
zero_optim.step()
|
||||
torch_optim.step()
|
||||
|
||||
check_param(model, torch_model)
|
||||
check_param(model, torch_model, mixed_precision)
|
||||
|
||||
|
||||
@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
|
||||
@parameterize('model_name', EXAMPLE_MODELS)
|
||||
def exam_tiny_example(placement_policy, model_name: str):
|
||||
@parameterize('mixed_precision', [torch.half, torch.bfloat16])
|
||||
def exam_tiny_example(placement_policy, model_name: str, mixed_precision: torch.dtype):
|
||||
set_seed(2008)
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
@@ -113,7 +132,7 @@ def exam_tiny_example(placement_policy, model_name: str):
|
||||
|
||||
chunk_manager = init_chunk_manager(model=model, init_device=get_current_device(), search_range_mb=1)
|
||||
gemini_manager = GeminiManager(placement_policy, chunk_manager)
|
||||
model = ZeroDDP(model, gemini_manager, pin_memory=True)
|
||||
model = ZeroDDP(model, gemini_manager, pin_memory=True, mixed_precision=mixed_precision)
|
||||
optimizer = HybridAdam(model.parameters(), lr=1e-3)
|
||||
zero_optim = ZeroOptimizer(optimizer, model, initial_scale=2)
|
||||
|
||||
@@ -121,6 +140,9 @@ def exam_tiny_example(placement_policy, model_name: str):
|
||||
torch_model.eval()
|
||||
|
||||
set_seed(dist.get_rank() * 3 + 128)
|
||||
rtol, atol = 1.5e-6, 2e-5
|
||||
if mixed_precision is torch.bfloat16:
|
||||
rtol, atol = 2e-3, 2e-3
|
||||
for i, (input_ids, label) in enumerate(train_dataloader):
|
||||
if i > 2:
|
||||
break
|
||||
@@ -133,12 +155,12 @@ def exam_tiny_example(placement_policy, model_name: str):
|
||||
|
||||
torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim)
|
||||
loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim)
|
||||
assert_close(torch_loss, loss, rtol=1.5e-6, atol=2e-5) # atol should be 2e-5 for torch lower than 1.12
|
||||
assert_close(torch_loss, loss, rtol=rtol, atol=atol) # atol should be 2e-5 for torch lower than 1.12
|
||||
|
||||
zero_optim.step()
|
||||
torch_optim.step()
|
||||
|
||||
check_param(model, torch_model)
|
||||
check_param(model, torch_model, mixed_precision)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
|
Reference in New Issue
Block a user