Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-08-13 13:45:51 +00:00
[misc] remove debug/print code
This commit is contained in:
parent 59bcf56c60
commit 034020bd04
@@ -535,28 +535,22 @@ def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=Non
 if sp_mode in ["split_gather", "ring"]:
     q_len *= sp_size

-rank = dist.get_rank()
-print(f"{rank=}, hidden states:{hidden_states.shape}")
 query_states = self.q_proj(hidden_states)
 key_states = self.k_proj(hidden_states)
 value_states = self.v_proj(hidden_states)

-rank = dist.get_rank()
-print(f"{rank=}, before all to all q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}")
 # sp: all-to-all comminucation when introducing sequence parallel
 if sp_mode == "all_to_all":
     query_states = all_to_all_comm(query_states, sp_group)
     key_states = all_to_all_comm(key_states, sp_group)
     value_states = all_to_all_comm(value_states, sp_group)
     bsz, q_len, _ = query_states.size()
-    print(f"{rank=}, after all to all q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}")
 # Flash attention requires the input to have the shape
 # batch_size x seq_length x head_dim x hidden_dim
 # therefore we just need to keep the original shape
 query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-print(f"{rank=}, after view to (b,s,h,d) q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}")

 kv_seq_len = key_states.shape[-2]
 if past_key_value is not None:
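Note (not part of the commit): the prints removed above were tracing how the sp_mode == "all_to_all" path reshapes q/k/v. A minimal single-process sketch of that shape bookkeeping, with made-up sizes and a plain tensor standing in for the output of all_to_all_comm:

    import torch

    # Made-up example sizes; none of these values come from the commit.
    sp_size, bsz, full_seq, num_heads, head_dim = 2, 1, 8, 4, 16
    hidden_size = num_heads * head_dim

    # Before the all-to-all, each rank holds a sequence shard of the activations.
    hidden_states = torch.randn(bsz, full_seq // sp_size, hidden_size)
    query_states = hidden_states  # stands in for self.q_proj(hidden_states)

    # The all-to-all gathers the full sequence and scatters the hidden/head
    # dimension across ranks; a fresh tensor stands in for
    # all_to_all_comm(query_states, sp_group) here.
    query_states = torch.randn(bsz, full_seq, hidden_size // sp_size)

    # This is why the forward re-reads bsz and q_len afterwards ...
    bsz, q_len, _ = query_states.size()
    assert q_len == full_seq

    # ... and why the subsequent .view yields per-rank heads for flash attention.
    heads_per_rank = num_heads // sp_size
    q = query_states.view(bsz, q_len, heads_per_rank, head_dim).transpose(1, 2)
    assert q.shape == (bsz, heads_per_rank, q_len, head_dim)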
@@ -565,7 +559,6 @@ def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=Non
 query_states, key_states = apply_rotary_pos_emb(
     query_states, key_states, cos, sin, position_ids, unsqueeze_dim=0
 )
-print(f"{rank=}, after rope q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}")

 if past_key_value is not None:
     cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
@@ -576,9 +569,6 @@ def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=Non
 query_states = query_states.transpose(1, 2)
 key_states = key_states.transpose(1, 2)
 value_states = value_states.transpose(1, 2)
-print(
-    f"{rank=}, after transpose to (b, nh, s, d) q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}"
-)
 dropout_rate = self.attention_dropout if self.training else 0.0

 # In PEFT, usually we cast the layer norms in float32 for training stability reasons
@@ -606,7 +596,6 @@ def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=Non
 query_states = query_states.to(target_dtype)
 key_states = key_states.to(target_dtype)
 value_states = value_states.to(target_dtype)
-print(f"{rank=}, before flash attn q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}")
 attn_output = self._flash_attention_forward(
     query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
 )
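Aside (not from the commit): the context lines above cast q/k/v back to a half-precision target dtype before the flash-attention call, since layer norms may have been upcast to float32 (e.g. by PEFT). A tiny standalone illustration of that cast pattern, with made-up tensors:

    import torch

    # Made-up activation; in the real forward this is a projected query state.
    query_states = torch.randn(1, 8, 4, 16, dtype=torch.float32)
    target_dtype = torch.float16  # e.g. the dtype of the attention weights

    # Flash-attention kernels expect fp16/bf16, so float32 activations left
    # over from upcast layer norms are cast back before the kernel call.
    if query_states.dtype == torch.float32:
        query_states = query_states.to(target_dtype)
    assert query_states.dtype == target_dtype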
@@ -50,7 +50,6 @@ class DeepseekPolicy(Policy):
     "sdpa": "DeepseekSdpaAttention",
 }
 policy = {}
-print(f"{self.origin_attn_implement=}")
 attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
 sp_mode = self.shard_config.sequence_parallelism_mode or None
 sp_size = self.shard_config.sequence_parallel_size or None
@@ -71,7 +70,6 @@ class DeepseekPolicy(Policy):
 # NOTE: we are replacing model forward for both sequence parallelism and pipeline parallelism
 # if both are enabled, one of them will be ignored
 raise NotImplementedError("Sequence parallelism is not supported with pipeline parallelism.")
-print(f"{attn_cls=}")
 self.append_or_create_method_replacement(
     description={
         "forward": get_deepseek_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
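Aside (not from the commit): append_or_create_method_replacement installs the forward returned by get_deepseek_flash_attention_forward on the attention class. The underlying idea can be pictured as a factory that closes over configuration plus a plain method rebinding; the sketch below uses only hypothetical toy names, not the ColossalAI API:

    class DummyAttention:
        def forward(self, x):
            return x

    def get_custom_forward(scale):
        # Factory in the spirit of get_deepseek_flash_attention_forward:
        # it closes over configuration and returns the function to install.
        def forward(self, x):
            return x * scale
        return forward

    # Rebind the method on the class so every instance dispatches to it.
    DummyAttention.forward = get_custom_forward(scale=2)

    attn = DummyAttention()
    assert attn.forward(3) == 6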
@@ -208,13 +206,8 @@ class DeepseekPolicy(Policy):

 @staticmethod
 def from_native_module(original_attn: nn.Module, *args, **kwargs) -> nn.Module:
-    flash_attn_module = flash_attn_cls(original_attn.config, original_attn.layer_idx)
-    flash_attn_module.q_proj = original_attn.q_proj
-    flash_attn_module.k_proj = original_attn.k_proj
-    flash_attn_module.v_proj = original_attn.v_proj
-    flash_attn_module.o_proj = original_attn.o_proj
-    flash_attn_module.rotary_emb = original_attn.rotary_emb
-    return flash_attn_module
+    original_attn.__class__ = flash_attn_cls
+    return original_attn

 self.append_or_create_submodule_replacement(
     description=SubModuleReplacementDescription(
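Aside (not from the commit): the rewritten from_native_module no longer builds a new module and copies the q/k/v/o projections; it reassigns __class__ so the same instance, with the same parameters, simply dispatches to the flash-attention forward. A minimal standalone sketch of that pattern with toy classes:

    import torch.nn as nn

    class EagerAttention(nn.Module):
        def __init__(self):
            super().__init__()
            self.q_proj = nn.Linear(4, 4)

        def forward(self, x):
            return self.q_proj(x)

    class FlashAttention(EagerAttention):
        # Same attributes, different forward implementation.
        def forward(self, x):
            return self.q_proj(x) * 2  # stand-in for a flash-attention kernel

    def from_native_module(original_attn: nn.Module) -> nn.Module:
        # Swapping __class__ changes which forward is dispatched while the
        # existing submodules and weights stay untouched (no copies).
        original_attn.__class__ = FlashAttention
        return original_attn

    attn = from_native_module(EagerAttention())
    assert isinstance(attn, FlashAttention)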
@@ -68,7 +68,6 @@ def init_deepseek():

 if hasattr(config, "pad_token_id"):
     config.pad_token_id = config.eos_token_id
-print(config)
 model = transformers.AutoModel.from_config(config, trust_remote_code=True)

 return model