diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py
index 7bdf1e65f..5417bf4eb 100644
--- a/colossalai/shardformer/modeling/qwen2.py
+++ b/colossalai/shardformer/modeling/qwen2.py
@@ -4,31 +4,23 @@ from typing import List, Optional, Tuple, Union
 import torch
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers.modeling_attn_mask_utils import (
+    _prepare_4d_causal_attention_mask,
+    _prepare_4d_causal_attention_mask_for_sdpa,
+)
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
     SequenceClassifierOutputWithPast,
 )
-
-try:
-    from transformers.modeling_attn_mask_utils import (
-        _prepare_4d_causal_attention_mask,
-        _prepare_4d_causal_attention_mask_for_sdpa,
-    )
-    from transformers.models.qwen2.modeling_qwen2 import (
-        Qwen2Attention,
-        Qwen2ForCausalLM,
-        Qwen2ForSequenceClassification,
-        Qwen2Model,
-        apply_rotary_pos_emb,
-        repeat_kv,
-    )
-except ImportError:
-    Qwen2Model = "Qwen2Model"
-    Qwen2ForCausalLM = "Qwen2ForCausalLM"
-    Qwen2Attention = "Qwen2Attention"
-    Qwen2ForSequenceClassification = "Qwen2ForSequenceClassification"
-
+from transformers.models.qwen2.modeling_qwen2 import (
+    Qwen2Attention,
+    Qwen2ForCausalLM,
+    Qwen2ForSequenceClassification,
+    Qwen2Model,
+    apply_rotary_pos_emb,
+    repeat_kv,
+)
 from transformers.utils import logging
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -434,7 +426,6 @@ class Qwen2PipelineForwards:
         logits = self.score(hidden_states)
 
         if self.config.pad_token_id is None and batch_size != 1:
-            print(self.config.pad_token_id)
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
             sequence_lengths = -1
diff --git a/colossalai/shardformer/policies/qwen2.py b/colossalai/shardformer/policies/qwen2.py
index 7f8a35e46..78b3bf528 100644
--- a/colossalai/shardformer/policies/qwen2.py
+++ b/colossalai/shardformer/policies/qwen2.py
@@ -4,6 +4,13 @@ from typing import Callable, Dict, List, Union
 import torch.nn as nn
 from torch import Tensor
 from torch.nn import Module
+from transformers.models.qwen2.modeling_qwen2 import (
+    Qwen2Attention,
+    Qwen2DecoderLayer,
+    Qwen2ForCausalLM,
+    Qwen2ForSequenceClassification,
+    Qwen2Model,
+)
 
 from colossalai.shardformer.layer import (
     FusedRMSNorm,
@@ -21,26 +28,6 @@ from ..modeling.qwen2 import (
     get_qwen2_flash_attention_forward,
     get_qwen2_model_forward_for_flash_attn,
 )
-
-try:
-    from transformers.models.qwen2.modeling_qwen2 import (
-        Qwen2Attention,
-        Qwen2DecoderLayer,
-        Qwen2FlashAttention2,
-        Qwen2ForCausalLM,
-        Qwen2ForSequenceClassification,
-        Qwen2Model,
-        Qwen2SdpaAttention,
-    )
-except ImportError:
-    Qwen2ForCausalLM = "Qwen2ForCausalLM"
-    Qwen2ForSequenceClassification = "Qwen2ForSequenceClassification"
-    Qwen2Attention = "Qwen2Attention"
-    Qwen2FlashAttention2 = "Qwen2FlashAttention2"
-    Qwen2SdpaAttention = "Qwen2SdpaAttention"
-    Qwen2DecoderLayer = "Qwen2DecoderLayer"
-    Qwen2Model = "Qwen2Model"
-
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
 __all__ = ["Qwen2Policy", "Qwen2ForCausalLMPolicy", "Qwen2ForSequenceClassificationPolicy"]
@@ -295,7 +282,6 @@ class Qwen2Policy(Policy):
         )
 
         if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
-            print("self.shard_config.enable_flash_attention", self.shard_config.enable_flash_attention)
             self.append_or_create_method_replacement(
                 description={
                     "forward": get_qwen2_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),