Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-12 12:47:21 +00:00
[Shardformer] fix the num_heads assert for llama model and qwen model (#5704)
* fix the num_heads assert
* fix the transformers import
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix the import
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
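The hunks below only cover the import handling in the Qwen2 modeling file; the num_heads assert fix named in the title lives in the llama and qwen policy files, which are not reproduced on this page. For orientation, the kind of check the title refers to looks roughly like the following sketch. Names such as check_num_heads, enable_tensor_parallelism, tensor_parallel_size, and num_key_value_heads are illustrative assumptions here, not lines from this commit.

# Illustrative sketch only -- not the exact code changed by this commit.
def check_num_heads(model_config, shard_config):
    # Only enforce divisibility when tensor parallelism is actually enabled.
    if shard_config.enable_tensor_parallelism:
        tp_size = shard_config.tensor_parallel_size
        assert (
            model_config.num_attention_heads % tp_size == 0
        ), "The number of attention heads must be divisible by the tensor parallel size."
        # Grouped-query models (llama/qwen variants) also shard the key/value heads.
        num_kv_heads = getattr(model_config, "num_key_value_heads", None)
        if num_kv_heads is not None:
            assert (
                num_kv_heads % tp_size == 0
            ), "The number of key/value heads must be divisible by the tensor parallel size."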
@@ -10,16 +10,20 @@ from transformers.modeling_outputs import (

try:
    from transformers.models.qwen2.modeling_qwen2 import (
        Qwen2Attention,
        Qwen2ForCausalLM,
        Qwen2ForSequenceClassification,
        Qwen2Model,
        _prepare_4d_causal_attention_mask,
        _prepare_4d_causal_attention_mask_for_sdpa,
        apply_rotary_pos_emb,
        repeat_kv,
    )
except ImportError:
    Qwen2Model = "Qwen2Model"
    Qwen2ForSequenceClassification = "Qwen2ForSequenceClassification"
    Qwen2ForCausalLM = "Qwen2ForCausalLM"
    Qwen2Attention = "Qwen2Attention"
    Qwen2ForSequenceClassification = "Qwen2ForSequenceClassification"

from transformers.utils import logging
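One way the string placeholders above can serve downstream code is as a cheap availability probe: the module still imports on an older transformers release without Qwen2, and only an actual attempt to use the integration fails. The following is a sketch of that idea, assuming nothing beyond the fallback assignments shown in the hunk; the helper name qwen2_available is hypothetical.

# Sketch, not part of this commit: after the try/except above, Qwen2Model is either
# the real transformers class or the literal string "Qwen2Model".
def qwen2_available() -> bool:
    # On the ImportError path the name is bound to a str, so callers can reject
    # a Qwen2 build early with a clear message instead of a late AttributeError.
    return not isinstance(Qwen2Model, str)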
@@ -451,10 +455,6 @@ class Qwen2PipelineForwards:


def get_qwen2_flash_attention_forward(shard_config: ShardConfig):
    from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, apply_rotary_pos_emb, repeat_kv

    from colossalai.shardformer.layer import ColoAttention

    def forward(
        self: Qwen2Attention,
        hidden_states: torch.Tensor,
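The hunk is truncated inside the forward signature, but the closure returned by get_qwen2_flash_attention_forward is normally attached to Qwen2Attention by a shardformer policy. A rough sketch of that wiring follows, assuming the usual method-replacement pattern; the enable_flash_attention flag and the surrounding policy variables are assumptions, not part of this hunk.

# Sketch of assumed usage inside a Qwen2 policy, not code from this commit.
if self.shard_config.enable_flash_attention:
    # Swap Qwen2Attention.forward for the ColoAttention-backed version.
    self.append_or_create_method_replacement(
        description={"forward": get_qwen2_flash_attention_forward(self.shard_config)},
        policy=policy,
        target_key=Qwen2Attention,
    )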