mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-23 11:44:15 +00:00
[chore] format qwen2
This commit is contained in:
parent
97f4bee9d8
commit
d436f6d08c
@ -4,31 +4,23 @@ from typing import List, Optional, Tuple, Union
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
from transformers.modeling_attn_mask_utils import (
|
||||
_prepare_4d_causal_attention_mask,
|
||||
_prepare_4d_causal_attention_mask_for_sdpa,
|
||||
)
|
||||
from transformers.modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
CausalLMOutputWithPast,
|
||||
SequenceClassifierOutputWithPast,
|
||||
)
|
||||
|
||||
try:
|
||||
from transformers.modeling_attn_mask_utils import (
|
||||
_prepare_4d_causal_attention_mask,
|
||||
_prepare_4d_causal_attention_mask_for_sdpa,
|
||||
)
|
||||
from transformers.models.qwen2.modeling_qwen2 import (
|
||||
Qwen2Attention,
|
||||
Qwen2ForCausalLM,
|
||||
Qwen2ForSequenceClassification,
|
||||
Qwen2Model,
|
||||
apply_rotary_pos_emb,
|
||||
repeat_kv,
|
||||
)
|
||||
except ImportError:
|
||||
Qwen2Model = "Qwen2Model"
|
||||
Qwen2ForCausalLM = "Qwen2ForCausalLM"
|
||||
Qwen2Attention = "Qwen2Attention"
|
||||
Qwen2ForSequenceClassification = "Qwen2ForSequenceClassification"
|
||||
|
||||
from transformers.models.qwen2.modeling_qwen2 import (
|
||||
Qwen2Attention,
|
||||
Qwen2ForCausalLM,
|
||||
Qwen2ForSequenceClassification,
|
||||
Qwen2Model,
|
||||
apply_rotary_pos_emb,
|
||||
repeat_kv,
|
||||
)
|
||||
from transformers.utils import logging
|
||||
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
@ -434,7 +426,6 @@ class Qwen2PipelineForwards:
|
||||
logits = self.score(hidden_states)
|
||||
|
||||
if self.config.pad_token_id is None and batch_size != 1:
|
||||
print(self.config.pad_token_id)
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = -1
|
||||
|
@ -4,6 +4,13 @@ from typing import Callable, Dict, List, Union
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.nn import Module
|
||||
from transformers.models.qwen2.modeling_qwen2 import (
|
||||
Qwen2Attention,
|
||||
Qwen2DecoderLayer,
|
||||
Qwen2ForCausalLM,
|
||||
Qwen2ForSequenceClassification,
|
||||
Qwen2Model,
|
||||
)
|
||||
|
||||
from colossalai.shardformer.layer import (
|
||||
FusedRMSNorm,
|
||||
@ -21,26 +28,6 @@ from ..modeling.qwen2 import (
|
||||
get_qwen2_flash_attention_forward,
|
||||
get_qwen2_model_forward_for_flash_attn,
|
||||
)
|
||||
|
||||
try:
|
||||
from transformers.models.qwen2.modeling_qwen2 import (
|
||||
Qwen2Attention,
|
||||
Qwen2DecoderLayer,
|
||||
Qwen2FlashAttention2,
|
||||
Qwen2ForCausalLM,
|
||||
Qwen2ForSequenceClassification,
|
||||
Qwen2Model,
|
||||
Qwen2SdpaAttention,
|
||||
)
|
||||
except ImportError:
|
||||
Qwen2ForCausalLM = "Qwen2ForCausalLM"
|
||||
Qwen2ForSequenceClassification = "Qwen2ForSequenceClassification"
|
||||
Qwen2Attention = "Qwen2Attention"
|
||||
Qwen2FlashAttention2 = "Qwen2FlashAttention2"
|
||||
Qwen2SdpaAttention = "Qwen2SdpaAttention"
|
||||
Qwen2DecoderLayer = "Qwen2DecoderLayer"
|
||||
Qwen2Model = "Qwen2Model"
|
||||
|
||||
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
__all__ = ["Qwen2Policy", "Qwen2ForCausalLMPolicy", "Qwen2ForSequenceClassificationPolicy"]
|
||||
@ -295,7 +282,6 @@ class Qwen2Policy(Policy):
|
||||
)
|
||||
|
||||
if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
|
||||
print("self.shard_config.enable_flash_attention", self.shard_config.enable_flash_attention)
|
||||
self.append_or_create_method_replacement(
|
||||
description={
|
||||
"forward": get_qwen2_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
|
||||
|
Loading…
Reference in New Issue
Block a user