[shardformer] Sequence Parallelism Optimization (#5533)

* sequence parallel optimization

* validate sequence parallel in llama (code to be polished)

* shardformer api writing

* integrate sequence parallel in ShardFormer

* fix pp bugs and sp bugs for LLaMA model

* integrating ring-based sequence parallelism into ShardFormer

* [sequence parallelism]: Add fused megatron function

* integrating ring-based sequence parallelism into ShardFormer

---------

Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>

* fix bugs when using sp and flash attention together

* fix operation function name

* support flash attention for ulysses-style sp

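The entry above pairs flash attention with Ulysses-style sp, whose core step is an all-to-all that trades each rank's sequence shard for a head shard before attention (and back afterwards). A minimal sketch of that exchange, with illustrative names and shapes rather than the exact ColossalAI implementation:

    # assumes torch.distributed is initialized and num_heads divides evenly by the sp size
    import torch
    import torch.distributed as dist

    def seq_to_head_all_to_all(x: torch.Tensor, sp_group) -> torch.Tensor:
        # x: [batch, seq_len / sp, num_heads, head_dim] -> [batch, seq_len, num_heads / sp, head_dim]
        sp_size = dist.get_world_size(sp_group)
        b, s_local, h, d = x.shape
        # group the heads into sp_size chunks; chunk i of our local sequence goes to rank i
        inp = x.reshape(b, s_local, sp_size, h // sp_size, d).permute(2, 0, 1, 3, 4).contiguous()
        out = torch.empty_like(inp)
        dist.all_to_all_single(out, inp, group=sp_group)
        # out[j] is rank j's sequence chunk for our head shard; stitch the full sequence back together
        return out.permute(1, 0, 2, 3, 4).reshape(b, sp_size * s_local, h // sp_size, d)

The inverse exchange after attention restores the sequence shard and the full set of heads.
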
* clarify sp process group

* fix compatibility bugs in moe plugin

* fix fused linear bugs

* fix linear layer test

* support gpt model all-to-all sp

* modify shard data dimension (meant to be dim=-1)

* support Megatron-style sp and distributed attn for llama model

* [shardformer] add megatron sp to llama

* support llama7B 128k with distributed attention

* [shardformer] robustness enhancement

* add block attn

* sp mode 1: keep input as a complete sequence

* fix sp compatibility

* finish sp mode 3 support for gpt

* using all_to_all_single when batch size is 1

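A hedged sketch of the fast path above: with batch size 1 the buffers can be exchanged as one flat tensor, so a single dist.all_to_all_single call replaces the list-based dist.all_to_all and its per-chunk bookkeeping (names here are illustrative):

    import torch
    import torch.distributed as dist

    def exchange_chunks(x: torch.Tensor, sp_group) -> torch.Tensor:
        # x is laid out as [sp_size * chunk, ...]; chunk i is sent to rank i and
        # the chunk received from rank i lands in slot i of the output
        out = torch.empty_like(x)
        dist.all_to_all_single(out, x.contiguous(), group=sp_group)
        return out
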
* support mode 2 sp in gpt2 (#5)

* [shardformer] add megatron sp to llama

* support llama7B 128k with distributed attention

* [shardformer] robustness enhancement

* add block attn

* sp mode 1: keep input as a complete sequence

* fix sp compatibility

* refactor ring implementation

* support mode 2 sp in gpt2

* polish code

* enable distributed attn mask when using sp mode 2 and 3 in llama

* automatically enable flash attn when using sp mode 2 and 3 in llama

* inplace attn mask

* add zero2 support for sequence parallel

* polish code

* fix bugs

* fix gemini checkpoint io

* loosen tensor-checking atol and rtol

* add comment

* fix llama layernorm grad

* fix zero grad

* fix zero grad

* fix conflict

* update the split and gather autograd functions

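The split and gather autograd functions referred to above are the core of the split_gather (Megatron-style) sp path: keep only one sequence chunk in the forward pass and all-gather the gradient in the backward pass (a mirrored gather-forward / split-backward pair handles the other boundary). A simplified sketch, not the exact ColossalAI implementation:

    import torch
    import torch.distributed as dist

    class SplitForwardGatherBackward(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x: torch.Tensor, dim: int, group) -> torch.Tensor:
            ctx.dim, ctx.group = dim, group
            rank = dist.get_rank(group)
            world_size = dist.get_world_size(group)
            # keep only this rank's chunk of the sequence dimension
            return x.chunk(world_size, dim=dim)[rank].contiguous()

        @staticmethod
        def backward(ctx, grad_output: torch.Tensor):
            world_size = dist.get_world_size(ctx.group)
            # gather the chunked gradients back into a full-sequence gradient
            parts = [torch.empty_like(grad_output) for _ in range(world_size)]
            dist.all_gather(parts, grad_output.contiguous(), group=ctx.group)
            return torch.cat(parts, dim=ctx.dim), None, None
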
* sequence parallel: inside text split (#6)

* polish code (part 1)

* polish code (part 2)

* polish code (part 2.5)

* polish code (part 3)

* sequence parallel: inside text split

* miscellaneous minor fixes

* polish code

* fix ulysses style ZeRO

* sequence parallel: inside text split

* miscellaneous minor fixes

* disaggregate sp group and dp group for sp

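Disaggregating the groups means each rank sits in exactly one dp group and one sp group drawn from a 2-D rank grid. ColossalAI derives these from its process-group mesh; a bare torch.distributed sketch of the layout (illustrative, assuming ranks are arranged dp-major) is:

    import torch.distributed as dist

    def create_dp_sp_groups(sp_size: int):
        world_size, rank = dist.get_world_size(), dist.get_rank()
        dp_size = world_size // sp_size
        dp_group = sp_group = None
        # grid layout: rank = dp_idx * sp_size + sp_idx
        for sp_idx in range(sp_size):
            ranks = [dp_idx * sp_size + sp_idx for dp_idx in range(dp_size)]
            group = dist.new_group(ranks)  # new_group is collective: every rank calls it for every group
            if rank in ranks:
                dp_group = group
        for dp_idx in range(dp_size):
            ranks = [dp_idx * sp_size + sp_idx for sp_idx in range(sp_size)]
            group = dist.new_group(ranks)
            if rank in ranks:
                sp_group = group
        return dp_group, sp_group
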
* fix llama and gpt sp

* polish code

* move ulysses grad sync to ddp (#9)

* remove zero_stage and unbind the grad sync for alltoall sp

* add 2d group creation test

* move ulysses grad sync to ddp

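Moving the Ulysses grad sync into ddp follows from the fact that under all_to_all sp the ranks of an sp group consume different tokens, so parameter gradients must be averaged over the combined dp x sp ranks rather than the dp group alone. A conceptual sketch with a hypothetical helper, not the plugin's actual hook:

    import torch
    import torch.distributed as dist

    def sync_grads(model: torch.nn.Module, dp_sp_group) -> None:
        # average parameter gradients over the combined dp x sp process group
        world_size = dist.get_world_size(dp_sp_group)
        for p in model.parameters():
            if p.grad is not None:
                dist.all_reduce(p.grad, group=dp_sp_group)
                p.grad.div_(world_size)
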
* add 2d group creation test

* remove useless code

* change shard config so that enable_all_optimizations no longer turns on sp by itself

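A hedged illustration of the config behaviour described above, assuming the ShardConfig fields in colossalai.shardformer around this PR (sequence_parallelism_mode in particular is an assumption of this sketch):

    from colossalai.shardformer import ShardConfig

    # enable_all_optimization no longer implies sequence parallelism; sp has to be
    # requested explicitly together with a mode such as "split_gather"
    shard_config = ShardConfig(
        enable_all_optimization=True,
        enable_sequence_parallelism=True,
        sequence_parallelism_mode="split_gather",
    )
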
* add sp warnings for several models

* remove useless code

---------

Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>
Author: Zhongkai Zhao
Date: 2024-04-03 17:15:47 +08:00
Committed by: GitHub
Parent: 7e0ec5a85c
Commit: 8e412a548e
33 changed files with 1630 additions and 256 deletions


@@ -15,13 +15,13 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
-def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool):
+def check_linear_1d_col(lazy_init: bool, seq_parallel_mode: bool, overlap: bool):
     ctx = LazyInitContext() if lazy_init else nullcontext()
     linear = nn.Linear(32, 128).cuda()
     with ctx:
         linear_copy = nn.Linear(32, 128).cuda()
     linear_col = Linear1D_Col.from_native_module(
-        linear_copy, process_group=None, gather_output=True, seq_parallel=seq_parallel, overlap=overlap
+        linear_copy, process_group=None, gather_output=True, seq_parallel_mode=seq_parallel_mode, overlap=overlap
     )
     # ensure that the parameters are distributed
@@ -43,7 +43,9 @@ def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool):
     x = torch.rand(2, 4, 32).cuda()
     x_for_unshard = x.expand_as(x.clone())
     x_for_unshard.requires_grad_(True)
-    x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
+    x_for_shard = (
+        x.expand_as(x.clone()) if seq_parallel_mode is None else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
+    )
     x_for_shard.requires_grad_(True)
     out = linear(x_for_unshard)
@@ -63,20 +65,20 @@ def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool):
     assert x_for_unshard.grad is not None
     target_unshard_gard = (
         x_for_unshard.grad
-        if seq_parallel is False
+        if seq_parallel_mode is None
         else torch.chunk(x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()]
     )
     assert_close(target_unshard_gard, x_for_shard.grad)
-def check_linear_1d_row(lazy_init: bool, seq_parallel: bool):
+def check_linear_1d_row(lazy_init: bool, seq_parallel_mode: bool):
     ctx = LazyInitContext() if lazy_init else nullcontext()
     linear = nn.Linear(32, 128).cuda()
     with ctx:
         linear_copy = nn.Linear(32, 128).cuda()
     linear_row = Linear1D_Row.from_native_module(
-        linear_copy, process_group=None, parallel_input=False, seq_parallel=seq_parallel
+        linear_copy, process_group=None, parallel_input=False, seq_parallel_mode=seq_parallel_mode
     )
     assert linear_row.weight.shape == torch.Size([128, 16])
@@ -98,7 +100,7 @@ def check_linear_1d_row(lazy_init: bool, seq_parallel: bool):
     # run forward
     out = linear(x_for_unshard)
     gather_out = linear_row(x_for_shard)
-    target_out = out if seq_parallel is False else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()]
+    target_out = out if seq_parallel_mode is None else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()]
     assert_close(target_out, gather_out)
     # check backward correctness
@@ -115,7 +117,7 @@ def check_linear_1d_row(lazy_init: bool, seq_parallel: bool):
     assert_close(x_for_unshard.grad, x_for_shard.grad)
-def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool):
+def check_linear_col_plus_row(lazy_init: bool, seq_parallel_mode: bool, overlap: bool):
     ctx = LazyInitContext() if lazy_init else nullcontext()
     linear_1 = nn.Linear(32, 128).cuda()
@@ -125,10 +127,10 @@ def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool
         linear_1_copy = nn.Linear(32, 128).cuda()
         linear_2_copy = nn.Linear(128, 32).cuda()
     linear_col = Linear1D_Col.from_native_module(
-        linear_1_copy, process_group=None, gather_output=False, seq_parallel=seq_parallel, overlap=overlap
+        linear_1_copy, process_group=None, gather_output=False, seq_parallel_mode=seq_parallel_mode, overlap=overlap
     )
     linear_row = Linear1D_Row.from_native_module(
-        linear_2_copy, process_group=None, parallel_input=True, seq_parallel=seq_parallel
+        linear_2_copy, process_group=None, parallel_input=True, seq_parallel_mode=seq_parallel_mode
     )
     linear_1.load_state_dict(linear_col.state_dict())
@@ -141,13 +143,17 @@ def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool
     x = torch.rand(2, 4, 32).cuda()
     x_for_unshard = x.expand_as(x.clone())
     x_for_unshard.requires_grad_(True)
-    x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
+    x_for_shard = (
+        x.expand_as(x.clone()) if seq_parallel_mode is None else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
+    )
     x_for_shard.requires_grad_(True)
     # run forward
     unshard_out = linear_2(linear_1(x_for_unshard))
     shard_out = linear_row(linear_col(x_for_shard))
-    target_out = unshard_out if seq_parallel is False else torch.chunk(unshard_out.clone(), 2, dim=1)[dist.get_rank()]
+    target_out = (
+        unshard_out if seq_parallel_mode is None else torch.chunk(unshard_out.clone(), 2, dim=1)[dist.get_rank()]
+    )
     assert_close(target_out, shard_out)
     # check backward correctness
@@ -163,19 +169,19 @@ def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool
     assert x_for_unshard.grad is not None
     target_unshard_gard = (
         x_for_unshard.grad
-        if seq_parallel is False
+        if seq_parallel_mode is None
         else torch.chunk(x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()]
     )
     assert_close(target_unshard_gard, x_for_shard.grad)
 @parameterize("lazy_init", [False, True])
-@parameterize("seq_parallel", [False, True])
+@parameterize("seq_parallel_mode", [None, "split_gather"])
 @parameterize("overlap", [True])
-def run_dist_linear_test(lazy_init, seq_parallel, overlap):
-    check_linear_1d_col(lazy_init, seq_parallel, overlap)
-    check_linear_1d_row(lazy_init, seq_parallel)
-    check_linear_col_plus_row(lazy_init, seq_parallel, overlap)
+def run_dist_linear_test(lazy_init, seq_parallel_mode, overlap):
+    check_linear_1d_col(lazy_init, seq_parallel_mode, overlap)
+    check_linear_1d_row(lazy_init, seq_parallel_mode)
+    check_linear_col_plus_row(lazy_init, seq_parallel_mode, overlap)
 def check_dist_linear(rank, world_size, port):
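
For reference, the parameterization above swaps the old boolean for a mode string, and the sharded path feeds each rank one chunk of the sequence dimension. In isolation (a standalone illustration, not part of the test file) the sharding the test relies on is simply:

    import torch

    x = torch.rand(2, 4, 32)                           # [batch, seq_len, hidden]
    world_size, rank = 2, 0                            # hypothetical 2-way sp group
    x_local = torch.chunk(x, world_size, dim=1)[rank]  # this rank's sequence shard
    assert x_local.shape == (2, 4 // world_size, 32)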