[shardformer/sequence parallel] Cherry pick commit to new branch (#4450)

* [shardformer/sequence parallel] Support sequence parallel for gpt2 (#4384)

* [sequence parallel] add sequence parallel linear col/row support (#4336)

* add sequence parallel linear col/row support

* add annotation

* add annotation

* add support for gpt2 fused qkv linear layer

* support sequence parallel in GPT2

* add docstring and note

* add requirements

* remove unused flash-attn

* modify flash attn test

* modify flash attn setting

* modify flash attn code

* add assert before divide, rename forward function

* [shardformer/test] fix gpt2 test with seq-parallel

* [shardformer/sequence parallel] Overlap input gather and grad computation during col backward (#4401)

* overlap input gather and grad computation during col backward

* modify test for overlap

* simplify code

* fix code and modify cuda stream synchronize

* [shardformer/sequence parallel] polish code
Authored by Bin Jia on 2023-08-16 15:41:20 +08:00, committed by GitHub
parent d20dceb9a3
commit 424629fea0
12 changed files with 655 additions and 65 deletions
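
Conceptually, the new seq_parallel mode uses a Megatron-style sequence-parallel layout: activations between the paired parallel linears are sharded along the sequence dimension, so the column layer all-gathers its input along that dimension before the matmul, and the row layer reduce-scatters its output back. The sketch below illustrates only that communication pattern; it is not the actual Linear1D_Col / Linear1D_Row implementation, and the function names and shapes are assumptions.

import torch
import torch.distributed as dist

def seq_parallel_col_forward(x_shard, weight, bias):
    # x_shard: [batch, seq_len // world_size, in_features], this rank's sequence chunk
    # weight, bias: this rank's column shard, nn.Linear convention [out_shard, in_features]
    world_size = dist.get_world_size()
    gathered = [torch.empty_like(x_shard) for _ in range(world_size)]
    dist.all_gather(gathered, x_shard.contiguous())
    x_full = torch.cat(gathered, dim=1)                        # full sequence: [batch, seq_len, in_features]
    return torch.nn.functional.linear(x_full, weight, bias)    # [batch, seq_len, out_shard]

def seq_parallel_row_forward(y_shard, weight, bias):
    # y_shard: [batch, seq_len, in_shard]; weight: [out_features, in_shard]
    world_size = dist.get_world_size()
    partial = torch.nn.functional.linear(y_shard, weight)      # partial sum over the hidden shards
    chunks = [c.contiguous() for c in torch.chunk(partial, world_size, dim=1)]
    out = torch.empty_like(chunks[0])
    dist.reduce_scatter(out, chunks)                           # sum partials, keep the local sequence chunk
    return out + bias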

View File

@@ -53,8 +53,7 @@ def rearrange(tensor: torch.Tensor, dim: int):
return rearanged_tensor
@parameterize('lazy_init', [False, True])
def check_linear_conv_1d_col(lazy_init: bool):
def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear = Conv1D(192, 48).cuda()
with ctx:
@@ -62,6 +61,7 @@ def check_linear_conv_1d_col(lazy_init: bool):
linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear_copy,
process_group=None,
gather_output=True,
seq_parallel=seq_parallel,
n_fused=3)
assert linear.weight.shape == torch.Size([48, 192])
@@ -76,10 +76,11 @@ def check_linear_conv_1d_col(lazy_init: bool):
linear.load_state_dict(linear_conv_col.state_dict())
# check computation correctness
x = torch.rand(4, 48).cuda()
x = torch.rand(1, 4, 48).cuda()
out = linear(x)
gather_out = linear_conv_col(x)
assert_close(rearrange(out, 1), gather_out)
x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
gather_out = linear_conv_col(x_for_shard)
assert_close(rearrange(out, -1), gather_out)
# check backward correctness
out.sum().backward()
@@ -89,14 +90,16 @@ def check_linear_conv_1d_col(lazy_init: bool):
assert_close(target_grad, linear_conv_col.weight.grad)
@parameterize('lazy_init', [False, True])
def check_linear_conv_1d_row(lazy_init: bool):
def check_linear_conv_1d_row(lazy_init: bool, seq_parallel: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear = Conv1D(192, 48).cuda()
with ctx:
linear_copy = Conv1D(192, 48).cuda()
linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False)
linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear_copy,
process_group=None,
parallel_input=False,
seq_parallel=seq_parallel)
assert linear.weight.shape == torch.Size([48, 192])
assert linear_row.weight.shape == torch.Size([24, 192])
@@ -109,10 +112,11 @@ def check_linear_conv_1d_row(lazy_init: bool):
linear.load_state_dict(linear_row.state_dict())
# check computation correctness
x = torch.rand(4, 48).cuda()
x = torch.rand(1, 4, 48).cuda()
out = linear(x)
gather_out = linear_row(x)
assert_close(out, gather_out)
target_out = out if seq_parallel is False else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()]
assert_close(target_out, gather_out)
# check backward correctness
out.sum().backward()
@@ -123,12 +127,18 @@ def check_linear_conv_1d_row(lazy_init: bool):
assert_close(target_grad, linear_row.weight.grad)
@parameterize('lazy_init', [False, True])
@parameterize('seq_parallel', [False, True])
def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel: bool):
check_linear_conv_1d_col(lazy_init, seq_parallel)
check_linear_conv_1d_row(lazy_init, seq_parallel)
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
# test for linear conv
check_linear_conv_1d_col()
check_linear_conv_1d_row()
check_gpt2_qkv_fused_linear_1d()
@rerun_if_address_is_in_use()
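
The assertions above reduce to one pattern: with seq_parallel enabled, each of the two ranks feeds only its own chunk of the sequence dimension into the sharded layer, and the unsharded reference output is chunked the same way before comparison. A standalone sketch of that helper (hypothetical name, two ranks assumed):

import torch
import torch.distributed as dist

def take_seq_chunk(x: torch.Tensor, seq_parallel: bool, num_chunks: int = 2) -> torch.Tensor:
    # Return what this rank feeds to (or compares against) a sequence-parallel layer:
    # the full tensor when seq_parallel is off, otherwise this rank's slice of dim=1.
    if not seq_parallel:
        return x.clone()
    return torch.chunk(x.clone(), num_chunks, dim=1)[dist.get_rank()]

# e.g. for x of shape [1, 4, 48] on 2 ranks: rank 0 gets x[:, :2, :], rank 1 gets x[:, 2:, :]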

View File

@@ -12,13 +12,16 @@ from colossalai.tensor.d_tensor import is_distributed_tensor
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
@parameterize('lazy_init', [False, True])
def check_linear_1d_col(lazy_init: bool):
def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear = nn.Linear(32, 128).cuda()
with ctx:
linear_copy = nn.Linear(32, 128).cuda()
linear_col = Linear1D_Col.from_native_module(linear_copy, process_group=None, gather_output=True)
linear_col = Linear1D_Col.from_native_module(linear_copy,
process_group=None,
gather_output=True,
seq_parallel=seq_parallel,
overlap=overlap)
# ensure that the parameters are distributed
assert is_distributed_tensor(linear_col.weight)
@@ -35,10 +38,11 @@ def check_linear_1d_col(lazy_init: bool):
linear_col.load_state_dict(linear.state_dict())
# check computation correctness
x = torch.rand(4, 32).cuda()
# [batch_size, seq_len, hidden_size]
x = torch.rand(2, 4, 32).cuda()
x_for_unshard = x.expand_as(x.clone())
x_for_unshard.requires_grad_(True)
x_for_shard = x.expand_as(x.clone())
x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
x_for_shard.requires_grad_(True)
out = linear(x_for_unshard)
@@ -56,17 +60,21 @@ def check_linear_1d_col(lazy_init: bool):
# check the input gradients
assert x_for_shard.grad is not None
assert x_for_unshard.grad is not None
assert_close(x_for_unshard.grad, x_for_shard.grad)
target_unshard_gard = x_for_unshard.grad if seq_parallel is False else torch.chunk(
x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()]
assert_close(target_unshard_gard, x_for_shard.grad)
@parameterize('lazy_init', [False, True])
def check_linear_1d_row(lazy_init: bool):
def check_linear_1d_row(lazy_init: bool, seq_parallel: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear = nn.Linear(32, 128).cuda()
with ctx:
linear_copy = nn.Linear(32, 128).cuda()
linear_row = Linear1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False)
linear_row = Linear1D_Row.from_native_module(linear_copy,
process_group=None,
parallel_input=False,
seq_parallel=seq_parallel)
assert linear_row.weight.shape == torch.Size([128, 16])
assert linear_row.bias.shape == torch.Size([128])
@@ -77,7 +85,8 @@ def check_linear_1d_row(lazy_init: bool):
linear_row.load_state_dict(linear.state_dict())
# check computation correctness
x = torch.rand(4, 32).cuda()
# [batch_size, seq_len, hidden_size]
x = torch.rand(2, 4, 32).cuda()
x_for_unshard = x.expand_as(x.clone())
x_for_unshard.requires_grad_(True)
x_for_shard = x.expand_as(x.clone())
@@ -86,7 +95,8 @@ def check_linear_1d_row(lazy_init: bool):
# run forward
out = linear(x_for_unshard)
gather_out = linear_row(x_for_shard)
assert_close(out, gather_out)
target_out = out if seq_parallel is False else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()]
assert_close(target_out, gather_out)
# check backward correctness
out.sum().backward()
@@ -102,8 +112,7 @@ def check_linear_1d_row(lazy_init: bool):
assert_close(x_for_unshard.grad, x_for_shard.grad)
@parameterize('lazy_init', [False, True])
def check_linear_col_plus_row(lazy_init: bool):
def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear_1 = nn.Linear(32, 128).cuda()
@@ -112,8 +121,15 @@ def check_linear_col_plus_row(lazy_init: bool):
with ctx:
linear_1_copy = nn.Linear(32, 128).cuda()
linear_2_copy = nn.Linear(128, 32).cuda()
linear_col = Linear1D_Col.from_native_module(linear_1_copy, process_group=None, gather_output=False)
linear_row = Linear1D_Row.from_native_module(linear_2_copy, process_group=None, parallel_input=True)
linear_col = Linear1D_Col.from_native_module(linear_1_copy,
process_group=None,
gather_output=False,
seq_parallel=seq_parallel,
overlap=overlap)
linear_row = Linear1D_Row.from_native_module(linear_2_copy,
process_group=None,
parallel_input=True,
seq_parallel=seq_parallel)
linear_1.load_state_dict(linear_col.state_dict())
linear_col.load_state_dict(linear_1.state_dict())
@@ -121,16 +137,18 @@ def check_linear_col_plus_row(lazy_init: bool):
linear_row.load_state_dict(linear_2.state_dict())
# check computation correctness
x = torch.rand(4, 32).cuda()
# [batch_size, seq_len, hidden_size]
x = torch.rand(2, 4, 32).cuda()
x_for_unshard = x.expand_as(x.clone())
x_for_unshard.requires_grad_(True)
x_for_shard = x.expand_as(x.clone())
x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
x_for_shard.requires_grad_(True)
# run forward
unshard_out = linear_2(linear_1(x_for_unshard))
shard_out = linear_row(linear_col(x_for_shard))
assert_close(unshard_out, shard_out)
target_out = unshard_out if seq_parallel is False else torch.chunk(unshard_out.clone(), 2, dim=1)[dist.get_rank()]
assert_close(target_out, shard_out)
# check backward correctness
unshard_out.sum().backward()
@@ -143,19 +161,28 @@ def check_linear_col_plus_row(lazy_init: bool):
# check the input gradients
assert x_for_shard.grad is not None
assert x_for_unshard.grad is not None
assert_close(x_for_unshard.grad, x_for_shard.grad)
target_unshard_gard = x_for_unshard.grad if seq_parallel is False else torch.chunk(
x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()]
assert_close(target_unshard_gard, x_for_shard.grad)
def run_dist(rank, world_size, port):
@parameterize('lazy_init', [False, True])
@parameterize('seq_parallel', [False, True])
@parameterize('overlap', [False, True])
def run_dist_linear_test(lazy_init, seq_parallel, overlap):
check_linear_1d_col(lazy_init, seq_parallel, overlap)
check_linear_1d_row(lazy_init, seq_parallel)
check_linear_col_plus_row(lazy_init, seq_parallel, overlap)
def check_dist_linear(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
check_linear_1d_col()
check_linear_1d_row()
check_linear_col_plus_row()
run_dist_linear_test()
@rerun_if_address_is_in_use()
def test_linear():
spawn(run_dist, nprocs=2)
spawn(check_dist_linear, nprocs=2)
if __name__ == '__main__':
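
The overlap flag exercised by this test comes from the second cherry-picked commit (#4401): in the column layer's backward pass, the all-gather of the sequence-sharded input, which is needed only for the weight gradient, is issued on a separate CUDA stream so it can run concurrently with the grad-input matmul on the default stream. A rough sketch of that overlap with assumed tensor names, leaving out the autograd plumbing of the real layer:

import torch
import torch.distributed as dist

def col_backward_with_overlap(grad_output, input_shard, weight, process_group=None):
    # grad_output: [batch, seq_len, out_shard]
    # input_shard: this rank's sequence chunk of the input saved in forward
    # weight: [out_shard, in_features]
    world_size = dist.get_world_size(process_group)
    gathered = [torch.empty_like(input_shard) for _ in range(world_size)]

    comm_stream = torch.cuda.Stream()
    comm_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(comm_stream):
        # launch the input all-gather asynchronously on the side stream
        handle = dist.all_gather(gathered, input_shard.contiguous(),
                                 group=process_group, async_op=True)

    # overlapped work: grad wrt the input does not need the all-gather result
    grad_input = grad_output.matmul(weight)      # the real layer then reduce-scatters this along seq

    handle.wait()
    torch.cuda.current_stream().wait_stream(comm_stream)
    total_input = torch.cat(gathered, dim=1)     # full-sequence input

    # the weight gradient does need the gathered input
    grad_weight = grad_output.reshape(-1, grad_output.shape[-1]).t().matmul(
        total_input.reshape(-1, total_input.shape[-1]))
    return grad_input, grad_weight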

View File

@@ -1,4 +1,5 @@
import copy
import math
from contextlib import nullcontext
from typing import Any, Callable, Dict, List, Optional
@@ -25,6 +26,7 @@ def build_model(model_fn,
enable_tensor_parallelism=True,
enable_flash_attention=False,
enable_jit_fused=False,
enable_sequence_parallelism=False,
use_lazy_init: bool = False):
# create new model
ctx = LazyInitContext() if use_lazy_init else nullcontext()
@@ -38,7 +40,8 @@ def build_model(model_fn,
shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
enable_tensor_parallelism=enable_tensor_parallelism,
enable_flash_attention=enable_flash_attention,
enable_jit_fused=enable_jit_fused)
enable_jit_fused=enable_jit_fused,
enable_sequence_parallelism=enable_sequence_parallelism)
model_copy = copy.deepcopy(org_model)
shard_former = ShardFormer(shard_config=shard_config)
sharded_model, shared_params = shard_former.optimize(model_copy)
@@ -135,6 +138,16 @@ def run_forward_backward_with_hybrid_plugin(org_model: Module, sharded_model: Mo
return loss
data = data_gen_fn()
if booster.plugin.enable_sequence_parallelism and booster.plugin.tp_size != 0:
seq_len = data['input_ids'].shape[1]
lcm = booster.plugin.tp_size * seq_len // math.gcd(booster.plugin.tp_size, seq_len)
times = lcm // seq_len
input_shape = data['input_ids'].shape
for k, v in data.items():
if v.shape == input_shape:
data[k] = v.repeat(1, times)
sharded_model.train()
if booster.plugin.stage_manager is not None:
for k, v in data.items():
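
The repeat logic in the hunk above just pads the generated batch so the sequence length divides evenly across the tensor-parallel ranks, which sequence parallelism requires. With hypothetical numbers, tp_size = 4 and seq_len = 6:

import math

tp_size, seq_len = 4, 6                                   # hypothetical values
lcm = tp_size * seq_len // math.gcd(tp_size, seq_len)     # lcm(4, 6) = 12
times = lcm // seq_len                                    # repeat each sequence twice
assert (seq_len * times) % tp_size == 0                   # 12 tokens split evenly over 4 ranks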

View File

@@ -106,6 +106,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32',
}, {
'tp_size': 4,
'pp_size': 1,
'enable_all_optimization': False,
'use_lazy_init': True,
'enable_sequence_parallelism': True,
'precision': 'fp32',
}])
@clear_cache_before_run()
def run_gpt2_test(test_config):
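
Outside the test harness, enabling the feature mirrors the build_model change above: pass enable_sequence_parallelism=True to ShardConfig alongside tensor parallelism. A minimal sketch (the import path and the model variable are assumptions):

from colossalai.shardformer import ShardConfig, ShardFormer

shard_config = ShardConfig(enable_tensor_parallelism=True,
                           enable_sequence_parallelism=True)    # new flag added by this commit
shard_former = ShardFormer(shard_config=shard_config)
sharded_gpt2, shared_params = shard_former.optimize(model)      # `model`: a GPT-2 module you provide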