Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-07-25 04:24:51 +00:00
[chore] format qwen2

commit d436f6d08c (parent 97f4bee9d8)

The commit replaces the try/except import guards around the transformers Qwen2 classes with plain top-level imports and drops two leftover debug print calls. Judging by the hunk context, the first two hunks come from the Qwen2 shardformer modeling module (class Qwen2PipelineForwards) and the remaining three from its shardformer policy (class Qwen2Policy).
@@ -4,31 +4,23 @@ from typing import List, Optional, Tuple, Union
 import torch
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers.modeling_attn_mask_utils import (
+    _prepare_4d_causal_attention_mask,
+    _prepare_4d_causal_attention_mask_for_sdpa,
+)
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
     SequenceClassifierOutputWithPast,
 )
-
-try:
-    from transformers.modeling_attn_mask_utils import (
-        _prepare_4d_causal_attention_mask,
-        _prepare_4d_causal_attention_mask_for_sdpa,
-    )
-    from transformers.models.qwen2.modeling_qwen2 import (
-        Qwen2Attention,
-        Qwen2ForCausalLM,
-        Qwen2ForSequenceClassification,
-        Qwen2Model,
-        apply_rotary_pos_emb,
-        repeat_kv,
-    )
-except ImportError:
-    Qwen2Model = "Qwen2Model"
-    Qwen2ForCausalLM = "Qwen2ForCausalLM"
-    Qwen2Attention = "Qwen2Attention"
-    Qwen2ForSequenceClassification = "Qwen2ForSequenceClassification"
-
+from transformers.models.qwen2.modeling_qwen2 import (
+    Qwen2Attention,
+    Qwen2ForCausalLM,
+    Qwen2ForSequenceClassification,
+    Qwen2Model,
+    apply_rotary_pos_emb,
+    repeat_kv,
+)
 from transformers.utils import logging
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
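The block removed in this hunk is the lazy-import guard the module previously relied on: when the installed transformers was too old to ship the Qwen2 classes, the names were bound to plain strings so that importing the file itself would not fail. With the guard gone, an unsupported transformers version now fails loudly at import time. A minimal sketch of the pattern, with a hypothetical build_qwen2 helper standing in for any later use of the class, written for illustration rather than copied from ColossalAI:

# A minimal sketch (illustrative, not ColossalAI's exact code) of the guard
# pattern removed above. The string fallback only keeps the module importable
# on an old transformers release; it is never a usable class.
try:
    from transformers.models.qwen2.modeling_qwen2 import Qwen2Model
except ImportError:
    Qwen2Model = "Qwen2Model"  # placeholder so later references do not raise NameError


def build_qwen2(config):
    # Hypothetical helper: with the guard removed, a missing Qwen2 now surfaces
    # at import time instead of reaching a check like this one.
    if isinstance(Qwen2Model, str):
        raise RuntimeError("Qwen2 support requires a newer transformers release")
    return Qwen2Model(config)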
@@ -434,7 +426,6 @@ class Qwen2PipelineForwards:
         logits = self.score(hidden_states)
 
         if self.config.pad_token_id is None and batch_size != 1:
-            print(self.config.pad_token_id)
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
             sequence_lengths = -1
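Beyond the deleted debug print, the surrounding context hints at why the check exists: sequence classification reads the logit at each sequence's last real token, and with more than one sequence per batch the pad token is the only way to find that position per row. A hedged illustration of that bookkeeping (the last_token_indices helper is made up for this example, not the library's code):

import torch


# Hedged illustration: locate the last non-padding position per row, which is
# what a defined pad_token_id makes possible for batch sizes above one.
def last_token_indices(input_ids: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    is_pad = input_ids.eq(pad_token_id)
    seq_len = input_ids.size(-1)
    # first pad position per row, or seq_len when a row has no padding at all
    first_pad = torch.where(is_pad.any(dim=-1), is_pad.int().argmax(dim=-1), torch.tensor(seq_len))
    return first_pad - 1


batch = torch.tensor([[5, 6, 7, 0, 0], [8, 9, 10, 11, 12]])  # 0 plays the pad id here
print(last_token_indices(batch, pad_token_id=0))  # tensor([2, 4])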
@@ -4,6 +4,13 @@ from typing import Callable, Dict, List, Union
 import torch.nn as nn
 from torch import Tensor
 from torch.nn import Module
+from transformers.models.qwen2.modeling_qwen2 import (
+    Qwen2Attention,
+    Qwen2DecoderLayer,
+    Qwen2ForCausalLM,
+    Qwen2ForSequenceClassification,
+    Qwen2Model,
+)
 
 from colossalai.shardformer.layer import (
     FusedRMSNorm,
@@ -21,26 +28,6 @@ from ..modeling.qwen2 import (
     get_qwen2_flash_attention_forward,
     get_qwen2_model_forward_for_flash_attn,
 )
-
-try:
-    from transformers.models.qwen2.modeling_qwen2 import (
-        Qwen2Attention,
-        Qwen2DecoderLayer,
-        Qwen2FlashAttention2,
-        Qwen2ForCausalLM,
-        Qwen2ForSequenceClassification,
-        Qwen2Model,
-        Qwen2SdpaAttention,
-    )
-except ImportError:
-    Qwen2ForCausalLM = "Qwen2ForCausalLM"
-    Qwen2ForSequenceClassification = "Qwen2ForSequenceClassification"
-    Qwen2Attention = "Qwen2Attention"
-    Qwen2FlashAttention2 = "Qwen2FlashAttention2"
-    Qwen2SdpaAttention = "Qwen2SdpaAttention"
-    Qwen2DecoderLayer = "Qwen2DecoderLayer"
-    Qwen2Model = "Qwen2Model"
-
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
 __all__ = ["Qwen2Policy", "Qwen2ForCausalLMPolicy", "Qwen2ForSequenceClassificationPolicy"]
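The guard removed here also covered Qwen2FlashAttention2 and Qwen2SdpaAttention. Those extra names matter because, in the transformers releases this code appears to target, one of three attention classes is instantiated depending on config._attn_implementation, so a policy that swaps the attention forward has to know every variant it might meet. A small orientation sketch; the dict and the attention_target helper are written by hand here, not imported from either library:

# Hedged sketch (not ColossalAI's code): which Qwen2 attention class a policy
# would need to patch, given how transformers picks the implementation.
ATTENTION_CLASS_BY_IMPL = {
    "eager": "Qwen2Attention",
    "flash_attention_2": "Qwen2FlashAttention2",
    "sdpa": "Qwen2SdpaAttention",
}


def attention_target(attn_implementation: str) -> str:
    # Fall back to the eager class name when the key is unknown.
    return ATTENTION_CLASS_BY_IMPL.get(attn_implementation, "Qwen2Attention")


print(attention_target("sdpa"))  # Qwen2SdpaAttention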
@@ -295,7 +282,6 @@ class Qwen2Policy(Policy):
             )
 
         if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
-            print("self.shard_config.enable_flash_attention", self.shard_config.enable_flash_attention)
             self.append_or_create_method_replacement(
                 description={
                     "forward": get_qwen2_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
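For context on the call that survives the deleted print: get_qwen2_flash_attention_forward builds a replacement forward function and append_or_create_method_replacement registers it so shardformer can patch it onto the attention modules. The sketch below shows the general monkey-patching idea only; apply_method_replacement and the shape of the description dict are assumptions, not ColossalAI's API.

from types import MethodType

import torch
import torch.nn as nn


# Illustrative only: applying a {"method name": function} description to a module.
def apply_method_replacement(module: nn.Module, description: dict) -> None:
    for method_name, new_fn in description.items():
        # Bind the replacement so it receives the module as `self`, mirroring the
        # "forward": ... entry registered by the policy above.
        setattr(module, method_name, MethodType(new_fn, module))


layer = nn.Linear(4, 2)


def custom_forward(self, x):
    # Same math as nn.Linear, spelled out to show the patched path is taken.
    return x @ self.weight.t() + self.bias


apply_method_replacement(layer, {"forward": custom_forward})
print(layer(torch.randn(4)).shape)  # torch.Size([2])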