This commit is contained in:
flybird11111 2025-04-24 16:20:42 +08:00
parent d7a9eb0f67
commit 2f615a49fd
5 changed files with 2 additions and 26 deletions

View File

@ -1056,8 +1056,6 @@ class HybridParallelPlugin(PipelinePluginBase):
assert (
not pp_style == "zbv" or scheduler_nodes is not None
), f"scheduler_nodes must not be None when using zero bubble pipeline."
# if sp_size is None or sp_size <= 1:
# enable_sequence_parallelism = False
if enable_sequence_parallelism:
self.sequence_parallelism_mode = (
sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all"

View File

@ -515,13 +515,6 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
if is_share_sp_tp(sp_mode):
q_len *= sp_size
# if sp_mode == "all_to_all":
# # query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
# # key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
# # value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
# # bsz, q_len, _ = query_states.size()
# # hidden_states = all_to_all_comm(hidden_states, sp_group, fp8_communication=shard_config.fp8_communication)
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
@ -548,7 +541,6 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
# cos, sin = self.rotary_emb(value_states, position_ids)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
@ -607,14 +599,12 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
)
else:
# attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
# return attn_output, attn_weights, past_key_value
return attn_output, attn_weights
return forward

View File

@ -36,19 +36,10 @@ class LlamaPolicy(Policy):
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaDecoderLayer,
# LlamaFlashAttention2,
LlamaModel,
# LlamaSdpaAttention,
)
# ATTN_IMPLEMENTATION = {
# "eager": LlamaAttention,
# "flash_attention_2": LlamaFlashAttention2,
# "sdpa": LlamaSdpaAttention,
# }
policy = {}
attn_cls = LlamaAttention
# attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
embedding_cls = None
if self.shard_config.enable_tensor_parallelism:
embedding_cls = VocabParallelEmbedding1D
@ -82,7 +73,7 @@ class LlamaPolicy(Policy):
num_kv_heads //= sp_size
decoder_attribute_replacement["num_key_value_heads"] = num_kv_heads
policy[attn_cls] = ModulePolicyDescription(
policy[LlamaAttention] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
)
if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
@ -91,7 +82,7 @@ class LlamaPolicy(Policy):
"forward": get_llama_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
},
policy=policy,
target_key=attn_cls,
target_key=LlamaAttention,
)
if self.pipeline_stage_manager is None:

View File

@ -225,7 +225,6 @@ class ModelSharder(object):
"""
if self.shard_config and self.shard_config.pipeline_stage_manager:
held_layers = self.policy.get_held_layers()
print("held_layers", held_layers)
set_tensors_to_none(self.model, exclude=set(held_layers))
return set(self._get_recursive_held_layers(held_layers))
return None

View File

@ -16,6 +16,4 @@ def set_tensors_to_none(model: nn.Module, exclude: Set[nn.Module] = set()) -> No
for n, p in model.named_parameters(recurse=False):
setattr(model, n, None)
for n, buf in model.named_buffers(recurse=False):
import torch
print("buffer", n, torch.distributed.get_rank())
setattr(model, n, None)