ColossalAI/tests/kit/model_zoo/transformers/llama.py
Zhongkai Zhao 8e412a548e
[shardformer] Sequence Parallelism Optimization (#5533)
* sequence parallel optimization

* validate sequence parallel in llama (code to be polished)

* shardformer api writing

* integrate sequence parallel in ShardFormer

* fix pp bugs and sp bugs for LLaMA model

* integrating ring-based sequence parallelism into ShardFormer

* [sequence parallelism]: Add fused megatron function

* integrating ring-based sequence parallelism into ShardFormer

---------

Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>

* fix bugs when using sp and flashattention together

* fix operation function name

* support flash attention for ulysses-style sp

* clarify sp process group

* fix compatibility bugs in moe plugin

* fix fused linear bugs

* fix linear layer test

* support gpt model all-to-all sp

* modify shard data dimension (meant to be dim=-1)

* support megatron-style sp and distributed attn for llama model

* [shardformer] add megatron sp to llama

* support llama7B 128k with distributed attention

* [shardformer] robustness enhancement

* add block attn

* sp mode 1: keep input as a complete sequence

* fix sp compatibility

* finish sp mode 3 support for gpt

* using all_to_all_single when batch size is 1

* support mode 2 sp in gpt2 (#5)

* [shardformer] add megatron sp to llama

* support llama7B 128k with distributed attention

* [shardformer] robustness enhancement

* add block attn

* sp mode 1: keep input as a complete sequence

* fix sp compatibility

* refactor ring implementation

* support mode 2 sp in gpt2

* polish code

* enable distributed attn mask when using sp mode 2 and 3 in llama

* automatically enable flash attn when using sp mode 2 and 3 in llama

* inplace attn mask

* add zero2 support for sequence parallel

* polish code

* fix bugs

* fix gemini checkpoint io

* loosen tensor checking atol and rtol

* add comment

* fix llama layernorm grad

* fix zero grad

* fix zero grad

* fix conflict

* update split and gather auto grad func

* sequence parallel: inside text split (#6)

* polish code (part 1)

* polish code (part 2)

* polish code (part 2.5)

* polish code (part 3)

* sequence parallel: inside text split

* miscellaneous minor fixes

* polish code

* fix ulysses style ZeRO

* sequence parallel: inside text split

* miscellaneous minor fixes

* disaggregate sp group and dp group for sp

* fix llama and gpt sp

* polish code

* move ulysses grad sync to ddp (#9)

* remove zero_stage and unbind the grad sync for alltoall sp

* add 2d group creation test

* move ulysses grad sync to ddp

* add 2d group creation test

* remove useless code

* change shard config not to enable sp when enable_all_optimizations

* add sp warnings for several models

* remove useless code

---------

Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>
2024-04-03 17:15:47 +08:00

101 lines
3.3 KiB
Python

import torch
import transformers

from ..registry import ModelAttribute, model_zoo

try:
    from transformers import LlamaConfig

    HAS_LLAMA = True
except ImportError:
    HAS_LLAMA = False

if HAS_LLAMA:
    # ===============================
    # Register LLaMA
    # ===============================

    def data_gen():
        # The input ids correspond to the sentence
        # 'Hello, my dog is cute'
        # (each row below repeats the 8-token encoding twice).
        #
        # The code used to produce them is given below:
        # -----------------------------------
        # from transformers import LlamaTokenizerFast
        # tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
        # input = 'Hello, my dog is cute'
        # tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
        # -----------------------------------

        input_ids = torch.Tensor(
            [
                [1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082],
                [1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082],
            ]
        ).long()

        # every position is a real token, so nothing is masked out
        attention_mask = torch.Tensor(
            [
                [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
            ]
        ).long()

        return dict(input_ids=input_ids, attention_mask=attention_mask)

    # labels are needed for causal LM
    def data_gen_for_casual_lm():
        data = data_gen()
        labels = data["input_ids"].clone()
        data["labels"] = labels
        return data

    # transform the output to a dict (identity here: the output is already dict-like)
    output_transform_fn = lambda x: x

    # functions to get the loss
    loss_fn = lambda output: output["last_hidden_state"].mean()
    loss_fn_for_casual_lm = lambda output: output["loss"]
    loss_fn_for_seq_classification = lambda output: output["logits"].mean()

    # a deliberately tiny configuration so the tests run quickly
    config = LlamaConfig(
        num_hidden_layers=8,
        hidden_size=32,
        intermediate_size=64,
        num_attention_heads=4,
        max_position_embeddings=128,
        num_labels=16,
    )

    # reuse the EOS token as the padding token
    if hasattr(config, "pad_token_id"):
        config.pad_token_id = config.eos_token_id

    # register the following models
    # transformers.LlamaModel,
    # transformers.LlamaForCausalLM,
    # transformers.LlamaForSequenceClassification,
    model_zoo.register(
        name="transformers_llama",
        model_fn=lambda: transformers.LlamaModel(config),
        data_gen_fn=data_gen,
        output_transform_fn=output_transform_fn,
        loss_fn=loss_fn,
        model_attribute=ModelAttribute(has_control_flow=True),
    )
    model_zoo.register(
        name="transformers_llama_for_casual_lm",
        model_fn=lambda: transformers.LlamaForCausalLM(config),
        data_gen_fn=data_gen_for_casual_lm,
        output_transform_fn=output_transform_fn,
        loss_fn=loss_fn_for_casual_lm,
        model_attribute=ModelAttribute(has_control_flow=True),
    )
    model_zoo.register(
        name="transformers_llama_for_sequence_classification",
        model_fn=lambda: transformers.LlamaForSequenceClassification(config),
        data_gen_fn=data_gen,
        output_transform_fn=output_transform_fn,
        loss_fn=loss_fn_for_seq_classification,
        model_attribute=ModelAttribute(has_control_flow=True),
    )
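
The registrations above bundle four callables per model name (`model_fn`, `data_gen_fn`, `output_transform_fn`, `loss_fn`). The following is a minimal sketch of how a test might consume such entries, assuming a registry that simply stores those callables under each name. `ToyZoo`, `ZooEntry`, `toy_model_fn`, `toy_data_gen` and `run_entry` are hypothetical names used only for illustration; this is not ColossalAI's `model_zoo` implementation, and the toy model is plain Python rather than a LLaMA.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List


@dataclass
class ZooEntry:
    """Hypothetical record mirroring the keyword arguments used in the file above."""

    model_fn: Callable[[], Callable[..., Any]]
    data_gen_fn: Callable[[], Dict[str, Any]]
    output_transform_fn: Callable[[Any], Dict[str, Any]]
    loss_fn: Callable[[Dict[str, Any]], float]


class ToyZoo:
    """Hypothetical stand-in for a model registry (an assumption, not the real one)."""

    def __init__(self) -> None:
        self._entries: Dict[str, ZooEntry] = {}

    def register(self, name: str, **callables: Any) -> None:
        # Refuse duplicate names so one test case cannot silently shadow another.
        if name in self._entries:
            raise KeyError(f"{name!r} is already registered")
        self._entries[name] = ZooEntry(**callables)

    def items(self):
        return self._entries.items()


def toy_model_fn():
    """Build a toy 'model': a masked mean over token ids for each row."""

    def forward(input_ids: List[List[int]], attention_mask: List[List[int]]):
        pooled = [
            sum(t * m for t, m in zip(ids, mask)) / sum(mask)
            for ids, mask in zip(input_ids, attention_mask)
        ]
        return {"last_hidden_state": pooled}

    return forward


def toy_data_gen():
    # Same keys as data_gen() above, but with plain lists instead of tensors.
    return dict(input_ids=[[1, 2, 3], [4, 5, 6]], attention_mask=[[1, 1, 1], [1, 1, 1]])


def run_entry(entry: ZooEntry) -> float:
    """Build the model, generate data, run forward, and reduce to a scalar loss."""
    model = entry.model_fn()
    data = entry.data_gen_fn()
    output = entry.output_transform_fn(model(**data))
    return entry.loss_fn(output)


zoo = ToyZoo()
zoo.register(
    name="toy_model",
    model_fn=toy_model_fn,
    data_gen_fn=toy_data_gen,
    output_transform_fn=lambda x: x,
    loss_fn=lambda output: sum(output["last_hidden_state"]) / len(output["last_hidden_state"]),
)

losses = {name: run_entry(entry) for name, entry in zoo.items()}
print(losses)  # {'toy_model': 3.5}
```

Keeping the model constructor, data generator, and loss as separate callables lets one generic test loop cover every registered model, which is presumably why the file above registers `LlamaModel`, `LlamaForCausalLM`, and `LlamaForSequenceClassification` with the same shape of arguments.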