[Feature] Split cross-entropy computation in SP (#5959)

* halfway

* fix cross-PP-stage position id length diff bug

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* unified cross entropy func for all shardformer models

* remove redundant lines

* add basic ring attn; debug cross entropy

* fwd bwd logic complete

* fwd bwd logic complete; add experimental triton rescale

* precision tests passed

* precision tests passed

* fix typos and remove misc files

* update softmax_lse shape by new interface

* change tester name

* remove buffer clone; support packed seq layout

* add varlen tests

* fix typo

* all tests passed

* add dkv_group; fix mask

* remove debug statements

* adapt chatglm, command-R, qwen

* debug

* halfway

* fix cross-PP-stage position id length diff bug

* fix typo

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* unified cross entropy func for all shardformer models

* remove redundant lines

* add basic ring attn; debug cross entropy

* fwd bwd logic complete

* fwd bwd logic complete; add experimental triton rescale

* precision tests passed

* precision tests passed

* fix typos and remove misc files

* add sp_mode to benchmark; fix varlen interface

* update softmax_lse shape by new interface

* add varlen tests

* fix typo

* all tests passed

* add dkv_group; fix mask

* remove debug statements

* add comments

* q1 index only once

* remove events to simplify stream sync

* simplify forward/backward logic

* 2d ring forward passed

* 2d ring backward passed

* fixes

* fix ring attn loss

* 2D ring backward + llama passed

* merge

* update logger

* fix typo

* rebase

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix typo

* remove typos

* fixes

* support GPT

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Wenxuan Tan authored on 2024-09-10 12:06:50 +08:00, committed by GitHub
commit 8fd25d6e09 (parent b3db1058ec)
25 changed files with 527 additions and 1173 deletions
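At its core, the change computes the cross-entropy loss on each sequence-parallel rank's local shard of the logits and then reduces across the sequence-parallel group, instead of gathering the full sequence first. The sketch below is illustrative only: it assumes a simple contiguous split and placeholder names (split_cross_entropy, sp_group), and is not the unified cross-entropy function the commit log refers to.

import torch
import torch.distributed as dist
import torch.nn.functional as F

def split_cross_entropy(local_logits, local_labels, sp_group, ignore_index=-100):
    """This rank's contribution to the mean cross entropy over all valid tokens.

    local_logits: [local_seq_len, vocab_size] logits held by this rank.
    local_labels: [local_seq_len] labels for the same slice (already shifted).
    """
    # Sum of per-token losses for this rank's tokens (ignored tokens contribute 0).
    local_loss = F.cross_entropy(
        local_logits, local_labels, ignore_index=ignore_index, reduction="sum"
    )
    # Global count of non-ignored tokens; the count carries no gradient.
    num_tokens = (local_labels != ignore_index).sum()
    dist.all_reduce(num_tokens, group=sp_group)
    # Dividing the local sum by the global count keeps local gradients correct;
    # summing this scalar over the SP group (e.g. for logging) gives the global mean.
    return local_loss / num_tokens.clamp(min=1)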


@@ -100,7 +100,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             atol, rtol = 5e-3, 5e-3
         if org_model.__class__.__name__ == "GPT2Model":
-            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
+            check_output_hidden_state(
+                org_output,
+                sharded_output,
+                stage_manager,
+                atol=atol,
+                rtol=rtol,
+                shard_config=booster.plugin.shard_config,
+            )
 
         check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
@@ -132,14 +139,27 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"test_config",
[
{
"tp_size": 4,
"pp_size": 1,
"num_microbatches": 1,
"sp_size": 2,
"tp_size": 1,
"pp_size": 2,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "ring",
"enable_flash_attention": False,
"sequence_parallelism_mode": "ring_attn",
"num_microbatches": 2,
"enable_all_optimization": True,
"use_lazy_init": True,
"precision": "fp32",
"precision": "fp16",
"initial_scale": 1,
},
{
"sp_size": 2,
"tp_size": 2,
"pp_size": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "ring_attn",
"num_microbatches": 1,
"enable_all_optimization": True,
"use_lazy_init": True,
"precision": "fp16",
"initial_scale": 1,
},
{
@@ -148,7 +168,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "split_gather",
"enable_flash_attention": False,
"enable_flash_attention": True,
"use_lazy_init": True,
"precision": "fp16",
"initial_scale": 1,
@@ -156,7 +176,18 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         {
             "tp_size": 2,
             "pp_size": 2,
-            "num_microbatches": 4,
+            "num_microbatches": 2,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "split_gather",
+            "enable_flash_attention": True,
+            "use_lazy_init": True,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 2,
+            "pp_size": 2,
+            "num_microbatches": 2,
             "enable_all_optimization": True,
             "use_lazy_init": True,
             "precision": "fp16",
@@ -185,7 +216,16 @@ def run_gpt2_test(test_config):
         loss_fn,
         _,
     ) in sub_model_zoo.items():
-        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+        if test_config.get("sequence_parallelism_mode", None) == "ring_attn" and name != "transformers_gpt_lm":
+            # Only wrote zigzag splitting for cross entropy loss
+            continue
+        try:
+            check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+        except Exception as e:
+            print(f"Failed config: {test_config} for model {name}")
+            raise (e)
 
     clear_layout_converter()
     torch.cuda.empty_cache()
@@ -226,7 +266,11 @@ def run_gpt2_3d_test(test_config):
         loss_fn,
         _,
     ) in sub_model_zoo.items():
-        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+        try:
+            check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+        except Exception as e:
+            print(f"Failed config: {test_config} for model {name}")
+            raise (e)
 
     clear_layout_converter()
     torch.cuda.empty_cache()
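
The "Only wrote zigzag splitting for cross entropy loss" comment above refers to the load-balanced split used with ring attention: the sequence is cut into 2 * sp_size chunks and rank r keeps chunks r and 2 * sp_size - 1 - r, pairing cheap and expensive positions under a causal mask. A minimal illustrative sketch of splitting labels this way follows; the function name and signature are assumptions for illustration, not the exact helper used in this PR.

import torch

def zigzag_split(seq: torch.Tensor, sp_rank: int, sp_size: int, dim: int = 1) -> torch.Tensor:
    """Return the two zigzag chunks owned by rank `sp_rank`.

    The tensor is cut into 2 * sp_size equal chunks along `dim`; rank r keeps
    chunks r and 2 * sp_size - 1 - r, balancing causal-attention work per rank.
    """
    chunks = seq.chunk(2 * sp_size, dim=dim)
    return torch.cat([chunks[sp_rank], chunks[2 * sp_size - 1 - sp_rank]], dim=dim)

# Example with sp_size = 2 and a sequence of 8 label positions:
labels = torch.arange(8).unsqueeze(0)              # [[0, 1, 2, 3, 4, 5, 6, 7]]
print(zigzag_split(labels, sp_rank=0, sp_size=2))  # chunks 0 and 3 -> [[0, 1, 6, 7]]
print(zigzag_split(labels, sp_rank=1, sp_size=2))  # chunks 1 and 2 -> [[2, 3, 4, 5]]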