[chore] format qwen2

botbw 2025-07-04 16:38:46 +08:00
parent 97f4bee9d8
commit d436f6d08c
2 changed files with 19 additions and 42 deletions

View File

@@ -4,31 +4,23 @@ from typing import List, Optional, Tuple, Union
 import torch
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers.modeling_attn_mask_utils import (
+    _prepare_4d_causal_attention_mask,
+    _prepare_4d_causal_attention_mask_for_sdpa,
+)
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
     SequenceClassifierOutputWithPast,
 )
+from transformers.models.qwen2.modeling_qwen2 import (
+    Qwen2Attention,
+    Qwen2ForCausalLM,
+    Qwen2ForSequenceClassification,
+    Qwen2Model,
+    apply_rotary_pos_emb,
+    repeat_kv,
+)
-try:
-    from transformers.modeling_attn_mask_utils import (
-        _prepare_4d_causal_attention_mask,
-        _prepare_4d_causal_attention_mask_for_sdpa,
-    )
-    from transformers.models.qwen2.modeling_qwen2 import (
-        Qwen2Attention,
-        Qwen2ForCausalLM,
-        Qwen2ForSequenceClassification,
-        Qwen2Model,
-        apply_rotary_pos_emb,
-        repeat_kv,
-    )
-except ImportError:
-    Qwen2Model = "Qwen2Model"
-    Qwen2ForCausalLM = "Qwen2ForCausalLM"
-    Qwen2Attention = "Qwen2Attention"
-    Qwen2ForSequenceClassification = "Qwen2ForSequenceClassification"
 from transformers.utils import logging

 from colossalai.pipeline.stage_manager import PipelineStageManager
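
Note: the fallback removed in this hunk bound the Qwen2 names to placeholder strings whenever the installed transformers release did not ship a Qwen2 implementation; with the unconditional imports above, the module now fails at import time on such releases instead. A minimal sketch of an explicit guard, for illustration only (the 4.37.0 threshold, where Qwen2 first appeared upstream, is an assumption and not part of this commit):

    # Illustrative only, not part of the commit: fail fast with a clear message
    # instead of binding the class names to placeholder strings.
    from packaging import version

    import transformers

    if version.parse(transformers.__version__) < version.parse("4.37.0"):  # assumed minimum for Qwen2
        raise ImportError(
            "transformers>=4.37.0 is required: transformers.models.qwen2.modeling_qwen2 "
            "is now imported unconditionally."
        )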
@@ -434,7 +426,6 @@ class Qwen2PipelineForwards:
         logits = self.score(hidden_states)

         if self.config.pad_token_id is None and batch_size != 1:
-            print(self.config.pad_token_id)
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
             sequence_lengths = -1
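
Note: the line removed here was a stray debug print in the error branch; the guard itself is unchanged. For context, the selection logic around it picks the last non-padding position per sequence and falls back to the final position when no pad token is configured. A rough paraphrase, not copied verbatim from the file and assuming right padding:

    from typing import Optional, Union

    import torch

    def last_token_index(input_ids: torch.Tensor, pad_token_id: Optional[int]) -> Union[int, torch.Tensor]:
        # Rough paraphrase of the surrounding selection logic, not verbatim from the file.
        if pad_token_id is None:
            if input_ids.shape[0] != 1:
                raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
            return -1  # use the last position of the single sequence
        # last position holding a non-pad token in each row (assumes right padding)
        return (input_ids != pad_token_id).int().sum(dim=-1) - 1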

View File

@@ -4,6 +4,13 @@ from typing import Callable, Dict, List, Union
 import torch.nn as nn
 from torch import Tensor
 from torch.nn import Module
+from transformers.models.qwen2.modeling_qwen2 import (
+    Qwen2Attention,
+    Qwen2DecoderLayer,
+    Qwen2ForCausalLM,
+    Qwen2ForSequenceClassification,
+    Qwen2Model,
+)

 from colossalai.shardformer.layer import (
     FusedRMSNorm,
@@ -21,26 +28,6 @@ from ..modeling.qwen2 import (
     get_qwen2_flash_attention_forward,
     get_qwen2_model_forward_for_flash_attn,
 )
-try:
-    from transformers.models.qwen2.modeling_qwen2 import (
-        Qwen2Attention,
-        Qwen2DecoderLayer,
-        Qwen2FlashAttention2,
-        Qwen2ForCausalLM,
-        Qwen2ForSequenceClassification,
-        Qwen2Model,
-        Qwen2SdpaAttention,
-    )
-except ImportError:
-    Qwen2ForCausalLM = "Qwen2ForCausalLM"
-    Qwen2ForSequenceClassification = "Qwen2ForSequenceClassification"
-    Qwen2Attention = "Qwen2Attention"
-    Qwen2FlashAttention2 = "Qwen2FlashAttention2"
-    Qwen2SdpaAttention = "Qwen2SdpaAttention"
-    Qwen2DecoderLayer = "Qwen2DecoderLayer"
-    Qwen2Model = "Qwen2Model"

 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

 __all__ = ["Qwen2Policy", "Qwen2ForCausalLMPolicy", "Qwen2ForSequenceClassificationPolicy"]
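
Note: besides dropping the ImportError fallback, this hunk removes Qwen2FlashAttention2 and Qwen2SdpaAttention from the imports altogether; only Qwen2Attention, Qwen2DecoderLayer, Qwen2ForCausalLM, Qwen2ForSequenceClassification, and Qwen2Model are imported at the top of the file (previous hunk). This is consistent with recent transformers releases, which expose a single Qwen2Attention class rather than per-backend subclasses. A small illustrative check, not part of the commit:

    # Illustrative only: report which Qwen2 attention classes the installed transformers exposes.
    import transformers.models.qwen2.modeling_qwen2 as qwen2_modeling

    for name in ("Qwen2Attention", "Qwen2FlashAttention2", "Qwen2SdpaAttention"):
        print(name, "available" if hasattr(qwen2_modeling, name) else "missing")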
@@ -295,7 +282,6 @@ class Qwen2Policy(Policy):
         )

         if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
-            print("self.shard_config.enable_flash_attention", self.shard_config.enable_flash_attention)
             self.append_or_create_method_replacement(
                 description={
                     "forward": get_qwen2_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),