[shardformer] fix import (#5788)

Author: Hongxin Liu
Date: 2024-06-06 19:09:50 +08:00
Committed by: GitHub
Parent: 5ead00ffc5
Commit: 73e88a5553
2 changed files with 8 additions and 4 deletions


@@ -9,13 +9,15 @@ from transformers.modeling_outputs import (
 )

 try:
+    from transformers.modeling_attn_mask_utils import (
+        _prepare_4d_causal_attention_mask,
+        _prepare_4d_causal_attention_mask_for_sdpa,
+    )
     from transformers.models.qwen2.modeling_qwen2 import (
         Qwen2Attention,
         Qwen2ForCausalLM,
         Qwen2ForSequenceClassification,
         Qwen2Model,
-        _prepare_4d_causal_attention_mask,
-        _prepare_4d_causal_attention_mask_for_sdpa,
         apply_rotary_pos_emb,
         repeat_kv,
     )
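
For context on the fix: the diff moves the two mask helpers out of the `transformers.models.qwen2.modeling_qwen2` import and imports them directly from `transformers.modeling_attn_mask_utils`, where they are actually defined; newer `transformers` releases apparently no longer pull these private helpers into the Qwen2 modeling module, which broke the old import path. Below is a minimal sketch of a version-tolerant variant of this import. It is not the committed code, and it assumes the helpers are reachable from at least one of the two locations.

# Sketch only: prefer the canonical location of the mask helpers and fall back
# to the older indirect path through modeling_qwen2 on earlier transformers versions.
try:
    from transformers.modeling_attn_mask_utils import (
        _prepare_4d_causal_attention_mask,
        _prepare_4d_causal_attention_mask_for_sdpa,
    )
except ImportError:
    # Older transformers releases imported these helpers inside modeling_qwen2,
    # so the private names were reachable through that module as well.
    from transformers.models.qwen2.modeling_qwen2 import (
        _prepare_4d_causal_attention_mask,
        _prepare_4d_causal_attention_mask_for_sdpa,
    )

The committed change keeps only the direct import inside the existing try/except block, which is simpler when the supported `transformers` range is known to provide `modeling_attn_mask_utils`.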