Mirror of https://github.com/hpcaitech/ColossalAI.git
[shardformer] Sequence Parallelism Optimization (#5533)
* sequence parallel optimization
* validate sequence parallel in llama (code to be polished)
* shardformer api writing
* integrate sequence parallel in ShardFormer
* fix pp bugs and sp bugs for LLaMA model
* integrating ring-based sequence parallelism into ShardFormer
* [sequence parallelism]: add fused megatron function
* integrating ring-based sequence parallelism into ShardFormer

---------

Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>

* fix bugs when using sp and FlashAttention together
* fix operation function name
* support flash attention for ulysses-style sp
* clarify sp process group
* fix compatibility bugs in moe plugin
* fix fused linear bugs
* fix linear layer test
* support gpt model all-to-all sp
* modify shard data dimension (meant to be dim=-1)
* support megatron-style sp and distributed attn for llama model
* [shardformer] add megatron sp to llama
* support llama7B 128k with distributed attention
* [shardformer] robustness enhancement
* add block attn
* sp mode 1: keep input as a complete sequence
* fix sp compatibility
* finish sp mode 3 support for gpt
* use all_to_all_single when batch size is 1
* support mode 2 sp in gpt2 (#5)
* [shardformer] add megatron sp to llama
* support llama7B 128k with distributed attention
* [shardformer] robustness enhancement
* add block attn
* sp mode 1: keep input as a complete sequence
* fix sp compatibility
* refactor ring implementation
* support mode 2 sp in gpt2
* polish code
* enable distributed attn mask when using sp mode 2 and 3 in llama
* automatically enable flash attn when using sp mode 2 and 3 in llama
* inplace attn mask
* add zero2 support for sequence parallel
* polish code
* fix bugs
* fix gemini checkpoint io
* loosen tensor checking atol and rtol
* add comment
* fix llama layernorm grad
* fix zero grad
* fix zero grad
* fix conflict
* update split and gather autograd func
* sequence parallel: inside text split (#6)
* polish code (part 1)
* polish code (part 2)
* polish code (part 2.5)
* polish code (part 3)
* sequence parallel: inside text split
* miscellaneous minor fixes
* polish code
* fix ulysses-style ZeRO
* sequence parallel: inside text split
* miscellaneous minor fixes
* disaggregate sp group and dp group for sp
* fix llama and gpt sp
* polish code
* move ulysses grad sync to ddp (#9)
* remove zero_stage and unbind the grad sync for all-to-all sp
* add 2d group creation test
* move ulysses grad sync to ddp
* add 2d group creation test
* remove useless code
* change shard config so sp is not enabled by enable_all_optimizations
* add sp warnings for several models
* remove useless code

---------

Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>
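For readers skimming the diff below: the user-visible change in these layers is that the boolean seq_parallel argument becomes a string-valued seq_parallel_mode (None or "split_gather" in this test). The following is a minimal sketch of the new call, with argument names taken from the diff; the import path and the 2-rank distributed launch are assumptions and are not shown here.

import torch
import torch.nn as nn
import torch.distributed as dist

from colossalai.shardformer.layer import Linear1D_Col  # assumed import path

def run_column_parallel(seq_parallel_mode):
    # seq_parallel_mode replaces the old boolean seq_parallel flag:
    # None disables sequence parallelism, "split_gather" enables the
    # split/gather style exercised by this test.
    linear = nn.Linear(32, 128).cuda()
    linear_col = Linear1D_Col.from_native_module(
        linear, process_group=None, gather_output=True,
        seq_parallel_mode=seq_parallel_mode, overlap=True,
    )

    x = torch.rand(2, 4, 32).cuda()  # (batch, seq_len, hidden)
    if seq_parallel_mode is None:
        x_local = x
    else:
        # each rank feeds only its own chunk of the sequence dimension
        x_local = torch.chunk(x, dist.get_world_size(), dim=1)[dist.get_rank()]
    return linear_col(x_local)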
@@ -15,13 +15,13 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn

 os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"


-def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool):
+def check_linear_1d_col(lazy_init: bool, seq_parallel_mode: bool, overlap: bool):
     ctx = LazyInitContext() if lazy_init else nullcontext()
     linear = nn.Linear(32, 128).cuda()
     with ctx:
         linear_copy = nn.Linear(32, 128).cuda()
     linear_col = Linear1D_Col.from_native_module(
-        linear_copy, process_group=None, gather_output=True, seq_parallel=seq_parallel, overlap=overlap
+        linear_copy, process_group=None, gather_output=True, seq_parallel_mode=seq_parallel_mode, overlap=overlap
     )

     # ensure that the parameters are distributed
@@ -43,7 +43,9 @@ def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool):
     x = torch.rand(2, 4, 32).cuda()
     x_for_unshard = x.expand_as(x.clone())
     x_for_unshard.requires_grad_(True)
-    x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
+    x_for_shard = (
+        x.expand_as(x.clone()) if seq_parallel_mode is None else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
+    )
     x_for_shard.requires_grad_(True)

     out = linear(x_for_unshard)
@@ -63,20 +65,20 @@ def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool):
     assert x_for_unshard.grad is not None
     target_unshard_gard = (
         x_for_unshard.grad
-        if seq_parallel is False
+        if seq_parallel_mode is None
         else torch.chunk(x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()]
     )
     assert_close(target_unshard_gard, x_for_shard.grad)


-def check_linear_1d_row(lazy_init: bool, seq_parallel: bool):
+def check_linear_1d_row(lazy_init: bool, seq_parallel_mode: bool):
     ctx = LazyInitContext() if lazy_init else nullcontext()

     linear = nn.Linear(32, 128).cuda()
     with ctx:
         linear_copy = nn.Linear(32, 128).cuda()
     linear_row = Linear1D_Row.from_native_module(
-        linear_copy, process_group=None, parallel_input=False, seq_parallel=seq_parallel
+        linear_copy, process_group=None, parallel_input=False, seq_parallel_mode=seq_parallel_mode
     )

     assert linear_row.weight.shape == torch.Size([128, 16])
@@ -98,7 +100,7 @@ def check_linear_1d_row(lazy_init: bool, seq_parallel: bool):
     # run forward
     out = linear(x_for_unshard)
     gather_out = linear_row(x_for_shard)
-    target_out = out if seq_parallel is False else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()]
+    target_out = out if seq_parallel_mode is None else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()]
     assert_close(target_out, gather_out)

     # check backward correctness
@@ -115,7 +117,7 @@ def check_linear_1d_row(lazy_init: bool, seq_parallel: bool):
     assert_close(x_for_unshard.grad, x_for_shard.grad)


-def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool):
+def check_linear_col_plus_row(lazy_init: bool, seq_parallel_mode: bool, overlap: bool):
     ctx = LazyInitContext() if lazy_init else nullcontext()

     linear_1 = nn.Linear(32, 128).cuda()
@@ -125,10 +127,10 @@ def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool
         linear_1_copy = nn.Linear(32, 128).cuda()
         linear_2_copy = nn.Linear(128, 32).cuda()
     linear_col = Linear1D_Col.from_native_module(
-        linear_1_copy, process_group=None, gather_output=False, seq_parallel=seq_parallel, overlap=overlap
+        linear_1_copy, process_group=None, gather_output=False, seq_parallel_mode=seq_parallel_mode, overlap=overlap
     )
     linear_row = Linear1D_Row.from_native_module(
-        linear_2_copy, process_group=None, parallel_input=True, seq_parallel=seq_parallel
+        linear_2_copy, process_group=None, parallel_input=True, seq_parallel_mode=seq_parallel_mode
     )

     linear_1.load_state_dict(linear_col.state_dict())
@@ -141,13 +143,17 @@ def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool
     x = torch.rand(2, 4, 32).cuda()
     x_for_unshard = x.expand_as(x.clone())
     x_for_unshard.requires_grad_(True)
-    x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
+    x_for_shard = (
+        x.expand_as(x.clone()) if seq_parallel_mode is None else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
+    )
     x_for_shard.requires_grad_(True)

     # run forward
     unshard_out = linear_2(linear_1(x_for_unshard))
     shard_out = linear_row(linear_col(x_for_shard))
-    target_out = unshard_out if seq_parallel is False else torch.chunk(unshard_out.clone(), 2, dim=1)[dist.get_rank()]
+    target_out = (
+        unshard_out if seq_parallel_mode is None else torch.chunk(unshard_out.clone(), 2, dim=1)[dist.get_rank()]
+    )
     assert_close(target_out, shard_out)

     # check backward correctness
@@ -163,19 +169,19 @@ def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool
     assert x_for_unshard.grad is not None
     target_unshard_gard = (
         x_for_unshard.grad
-        if seq_parallel is False
+        if seq_parallel_mode is None
         else torch.chunk(x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()]
     )
     assert_close(target_unshard_gard, x_for_shard.grad)


 @parameterize("lazy_init", [False, True])
-@parameterize("seq_parallel", [False, True])
+@parameterize("seq_parallel_mode", [None, "split_gather"])
 @parameterize("overlap", [True])
-def run_dist_linear_test(lazy_init, seq_parallel, overlap):
-    check_linear_1d_col(lazy_init, seq_parallel, overlap)
-    check_linear_1d_row(lazy_init, seq_parallel)
-    check_linear_col_plus_row(lazy_init, seq_parallel, overlap)
+def run_dist_linear_test(lazy_init, seq_parallel_mode, overlap):
+    check_linear_1d_col(lazy_init, seq_parallel_mode, overlap)
+    check_linear_1d_row(lazy_init, seq_parallel_mode)
+    check_linear_col_plus_row(lazy_init, seq_parallel_mode, overlap)


 def check_dist_linear(rank, world_size, port):
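All of the updated assertions follow one pattern: when a sequence-parallel mode is active, a rank's expected output (or input gradient) is the matching torch.chunk(..., 2, dim=1) slice of the full-sequence result. The following single-process sketch in plain PyTorch (no distributed launch; world size 2 mirrors the test setup) shows why that slice is the right reference for a linear layer.

import torch
import torch.nn as nn

# A linear layer acts position-wise along the sequence dimension, so applying
# it to one rank's sequence chunk yields exactly the matching chunk of the
# full-sequence output; this is the invariant the test's assert_close calls encode.
torch.manual_seed(0)
linear = nn.Linear(32, 128)

x = torch.rand(2, 4, 32)      # (batch, seq_len, hidden), same shape as in the test
full_out = linear(x)

world_size = 2                # mirrors the two-rank test configuration
for rank in range(world_size):
    x_local = torch.chunk(x, world_size, dim=1)[rank]            # this rank's input slice
    expected = torch.chunk(full_out, world_size, dim=1)[rank]    # matching output slice
    assert torch.allclose(linear(x_local), expected, atol=1e-6)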