diff --git a/colossalai/shardformer/modeling/deepseek.py b/colossalai/shardformer/modeling/deepseek.py
index 468b890ab..52ea6c22b 100644
--- a/colossalai/shardformer/modeling/deepseek.py
+++ b/colossalai/shardformer/modeling/deepseek.py
@@ -535,28 +535,22 @@ def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=Non
         if sp_mode in ["split_gather", "ring"]:
             q_len *= sp_size
 
-        rank = dist.get_rank()
-        print(f"{rank=}, hidden states:{hidden_states.shape}")
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
-        rank = dist.get_rank()
-        print(f"{rank=}, before all to all q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}")
 
         # sp: all-to-all comminucation when introducing sequence parallel
         if sp_mode == "all_to_all":
             query_states = all_to_all_comm(query_states, sp_group)
             key_states = all_to_all_comm(key_states, sp_group)
             value_states = all_to_all_comm(value_states, sp_group)
             bsz, q_len, _ = query_states.size()
-            print(f"{rank=}, after all to all q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}")
 
         # Flash attention requires the input to have the shape
         # batch_size x seq_length x head_dim x hidden_dim
         # therefore we just need to keep the original shape
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        print(f"{rank=}, after view to (b,s,h,d) q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}")
 
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
@@ -565,7 +559,6 @@ def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=Non
         query_states, key_states = apply_rotary_pos_emb(
             query_states, key_states, cos, sin, position_ids, unsqueeze_dim=0
         )
-        print(f"{rank=}, after rope q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}")
 
         if past_key_value is not None:
             cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
@@ -576,9 +569,6 @@ def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=Non
         query_states = query_states.transpose(1, 2)
         key_states = key_states.transpose(1, 2)
         value_states = value_states.transpose(1, 2)
-        print(
-            f"{rank=}, after transpose to (b, nh, s, d) q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}"
-        )
 
         dropout_rate = self.attention_dropout if self.training else 0.0
 
         # In PEFT, usually we cast the layer norms in float32 for training stability reasons
@@ -606,7 +596,6 @@ def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=Non
             query_states = query_states.to(target_dtype)
             key_states = key_states.to(target_dtype)
             value_states = value_states.to(target_dtype)
-        print(f"{rank=}, before flash attn q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}")
 
         attn_output = self._flash_attention_forward(
             query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
         )
diff --git a/colossalai/shardformer/policies/deepseek.py b/colossalai/shardformer/policies/deepseek.py
index d1d004ed5..963bd9d67 100644
--- a/colossalai/shardformer/policies/deepseek.py
+++ b/colossalai/shardformer/policies/deepseek.py
@@ -50,7 +50,6 @@ class DeepseekPolicy(Policy):
             "sdpa": "DeepseekSdpaAttention",
         }
         policy = {}
-        print(f"{self.origin_attn_implement=}")
         attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
         sp_mode = self.shard_config.sequence_parallelism_mode or None
         sp_size = self.shard_config.sequence_parallel_size or None
@@ -71,7 +70,6 @@ class DeepseekPolicy(Policy):
             # NOTE: we are replacing model forward for both sequence parallelism and pipeline parallelism
             # if both are enabled, one of them will be ignored
             raise NotImplementedError("Sequence parallelism is not supported with pipeline parallelism.")
-        print(f"{attn_cls=}")
         self.append_or_create_method_replacement(
             description={
                 "forward": get_deepseek_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
@@ -208,13 +206,8 @@ class DeepseekPolicy(Policy):
 
             @staticmethod
             def from_native_module(original_attn: nn.Module, *args, **kwargs) -> nn.Module:
-                flash_attn_module = flash_attn_cls(original_attn.config, original_attn.layer_idx)
-                flash_attn_module.q_proj = original_attn.q_proj
-                flash_attn_module.k_proj = original_attn.k_proj
-                flash_attn_module.v_proj = original_attn.v_proj
-                flash_attn_module.o_proj = original_attn.o_proj
-                flash_attn_module.rotary_emb = original_attn.rotary_emb
-                return flash_attn_module
+                original_attn.__class__ = flash_attn_cls
+                return original_attn
 
         self.append_or_create_submodule_replacement(
             description=SubModuleReplacementDescription(
diff --git a/tests/kit/model_zoo/transformers/deepseek.py b/tests/kit/model_zoo/transformers/deepseek.py
index b8b446b57..ad73640a5 100644
--- a/tests/kit/model_zoo/transformers/deepseek.py
+++ b/tests/kit/model_zoo/transformers/deepseek.py
@@ -68,7 +68,6 @@ def init_deepseek():
     if hasattr(config, "pad_token_id"):
         config.pad_token_id = config.eos_token_id
-    print(config)
     model = transformers.AutoModel.from_config(config, trust_remote_code=True)
     return model
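
Note on the from_native_module change in colossalai/shardformer/policies/deepseek.py: instead of instantiating flash_attn_cls and copying q_proj/k_proj/v_proj/o_proj/rotary_emb across one by one, the patch reassigns the instance's __class__, so the wrapped module keeps its existing parameters and only picks up the new forward. A minimal, self-contained sketch of that pattern follows; NativeAttention and PatchedAttention are hypothetical stand-ins, not the actual ColossalAI or transformers classes.

# Sketch (assumed names, not ColossalAI code) of the __class__ swap pattern:
# the instance keeps its parameters and buffers, only its methods change.
import torch
import torch.nn as nn


class NativeAttention(nn.Module):
    # hypothetical stand-in for the original attention module
    def __init__(self, hidden_size: int):
        super().__init__()
        self.q_proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.q_proj(x)


class PatchedAttention(NativeAttention):
    # hypothetical stand-in for the flash-attention subclass: same state,
    # different forward implementation
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return 2 * self.q_proj(x)  # placeholder for the patched attention path


def from_native_module(original_attn: nn.Module) -> nn.Module:
    # Re-point the instance's class instead of building a new module and
    # copying submodules over; existing weights (and hooks) are reused as-is.
    original_attn.__class__ = PatchedAttention
    return original_attn


attn = NativeAttention(8)
weight_before = attn.q_proj.weight.data_ptr()
attn = from_native_module(attn)
assert isinstance(attn, PatchedAttention)
assert attn.q_proj.weight.data_ptr() == weight_before  # same storage, new forward

This relies on the replacement class being layout-compatible with the original one (same attribute names, no extra state that its __init__ would normally set up); the sketch above assumes exactly that.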