Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-04 02:26:51 +00:00)
[feat] support qwen3 in shardformer
@@ -4,31 +4,23 @@ from typing import List, Optional, Tuple, Union
 import torch
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers.modeling_attn_mask_utils import (
+    _prepare_4d_causal_attention_mask,
+    _prepare_4d_causal_attention_mask_for_sdpa,
+)
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
     SequenceClassifierOutputWithPast,
 )
-
-try:
-    from transformers.modeling_attn_mask_utils import (
-        _prepare_4d_causal_attention_mask,
-        _prepare_4d_causal_attention_mask_for_sdpa,
-    )
-    from transformers.models.qwen2.modeling_qwen2 import (
-        Qwen2Attention,
-        Qwen2ForCausalLM,
-        Qwen2ForSequenceClassification,
-        Qwen2Model,
-        apply_rotary_pos_emb,
-        repeat_kv,
-    )
-except ImportError:
-    Qwen2Model = "Qwen2Model"
-    Qwen2ForCausalLM = "Qwen2ForCausalLM"
-    Qwen2Attention = "Qwen2Attention"
-    Qwen2ForSequenceClassification = "Qwen2ForSequenceClassification"
-
+from transformers.models.qwen2.modeling_qwen2 import (
+    Qwen2Attention,
+    Qwen2ForCausalLM,
+    Qwen2ForSequenceClassification,
+    Qwen2Model,
+    apply_rotary_pos_emb,
+    repeat_kv,
+)
 from transformers.utils import logging
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
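Note: the try/except deleted in this hunk is the usual guard for model classes that older transformers releases may lack; binding the names to placeholder strings keeps the module importable and defers the failure to first use. Dropping the guard makes a transformers build that ships Qwen2 a hard requirement, so incompatible installs now fail at import time rather than deep inside a forward pass. A minimal sketch of the removed idiom (the is_qwen2_available helper is illustrative, not part of this file):

try:
    from transformers.models.qwen2.modeling_qwen2 import Qwen2Model
except ImportError:
    # Placeholder string keeps the module importable on old transformers;
    # any code that actually instantiates the class fails later instead.
    Qwen2Model = "Qwen2Model"

def is_qwen2_available() -> bool:
    # Illustrative helper: True only if the real class was imported.
    return not isinstance(Qwen2Model, str)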
@@ -434,7 +426,6 @@ class Qwen2PipelineForwards:
         logits = self.score(hidden_states)
 
         if self.config.pad_token_id is None and batch_size != 1:
-            print(self.config.pad_token_id)
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
             sequence_lengths = -1
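Note: the deleted line is a stray debug print; the surrounding code is the standard sequence-classification pooling, where sequence_lengths = -1 means "pool the last position of every row" and a defined pad token lets each row's true last token be located. A sketch of that lookup, mirroring the Hugging Face pattern (last_token_indices and its names are illustrative):

from typing import Optional

import torch

def last_token_indices(input_ids: torch.LongTensor, pad_token_id: Optional[int]):
    if pad_token_id is None:
        # No pad token: assume every sequence fills its row, pool position -1.
        return -1
    # First pad position per row; the token just before it is the last real token.
    lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
    # A row with no padding gives argmax 0, hence -1; the modulo wraps that
    # to the final position, matching the no-pad case.
    return lengths % input_ids.shape[-1]

This is also why the check above raises for batch_size != 1 without a pad token: with no pad marker there is no way to tell where each sequence ends, so only a single full-length row can be handled safely.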