[Shardformer] fix the num_heads assert for the llama and qwen models (#5704)

* fix the num_heads assert

* fix the transformers import

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix the import

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Author: Wang Binluo
Date: 2024-05-10 15:33:39 +08:00
Committed by: GitHub
Parent commit: a3cc68ca93
Commit: 537f6a3855
3 changed files with 37 additions and 30 deletions
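
The hunks below only touch the Qwen2 modeling imports; the num_heads assert named in the title is fixed in the llama and qwen policy files, which are among the three changed files but are not shown in this excerpt. As a rough, hypothetical sketch of the kind of tensor-parallel head check such a policy enforces (the function name and structure here are assumptions, not the commit's code):

# Hypothetical illustration of a head-count check for tensor parallelism;
# not the actual ColossalAI policy code touched by this commit.
def assert_heads_divisible(num_attention_heads: int, num_key_value_heads: int, tp_size: int) -> None:
    if tp_size <= 1:
        return  # nothing is sharded, so there is no divisibility requirement
    assert num_attention_heads % tp_size == 0, (
        f"num_attention_heads ({num_attention_heads}) must be divisible by tensor parallel size ({tp_size})"
    )
    assert num_key_value_heads % tp_size == 0, (
        f"num_key_value_heads ({num_key_value_heads}) must be divisible by tensor parallel size ({tp_size})"
    )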


@@ -10,16 +10,20 @@ from transformers.modeling_outputs import (
try:
    from transformers.models.qwen2.modeling_qwen2 import (
        Qwen2Attention,
        Qwen2ForCausalLM,
        Qwen2ForSequenceClassification,
        Qwen2Model,
        _prepare_4d_causal_attention_mask,
        _prepare_4d_causal_attention_mask_for_sdpa,
        apply_rotary_pos_emb,
        repeat_kv,
    )
except ImportError:
    Qwen2Model = "Qwen2Model"
    Qwen2ForSequenceClassification = "Qwen2ForSequenceClassification"
    Qwen2ForCausalLM = "Qwen2ForCausalLM"
    Qwen2Attention = "Qwen2Attention"

from transformers.utils import logging
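
The string fallbacks above let the module import cleanly on transformers versions that do not ship Qwen2: each missing class is replaced by its class name as a plain string, so module-level references still resolve. A minimal standalone sketch of the same pattern (the is_qwen2_attention helper is illustrative, not part of the commit):

# Minimal sketch of the guarded-import pattern, assuming only that
# transformers may or may not provide the Qwen2 classes.
try:
    from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention
except ImportError:
    # Fall back to the class name as a string so import-time references
    # (e.g. name-based lookups) do not raise.
    Qwen2Attention = "Qwen2Attention"


def is_qwen2_attention(module: object) -> bool:
    # Compare by class name, so this works whether Qwen2Attention resolved
    # to the real class or to the fallback string.
    target = Qwen2Attention if isinstance(Qwen2Attention, str) else Qwen2Attention.__name__
    return type(module).__name__ == target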
@@ -451,10 +455,6 @@ class Qwen2PipelineForwards:
def get_qwen2_flash_attention_forward(shard_config: ShardConfig):
    from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, apply_rotary_pos_emb, repeat_kv

    from colossalai.shardformer.layer import ColoAttention

    def forward(
        self: Qwen2Attention,
        hidden_states: torch.Tensor,
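
get_qwen2_flash_attention_forward builds a replacement forward as a closure over shard_config, and the symbols it needs (Qwen2Attention, apply_rotary_pos_emb, repeat_kv) are now also available from the guarded module-level import above. A simplified, generic sketch of that closure-and-rebind pattern (ToyBlock and all names below are placeholders, not Shardformer's API):

import types

import torch
import torch.nn as nn


def get_scaled_forward(scale: float):
    # Build a replacement forward that closes over configuration (here `scale`),
    # much like a custom attention forward can close over a shard config.
    def forward(self: nn.Module, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.proj(hidden_states) * scale

    return forward


class ToyBlock(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(8, 8)


block = ToyBlock()
# Bind the generated function to this instance so it serves as its forward.
block.forward = types.MethodType(get_scaled_forward(0.5), block)
out = block.forward(torch.randn(2, 8))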