[Feature] Enable PP + SP for llama (#5868)

* fix cross-PP-stage position id length diff bug

* fix typo

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* use a one cross entropy func for all shardformer models

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Edenzzzz
2024-07-09 18:05:20 +08:00
committed by GitHub
parent 66abf1c6e8
commit fbf33ecd01
12 changed files with 148 additions and 323 deletions

View File

@@ -59,10 +59,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
if (
booster.plugin.zero_stage in [1, 2]
and booster.plugin.shard_config.enable_sequence_parallelism
and booster.plugin.shard_config.pipeline_stage_manager is None
and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all"
):
master2working = sharded_optimizer.get_master_to_working_map()
for p1, p2 in zip(llama_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]):
working_p = sharded_optimizer.master_to_working_param[id(p2)]
working_p = master2working[id(p2)]
grads = sharded_optimizer.get_partitioned_gradients_by_param_id(0, id(working_p))
grad_index = (
0
@@ -146,6 +148,19 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
@parameterize(
"test_config",
[
{ # Ulysess + Flash attention
"tp_size": 1,
"pp_size": 2,
"sp_size": 2,
"num_microbatches": 2,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "all_to_all",
"enable_flash_attention": True,
"use_lazy_init": True,
"zero_stage": 0,
"precision": "fp16",
"initial_scale": 1,
},
{ # Test ring + Flash attention
"tp_size": 2,
"pp_size": 1,
@@ -159,19 +174,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"precision": "fp16",
"initial_scale": 1,
},
{ # Ulysess + Flash attention
"tp_size": 1,
"pp_size": 2,
"sp_size": 2,
"num_microbatches": 2,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "all_to_all",
"enable_flash_attention": True,
"use_lazy_init": True,
"zero_stage": 1,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 1,
"pp_size": 1,
@@ -245,7 +247,6 @@ def run_llama_test(test_config):
except Exception as e:
print(f"Failed config: {test_config}")
raise e
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()