[kernel] support pure fp16 for cpu adam and update gemini optim tests (#4921)

* [kernel] support pure fp16 for cpu adam (#4896)

* [kernel] fix cpu adam kernel for pure fp16 and update tests (#4919)

* [kernel] fix cpu adam

* [test] update gemini optim test
Author: Hongxin Liu
Committed: 2023-10-16 21:56:53 +08:00 (by GitHub)
Parent: 7768afbad0
Commit: 4f68b3f10c
8 changed files with 148 additions and 136 deletions

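For context on what the diff below exercises: with master_weights disabled, GeminiDDP keeps parameters purely in fp16, so the CPU Adam kernel must step on fp16 tensors directly instead of on fp32 master copies. The snippet is a minimal sketch of such a step, not the actual ColossalAI kernel; the function name fp16_adam_step and its defaults are illustrative assumptions.

import torch

def fp16_adam_step(param, grad, exp_avg, exp_avg_sq, step,
                   lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    # param, grad and both moment buffers are torch.float16 CPU tensors;
    # the arithmetic runs in fp32 and the results are rounded back to fp16.
    p32, g32 = param.float(), grad.float()
    m32, v32 = exp_avg.float(), exp_avg_sq.float()
    m32.mul_(beta1).add_(g32, alpha=1 - beta1)
    v32.mul_(beta2).addcmul_(g32, g32, value=1 - beta2)
    denom = (v32 / (1 - beta2 ** step)).sqrt().add_(eps)
    p32.addcdiv_(m32 / (1 - beta1 ** step), denom, value=-lr)
    # Rounding the state back to fp16 every step is where error accumulates,
    # which is why the updated test relaxes its checks when master_weights=False.
    param.copy_(p32)
    exp_avg.copy_(m32)
    exp_avg_sq.copy_(v32)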

@@ -52,7 +52,8 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module):
 @parameterize("placement_config", PLACEMENT_CONFIGS)
 @parameterize("model_name", ["gpt2"])
-def exam_grad_clipping(placement_config, model_name: str):
+@parameterize("master_weights", [True, False])
+def exam_grad_clipping(placement_config, model_name: str, master_weights: bool):
     set_seed(1912)
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
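The hunk above relies on colossalai.testing's @parameterize to fan the new master_weights argument out over both values. A rough, self-contained sketch of that behaviour (assumed semantics, not the library's implementation):

import functools

def parameterize(name, values):
    # Re-run the wrapped test once per listed value, passing it as a keyword
    # argument; stacking decorators yields the cross product of all settings.
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            for value in values:
                fn(*args, **{**kwargs, name: value})
        return wrapper
    return decorator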
@@ -82,6 +83,7 @@ def exam_grad_clipping(placement_config, model_name: str):
         chunk_config_dict=config_dict,
         chunk_init_device=init_device,
         pin_memory=True,
+        master_weights=master_weights,
         **placement_config,
     )
@@ -103,7 +105,10 @@ def exam_grad_clipping(placement_config, model_name: str):
         torch_loss = run_fwd_bwd(torch_model, data, label, criterion, torch_optim)
         loss = run_fwd_bwd(model, data, label, criterion, zero_optim)
-        assert_close(torch_loss, loss)
+
+        # as no master weights leads to error accumulation, we don't check the loss
+        if master_weights:
+            assert_close(torch_loss, loss)
 
         import apex.amp as apex_amp
@@ -111,7 +116,8 @@ def exam_grad_clipping(placement_config, model_name: str):
         torch_optim.step()
         zero_optim.step()
 
-        check_param(model, torch_model)
+        if master_weights:
+            check_param(model, torch_model)
 
 
 def run_dist(rank, world_size, port):
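The relaxed checks in the last two hunks come from fp16 rounding: an update smaller than half a ULP of the parameter is lost entirely, and the drift compounds across steps when no fp32 master copy absorbs it. A tiny standalone illustration (not part of the test):

import torch

p16 = torch.ones(1, dtype=torch.float16)
p32 = torch.ones(1, dtype=torch.float32)
for _ in range(1000):
    p16 += torch.tensor(1e-4, dtype=torch.float16)  # rounds back to 1.0 each step
    p32 += 1e-4
print(p16.item(), p32.item())  # ~1.0 vs ~1.1: the fp16-only run never moves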