[shardformer] Sequence Parallelism Optimization (#5533)

* sequence parallel optimization

* validate sequence parallel in llama (code to be polished)

* shardformer api writing

* integrate sequence parallel in ShardFormer

* fix pp bugs and sp bugs for LlaMa model

* integrating ring-based sequence parallelism into ShardFormer

* [sequence parallelism]: Add fused megatron function

* integrating ring-based sequence parallelism into ShardFormer

---------

Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>

* fix bugs when useing sp and flashattention together

* fix operation function name

* support flash attention for ulysses-style sp

* clarify sp process group

* fix compatibility bugs in moe plugin

* fix fused linear bugs

* fix linear layer test

* support gpt model all-to-all sp

* modify shard data dimension (meant to be dim=-1)

* support megtron-style sp and distributed attn for llama model

* [shardformer] add megatron sp to llama

* support llama7B 128k with distributed attention

* [shardformer] robustness enhancement

* add block attn

* sp mode 1: keep input as a complete sequence

* fix sp compatability

* finish sp mode 3 support for gpt

* using all_to_all_single when batch size is 1

* support mode 2 sp in gpt2 (#5)

* [shardformer] add megatron sp to llama

* support llama7B 128k with distributed attention

* [shardformer] robustness enhancement

* add block attn

* sp mode 1: keep input as a complete sequence

* fix sp compatability

* refactor ring implementation

* support mode 2 sp in gpt2

* polish code

* enable distributed attn mask when using sp mode 2 and 3 in llama

* automatically enable flash attn when using sp mode 2 and 3 in llama

* inplace attn mask

* add zero2 support for sequence parallel

* polish code

* fix bugs

* fix gemini checkpoint io

* loose tensor checking atol and rtol

* add comment

* fix llama layernorm grad

* fix zero grad

* fix zero grad

* fix conflict

* update split and gather auto grad func

* sequence parallel: inside text split (#6)

* polish code (part 1)

* polish code (part 2)

* polish code (part 2.5)

* polish code (part 3)

* sequence parallel: inside text split

* miscellaneous minor fixes

* polish code

* fix ulysses style ZeRO

* sequence parallel: inside text split

* miscellaneous minor fixes

* disaggregate sp group and dp group for  sp

* fix llama and gpt sp

* polish code

* move ulysses grad sync to ddp (#9)

* remove zero_stage and unbind the grad sync for alltoall sp

* add 2d group creation test

* move ulysses grad sync to ddp

* add 2d group creation test

* remove useless code

* change shard config not to enable sp when enable_all_optimizations

* add sp warnings for several model

* remove useless code

---------

Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>
This commit is contained in:
Zhongkai Zhao
2024-04-03 17:15:47 +08:00
committed by GitHub
parent 7e0ec5a85c
commit 8e412a548e
33 changed files with 1630 additions and 256 deletions

View File

@@ -18,8 +18,23 @@ def data_gen():
# tokenized_input = tokenizer(input, return_tensors='pt')
# input_ids = tokenized_input['input_ids']
# attention_mask = tokenized_input['attention_mask']
input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
# input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)
# attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
input_ids = torch.tensor(
[
[15496, 11, 616, 3290, 318, 13779, 318, 13779, 15496, 11, 616, 3290, 318, 13779, 318, 13779],
[15496, 11, 616, 3290, 318, 13779, 318, 13779, 15496, 11, 616, 3290, 318, 13779, 318, 13779],
],
dtype=torch.int64,
)
attention_mask = torch.tensor(
[
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
],
dtype=torch.int64,
)
return dict(input_ids=input_ids, attention_mask=attention_mask)
@@ -35,9 +50,9 @@ def data_gen_for_question_answering():
# question answering data gen
# `labels` is the type not the token id for token classification, 0 or 1
data = data_gen()
start_positions = torch.tensor([0], dtype=torch.int64)
start_positions = torch.tensor([[0], [0]], dtype=torch.int64)
data["start_positions"] = start_positions
end_positions = torch.tensor([1], dtype=torch.int64)
end_positions = torch.tensor([[1], [1]], dtype=torch.int64)
data["end_positions"] = end_positions
return data
@@ -46,14 +61,20 @@ def data_gen_for_token_classification():
# token classification data gen
# `labels` is the type not the token id for token classification, 0 or 1
data = data_gen()
data["labels"] = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 1]], dtype=torch.int64)
data["labels"] = torch.tensor(
[
[0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1],
],
dtype=torch.int64,
)
return data
def data_gen_for_sequence_classification():
# sequence classification data gen
data = data_gen()
data["labels"] = torch.tensor([1], dtype=torch.int64)
data["labels"] = torch.tensor([[1], [1]], dtype=torch.int64)
return data
@@ -61,12 +82,18 @@ def date_gen_for_double_heads():
num_choices = 2
batch_size = 2
input_ids = torch.tensor(
[[15496, 11, 616, 3290, 318, 13779, 318, 13779], [15496, 11, 616, 3290, 318, 13779, 318, 13779]],
[
[15496, 11, 616, 3290, 318, 13779, 318, 13779, 15496, 11, 616, 3290, 318, 13779, 318, 13779],
[15496, 11, 616, 3290, 318, 13779, 318, 13779, 15496, 11, 616, 3290, 318, 13779, 318, 13779],
],
dtype=torch.int64,
)
attention_mask = torch.tensor(
[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
dtype=torch.int64,
)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
mc_labels = torch.zeros(input_ids.shape[0], dtype=torch.int64)
mc_labels = torch.zeros(input_ids.shape[0], dtype=torch.int64)
mc_token_ids = torch.arange(0, num_choices, dtype=torch.int64)
mc_token_ids = mc_token_ids.expand((batch_size, num_choices))
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, num_choices, -1).contiguous()
@@ -103,6 +130,7 @@ config = transformers.GPT2Config(
hidden_dropout=0,
problem_type="single_label_classification",
pad_token_id=50256,
tie_word_embeddings=True,
)
config_for_token_classification = copy.deepcopy(config)