Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-08 12:30:42 +00:00
[shardformer/sequence parallel] Cherry pick commit to new branch (#4450)
* [shardformer/sequence parallel] Support sequence parallel for gpt2 (#4384)
* [sequence parallel] add sequence parallel linear col/row support (#4336)
* add sequence parallel linear col/row support
* add annotation
* add annotation
* add support for gpt2 fused qkv linear layer
* support sequence parallel in GPT2
* add docstring and note
* add requirements
* remove unused flash-attn
* modify flash attn test
* modify flash attn setting
* modify flash attn code
* add assert before divide, rename forward function
* [shardformer/test] fix gpt2 test with seq-parallel
* [shardformer/sequence parallel] Overlap input gather and grad computation during col backward (#4401)
* overlap gather input / grad computing during col backward
* modify test for overlap
* simplify code
* fix code and modify cuda stream synchronize
* [shardformer/sequence parallel] polish code
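The commit's core idea, exercised by the test diffs below, is to shard activations along the sequence dimension (dim 1 of a [batch_size, seq_len, hidden_size] tensor) across tensor-parallel ranks and gather them back wherever a full sequence is needed. The following minimal, self-contained sketch only mirrors the torch.chunk(..., dim=1)[dist.get_rank()] idiom used in the tests; the helper names are invented for illustration and are not ColossalAI APIs.

import torch

def shard_along_seq(x: torch.Tensor, world_size: int, rank: int) -> torch.Tensor:
    # Return this rank's slice of a [batch, seq_len, hidden] tensor, split on dim=1,
    # mirroring torch.chunk(x, 2, dim=1)[dist.get_rank()] in the tests below.
    return torch.chunk(x, world_size, dim=1)[rank]

def gather_along_seq(chunks: list) -> torch.Tensor:
    # In a distributed run this would be an all-gather over the process group;
    # here the per-rank chunks are simply concatenated back along the sequence dim.
    return torch.cat(chunks, dim=1)

if __name__ == '__main__':
    world_size = 2
    x = torch.rand(2, 4, 32)    # [batch_size, seq_len, hidden_size]
    chunks = [shard_along_seq(x, world_size, r) for r in range(world_size)]
    assert chunks[0].shape == (2, 2, 32)
    assert torch.equal(gather_along_seq(chunks), x)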
@@ -53,8 +53,7 @@ def rearrange(tensor: torch.Tensor, dim: int):
     return rearanged_tensor


-@parameterize('lazy_init', [False, True])
-def check_linear_conv_1d_col(lazy_init: bool):
+def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool):
     ctx = LazyInitContext() if lazy_init else nullcontext()
     linear = Conv1D(192, 48).cuda()
     with ctx:
@@ -62,6 +61,7 @@ def check_linear_conv_1d_col(lazy_init: bool):
         linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear_copy,
                                                                        process_group=None,
                                                                        gather_output=True,
+                                                                       seq_parallel=seq_parallel,
                                                                        n_fused=3)

     assert linear.weight.shape == torch.Size([48, 192])
@@ -76,10 +76,11 @@ def check_linear_conv_1d_col(lazy_init: bool):
     linear.load_state_dict(linear_conv_col.state_dict())

     # check computation correctness
-    x = torch.rand(4, 48).cuda()
+    x = torch.rand(1, 4, 48).cuda()
     out = linear(x)
-    gather_out = linear_conv_col(x)
-    assert_close(rearrange(out, 1), gather_out)
+    x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
+    gather_out = linear_conv_col(x_for_shard)
+    assert_close(rearrange(out, -1), gather_out)

     # check backward correctness
     out.sum().backward()
@@ -89,14 +90,16 @@ def check_linear_conv_1d_col(lazy_init: bool):
     assert_close(target_grad, linear_conv_col.weight.grad)


-@parameterize('lazy_init', [False, True])
-def check_linear_conv_1d_row(lazy_init: bool):
+def check_linear_conv_1d_row(lazy_init: bool, seq_parallel: bool):
     ctx = LazyInitContext() if lazy_init else nullcontext()

     linear = Conv1D(192, 48).cuda()
     with ctx:
         linear_copy = Conv1D(192, 48).cuda()
-        linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False)
+        linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear_copy,
+                                                                  process_group=None,
+                                                                  parallel_input=False,
+                                                                  seq_parallel=seq_parallel)

     assert linear.weight.shape == torch.Size([48, 192])
     assert linear_row.weight.shape == torch.Size([24, 192])
@@ -109,10 +112,11 @@ def check_linear_conv_1d_row(lazy_init: bool):
     linear.load_state_dict(linear_row.state_dict())

     # check computation correctness
-    x = torch.rand(4, 48).cuda()
+    x = torch.rand(1, 4, 48).cuda()
     out = linear(x)
     gather_out = linear_row(x)
-    assert_close(out, gather_out)
+    target_out = out if seq_parallel is False else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()]
+    assert_close(target_out, gather_out)

     # check backward correctness
     out.sum().backward()
@@ -123,12 +127,18 @@ def check_linear_conv_1d_row(lazy_init: bool):
     assert_close(target_grad, linear_row.weight.grad)


+@parameterize('lazy_init', [False, True])
+@parameterize('seq_parallel', [False, True])
+def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel: bool):
+    check_linear_conv_1d_col(lazy_init, seq_parallel)
+    check_linear_conv_1d_row(lazy_init, seq_parallel)
+
+
 def run_dist(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

     # test for linear conv
-    check_linear_conv_1d_col()
-    check_linear_conv_1d_row()
+    check_gpt2_qkv_fused_linear_1d()


 @rerun_if_address_is_in_use()
@@ -12,13 +12,16 @@ from colossalai.tensor.d_tensor import is_distributed_tensor
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


-@parameterize('lazy_init', [False, True])
-def check_linear_1d_col(lazy_init: bool):
+def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool):
     ctx = LazyInitContext() if lazy_init else nullcontext()
     linear = nn.Linear(32, 128).cuda()
     with ctx:
         linear_copy = nn.Linear(32, 128).cuda()
-        linear_col = Linear1D_Col.from_native_module(linear_copy, process_group=None, gather_output=True)
+        linear_col = Linear1D_Col.from_native_module(linear_copy,
+                                                     process_group=None,
+                                                     gather_output=True,
+                                                     seq_parallel=seq_parallel,
+                                                     overlap=overlap)

     # ensure that the parameters are distributed
     assert is_distributed_tensor(linear_col.weight)
@@ -35,10 +38,11 @@ def check_linear_1d_col(lazy_init: bool):
     linear_col.load_state_dict(linear.state_dict())

     # check computation correctness
-    x = torch.rand(4, 32).cuda()
+    # [batch_size, seq_len, hidden_size]
+    x = torch.rand(2, 4, 32).cuda()
     x_for_unshard = x.expand_as(x.clone())
     x_for_unshard.requires_grad_(True)
-    x_for_shard = x.expand_as(x.clone())
+    x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
     x_for_shard.requires_grad_(True)

     out = linear(x_for_unshard)
@@ -56,17 +60,21 @@ def check_linear_1d_col(lazy_init: bool):
     # check the input gradients
     assert x_for_shard.grad is not None
     assert x_for_unshard.grad is not None
-    assert_close(x_for_unshard.grad, x_for_shard.grad)
+    target_unshard_gard = x_for_unshard.grad if seq_parallel is False else torch.chunk(
+        x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()]
+    assert_close(target_unshard_gard, x_for_shard.grad)


-@parameterize('lazy_init', [False, True])
-def check_linear_1d_row(lazy_init: bool):
+def check_linear_1d_row(lazy_init: bool, seq_parallel: bool):
     ctx = LazyInitContext() if lazy_init else nullcontext()

     linear = nn.Linear(32, 128).cuda()
     with ctx:
         linear_copy = nn.Linear(32, 128).cuda()
-        linear_row = Linear1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False)
+        linear_row = Linear1D_Row.from_native_module(linear_copy,
+                                                     process_group=None,
+                                                     parallel_input=False,
+                                                     seq_parallel=seq_parallel)

     assert linear_row.weight.shape == torch.Size([128, 16])
     assert linear_row.bias.shape == torch.Size([128])
@@ -77,7 +85,8 @@ def check_linear_1d_row(lazy_init: bool):
     linear_row.load_state_dict(linear.state_dict())

     # check computation correctness
-    x = torch.rand(4, 32).cuda()
+    # [batch_size, seq_len, hidden_size]
+    x = torch.rand(2, 4, 32).cuda()
     x_for_unshard = x.expand_as(x.clone())
     x_for_unshard.requires_grad_(True)
     x_for_shard = x.expand_as(x.clone())
@@ -86,7 +95,8 @@ def check_linear_1d_row(lazy_init: bool):
     # run forward
     out = linear(x_for_unshard)
     gather_out = linear_row(x_for_shard)
-    assert_close(out, gather_out)
+    target_out = out if seq_parallel is False else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()]
+    assert_close(target_out, gather_out)

     # check backward correctness
     out.sum().backward()
@@ -102,8 +112,7 @@ def check_linear_1d_row(lazy_init: bool):
     assert_close(x_for_unshard.grad, x_for_shard.grad)


-@parameterize('lazy_init', [False, True])
-def check_linear_col_plus_row(lazy_init: bool):
+def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool):
     ctx = LazyInitContext() if lazy_init else nullcontext()

     linear_1 = nn.Linear(32, 128).cuda()
@@ -112,8 +121,15 @@ def check_linear_col_plus_row(lazy_init: bool):
     with ctx:
         linear_1_copy = nn.Linear(32, 128).cuda()
         linear_2_copy = nn.Linear(128, 32).cuda()
-        linear_col = Linear1D_Col.from_native_module(linear_1_copy, process_group=None, gather_output=False)
-        linear_row = Linear1D_Row.from_native_module(linear_2_copy, process_group=None, parallel_input=True)
+        linear_col = Linear1D_Col.from_native_module(linear_1_copy,
+                                                     process_group=None,
+                                                     gather_output=False,
+                                                     seq_parallel=seq_parallel,
+                                                     overlap=overlap)
+        linear_row = Linear1D_Row.from_native_module(linear_2_copy,
+                                                     process_group=None,
+                                                     parallel_input=True,
+                                                     seq_parallel=seq_parallel)

     linear_1.load_state_dict(linear_col.state_dict())
     linear_col.load_state_dict(linear_1.state_dict())
@@ -121,16 +137,18 @@ def check_linear_col_plus_row(lazy_init: bool):
     linear_row.load_state_dict(linear_2.state_dict())

     # check computation correctness
-    x = torch.rand(4, 32).cuda()
+    # [batch_size, seq_len, hidden_size]
+    x = torch.rand(2, 4, 32).cuda()
     x_for_unshard = x.expand_as(x.clone())
     x_for_unshard.requires_grad_(True)
-    x_for_shard = x.expand_as(x.clone())
+    x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
     x_for_shard.requires_grad_(True)

     # run forward
     unshard_out = linear_2(linear_1(x_for_unshard))
     shard_out = linear_row(linear_col(x_for_shard))
-    assert_close(unshard_out, shard_out)
+    target_out = unshard_out if seq_parallel is False else torch.chunk(unshard_out.clone(), 2, dim=1)[dist.get_rank()]
+    assert_close(target_out, shard_out)

     # check backward correctness
     unshard_out.sum().backward()
@@ -143,19 +161,28 @@ def check_linear_col_plus_row(lazy_init: bool):
     # check the input gradients
     assert x_for_shard.grad is not None
     assert x_for_unshard.grad is not None
-    assert_close(x_for_unshard.grad, x_for_shard.grad)
+    target_unshard_gard = x_for_unshard.grad if seq_parallel is False else torch.chunk(
+        x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()]
+    assert_close(target_unshard_gard, x_for_shard.grad)


-def run_dist(rank, world_size, port):
+@parameterize('lazy_init', [False, True])
+@parameterize('seq_parallel', [False, True])
+@parameterize('overlap', [False, True])
+def run_dist_linear_test(lazy_init, seq_parallel, overlap):
+    check_linear_1d_col(lazy_init, seq_parallel, overlap)
+    check_linear_1d_row(lazy_init, seq_parallel)
+    check_linear_col_plus_row(lazy_init, seq_parallel, overlap)
+
+
+def check_dist_linear(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    check_linear_1d_col()
-    check_linear_1d_row()
-    check_linear_col_plus_row()
+    run_dist_linear_test()


 @rerun_if_address_is_in_use()
 def test_linear():
-    spawn(run_dist, nprocs=2)
+    spawn(check_dist_linear, nprocs=2)


 if __name__ == '__main__':
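The overlap=overlap flag exercised above corresponds to the second part of the commit message: during the column-linear backward pass, the all-gather of the sequence-sharded input can run on a side CUDA stream while the gradient with respect to the input is computed on the default stream. The sketch below only illustrates that scheduling pattern under simplifying assumptions (a single process, with the "gather" reduced to a copy); the function name is invented and this is not the ColossalAI kernel.

import torch

def col_linear_backward_with_overlap(grad_output: torch.Tensor,
                                     weight: torch.Tensor,
                                     input_shard: torch.Tensor):
    # grad_output: [batch, seq, out]; weight: [out, in]; input_shard: this rank's
    # sequence slice, which the gather would restore to the full sequence length.
    if not torch.cuda.is_available():
        # Sequential CPU fallback: the "gather" is a no-op for a single rank.
        total_input = input_shard
        grad_input = grad_output @ weight
        grad_weight = grad_output.flatten(0, 1).t() @ total_input.flatten(0, 1)
        return grad_input, grad_weight

    gather_stream = torch.cuda.Stream()
    with torch.cuda.stream(gather_stream):
        # Stand-in for the all-gather along the sequence dimension.
        total_input = input_shard.clone()
    # Overlap: grad_input does not need the gathered input, so compute it now.
    grad_input = grad_output @ weight
    # Wait for the gather before using total_input for the weight gradient.
    torch.cuda.current_stream().wait_stream(gather_stream)
    grad_weight = grad_output.flatten(0, 1).t() @ total_input.flatten(0, 1)
    return grad_input, grad_weight

The essential point, which the overlap test parameter checks for correctness, is that grad_input and the input gather are independent and can proceed concurrently; the commit message's "fix code and modify cuda stream synchronize" refers to getting that synchronization right.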
@@ -1,4 +1,5 @@
 import copy
+import math
 from contextlib import nullcontext
 from typing import Any, Callable, Dict, List, Optional

@@ -25,6 +26,7 @@ def build_model(model_fn,
                 enable_tensor_parallelism=True,
                 enable_flash_attention=False,
                 enable_jit_fused=False,
+                enable_sequence_parallelism=False,
                 use_lazy_init: bool = False):
     # create new model
     ctx = LazyInitContext() if use_lazy_init else nullcontext()
@@ -38,7 +40,8 @@ def build_model(model_fn,
     shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
                                enable_tensor_parallelism=enable_tensor_parallelism,
                                enable_flash_attention=enable_flash_attention,
-                               enable_jit_fused=enable_jit_fused)
+                               enable_jit_fused=enable_jit_fused,
+                               enable_sequence_parallelism=enable_sequence_parallelism)
     model_copy = copy.deepcopy(org_model)
     shard_former = ShardFormer(shard_config=shard_config)
     sharded_model, shared_params = shard_former.optimize(model_copy)
@@ -135,6 +138,16 @@ def run_forward_backward_with_hybrid_plugin(org_model: Module, sharded_model: Mo
         return loss

     data = data_gen_fn()
+
+    if booster.plugin.enable_sequence_parallelism and booster.plugin.tp_size != 0:
+        seq_len = data['input_ids'].shape[1]
+        lcm = booster.plugin.tp_size * seq_len // math.gcd(booster.plugin.tp_size, seq_len)
+        times = lcm // seq_len
+        input_shape = data['input_ids'].shape
+        for k, v in data.items():
+            if v.shape == input_shape:
+                data[k] = v.repeat(1, times)
+
     sharded_model.train()
     if booster.plugin.stage_manager is not None:
         for k, v in data.items():
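The hunk above pads every input that shares the input_ids shape so that seq_len becomes divisible by the tensor-parallel size, by repeating the sequence up to lcm(tp_size, seq_len). A small standalone sketch of that arithmetic follows; the helper name is hypothetical, not the test utility itself.

import math
import torch

def pad_seq_to_multiple_of_tp(data: dict, tp_size: int) -> dict:
    # Repeat sequence-shaped tensors so seq_len becomes lcm(tp_size, seq_len),
    # which is divisible by tp_size and therefore splits evenly across ranks.
    seq_len = data['input_ids'].shape[1]
    lcm = tp_size * seq_len // math.gcd(tp_size, seq_len)
    times = lcm // seq_len
    input_shape = data['input_ids'].shape
    return {k: (v.repeat(1, times) if v.shape == input_shape else v) for k, v in data.items()}

if __name__ == '__main__':
    batch = {
        'input_ids': torch.ones(2, 6, dtype=torch.long),
        'attention_mask': torch.ones(2, 6, dtype=torch.long),
    }
    padded = pad_seq_to_multiple_of_tp(batch, tp_size=4)
    # lcm(4, 6) = 12, so each sequence is repeated twice to length 12
    assert padded['input_ids'].shape == (2, 12)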
@@ -106,6 +106,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     'enable_all_optimization': True,
     'use_lazy_init': False,
     'precision': 'fp32',
+}, {
+    'tp_size': 4,
+    'pp_size': 1,
+    'enable_all_optimization': False,
+    'use_lazy_init': True,
+    'enable_sequence_parallelism': True,
+    'precision': 'fp32',
 }])
 @clear_cache_before_run()
 def run_gpt2_test(test_config):