hpcaitech/ColossalAI (mirror of https://github.com/hpcaitech/ColossalAI.git)
upgrade llama
@@ -36,19 +36,19 @@ class LlamaPolicy(Policy):
         from transformers.models.llama.modeling_llama import (
             LlamaAttention,
             LlamaDecoderLayer,
-            LlamaFlashAttention2,
+            # LlamaFlashAttention2,
             LlamaModel,
-            LlamaSdpaAttention,
+            # LlamaSdpaAttention,
         )
-
-        ATTN_IMPLEMENTATION = {
-            "eager": LlamaAttention,
-            "flash_attention_2": LlamaFlashAttention2,
-            "sdpa": LlamaSdpaAttention,
-        }
+        # ATTN_IMPLEMENTATION = {
+        #     "eager": LlamaAttention,
+        #     "flash_attention_2": LlamaFlashAttention2,
+        #     "sdpa": LlamaSdpaAttention,
+        # }
         policy = {}

-        attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
+        attn_cls = LlamaAttention
+        # attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
         embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
             embedding_cls = VocabParallelEmbedding1D
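Context for this hunk: recent transformers releases removed the separate LlamaFlashAttention2 and LlamaSdpaAttention classes and route every backend ("eager", "sdpa", "flash_attention_2") through the single LlamaAttention class, which is why the commit comments out the per-backend map and pins attn_cls = LlamaAttention. A minimal, version-tolerant sketch of the same idea (not the committed code; the helper name pick_attn_cls and its default argument are illustrative):

# Sketch only: resolve the attention class across transformers versions.
# Older releases ship LlamaFlashAttention2 / LlamaSdpaAttention as separate
# classes; newer releases cover every backend with LlamaAttention alone, so
# the map collapses to a single class.
from transformers.models.llama.modeling_llama import LlamaAttention

try:
    from transformers.models.llama.modeling_llama import (  # older transformers only
        LlamaFlashAttention2,
        LlamaSdpaAttention,
    )

    ATTN_IMPLEMENTATION = {
        "eager": LlamaAttention,
        "flash_attention_2": LlamaFlashAttention2,
        "sdpa": LlamaSdpaAttention,
    }
except ImportError:
    ATTN_IMPLEMENTATION = {
        "eager": LlamaAttention,
        "flash_attention_2": LlamaAttention,
        "sdpa": LlamaAttention,
    }


def pick_attn_cls(attn_implementation: str = "eager"):
    """Return the attention class registered for the requested backend."""
    return ATTN_IMPLEMENTATION.get(attn_implementation, LlamaAttention)

Something like attn_cls = pick_attn_cls(self.origin_attn_implement) would then keep the old per-backend behaviour on older transformers while degrading gracefully to the single class on newer ones.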
@@ -354,6 +354,7 @@ class LlamaPolicy(Policy):
         stage_manager = self.pipeline_stage_manager

         held_layers = []
+        held_layers.append(module.rotary_emb)
         if stage_manager.is_interleave:
             assert stage_manager.num_model_chunks is not None
             layers_per_stage = stage_manager.distribute_layers(len(module.layers))
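Context for this hunk: in recent transformers versions LlamaModel owns a single rotary embedding module (model.rotary_emb) that all decoder layers share, so under pipeline parallelism every stage must hold it to build position embeddings for its local layers; that is what the added held_layers.append(module.rotary_emb) does. A simplified sketch of the idea (not the ColossalAI implementation; held_layers_for_stage and its arguments are illustrative):

# Sketch only: collect the modules one pipeline stage must keep locally,
# assuming `model` is shaped like transformers' LlamaModel (embed_tokens,
# layers as an nn.ModuleList, norm, and a shared rotary_emb).
from typing import List

import torch.nn as nn


def held_layers_for_stage(
    model: nn.Module, layer_ids: List[int], is_first: bool, is_last: bool
) -> List[nn.Module]:
    held: List[nn.Module] = []
    # Every stage keeps the shared rotary embedding, mirroring the committed change.
    held.append(model.rotary_emb)
    if is_first:
        held.append(model.embed_tokens)
    held.extend(model.layers[i] for i in layer_ids)
    if is_last:
        held.append(model.norm)
    return held

Duplicating rotary_emb on every stage only replicates a small buffer-only module, which is simpler than shipping position embeddings between stages.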