upgrade llama

Author: flybird11111
Date: 2025-04-24 14:54:15 +08:00
parent 0c5ed65305
commit 686982764c
6 changed files with 46 additions and 47 deletions


@@ -36,19 +36,19 @@ class LlamaPolicy(Policy):
         from transformers.models.llama.modeling_llama import (
             LlamaAttention,
             LlamaDecoderLayer,
-            LlamaFlashAttention2,
+            # LlamaFlashAttention2,
             LlamaModel,
-            LlamaSdpaAttention,
+            # LlamaSdpaAttention,
         )
-        ATTN_IMPLEMENTATION = {
-            "eager": LlamaAttention,
-            "flash_attention_2": LlamaFlashAttention2,
-            "sdpa": LlamaSdpaAttention,
-        }
+        # ATTN_IMPLEMENTATION = {
+        #     "eager": LlamaAttention,
+        #     "flash_attention_2": LlamaFlashAttention2,
+        #     "sdpa": LlamaSdpaAttention,
+        # }
         policy = {}
-        attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
+        attn_cls = LlamaAttention
+        # attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
         embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
             embedding_cls = VocabParallelEmbedding1D
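
In recent transformers releases, LlamaFlashAttention2 and LlamaSdpaAttention were removed from modeling_llama and backend dispatch now happens inside LlamaAttention itself, which is why this hunk comments out the per-backend mapping and hard-codes attn_cls = LlamaAttention. Below is a minimal sketch of a version-tolerant alternative; the try/except fallback and the reuse of the ATTN_IMPLEMENTATION name are assumptions of the sketch, not part of the commit.

# Sketch only: keep the old mapping on transformers versions that still ship the
# per-backend Llama attention classes, and fall back to LlamaAttention otherwise.
try:
    from transformers.models.llama.modeling_llama import (
        LlamaAttention,
        LlamaFlashAttention2,
        LlamaSdpaAttention,
    )

    ATTN_IMPLEMENTATION = {
        "eager": LlamaAttention,
        "flash_attention_2": LlamaFlashAttention2,
        "sdpa": LlamaSdpaAttention,
    }
except ImportError:
    # Newer layout: a single LlamaAttention class serves every attn_implementation.
    from transformers.models.llama.modeling_llama import LlamaAttention

    ATTN_IMPLEMENTATION = {
        "eager": LlamaAttention,
        "flash_attention_2": LlamaAttention,
        "sdpa": LlamaAttention,
    }

# The policy could then keep looking the class up, e.g.
# attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement], instead of hard-coding it.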
@@ -354,6 +354,7 @@ class LlamaPolicy(Policy):
         stage_manager = self.pipeline_stage_manager
         held_layers = []
+        held_layers.append(module.rotary_emb)
         if stage_manager.is_interleave:
             assert stage_manager.num_model_chunks is not None
             layers_per_stage = stage_manager.distribute_layers(len(module.layers))
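
This hunk adds module.rotary_emb to the held layers because newer transformers versions create the rotary embedding once on LlamaModel and share it across all decoder layers, so every pipeline stage that runs decoder layers needs it locally. Below is a small illustration of that partitioning idea under an assumed even layer split; it is not ColossalAI's actual get_held_layers, and the num_stages/per_stage variables are hypothetical.

# Illustration only: each stage holds the shared rotary_emb plus its slice of layers.
# Assumes a transformers version where LlamaModel owns rotary_emb (what this commit targets).
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel

config = LlamaConfig(hidden_size=64, intermediate_size=128,
                     num_hidden_layers=4, num_attention_heads=4)
model = LlamaModel(config)

num_stages = 2
per_stage = len(model.layers) // num_stages  # naive even split, just for the sketch
for stage in range(num_stages):
    held = [model.rotary_emb]  # parameter-free and cheap, so it is replicated on every stage
    if stage == 0:
        held.append(model.embed_tokens)  # first stage also owns the token embedding
    held.extend(model.layers[stage * per_stage:(stage + 1) * per_stage])
    if stage == num_stages - 1:
        held.append(model.norm)  # last stage owns the final RMSNorm
    print(f"stage {stage}: holds {len(held)} modules")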