diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py
index 7459d35de..8f8ab25a5 100644
--- a/colossalai/shardformer/modeling/qwen2.py
+++ b/colossalai/shardformer/modeling/qwen2.py
@@ -10,16 +10,20 @@ from transformers.modeling_outputs import (
 
 try:
     from transformers.models.qwen2.modeling_qwen2 import (
+        Qwen2Attention,
         Qwen2ForCausalLM,
         Qwen2ForSequenceClassification,
         Qwen2Model,
         _prepare_4d_causal_attention_mask,
         _prepare_4d_causal_attention_mask_for_sdpa,
+        apply_rotary_pos_emb,
+        repeat_kv,
     )
 except ImportError:
     Qwen2Model = "Qwen2Model"
-    Qwen2ForSequenceClassification = "Qwen2ForSequenceClassification"
     Qwen2ForCausalLM = "Qwen2ForCausalLM"
+    Qwen2Attention = "Qwen2Attention"
+    Qwen2ForSequenceClassification = "Qwen2ForSequenceClassification"
 
 from transformers.utils import logging
 
@@ -451,10 +455,6 @@ class Qwen2PipelineForwards:
 
 
 def get_qwen2_flash_attention_forward(shard_config: ShardConfig):
-    from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, apply_rotary_pos_emb, repeat_kv
-
-    from colossalai.shardformer.layer import ColoAttention
-
     def forward(
         self: Qwen2Attention,
         hidden_states: torch.Tensor,
diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py
index 6e541f792..713175c6c 100644
--- a/colossalai/shardformer/policies/llama.py
+++ b/colossalai/shardformer/policies/llama.py
@@ -141,9 +141,11 @@ class LlamaPolicy(Policy):
             assert (
                 self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
             ), f"The number of attention heads must be divisible by tensor parallel size."
-            assert (
-                self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
-            ), f"The number of key_value heads must be divisible by tensor parallel size."
+            if hasattr(self.model.config, "num_key_value_heads"):
+                assert (
+                    self.model.config.num_key_value_heads >= self.shard_config.tensor_parallel_size
+                    and self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
+                ), f"The number of key_value heads must be divisible by, and must not be less than tensor parallel size."
             decoder_attribute_replacement = {
                 "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                 "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
diff --git a/colossalai/shardformer/policies/qwen2.py b/colossalai/shardformer/policies/qwen2.py
index 01f9abb13..3e427c4a1 100644
--- a/colossalai/shardformer/policies/qwen2.py
+++ b/colossalai/shardformer/policies/qwen2.py
@@ -21,6 +21,26 @@ from ..modeling.qwen2 import (
     get_qwen2_flash_attention_forward,
     get_qwen2_model_forward_for_flash_attn,
 )
+
+try:
+    from transformers.models.qwen2.modeling_qwen2 import (
+        Qwen2Attention,
+        Qwen2DecoderLayer,
+        Qwen2FlashAttention2,
+        Qwen2ForCausalLM,
+        Qwen2ForSequenceClassification,
+        Qwen2Model,
+        Qwen2SdpaAttention,
+    )
+except ImportError:
+    Qwen2ForCausalLM = "Qwen2ForCausalLM"
+    Qwen2ForSequenceClassification = "Qwen2ForSequenceClassification"
+    Qwen2Attention = "Qwen2Attention"
+    Qwen2FlashAttention2 = "Qwen2FlashAttention2"
+    Qwen2SdpaAttention = "Qwen2SdpaAttention"
+    Qwen2DecoderLayer = "Qwen2DecoderLayer"
+    Qwen2Model = "Qwen2Model"
+
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
 __all__ = ["Qwen2Policy", "Qwen2ForCausalLMPolicy", "Qwen2ForSequenceClassificationPolicy"]
@@ -45,21 +65,6 @@ class Qwen2Policy(Policy):
         return self.model
 
     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        try:
-            from transformers.models.qwen2.modeling_qwen2 import (
-                Qwen2Attention,
-                Qwen2DecoderLayer,
-                Qwen2FlashAttention2,
-                Qwen2Model,
-                Qwen2SdpaAttention,
-            )
-        except ImportError:
-            Qwen2Attention = "Qwen2Attention"
-            Qwen2FlashAttention2 = "Qwen2FlashAttention2"
-            Qwen2SdpaAttention = "Qwen2SdpaAttention"
-            Qwen2DecoderLayer = "Qwen2DecoderLayer"
-            Qwen2Model = "Qwen2Model"
-
         ATTN_IMPLEMENTATION = {
             "eager": Qwen2Attention,
             "flash_attention_2": Qwen2FlashAttention2,
@@ -82,6 +87,13 @@ class Qwen2Policy(Policy):
             warnings.warn("Qwen2 doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
 
         if self.shard_config.enable_tensor_parallelism:
+            assert (
+                self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of attention heads must be divisible by tensor parallel size."
+            if hasattr(self.model.config, "num_key_value_heads"):
+                assert (
+                    self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
+                ), f"The number of key_value heads must be divisible by tensor parallel size."
             decoder_attribute_replacement = {
                 "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                 "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
@@ -256,7 +268,6 @@ class Qwen2Policy(Policy):
 class Qwen2ModelPolicy(Qwen2Policy):
     def module_policy(self):
         policy = super().module_policy()
-        from transformers.models.qwen2.modeling_qwen2 import Qwen2Model
 
         if self.pipeline_stage_manager:
             # set None as default
@@ -277,10 +288,7 @@ class Qwen2ModelPolicy(Qwen2Policy):
 
 class Qwen2ForCausalLMPolicy(Qwen2Policy):
     def module_policy(self):
-        from transformers import Qwen2ForCausalLM
-
         policy = super().module_policy()
-
         setattr(self.shard_config, "causal_lm", True)
 
         if self.shard_config.enable_tensor_parallelism:
@@ -330,10 +338,7 @@ class Qwen2ForCausalLMPolicy(Qwen2Policy):
 
 class Qwen2ForSequenceClassificationPolicy(Qwen2Policy):
     def module_policy(self):
-        from transformers import Qwen2ForSequenceClassification
-
        policy = super().module_policy()
-
         if self.shard_config.enable_tensor_parallelism:
             # add a new item for sequence classification
             new_item = {
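
Note (not part of the patch): a minimal sketch of the head-count validation the updated LlamaPolicy and Qwen2Policy perform before sharding attention under tensor parallelism. The standalone function check_head_divisibility is hypothetical; the real checks live inside each policy's module_policy() and read self.model.config and self.shard_config.

# Minimal sketch, assuming a hypothetical helper; mirrors the asserts added in the diff.
from typing import Optional


def check_head_divisibility(num_attention_heads: int, num_key_value_heads: Optional[int], tp_size: int) -> None:
    # Query heads must split evenly across tensor-parallel ranks.
    assert (
        num_attention_heads % tp_size == 0
    ), "The number of attention heads must be divisible by tensor parallel size."
    # GQA checkpoints expose num_key_value_heads; KV heads must also split evenly,
    # and each rank needs at least one KV head.
    if num_key_value_heads is not None:
        assert (
            num_key_value_heads >= tp_size and num_key_value_heads % tp_size == 0
        ), "The number of key_value heads must be divisible by, and not less than, tensor parallel size."


# Example: a GQA config with 64 query heads and 8 KV heads shards cleanly onto 8 ranks;
# the same config would fail for tp_size=16, since there are fewer KV heads than ranks.
check_head_divisibility(num_attention_heads=64, num_key_value_heads=8, tp_size=8)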