import torch
import transformers

from ..registry import ModelAttribute, model_zoo

try:
    from transformers import LlamaConfig

    HAS_LLAMA = True
except ImportError:
    HAS_LLAMA = False

if HAS_LLAMA:
    # ===============================
    # Register LLaMA
    # ===============================

    def data_gen():
        # the input ids below correspond to the sentence
        # 'Hello, my dog is cute'
        # (tokenized and repeated twice to fill a length-16 sequence)
        #
        # the code used to generate them is given below:
        # -----------------------------------
        # from transformers import LlamaTokenizerFast
        # tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
        # input = 'Hello, my dog is cute'
        # tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
        # -----------------------------------

        input_ids = torch.Tensor(
            [
                [1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082],
                [1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082],
            ]
        ).long()

        attention_mask = torch.Tensor(
            [
                [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
            ]
        ).long()

        return dict(input_ids=input_ids, attention_mask=attention_mask)

    # labels are needed for causal LM
    def data_gen_for_casual_lm():
        data = data_gen()
        labels = data["input_ids"].clone()
        data["labels"] = labels
        return data
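
    # For reference, each batch produced by the generators above has the layout
    # (shapes derived from the literal tensors in data_gen):
    #   input_ids:      LongTensor of shape (2, 16)
    #   attention_mask: LongTensor of shape (2, 16)
    #   labels:         LongTensor of shape (2, 16), a clone of input_ids (data_gen_for_casual_lm only)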

    # transform the output to a dict
    output_transform_fn = lambda x: x

    # functions to get the loss
    loss_fn = lambda output: output["last_hidden_state"].mean()
    loss_fn_for_casual_lm = lambda output: output["loss"]
    loss_fn_for_seq_classification = lambda output: output["logits"].mean()

    config = LlamaConfig(
        num_hidden_layers=8,
        hidden_size=32,
        intermediate_size=64,
        num_attention_heads=4,
        max_position_embeddings=128,
        num_labels=16,
    )

    if hasattr(config, "pad_token_id"):
        config.pad_token_id = config.eos_token_id
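
    # A minimal smoke-test sketch (not invoked anywhere in this module): it uses only the
    # objects defined above and shows how the pieces fit together for the causal-LM case,
    # i.e. the batch from data_gen_for_casual_lm() carries `labels`, so the model returns a
    # loss that loss_fn_for_casual_lm can pick up. The helper name `_example_casual_lm_step`
    # is illustrative only and is not part of the model zoo API.
    def _example_casual_lm_step():
        model = transformers.LlamaForCausalLM(config)
        data = data_gen_for_casual_lm()
        output = model(**data)  # `labels` makes the model compute a LM loss internally
        return loss_fn_for_casual_lm(output)  # scalar loss tensor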

    # register the following models
    # transformers.LlamaModel,
    # transformers.LlamaForCausalLM,
    # transformers.LlamaForSequenceClassification,
    model_zoo.register(
        name="transformers_llama",
        model_fn=lambda: transformers.LlamaModel(config),
        data_gen_fn=data_gen,
        output_transform_fn=output_transform_fn,
        loss_fn=loss_fn,
        model_attribute=ModelAttribute(has_control_flow=True),
    )
    model_zoo.register(
        name="transformers_llama_for_casual_lm",
        model_fn=lambda: transformers.LlamaForCausalLM(config),
        data_gen_fn=data_gen_for_casual_lm,
        output_transform_fn=output_transform_fn,
        loss_fn=loss_fn_for_casual_lm,
        model_attribute=ModelAttribute(has_control_flow=True),
    )
    model_zoo.register(
        name="transformers_llama_for_sequence_classification",
        model_fn=lambda: transformers.LlamaForSequenceClassification(config),
        data_gen_fn=data_gen,
        output_transform_fn=output_transform_fn,
        loss_fn=loss_fn_for_seq_classification,
        model_attribute=ModelAttribute(has_control_flow=True),
    )
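
    # How these entries are typically consumed (a sketch, kept as a comment because the exact
    # registry interface may differ across versions): a test is assumed to fetch the llama
    # entries by name prefix and unpack each record into its callables, e.g.
    # -----------------------------------
    # sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")
    # for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
    #     model = model_fn()
    #     output = model(**data_gen_fn())
    #     loss = loss_fn(output_transform_fn(output))
    # -----------------------------------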