[feat] support qwen3 in shardformer

This commit is contained in:
botbw
2025-07-10 13:57:52 +08:00
committed by GitHub
parent 97f4bee9d8
commit d097224d90
9 changed files with 1829 additions and 42 deletions

View File

@@ -4,31 +4,23 @@ from typing import List, Optional, Tuple, Union
 import torch
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers.modeling_attn_mask_utils import (
+    _prepare_4d_causal_attention_mask,
+    _prepare_4d_causal_attention_mask_for_sdpa,
+)
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
     SequenceClassifierOutputWithPast,
 )
-
-try:
-    from transformers.modeling_attn_mask_utils import (
-        _prepare_4d_causal_attention_mask,
-        _prepare_4d_causal_attention_mask_for_sdpa,
-    )
-    from transformers.models.qwen2.modeling_qwen2 import (
-        Qwen2Attention,
-        Qwen2ForCausalLM,
-        Qwen2ForSequenceClassification,
-        Qwen2Model,
-        apply_rotary_pos_emb,
-        repeat_kv,
-    )
-except ImportError:
-    Qwen2Model = "Qwen2Model"
-    Qwen2ForCausalLM = "Qwen2ForCausalLM"
-    Qwen2Attention = "Qwen2Attention"
-    Qwen2ForSequenceClassification = "Qwen2ForSequenceClassification"
-
+from transformers.models.qwen2.modeling_qwen2 import (
+    Qwen2Attention,
+    Qwen2ForCausalLM,
+    Qwen2ForSequenceClassification,
+    Qwen2Model,
+    apply_rotary_pos_emb,
+    repeat_kv,
+)
 from transformers.utils import logging
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -434,7 +426,6 @@ class Qwen2PipelineForwards:
             logits = self.score(hidden_states)
 
             if self.config.pad_token_id is None and batch_size != 1:
-                print(self.config.pad_token_id)
                 raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
             if self.config.pad_token_id is None:
                 sequence_lengths = -1