mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 04:50:17 +00:00
[kernel] support pure fp16 for cpu adam and update gemini optim tests (#4921)
* [kernel] support pure fp16 for cpu adam (#4896) * [kernel] fix cpu adam kernel for pure fp16 and update tests (#4919) * [kernel] fix cpu adam * [test] update gemini optim test
This commit is contained in:
@@ -52,7 +52,8 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module):
|
||||
|
||||
@parameterize("placement_config", PLACEMENT_CONFIGS)
|
||||
@parameterize("model_name", ["gpt2"])
|
||||
def exam_grad_clipping(placement_config, model_name: str):
|
||||
@parameterize("master_weights", [True, False])
|
||||
def exam_grad_clipping(placement_config, model_name: str, master_weights: bool):
|
||||
set_seed(1912)
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
@@ -82,6 +83,7 @@ def exam_grad_clipping(placement_config, model_name: str):
|
||||
chunk_config_dict=config_dict,
|
||||
chunk_init_device=init_device,
|
||||
pin_memory=True,
|
||||
master_weights=master_weights,
|
||||
**placement_config,
|
||||
)
|
||||
|
||||
@@ -103,7 +105,10 @@ def exam_grad_clipping(placement_config, model_name: str):
|
||||
|
||||
torch_loss = run_fwd_bwd(torch_model, data, label, criterion, torch_optim)
|
||||
loss = run_fwd_bwd(model, data, label, criterion, zero_optim)
|
||||
assert_close(torch_loss, loss)
|
||||
|
||||
# as no master weights leads to error accumulation, we don't check the loss
|
||||
if master_weights:
|
||||
assert_close(torch_loss, loss)
|
||||
|
||||
import apex.amp as apex_amp
|
||||
|
||||
@@ -111,7 +116,8 @@ def exam_grad_clipping(placement_config, model_name: str):
|
||||
torch_optim.step()
|
||||
zero_optim.step()
|
||||
|
||||
check_param(model, torch_model)
|
||||
if master_weights:
|
||||
check_param(model, torch_model)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
|
Reference in New Issue
Block a user