Mirror of https://github.com/hpcaitech/ColossalAI.git
[kernel] support pure fp16 for cpu adam and update gemini optim tests (#4921)
* [kernel] support pure fp16 for cpu adam (#4896)
* [kernel] fix cpu adam kernel for pure fp16 and update tests (#4919)
* [kernel] fix cpu adam
* [test] update gemini optim test
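For orientation, "pure fp16" here means the parameters, gradients, and Adam states handed to the CPU Adam kernel are all torch.half, with no fp32 master copy. Below is a minimal usage sketch, not code from this PR, assuming a ColossalAI build with the cpu_adam kernel extension available:

import torch
from colossalai.nn.optimizer import HybridAdam  # dispatches CPU tensors to the cpu adam kernel

# pure fp16: parameter, gradient, and Adam states all stay in torch.half on the CPU
p = torch.nn.Parameter(torch.randn(1024, dtype=torch.half))
p.grad = torch.randn(1024, dtype=torch.half)

optim = HybridAdam([p], lr=1e-3)
optim.step()  # with this change the CPU kernel updates fp16 parameters/states directly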
@@ -13,9 +13,7 @@ from colossalai.utils import get_current_device, multi_tensor_applier
_FUSED_ALLOWED_P_G_TYPES = [
    (torch.float, torch.half),
    (torch.float, torch.float),
    (torch.half, torch.float),
    (torch.half, torch.half),
    (torch.bfloat16, torch.float),
    (torch.float, torch.bfloat16),
    (torch.bfloat16, torch.bfloat16),
]
@@ -23,7 +21,6 @@ _FUSED_ALLOWED_P_G_TYPES = [
_CPU_ALLOWED_P_G_TYPES = [
    (torch.float, torch.half),
    (torch.float, torch.float),
    (torch.half, torch.float),
    (torch.half, torch.half),
]
@@ -138,8 +135,8 @@ def check_adam_kernel(
    master_exp_avg_sq = torch.zeros_like(master_p)
    p = master_p.clone().to(p_dtype)
    g = master_g.clone().to(g_dtype)
-    exp_avg = master_exp_avg.clone()
-    exp_avg_sq = master_exp_avg_sq.clone()
+    exp_avg = master_exp_avg.clone().to(p_dtype)
+    exp_avg_sq = master_exp_avg_sq.clone().to(p_dtype)

    for step in range(1, 1 + n_steps):
        torch_adam.update(step, master_p, master_g, master_exp_avg, master_exp_avg_sq)
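For reference, torch_adam.update(...) above is the test's fp32 baseline. A rough stand-in for the arithmetic it performs, written as a plain-PyTorch Adam step matching that call signature (an illustrative assumption, not the repository's TorchAdam class):

import torch

def reference_adam_update(step, p, g, exp_avg, exp_avg_sq,
                          lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    # classic Adam: update biased first/second moment estimates in place
    exp_avg.mul_(beta1).add_(g, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(g, g, value=1 - beta2)
    # bias correction for the given step (1-indexed, as in the loop above)
    bias_correction1 = 1 - beta1**step
    bias_correction2 = 1 - beta2**step
    denom = (exp_avg_sq / bias_correction2).sqrt().add_(eps)
    # in-place parameter update
    p.addcdiv_(exp_avg, denom, value=-lr / bias_correction1)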
@@ -21,8 +21,6 @@ _ALLOWED_P_G_TYPES = [
    (torch.float, torch.float),  # pure fp32
    (torch.float, torch.half),  # fp16 amp
    (torch.float, torch.bfloat16),  # bfloat16 amp
-    # (torch.half, torch.half),  # FIXME(ver217): cpu adam kernel does not support pure fp16
-    # (torch.bfloat16, torch.bfloat16),  # FIXME(ver217): cpu adam kernel does not support pure bfloat16
]

N_STEPS = 3
@@ -52,7 +52,8 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module):

@parameterize("placement_config", PLACEMENT_CONFIGS)
@parameterize("model_name", ["gpt2"])
-def exam_grad_clipping(placement_config, model_name: str):
+@parameterize("master_weights", [True, False])
+def exam_grad_clipping(placement_config, model_name: str, master_weights: bool):
    set_seed(1912)
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@@ -82,6 +83,7 @@ def exam_grad_clipping(placement_config, model_name: str):
        chunk_config_dict=config_dict,
        chunk_init_device=init_device,
        pin_memory=True,
+        master_weights=master_weights,
        **placement_config,
    )

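The new master_weights flag controls whether Gemini keeps an extra fp32 master copy of the parameters; with master_weights=False only the half-precision copy exists, which is why the hunks below skip the exact loss and parameter checks in that mode. A minimal construction sketch, assuming a single-process distributed context is already initialized (as the tests do via their run_dist launcher) and using a placeholder model instead of the tests' GPT-2-like registry models:

import torch
import torch.nn as nn
from colossalai.nn.optimizer import HybridAdam
from colossalai.zero import GeminiDDP, GeminiOptimizer

# placeholder model; chunk configuration is left to GeminiDDP's default search here
net = nn.Sequential(nn.Linear(32, 32), nn.GELU(), nn.Linear(32, 32)).cuda()

model = GeminiDDP(
    net,
    mixed_precision=torch.half,
    master_weights=False,  # keep only the fp16 copy, no fp32 master weights
)
zero_optim = GeminiOptimizer(HybridAdam(model.parameters(), lr=1e-3), model, initial_scale=128)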
@@ -103,7 +105,10 @@ def exam_grad_clipping(placement_config, model_name: str):

    torch_loss = run_fwd_bwd(torch_model, data, label, criterion, torch_optim)
    loss = run_fwd_bwd(model, data, label, criterion, zero_optim)
-    assert_close(torch_loss, loss)
+
+    # as no master weights leads to error accumulation, we don't check the loss
+    if master_weights:
+        assert_close(torch_loss, loss)

    import apex.amp as apex_amp

@@ -111,7 +116,8 @@ def exam_grad_clipping(placement_config, model_name: str):
    torch_optim.step()
    zero_optim.step()

-    check_param(model, torch_model)
+    if master_weights:
+        check_param(model, torch_model)


def run_dist(rank, world_size, port):
@@ -70,12 +70,14 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dty
@parameterize("placement_config", PLACEMENT_CONFIGS)
@parameterize("model_name", TEST_MODELS)
@parameterize("mixed_precision", [torch.half, torch.bfloat16])
-def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype):
+@parameterize("master_weights", [True, False])
+def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype, master_weights: bool):
    set_seed(42)
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    torch_model = model_builder().cuda()
+    # apex no master weights leads to nan, so we don't use it
    amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=128)
    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
    torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
@@ -90,7 +92,9 @@ def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dt
    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
    config_dict[world_size]["chunk_size"] = 5000
    config_dict[world_size]["keep_gathered"] = False
-    model = GeminiDDP(model, config_dict, **placement_config, mixed_precision=mixed_precision)
+    model = GeminiDDP(
+        model, config_dict, **placement_config, mixed_precision=mixed_precision, master_weights=master_weights
+    )

    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    zero_optim = GeminiOptimizer(optimizer, model, initial_scale=128)
@@ -109,12 +113,15 @@ def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dt

    torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim)
    loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim)
-    assert_close(torch_loss, loss, rtol=rtol, atol=atol)
+    # as no master weights leads to error accumulation, we don't check the loss
+    if master_weights:
+        assert_close(torch_loss, loss, rtol=rtol, atol=atol)

    zero_optim.step()
    torch_optim.step()

-    check_param(model, torch_model, mixed_precision)
+    if master_weights:
+        check_param(model, torch_model, mixed_precision)


@parameterize("placement_config", PLACEMENT_CONFIGS)