mirror of https://github.com/hpcaitech/ColossalAI.git

fix
@@ -36,19 +36,10 @@ class LlamaPolicy(Policy):
         from transformers.models.llama.modeling_llama import (
             LlamaAttention,
             LlamaDecoderLayer,
-            # LlamaFlashAttention2,
             LlamaModel,
-            # LlamaSdpaAttention,
         )
 
-        # ATTN_IMPLEMENTATION = {
-        #     "eager": LlamaAttention,
-        #     "flash_attention_2": LlamaFlashAttention2,
-        #     "sdpa": LlamaSdpaAttention,
-        # }
         policy = {}
-        attn_cls = LlamaAttention
-        # attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
         embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
             embedding_cls = VocabParallelEmbedding1D
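The lines dropped here are leftovers from older transformers releases, which shipped separate LlamaAttention, LlamaFlashAttention2, and LlamaSdpaAttention classes and let the policy pick one via origin_attn_implement. Newer transformers folds all backends into a single LlamaAttention class whose backend is chosen at runtime from config._attn_implementation, so the policy can key on LlamaAttention directly. A minimal sketch of a version guard, not part of this commit (the 4.48.0 cutoff is an assumption):

import transformers
from packaging import version
from transformers.models.llama.modeling_llama import LlamaAttention

# Assumed cutoff: releases from roughly v4.48 onward expose only LlamaAttention;
# the backend ("eager", "sdpa", "flash_attention_2") is resolved inside forward()
# from config._attn_implementation.
if version.parse(transformers.__version__) >= version.parse("4.48.0"):
    attn_classes = (LlamaAttention,)
else:
    # Older releases exposed one subclass per attention backend.
    from transformers.models.llama.modeling_llama import (
        LlamaFlashAttention2,
        LlamaSdpaAttention,
    )
    attn_classes = (LlamaAttention, LlamaFlashAttention2, LlamaSdpaAttention)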
@@ -82,7 +73,7 @@ class LlamaPolicy(Policy):
                 num_kv_heads //= sp_size
                 decoder_attribute_replacement["num_key_value_heads"] = num_kv_heads
 
-            policy[attn_cls] = ModulePolicyDescription(
+            policy[LlamaAttention] = ModulePolicyDescription(
                 attribute_replacement=decoder_attribute_replacement,
             )
         if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
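For context, the attribute_replacement field of a ModulePolicyDescription maps attribute names to new values that ShardFormer writes onto every module matching the policy key, here LlamaAttention, so each rank's attention module sees head counts divided by the sequence-parallel size. A rough sketch of that mechanism, using a hypothetical helper name rather than ColossalAI's actual sharder code:

import torch.nn as nn


def apply_attribute_replacement(model: nn.Module, target_cls: type, replacement: dict) -> None:
    # Hypothetical illustration of what attribute_replacement amounts to;
    # the real logic lives in ColossalAI's ShardFormer sharder.
    for module in model.modules():
        if isinstance(module, target_cls):
            for name, value in replacement.items():
                # e.g. name = "num_key_value_heads", value = num_kv_heads // sp_size
                setattr(module, name, value)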
@@ -91,7 +82,7 @@ class LlamaPolicy(Policy):
                     "forward": get_llama_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
                 },
                 policy=policy,
-                target_key=attn_cls,
+                target_key=LlamaAttention,
             )
 
         if self.pipeline_stage_manager is None:
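The method replacement above swaps each LlamaAttention module's forward for the one returned by get_llama_flash_attention_forward. In spirit, binding a replacement forward onto matching modules looks like the hypothetical helper below; it is a sketch, not ColossalAI's append_or_create_method_replacement implementation:

from types import MethodType

import torch.nn as nn


def replace_forward(model: nn.Module, target_cls: type, new_forward) -> None:
    # Hypothetical stand-in for a {"forward": new_forward} method replacement
    # keyed on a module class.
    for module in model.modules():
        if isinstance(module, target_cls):
            # Bind the function as an instance method so `self` is supplied.
            module.forward = MethodType(new_forward, module)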