Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-12 12:47:21 +00:00)
[shardformer] Sequence Parallelism Optimization (#5533)
* sequence parallel optimization
* validate sequence parallel in llama (code to be polished)
* shardformer api writing
* integrate sequence parallel in ShardFormer
* fix pp bugs and sp bugs for LlaMa model
* integrating ring-based sequence parallelism into ShardFormer
* [sequence parallelism]: Add fused megatron function
* integrating ring-based sequence parallelism into ShardFormer
---------
Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>
* fix bugs when useing sp and flashattention together
* fix operation function name
* support flash attention for ulysses-style sp
* clarify sp process group
* fix compatibility bugs in moe plugin
* fix fused linear bugs
* fix linear layer test
* support gpt model all-to-all sp
* modify shard data dimension (meant to be dim=-1)
* support megtron-style sp and distributed attn for llama model
* [shardformer] add megatron sp to llama
* support llama7B 128k with distributed attention
* [shardformer] robustness enhancement
* add block attn
* sp mode 1: keep input as a complete sequence
* fix sp compatability
* finish sp mode 3 support for gpt
* using all_to_all_single when batch size is 1
* support mode 2 sp in gpt2 (#5)
* [shardformer] add megatron sp to llama
* support llama7B 128k with distributed attention
* [shardformer] robustness enhancement
* add block attn
* sp mode 1: keep input as a complete sequence
* fix sp compatability
* refactor ring implementation
* support mode 2 sp in gpt2
* polish code
* enable distributed attn mask when using sp mode 2 and 3 in llama
* automatically enable flash attn when using sp mode 2 and 3 in llama
* inplace attn mask
* add zero2 support for sequence parallel
* polish code
* fix bugs
* fix gemini checkpoint io
* loose tensor checking atol and rtol
* add comment
* fix llama layernorm grad
* fix zero grad
* fix zero grad
* fix conflict
* update split and gather auto grad func
* sequence parallel: inside text split (#6)
* polish code (part 1)
* polish code (part 2)
* polish code (part 2.5)
* polish code (part 3)
* sequence parallel: inside text split
* miscellaneous minor fixes
* polish code
* fix ulysses style ZeRO
* sequence parallel: inside text split
* miscellaneous minor fixes
* disaggregate sp group and dp group for sp
* fix llama and gpt sp
* polish code
* move ulysses grad sync to ddp (#9)
* remove zero_stage and unbind the grad sync for alltoall sp
* add 2d group creation test
* move ulysses grad sync to ddp
* add 2d group creation test
* remove useless code
* change shard config not to enable sp when enable_all_optimizations
* add sp warnings for several model
* remove useless code
---------
Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>
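The commit introduces a `sequence_parallelism_mode` switch with three values: "split_gather" and "ring" (tensor-parallel-aligned, Megatron-style splitting of the sequence) and "all_to_all" (Ulysses-style, with its own `sp_size` group). The test configs in the diff below exercise all three. As rough orientation only, here is a minimal sketch of feeding such a config to a hybrid-parallel plugin; the keyword names are copied from the `test_config` dictionaries in this diff, while the launch/boost boilerplate around them is assumed rather than taken from this commit:

```python
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# assumed to run under a distributed launcher (e.g. torchrun / colossalai run)
colossalai.launch_from_torch(config={})

# keys mirror the test_config dictionaries below; the exact plugin signature is assumed
plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    sp_size=2,                               # sequence-parallel group size (Ulysses-style)
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="all_to_all",  # or "split_gather" / "ring"
    precision="fp16",
    initial_scale=1,
)
booster = Booster(plugin=plugin)
# model, optimizer, criterion, dataloader, _ = booster.boost(model, optimizer, criterion, dataloader)
```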
@@ -18,8 +18,23 @@ def data_gen():
# tokenized_input = tokenizer(input, return_tensors='pt')
# input_ids = tokenized_input['input_ids']
# attention_mask = tokenized_input['attention_mask']
input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
# input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)
# attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
input_ids = torch.tensor(
[
[15496, 11, 616, 3290, 318, 13779, 318, 13779, 15496, 11, 616, 3290, 318, 13779, 318, 13779],
[15496, 11, 616, 3290, 318, 13779, 318, 13779, 15496, 11, 616, 3290, 318, 13779, 318, 13779],
],
dtype=torch.int64,
)
attention_mask = torch.tensor(
[
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
],
dtype=torch.int64,
)

return dict(input_ids=input_ids, attention_mask=attention_mask)


@@ -35,9 +50,9 @@ def data_gen_for_question_answering():
# question answering data gen
# `labels` is the type not the token id for token classification, 0 or 1
data = data_gen()
start_positions = torch.tensor([0], dtype=torch.int64)
start_positions = torch.tensor([[0], [0]], dtype=torch.int64)
data["start_positions"] = start_positions
end_positions = torch.tensor([1], dtype=torch.int64)
end_positions = torch.tensor([[1], [1]], dtype=torch.int64)
data["end_positions"] = end_positions
return data

@@ -46,14 +61,20 @@ def data_gen_for_token_classification():
# token classification data gen
# `labels` is the type not the token id for token classification, 0 or 1
data = data_gen()
data["labels"] = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 1]], dtype=torch.int64)
data["labels"] = torch.tensor(
[
[0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1],
],
dtype=torch.int64,
)
return data


def data_gen_for_sequence_classification():
# sequence classification data gen
data = data_gen()
data["labels"] = torch.tensor([1], dtype=torch.int64)
data["labels"] = torch.tensor([[1], [1]], dtype=torch.int64)
return data


@@ -61,12 +82,18 @@ def date_gen_for_double_heads():
num_choices = 2
batch_size = 2
input_ids = torch.tensor(
[[15496, 11, 616, 3290, 318, 13779, 318, 13779], [15496, 11, 616, 3290, 318, 13779, 318, 13779]],
[
[15496, 11, 616, 3290, 318, 13779, 318, 13779, 15496, 11, 616, 3290, 318, 13779, 318, 13779],
[15496, 11, 616, 3290, 318, 13779, 318, 13779, 15496, 11, 616, 3290, 318, 13779, 318, 13779],
],
dtype=torch.int64,
)
attention_mask = torch.tensor(
[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
dtype=torch.int64,
)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
mc_labels = torch.zeros(input_ids.shape[0], dtype=torch.int64)

mc_labels = torch.zeros(input_ids.shape[0], dtype=torch.int64)
mc_token_ids = torch.arange(0, num_choices, dtype=torch.int64)
mc_token_ids = mc_token_ids.expand((batch_size, num_choices))
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, num_choices, -1).contiguous()
@@ -103,6 +130,7 @@ config = transformers.GPT2Config(
hidden_dropout=0,
problem_type="single_label_classification",
pad_token_id=50256,
tie_word_embeddings=True,
)

config_for_token_classification = copy.deepcopy(config)
@@ -28,9 +28,19 @@ if HAS_LLAMA:
# -----------------------------------

input_ids = torch.Tensor(
[[1, 15043, 29892, 590, 11203, 338, 274, 1082], [1, 15043, 29892, 590, 11203, 338, 274, 1082]]
[
[1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082],
[1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082],
]
).long()
attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]]).long()

attention_mask = torch.Tensor(
[
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
]
).long()

return dict(input_ids=input_ids, attention_mask=attention_mask)

# label is needed for casual lm
@@ -44,7 +44,10 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b

(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
bert_model = model_fn()
enable_all_optimization = True if tp_size > 1 else False

enable_flash_attention = True if tp_size > 1 else False
enable_fused_normalization = True if tp_size > 1 else False
enable_jit_fused = True if tp_size > 1 else False

with shared_tempdir() as tempdir:
pretrained_path = os.path.join(tempdir, "pretrained")
@@ -54,7 +57,9 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
plugin = GeminiPlugin(
**placement_config,
tp_size=tp_size,
enable_all_optimization=enable_all_optimization,
enable_flash_attention=enable_flash_attention,
enable_fused_normalization=enable_fused_normalization,
enable_jit_fused=enable_jit_fused,
extra_dp_size=extra_dp_size,
)
booster = Booster(plugin=plugin)
@@ -80,7 +85,9 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int, tp_size: int, zero_size: int):
(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
criterion = lambda x: x.mean()
enable_all_optimization = True if tp_size > 1 else False
enable_flash_attention = True if tp_size > 1 else False
enable_fused_normalization = True if tp_size > 1 else False
enable_jit_fused = True if tp_size > 1 else False
extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
plugin = GeminiPlugin(
**placement_config,
@@ -88,7 +95,9 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
initial_scale=(2**14),
tp_size=tp_size,
extra_dp_size=extra_dp_size,
enable_all_optimization=enable_all_optimization,
enable_flash_attention=enable_flash_attention,
enable_fused_normalization=enable_fused_normalization,
enable_jit_fused=enable_jit_fused,
)
booster = Booster(plugin=plugin)
@@ -84,6 +84,30 @@ def check_process_group_mesh_with_cases():
2: [2],
3: [3],
}
TPxPP_RANKS_IN_GROUP = {
0: [0, 1, 2, 3],
1: [0, 1, 2, 3],
2: [0, 1, 2, 3],
3: [0, 1, 2, 3],
}
DPxTP_RANKS_IN_GROUP = {
0: [0, 1],
1: [0, 1],
2: [2, 3],
3: [2, 3],
}
TPxPP_PARTIAL_INDICES = {
0: [[0, 1], [0]],
1: [[1], [0, 1]],
2: [[0], [0, 1]],
3: [[0, 1], [1]],
}
TPxPP_RANKS_IN_GROUP_PARTIAL = {
0: [0, 1],
1: [1, 3],
2: [0, 2],
3: [2, 3],
}

pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE, TP_SIZE)

@@ -107,6 +131,12 @@ def check_process_group_mesh_with_cases():
assert pg_mesh.get_ranks_in_group(pp_group) == PP_RANKS_IN_GROUP[rank]
dp_group = pg_mesh.get_group_along_axis(DP_DIM)
assert pg_mesh.get_ranks_in_group(dp_group) == DP_RANKS_IN_GROUP[rank]
dpxtp_group = pg_mesh.create_group_along_axis([DP_DIM, TP_DIM])
assert pg_mesh.get_ranks_in_group(dpxtp_group) == DPxTP_RANKS_IN_GROUP[rank]
tpxpp_group = pg_mesh.create_group_along_axis([TP_DIM, PP_DIM])
assert pg_mesh.get_ranks_in_group(tpxpp_group) == TPxPP_RANKS_IN_GROUP[rank]
tpxpp_group_partial = pg_mesh.create_group_along_axis([TP_DIM, PP_DIM], TPxPP_PARTIAL_INDICES[rank])
assert pg_mesh.get_ranks_in_group(tpxpp_group_partial) == TPxPP_RANKS_IN_GROUP_PARTIAL[rank]

# check prev rank
if RANK_TO_COORDINATE[rank][TP_DIM] != 0:
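The DPxTP and TPxPP expectations above fall out of laying the four ranks on a (DP, PP, TP) mesh. A quick standalone way to sanity-check them; the mesh sizes (1, 2, 2) are an assumption inferred from the expected groups, since the DP_SIZE/PP_SIZE/TP_SIZE constants sit outside this hunk:

```python
import numpy as np

# Assumed mesh: ProcessGroupMesh(DP_SIZE=1, PP_SIZE=2, TP_SIZE=2) over 4 ranks.
mesh = np.arange(4).reshape(1, 2, 2)  # axes: (DP, PP, TP)

# TP x PP group: fix the DP coordinate, collect every (pp, tp) combination.
print(mesh[0].flatten().tolist())        # [0, 1, 2, 3]

# DP x TP groups: fix the PP coordinate, collect every (dp, tp) combination.
print(mesh[:, 0, :].flatten().tolist())  # [0, 1]
print(mesh[:, 1, :].flatten().tolist())  # [2, 3]
```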
@@ -56,13 +56,18 @@ def rearrange(tensor: torch.Tensor, dim: int):
return rearanged_tensor


def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool):
def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear = Conv1D(192, 48).cuda()
with ctx:
linear_copy = Conv1D(192, 48).cuda()
linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(
linear_copy, process_group=None, gather_output=True, seq_parallel=seq_parallel, n_fused=3, overlap=overlap
linear_copy,
process_group=None,
gather_output=True,
seq_parallel_mode=seq_parallel_mode,
n_fused=3,
overlap=overlap,
)

assert linear.weight.shape == torch.Size([48, 192])
@@ -79,7 +84,9 @@ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool)
# check computation correctness
x = torch.rand(1, 4, 48).cuda()
out = linear(x)
x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
x_for_shard = (
x.expand_as(x.clone()) if seq_parallel_mode is None else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
)
gather_out = linear_conv_col(x_for_shard)
assert_close(rearrange(out, -1), gather_out)

@@ -91,14 +98,14 @@ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool)
assert_close(target_grad, linear_conv_col.weight.grad)


def check_linear_conv_1d_row(lazy_init: bool, seq_parallel: bool):
def check_linear_conv_1d_row(lazy_init: bool, seq_parallel_mode: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()

linear = Conv1D(192, 48).cuda()
with ctx:
linear_copy = Conv1D(192, 48).cuda()
linear_row = GPT2FusedLinearConv1D_Row.from_native_module(
linear_copy, process_group=None, parallel_input=False, seq_parallel=seq_parallel
linear_copy, process_group=None, parallel_input=False, seq_parallel_mode=seq_parallel_mode
)

assert linear.weight.shape == torch.Size([48, 192])
@@ -115,7 +122,7 @@ def check_linear_conv_1d_row(lazy_init: bool, seq_parallel: bool):
x = torch.rand(1, 4, 48).cuda()
out = linear(x)
gather_out = linear_row(x)
target_out = out if seq_parallel is False else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()]
target_out = out if seq_parallel_mode is None else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()]
assert_close(target_out, gather_out)

# check backward correctness
@@ -128,11 +135,11 @@ def check_linear_conv_1d_row(lazy_init: bool, seq_parallel: bool):


@parameterize("lazy_init", [False, True])
@parameterize("seq_parallel", [False, True])
@parameterize("seq_parallel_mode", ["split_gather", None])
@parameterize("overlap", [True])
def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel: bool, overlap: bool):
check_linear_conv_1d_col(lazy_init, seq_parallel, overlap)
check_linear_conv_1d_row(lazy_init, seq_parallel)
def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel_mode: bool, overlap: bool):
check_linear_conv_1d_col(lazy_init, seq_parallel_mode, overlap)
check_linear_conv_1d_row(lazy_init, seq_parallel_mode)


def run_dist(rank, world_size, port):
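In these tests a non-None `seq_parallel_mode` (here "split_gather") means the layer's input arrives sharded along the sequence dimension (hence the `torch.chunk(..., 2, dim=1)[dist.get_rank()]` calls): the column-parallel layer gathers the full sequence internally, while the row-parallel layer hands back only the local sequence shard. A forward-only sketch of the two collectives involved, using plain torch.distributed rather than ColossalAI's autograd-aware ops:

```python
import torch
import torch.distributed as dist

# Illustration only: assumes torch.distributed is already initialized.

def gather_sequence(x_shard: torch.Tensor, dim: int = 1) -> torch.Tensor:
    """All-gather sequence shards so the matmul sees the full sequence."""
    chunks = [torch.empty_like(x_shard) for _ in range(dist.get_world_size())]
    dist.all_gather(chunks, x_shard.contiguous())
    return torch.cat(chunks, dim=dim)

def split_sequence(x_full: torch.Tensor, dim: int = 1) -> torch.Tensor:
    """Keep only this rank's sequence shard of the full output."""
    return torch.chunk(x_full, dist.get_world_size(), dim=dim)[dist.get_rank()]
```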
@@ -15,13 +15,13 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"


def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool):
def check_linear_1d_col(lazy_init: bool, seq_parallel_mode: bool, overlap: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear = nn.Linear(32, 128).cuda()
with ctx:
linear_copy = nn.Linear(32, 128).cuda()
linear_col = Linear1D_Col.from_native_module(
linear_copy, process_group=None, gather_output=True, seq_parallel=seq_parallel, overlap=overlap
linear_copy, process_group=None, gather_output=True, seq_parallel_mode=seq_parallel_mode, overlap=overlap
)

# ensure that the parameters are distributed
@@ -43,7 +43,9 @@ def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool):
x = torch.rand(2, 4, 32).cuda()
x_for_unshard = x.expand_as(x.clone())
x_for_unshard.requires_grad_(True)
x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
x_for_shard = (
x.expand_as(x.clone()) if seq_parallel_mode is None else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
)
x_for_shard.requires_grad_(True)

out = linear(x_for_unshard)
@@ -63,20 +65,20 @@ def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool):
assert x_for_unshard.grad is not None
target_unshard_gard = (
x_for_unshard.grad
if seq_parallel is False
if seq_parallel_mode is None
else torch.chunk(x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()]
)
assert_close(target_unshard_gard, x_for_shard.grad)


def check_linear_1d_row(lazy_init: bool, seq_parallel: bool):
def check_linear_1d_row(lazy_init: bool, seq_parallel_mode: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()

linear = nn.Linear(32, 128).cuda()
with ctx:
linear_copy = nn.Linear(32, 128).cuda()
linear_row = Linear1D_Row.from_native_module(
linear_copy, process_group=None, parallel_input=False, seq_parallel=seq_parallel
linear_copy, process_group=None, parallel_input=False, seq_parallel_mode=seq_parallel_mode
)

assert linear_row.weight.shape == torch.Size([128, 16])
@@ -98,7 +100,7 @@ def check_linear_1d_row(lazy_init: bool, seq_parallel: bool):
# run forward
out = linear(x_for_unshard)
gather_out = linear_row(x_for_shard)
target_out = out if seq_parallel is False else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()]
target_out = out if seq_parallel_mode is None else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()]
assert_close(target_out, gather_out)

# check backward correctness
@@ -115,7 +117,7 @@ def check_linear_1d_row(lazy_init: bool, seq_parallel: bool):
assert_close(x_for_unshard.grad, x_for_shard.grad)


def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool):
def check_linear_col_plus_row(lazy_init: bool, seq_parallel_mode: bool, overlap: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()

linear_1 = nn.Linear(32, 128).cuda()
@@ -125,10 +127,10 @@ def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool
linear_1_copy = nn.Linear(32, 128).cuda()
linear_2_copy = nn.Linear(128, 32).cuda()
linear_col = Linear1D_Col.from_native_module(
linear_1_copy, process_group=None, gather_output=False, seq_parallel=seq_parallel, overlap=overlap
linear_1_copy, process_group=None, gather_output=False, seq_parallel_mode=seq_parallel_mode, overlap=overlap
)
linear_row = Linear1D_Row.from_native_module(
linear_2_copy, process_group=None, parallel_input=True, seq_parallel=seq_parallel
linear_2_copy, process_group=None, parallel_input=True, seq_parallel_mode=seq_parallel_mode
)

linear_1.load_state_dict(linear_col.state_dict())
@@ -141,13 +143,17 @@ def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool
x = torch.rand(2, 4, 32).cuda()
x_for_unshard = x.expand_as(x.clone())
x_for_unshard.requires_grad_(True)
x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
x_for_shard = (
x.expand_as(x.clone()) if seq_parallel_mode is None else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
)
x_for_shard.requires_grad_(True)

# run forward
unshard_out = linear_2(linear_1(x_for_unshard))
shard_out = linear_row(linear_col(x_for_shard))
target_out = unshard_out if seq_parallel is False else torch.chunk(unshard_out.clone(), 2, dim=1)[dist.get_rank()]
target_out = (
unshard_out if seq_parallel_mode is None else torch.chunk(unshard_out.clone(), 2, dim=1)[dist.get_rank()]
)
assert_close(target_out, shard_out)

# check backward correctness
@@ -163,19 +169,19 @@ def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool
assert x_for_unshard.grad is not None
target_unshard_gard = (
x_for_unshard.grad
if seq_parallel is False
if seq_parallel_mode is None
else torch.chunk(x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()]
)
assert_close(target_unshard_gard, x_for_shard.grad)


@parameterize("lazy_init", [False, True])
@parameterize("seq_parallel", [False, True])
@parameterize("seq_parallel_mode", [None, "split_gather"])
@parameterize("overlap", [True])
def run_dist_linear_test(lazy_init, seq_parallel, overlap):
check_linear_1d_col(lazy_init, seq_parallel, overlap)
check_linear_1d_row(lazy_init, seq_parallel)
check_linear_col_plus_row(lazy_init, seq_parallel, overlap)
def run_dist_linear_test(lazy_init, seq_parallel_mode, overlap):
check_linear_1d_col(lazy_init, seq_parallel_mode, overlap)
check_linear_1d_row(lazy_init, seq_parallel_mode)
check_linear_col_plus_row(lazy_init, seq_parallel_mode, overlap)


def check_dist_linear(rank, world_size, port):
tests/test_shardformer/test_layer/test_sequence_parallel.py (new file, 178 lines)
@@ -0,0 +1,178 @@
import copy

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.testing import assert_close

import colossalai
from colossalai.shardformer.layer import all_to_all_comm
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


class SequenceParallelAttention(torch.nn.Module):
    """Initialization.

    Arguments:
        local_attention (Module): local attention with q,k,v
        sequence_process_group (ProcessGroup): sequence parallel process group
        scatter_idx (int): scatter_idx for all2all comm
        gather_idx (int): gather_idx for all2all comm
    """

    def __init__(
        self,
        heads_num: torch.Tensor,
        hidden_dim: torch.Tensor,
        enable_sequence_parallellism: bool = False,
        sequence_process_group: dist.ProcessGroup = None,
        scatter_idx: int = 2,
        gather_idx: int = 1,
    ) -> None:
        super(SequenceParallelAttention, self).__init__()
        self.spg = sequence_process_group
        self.scatter_idx = scatter_idx
        self.gather_idx = gather_idx
        self.heads_num = heads_num
        self.hidden_dim = hidden_dim
        assert hidden_dim % heads_num == 0
        self.head_dim = hidden_dim // heads_num
        self.enable_sequence_parallellism = enable_sequence_parallellism

        self.q = nn.Linear(hidden_dim, hidden_dim)
        self.k = nn.Linear(hidden_dim, hidden_dim)
        self.v = nn.Linear(hidden_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, hidden_dim)

    def attn(self, q, k, v):
        batch_size, seq_len = q.shape[0], q.shape[1]

        scale = self.head_dim**0.5
        qk = torch.matmul(q, k.transpose(-2, -1)) / scale
        weights = F.softmax(qk, dim=-1)

        attention_score = torch.matmul(weights, v)

        return attention_score

    def forward(self, x) -> Tensor:
        bsz, q_len, _ = x.size()

        seq_len = q_len * dist.get_world_size(self.spg) if self.enable_sequence_parallellism else q_len
        num_heads = (
            self.heads_num // dist.get_world_size(self.spg) if self.enable_sequence_parallellism else self.heads_num
        )

        # in shape : e.g., [s/p:h:]
        query_states = self.q(x)
        key_states = self.k(x)
        value_states = self.v(x)

        if self.enable_sequence_parallellism:
            query_states = all_to_all_comm(query_states, self.spg, self.scatter_idx, self.gather_idx)
            key_states = all_to_all_comm(key_states, self.spg, self.scatter_idx, self.gather_idx)
            value_states = all_to_all_comm(value_states, self.spg, self.scatter_idx, self.gather_idx)

        query_states = query_states.view(bsz, seq_len, num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, seq_len, num_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, seq_len, num_heads, self.head_dim).transpose(1, 2)
        # out shape : e.g., [s:h/p:]
        attn_score = self.attn(query_states, key_states, value_states)
        attn_score = attn_score.transpose(1, 2).contiguous()
        attn_score = attn_score.reshape(bsz, seq_len, num_heads * self.head_dim)
        if self.enable_sequence_parallellism:
            attn_score = all_to_all_comm(attn_score, self.spg, self.gather_idx, self.scatter_idx)

        # output e.g., [s/p::h]
        output = self.out(attn_score)

        return output


def seq_parallel_attn(seq_len, hidden_dim, head_num, batch_size):
    seq_len = seq_len
    hidden_dim = hidden_dim
    head_num = head_num
    batch_size = batch_size
    world_size = dist.get_world_size()

    x = torch.randn(batch_size, seq_len, hidden_dim).cuda()
    x_unshard = x.clone()
    x_unshard.requires_grad_(True)
    x_input = torch.chunk(x.clone(), world_size, dim=1)[dist.get_rank()]
    x_input.requires_grad_(True)

    # Multi-head Attention
    mha = SequenceParallelAttention(head_num, hidden_dim).cuda()
    # Multi-head Attention forward
    mha_out = mha(x_unshard)

    # Sequence parallel Attention
    sp_attn = SequenceParallelAttention(head_num, hidden_dim, True).cuda()
    sp_attn.load_state_dict(copy.deepcopy(mha.state_dict()))
    # Sequence parallel Attention forward
    dist_attn_out = sp_attn(x_input)

    # gather the output of sequence parallel attention
    out_list = [torch.empty_like(dist_attn_out) for _ in range(world_size)]
    dist.all_gather(out_list, dist_attn_out)
    seq_out = torch.cat(out_list, dim=1)

    # forward result check
    assert_close(seq_out, mha_out)

    # Multi-head Attention backward
    mha_out.sum().backward()
    q_grad = mha.q.weight.grad
    k_grad = mha.k.weight.grad
    v_grad = mha.v.weight.grad
    o_grad = mha.out.weight.grad
    x_grad = x_unshard.grad

    # Sequence parallel Attention backward
    dist_attn_out.sum().backward()
    q_grad_seq = sp_attn.q.weight.grad
    k_grad_seq = sp_attn.k.weight.grad
    v_grad_seq = sp_attn.v.weight.grad
    o_grad_seq = sp_attn.out.weight.grad
    x_grad_seq = x_input.grad
    # all_reduce the grad of sequence parallel attention weight
    dist.all_reduce(q_grad_seq)
    dist.all_reduce(k_grad_seq)
    dist.all_reduce(v_grad_seq)
    dist.all_reduce(o_grad_seq)
    # gather the grad of sequence parallel attention input
    x_grad_seq_list = [torch.empty_like(x_grad_seq) for _ in range(world_size)]
    dist.all_gather(x_grad_seq_list, x_grad_seq)
    x_grad_seq_gather = torch.cat(x_grad_seq_list, dim=1)

    # backward result check
    assert_close(q_grad_seq, q_grad)
    assert_close(k_grad_seq, k_grad)
    assert_close(v_grad_seq, v_grad, atol=1e-4, rtol=1e-4)
    assert_close(o_grad_seq, o_grad)
    assert_close(x_grad_seq_gather, x_grad)


@parameterize("seq_len", [128])
@parameterize("hidden_dim", [64])
@parameterize("head_num", [4])
@parameterize("batch_size", [1])
def run_seq_parallel_attn(seq_len, hidden_dim, head_num, batch_size):
    seq_parallel_attn(seq_len, hidden_dim, head_num, batch_size)


def check_all2all_attn(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_seq_parallel_attn()


@rerun_if_address_is_in_use()
def test_all_to_all_attention():
    spawn(check_all2all_attn, nprocs=4)


if __name__ == "__main__":
    test_all_to_all_attention()
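The scatter_idx/gather_idx convention used in the new test (scatter the head/hidden dimension and gather the sequence dimension on the way into attention, then the reverse on the way out) can be checked on a single device by emulating what the all-to-all does to the per-rank shards. A small sketch with made-up sizes; it is an emulation of the layout change, not the `all_to_all_comm` implementation:

```python
import torch

world_size, bsz, seq, hidden = 4, 2, 16, 64
# per-rank inputs before the all-to-all: full hidden width, 1/world_size of the sequence
x = torch.randn(world_size, bsz, seq // world_size, hidden)

# emulate scatter_idx=2, gather_idx=1: every rank ends up with the full sequence
# but only hidden/world_size of the channels
full = x.permute(1, 0, 2, 3).reshape(bsz, seq, hidden)  # gather shards along dim 1
out = full.reshape(bsz, seq, world_size, hidden // world_size).permute(2, 0, 1, 3)
print(out.shape)  # torch.Size([4, 2, 16, 16]): (rank, bsz, seq, hidden // world_size)
```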
@@ -1,5 +1,4 @@
import copy
import math
from contextlib import nullcontext
from typing import Any, Callable, Dict, List, Optional

@@ -123,7 +122,6 @@ def build_model_from_hybrid_plugin(model_fn: Callable, loss_fn: Callable, test_c
sharded_model = copy.deepcopy(org_model)
if use_lazy_init:
ctx.materialize(org_model)

org_model = org_model.cuda()
org_optimizer = Adam(org_model.parameters(), lr=1e-3)
sharded_optimizer = Adam(sharded_model.parameters(), lr=1e-3)
@@ -162,24 +160,22 @@ def run_forward_backward_with_hybrid_plugin(

data = data_gen_fn()

if booster.plugin.shard_config.enable_sequence_parallelism and booster.plugin.tp_size != 0:
seq_len = data["input_ids"].shape[-1]
lcm = booster.plugin.tp_size * seq_len // math.gcd(booster.plugin.tp_size, seq_len)
times = lcm // seq_len
input_shape = data["input_ids"].shape
for k, v in data.items():
if v.shape == input_shape:
data[k] = v.repeat((1,) * (v.dim() - 1) + (times,))
shard_test_data = {}
for k, v in data.items():
shard_test_data[k] = data[k].clone()
unshard_test_data = {}
for k, v in data.items():
unshard_test_data[k] = data[k].clone()

sharded_model.train()
if booster.plugin.stage_manager is not None:
for k, v in data.items():
for k, v in shard_test_data.items():
if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
new_shape = [1] * v.dim()
new_shape[0] = 4
data[k] = v.to("cuda").repeat(*new_shape)
shard_test_data[k] = v.to("cuda").repeat(*new_shape)

data_iter = iter([data])
data_iter = iter([shard_test_data])
sharded_output = booster.execute_pipeline(
data_iter,
sharded_model,
@@ -189,17 +185,22 @@ def run_forward_backward_with_hybrid_plugin(
return_outputs=True,
)
sharded_loss = sharded_output["loss"]
else:
data = {k: v.cuda() for k, v in data.items()}
sharded_output = sharded_model(**data)

else:
shard_test_data = {k: v.cuda() for k, v in shard_test_data.items()}
sharded_output = sharded_model(**shard_test_data)
sharded_loss = criterion(sharded_output)
sharded_optimizer.backward(sharded_loss)

org_model.train()
data = {k: v.cuda() for k, v in data.items()}
org_output = org_model(**data)

if booster.plugin.stage_manager is not None:
for k, v in unshard_test_data.items():
if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
new_shape = [1] * v.dim()
new_shape[0] = 4
unshard_test_data[k] = v.to("cuda").repeat(*new_shape)
unshard_test_data = {k: v.cuda() for k, v in unshard_test_data.items()}
org_output = org_model(**unshard_test_data)
org_loss = criterion(org_output)
org_loss.backward()

@@ -212,7 +213,6 @@ def check_output_hidden_state(
stage_manager: Optional[PipelineStageManager] = None,
atol: float = 1e-5,
rtol: float = 1e-3,
dim: int = 0,
):
org_hidden_state = org_output.last_hidden_state
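The data-repetition logic in `run_forward_backward_with_hybrid_plugin` above makes the token sequence length a multiple of `tp_size` before sequence parallelism splits it: every tensor that shares the `input_ids` shape is repeated along the last dimension `lcm(tp_size, seq_len) / seq_len` times. A standalone numeric sketch of that rule, with hypothetical sizes:

```python
import math

import torch

# seq_len=6 is not divisible by tp_size=4, so the inputs get repeated along the last dim.
tp_size = 4
input_ids = torch.arange(6).unsqueeze(0)                # shape [1, 6]
seq_len = input_ids.shape[-1]
lcm = tp_size * seq_len // math.gcd(tp_size, seq_len)   # 12
times = lcm // seq_len                                   # repeat twice
input_ids = input_ids.repeat((1,) * (input_ids.dim() - 1) + (times,))
print(input_ids.shape)  # torch.Size([1, 12]), now divisible by tp_size
```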
@@ -100,6 +100,28 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
@parameterize(
"test_config",
[
{
"tp_size": 4,
"pp_size": 1,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "ring",
"enable_flash_attention": False,
"use_lazy_init": True,
"precision": "fp32",
"initial_scale": 1,
},
{
"tp_size": 4,
"pp_size": 1,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "split_gather",
"enable_flash_attention": False,
"use_lazy_init": True,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 2,
"pp_size": 1,
@@ -154,7 +176,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
)
def run_bert_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_bert")

for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
@@ -99,6 +99,28 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
@parameterize(
"test_config",
[
{
"tp_size": 4,
"pp_size": 1,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "ring",
"enable_flash_attention": False,
"use_lazy_init": True,
"precision": "fp32",
"initial_scale": 1,
},
{
"tp_size": 4,
"pp_size": 1,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "split_gather",
"enable_flash_attention": False,
"use_lazy_init": True,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 2,
"pp_size": 2,
@@ -135,6 +135,28 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
@parameterize(
"test_config",
[
{
"tp_size": 4,
"pp_size": 1,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "ring",
"enable_flash_attention": False,
"use_lazy_init": True,
"precision": "fp32",
"initial_scale": 1,
},
{
"tp_size": 4,
"pp_size": 1,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "split_gather",
"enable_flash_attention": False,
"use_lazy_init": True,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 2,
"pp_size": 2,
@@ -131,6 +131,28 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
@parameterize(
"test_config",
[
{
"tp_size": 4,
"pp_size": 1,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "ring",
"enable_flash_attention": False,
"use_lazy_init": True,
"precision": "fp32",
"initial_scale": 1,
},
{
"tp_size": 4,
"pp_size": 1,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "split_gather",
"enable_flash_attention": False,
"use_lazy_init": True,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 2,
"pp_size": 2,
@@ -2,6 +2,8 @@ import os

import pytest
import torch
import torch.distributed as dist
from torch.testing import assert_close

import colossalai
from colossalai.logging import disable_existing_loggers
@@ -46,6 +48,26 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,

row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"]
col_layer_for_check = ["layers[0].self_attn.o_proj"]
# Here we check the grad of layernorm because an all-reduce operation should be performed during sequence parallelism
norm_layer_for_check = ["layers[0].input_layernorm", "layers[0].post_attention_layernorm"]

# During pipeline parallelism, we cannot get the grad of norm layer during first stage, so we only check this when pp is not enbaled
if stage_manager is None:
norm_layer_for_check.append("norm")

# Check the grad when using ZeRO-1 and ZeRO-2
if (
booster.plugin.zero_stage in [1, 2]
and booster.plugin.shard_config.enable_sequence_parallelism
and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all"
):
for p1, p2 in zip(llama_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]):
working_p = sharded_optimizer._param_store.master_to_working_param[id(p2)]
grads = sharded_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(working_p))
grad_index = 0 if sharded_optimizer._partition_grads else sharded_optimizer._local_rank
grad = grads[grad_index]
sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()]
assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False)

# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
grads_to_check = {}
@@ -60,8 +82,19 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
col_layer_grads = get_grad_tensors_for_check(
llama_model, shard_llama_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
)
norm_layer_grads = get_grad_tensors_for_check(
llama_model,
shard_llama_model,
norm_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False,
)
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
grads_to_check.update(norm_layer_grads)

# optimizer executes step
org_optimizer.step()
@@ -98,6 +131,74 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
@parameterize(
"test_config",
[
{
"tp_size": 2,
"pp_size": 1,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "ring",
"enable_flash_attention": True,
"use_lazy_init": True,
"zero_stage": 2,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 4,
"pp_size": 1,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "ring",
"enable_flash_attention": False,
"use_lazy_init": True,
"precision": "fp32",
"initial_scale": 1,
},
{
"tp_size": 4,
"pp_size": 1,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "split_gather",
"enable_flash_attention": False,
"use_lazy_init": True,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 1,
"pp_size": 1,
"sp_size": 2,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "all_to_all",
"use_lazy_init": True,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 1,
"pp_size": 1,
"sp_size": 2,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "all_to_all",
"use_lazy_init": True,
"zero_stage": 2,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 1,
"pp_size": 1,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "all_to_all",
"enable_flash_attention": False,
"use_lazy_init": True,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 2,
"pp_size": 2,