[chore] format qwen2

botbw 2025-07-04 16:38:46 +08:00
parent 97f4bee9d8
commit d436f6d08c
2 changed files with 19 additions and 42 deletions

colossalai/shardformer/modeling/qwen2.py

@@ -4,31 +4,23 @@ from typing import List, Optional, Tuple, Union
 import torch
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers.modeling_attn_mask_utils import (
+    _prepare_4d_causal_attention_mask,
+    _prepare_4d_causal_attention_mask_for_sdpa,
+)
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
     SequenceClassifierOutputWithPast,
 )
-try:
-    from transformers.modeling_attn_mask_utils import (
-        _prepare_4d_causal_attention_mask,
-        _prepare_4d_causal_attention_mask_for_sdpa,
-    )
-    from transformers.models.qwen2.modeling_qwen2 import (
-        Qwen2Attention,
-        Qwen2ForCausalLM,
-        Qwen2ForSequenceClassification,
-        Qwen2Model,
-        apply_rotary_pos_emb,
-        repeat_kv,
-    )
-except ImportError:
-    Qwen2Model = "Qwen2Model"
-    Qwen2ForCausalLM = "Qwen2ForCausalLM"
-    Qwen2Attention = "Qwen2Attention"
-    Qwen2ForSequenceClassification = "Qwen2ForSequenceClassification"
+from transformers.models.qwen2.modeling_qwen2 import (
+    Qwen2Attention,
+    Qwen2ForCausalLM,
+    Qwen2ForSequenceClassification,
+    Qwen2Model,
+    apply_rotary_pos_emb,
+    repeat_kv,
+)
 from transformers.utils import logging
 from colossalai.pipeline.stage_manager import PipelineStageManager
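The hunk above drops the try/except ImportError guard and binds the Qwen2 symbols unconditionally, so an unsupported transformers release now fails at import time instead of silently substituting placeholder strings. For comparison, a minimal sketch of an explicit version guard that would fail with a clearer message (illustrative only, not part of this commit; the minimum version is an assumption):

    import transformers
    from packaging import version

    # Assumed lower bound: the first transformers release shipping Qwen2 modeling code.
    _MIN_TRANSFORMERS = "4.37.0"

    if version.parse(transformers.__version__) < version.parse(_MIN_TRANSFORMERS):
        raise ImportError(
            f"Qwen2 shardformer support assumes transformers>={_MIN_TRANSFORMERS}, "
            f"found {transformers.__version__}; please upgrade transformers."
        )

    from transformers.models.qwen2.modeling_qwen2 import Qwen2Model  # safe after the check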
@@ -434,7 +426,6 @@ class Qwen2PipelineForwards:
         logits = self.score(hidden_states)
         if self.config.pad_token_id is None and batch_size != 1:
-            print(self.config.pad_token_id)
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
             sequence_lengths = -1
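For context on the check beside the deleted debug print: the pad token is what locates the last real token of each padded sequence, which is why batched inputs are rejected when no pad token is defined. A rough standalone sketch of that pooling step (illustrative, assuming right-padded inputs; not the repository's exact code):

    from typing import Optional

    import torch

    def pool_last_token_logits(logits: torch.Tensor, input_ids: torch.Tensor, pad_token_id: Optional[int]) -> torch.Tensor:
        """Pick each sequence's logits at its last real token (hypothetical helper)."""
        batch_size = logits.shape[0]
        if pad_token_id is None:
            if batch_size != 1:
                raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
            sequence_lengths = -1  # single unpadded sequence: take the last position
        else:
            # index of the last non-pad token in each row (assumes right padding)
            sequence_lengths = (input_ids != pad_token_id).long().sum(dim=-1) - 1
        return logits[torch.arange(batch_size, device=logits.device), sequence_lengths]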

colossalai/shardformer/policies/qwen2.py

@@ -4,6 +4,13 @@ from typing import Callable, Dict, List, Union
 import torch.nn as nn
 from torch import Tensor
 from torch.nn import Module
+from transformers.models.qwen2.modeling_qwen2 import (
+    Qwen2Attention,
+    Qwen2DecoderLayer,
+    Qwen2ForCausalLM,
+    Qwen2ForSequenceClassification,
+    Qwen2Model,
+)
 from colossalai.shardformer.layer import (
     FusedRMSNorm,
@@ -21,26 +28,6 @@ from ..modeling.qwen2 import (
     get_qwen2_flash_attention_forward,
     get_qwen2_model_forward_for_flash_attn,
 )
-try:
-    from transformers.models.qwen2.modeling_qwen2 import (
-        Qwen2Attention,
-        Qwen2DecoderLayer,
-        Qwen2FlashAttention2,
-        Qwen2ForCausalLM,
-        Qwen2ForSequenceClassification,
-        Qwen2Model,
-        Qwen2SdpaAttention,
-    )
-except ImportError:
-    Qwen2ForCausalLM = "Qwen2ForCausalLM"
-    Qwen2ForSequenceClassification = "Qwen2ForSequenceClassification"
-    Qwen2Attention = "Qwen2Attention"
-    Qwen2FlashAttention2 = "Qwen2FlashAttention2"
-    Qwen2SdpaAttention = "Qwen2SdpaAttention"
-    Qwen2DecoderLayer = "Qwen2DecoderLayer"
-    Qwen2Model = "Qwen2Model"
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 __all__ = ["Qwen2Policy", "Qwen2ForCausalLMPolicy", "Qwen2ForSequenceClassificationPolicy"]
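The guard deleted above also carried Qwen2FlashAttention2 and Qwen2SdpaAttention placeholders, while the new unconditional import at the top of the file keeps only Qwen2Attention; in recent transformers releases the attention backend is selected through the config rather than through separate attention subclasses. A small sketch of that selection (illustrative only; the tiny config values and the chosen backend are arbitrary assumptions):

    from transformers import Qwen2Config
    from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM

    # Deliberately tiny config so the model builds instantly; values are arbitrary.
    cfg = Qwen2Config(
        vocab_size=1000,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=2,
        attn_implementation="sdpa",  # or "eager" / "flash_attention_2"
    )
    model = Qwen2ForCausalLM(cfg)
    # Newer transformers expose a single Qwen2Attention class regardless of backend.
    print(type(model.model.layers[0].self_attn).__name__)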
@@ -295,7 +282,6 @@ class Qwen2Policy(Policy):
             )
         if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
-            print("self.shard_config.enable_flash_attention", self.shard_config.enable_flash_attention)
             self.append_or_create_method_replacement(
                 description={
                     "forward": get_qwen2_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),