[zero] solve hang

Author: hxwang
Date: 2024-07-05 07:19:37 +00:00
Committed by: Hongxin Liu
Parent: 0fad23c691
Commit: 46c069b0db
12 changed files with 113 additions and 390 deletions


@@ -25,13 +25,14 @@ os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
 def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
     # TODO: SGD failed for full dp
     org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
-        model_fn, loss_fn, test_config, pluggin_cls=MoeHybridParallelPlugin, optim_class=torch.optim.SGD
-    )
-    org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
-        org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
-    )
+        model_fn, loss_fn, test_config, pluggin_cls=MoeHybridParallelPlugin, optim_class=torch.optim.Adam
+    )
+    with torch.autograd.set_detect_anomaly(True):
+        org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
+            org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
+        )

     stage_manager = booster.plugin.stage_manager
     tp_group = booster.plugin.tp_group
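The wrapper added above is PyTorch's built-in anomaly detector: inside torch.autograd.set_detect_anomaly(True), any backward op that produces NaN raises immediately, with a traceback pointing at the forward op that built the graph node, which makes a silently diverging or hanging run much easier to localize. A minimal self-contained sketch of the mechanism (toy tensors, not part of this test suite):

import torch

# Toy repro: the forward result looks fine (0.0), but SqrtBackward0 computes 0/0 -> NaN.
x = torch.zeros(1, requires_grad=True)
with torch.autograd.set_detect_anomaly(True):
    y = torch.sqrt(x) * 0.0
    try:
        y.sum().backward()
    except RuntimeError as err:
        # Anomaly mode names the offending backward node and prints the forward traceback.
        print(err)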
@@ -73,6 +74,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     grads_to_check.update(col_layer_grads)
     grads_to_check.update(row_layer_grads)
+    # check grads
+    check_all_grad_tensors(grads_to_check)
+
     # optimizer executes step
     org_optimizer.step()
     sharded_optimizer.step()
@@ -103,9 +107,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         verbose=False,
     )
-    # check grads
-    check_all_grad_tensors(grads_to_check)
-
     torch.cuda.empty_cache()
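Taken together, the two hunks above move check_all_grad_tensors(grads_to_check) from after the optimizer step to immediately after the backward pass. The likely reason (inferred from the diff and the commit title, not stated explicitly): once step() runs, gradients may already be consumed, rescaled, or freed, and with a ZeRO-sharded optimizer the ranks can fall out of lockstep on the collectives the check performs, which is a classic way to hang. A plain-PyTorch sketch of the safe ordering, using a hypothetical stand-in for the helper:

import torch
import torch.nn as nn

def check_grads_match(model_a, model_b, rtol=1e-5, atol=1e-5):
    # Hypothetical stand-in for check_all_grad_tensors: compare gradients of two
    # models that are expected to stay numerically in sync.
    for (name, p_a), (_, p_b) in zip(model_a.named_parameters(), model_b.named_parameters()):
        assert p_a.grad is not None and p_b.grad is not None, f"missing grad for {name}"
        assert torch.allclose(p_a.grad, p_b.grad, rtol=rtol, atol=atol), f"grad mismatch on {name}"

model_a = nn.Linear(4, 4)
model_b = nn.Linear(4, 4)
model_b.load_state_dict(model_a.state_dict())
opt_a = torch.optim.Adam(model_a.parameters())
opt_b = torch.optim.Adam(model_b.parameters())

x = torch.randn(2, 4)
model_a(x).sum().backward()
model_b(x).sum().backward()

check_grads_match(model_a, model_b)  # gradients are still intact here
opt_a.step()                         # after this, grads may be stale or cleared
opt_b.step()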
@@ -114,37 +115,22 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     [
         {
             "tp_size": 1,
-            "pp_size": 4,
+            "pp_size": 2,
+            "num_microbatches": 2,
             "ep_size": 1,
-            "num_microbatches": 4,
             "zero_stage": 0,
-            "enable_all_optimization": True,
-            "use_lazy_init": False,
-            "precision": "fp16",
-            "initial_scale": 1,
-        },
-        # {
+            "precision": "fp32",
+        }, # pp + ep
+        # {"tp_size": 1, "pp_size": 1, "ep_size": 1, "zero_stage": 1, "precision": "fp16"}, # full dp for moe and non-moe
+        # { # moe_dp = 2, non_moe_dp = 4
         #     "tp_size": 1,
         #     "pp_size": 1,
-        #     "ep_size": 4,
         #     "num_microbatches": 2,
+        #     "ep_size": 2,
         #     "zero_stage": 1,
-        #     "enable_all_optimization": True,
-        #     "use_lazy_init": False,
-        #     "precision": "fp16",
-        #     "initial_scale": 1,
-        # },
-        # {
-        #     "tp_size": 1,
-        #     "pp_size": 1,
-        #     "ep_size": 4,
-        #     "num_microbatches": 2,
-        #     "zero_stage": 2,
-        #     "enable_all_optimization": True,
-        #     "use_lazy_init": False,
-        #     "precision": "fp16",
-        #     "initial_scale": 1,
-        # },
+        # }, # moe_dp = 1, non_moe_dp = 4
+        # {"tp_size": 1, "pp_size": 1, "ep_size": 4, "zero_stage": 1, "precision": "fp16"},
+        # {"tp_size": 1, "pp_size": 1, "ep_size": 1, "zero_stage": 0, "precision": "fp32"}, # full dp for moe and non-moe
     ],
 )
 def run_mixtral_test(test_config):
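For orientation, the single active config above (tp_size=1, pp_size=2, ep_size=1, zero_stage=0, fp32) is what build_model_from_hybrid_plugin feeds into the booster plugin. A rough sketch of the equivalent plugin construction, assuming the dict keys map one-to-one onto MoeHybridParallelPlugin's constructor arguments and that the import path below is current (both are assumptions, not taken from this commit):

from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin  # assumed path

plugin = MoeHybridParallelPlugin(
    tp_size=1,           # no tensor parallelism
    pp_size=2,           # two pipeline stages
    ep_size=1,           # no expert parallelism in the active case
    num_microbatches=2,  # pipeline microbatches
    zero_stage=0,        # plain data parallelism for non-MoE params
    precision="fp32",    # full precision, as switched to in this commit
)
booster = Booster(plugin=plugin)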