upgrade llama

This commit is contained in:
flybird11111 2025-04-24 14:54:15 +08:00
parent 0c5ed65305
commit 686982764c
6 changed files with 46 additions and 47 deletions

View File

@@ -1056,6 +1056,8 @@ class HybridParallelPlugin(PipelinePluginBase):
assert (
not pp_style == "zbv" or scheduler_nodes is not None
), f"scheduler_nodes must not be None when using zero bubble pipeline."
if sp_size is None or sp_size <= 1:
enable_sequence_parallelism = False
if enable_sequence_parallelism:
self.sequence_parallelism_mode = (
sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all"
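The net effect of this hunk: the plugin now quietly falls back to non-sequence-parallel execution whenever sp_size is unset or 1, rather than leaving enable_sequence_parallelism set with a trivial SP group. A minimal standalone sketch of that guard (the function name and return shape below are illustrative, not part of the plugin):

from typing import Optional, Tuple

def resolve_sequence_parallelism(
    enable_sequence_parallelism: bool,
    sp_size: Optional[int],
    sequence_parallelism_mode: Optional[str],
) -> Tuple[bool, Optional[str]]:
    """Disable sequence parallelism when the SP group would be trivial (size <= 1)."""
    if sp_size is None or sp_size <= 1:
        enable_sequence_parallelism = False
    mode = None
    if enable_sequence_parallelism:
        # Mirrors the plugin's existing default: fall back to "all_to_all" when no mode is given.
        mode = sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all"
    return enable_sequence_parallelism, mode

# sp_size=1 now disables SP instead of building a one-rank sequence-parallel group.
assert resolve_sequence_parallelism(True, 1, "ring_attn") == (False, None)
assert resolve_sequence_parallelism(True, 2, None) == (True, "all_to_all")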

View File

@@ -94,6 +94,7 @@ class LlamaPipelineForwards:
batch_size, seq_length = input_shape
device = hidden_states.device
# Support SP + PP
sp_mode = shard_config.sequence_parallelism_mode
sp_group = shard_config.sequence_parallel_process_group
@@ -112,6 +113,7 @@ class LlamaPipelineForwards:
raise ValueError("cache_position is a required argument when using StaticCache.")
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=device)
seq_length_with_past = seq_length + past_seen_tokens
if output_attentions:
@@ -141,7 +143,7 @@ class LlamaPipelineForwards:
invert=(sp_mode != "ring_attn"),
)
else:
attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, hidden_states, cache_position)
attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, hidden_states, cache_position, past_key_values)
# Support SP + PP. Later stages have already received the split input.
split_input = disable_pp or stage_manager.is_first_stage()
@@ -177,6 +179,7 @@ class LlamaPipelineForwards:
all_self_attns = () if output_attentions else None
next_decoder_cache = None
start_idx, end_idx = (0, len(self.layers)) if disable_pp else (stage_index[0], stage_index[1])
position_embeddings = self.rotary_emb(hidden_states, position_ids)
num_ckpt_layers = 0
if self.gradient_checkpointing and self.training:
@@ -204,6 +207,7 @@ class LlamaPipelineForwards:
output_attentions,
use_cache,
cache_position,
position_embeddings
)
else:
layer_outputs = decoder_layer(
@@ -214,6 +218,7 @@ class LlamaPipelineForwards:
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings
)
hidden_states = layer_outputs[0]
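Taken together, the pipeline-forward hunks above track the newer transformers Llama interface: _update_causal_mask now also receives past_key_values, and the rotary tables are computed once per forward via self.rotary_emb(hidden_states, position_ids) and threaded into every decoder layer call (including the gradient-checkpointed path) as position_embeddings, since the layers no longer recompute RoPE internally. A toy sketch of the compute-once, pass-everywhere pattern (the classes below are illustrative stand-ins, not the real modules):

import torch
import torch.nn as nn

class ToyRotary(nn.Module):
    # Stand-in for the model-level rotary embedding: returns (cos, sin) tables.
    def forward(self, hidden_states: torch.Tensor, position_ids: torch.Tensor):
        cos = torch.ones(*position_ids.shape, hidden_states.shape[-1], dtype=hidden_states.dtype)
        sin = torch.zeros_like(cos)
        return cos, sin

class ToyDecoderLayer(nn.Module):
    # Stand-in decoder layer: consumes the precomputed tuple instead of recomputing RoPE.
    def forward(self, hidden_states, position_embeddings=None, **kwargs):
        cos, sin = position_embeddings
        return (hidden_states,)

rotary_emb = ToyRotary()
layers = nn.ModuleList(ToyDecoderLayer() for _ in range(2))
hidden_states = torch.randn(1, 8, 16)
position_ids = torch.arange(8).unsqueeze(0)

position_embeddings = rotary_emb(hidden_states, position_ids)  # computed once
for layer in layers:
    hidden_states = layer(hidden_states, position_embeddings=position_embeddings)[0]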
@@ -486,8 +491,8 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[Union[torch.Tensor, Dict]] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
@@ -505,30 +510,21 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
)
bsz, q_len, _ = hidden_states.size()
input_shape = hidden_states.shape[:-1]
# sp: modify sp_len when sequence parallel mode is ring
if is_share_sp_tp(sp_mode):
q_len *= sp_size
if self.config.pretraining_tp > 1:
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
# if sp_mode == "all_to_all":
# # query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
# # key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
# # value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
# # bsz, q_len, _ = query_states.size()
# # hidden_states = all_to_all_comm(hidden_states, sp_group, fp8_communication=shard_config.fp8_communication)
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
query_states = torch.cat(query_states, dim=-1)
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
key_states = torch.cat(key_states, dim=-1)
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# sp: all-to-all communication when introducing sequence parallel
if sp_mode == "all_to_all":
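Within the attention forward, the pretraining_tp weight-slicing branch disappears and the q/k/v projections become single linear calls, apparently in line with newer transformers code that no longer special-cases pretraining_tp; the sequence-parallel all_to_all then operates on those projected states. A small sketch showing that the removed slicing path and the plain projection are numerically equivalent (toy sizes, illustrative only):

import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_size, num_heads, head_dim, pretraining_tp = 64, 4, 16, 2
q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
hidden_states = torch.randn(2, 8, hidden_size)

# Removed path: split the weight into pretraining_tp row blocks and concatenate partial matmuls.
query_slices = q_proj.weight.split((num_heads * head_dim) // pretraining_tp, dim=0)
q_sliced = torch.cat([F.linear(hidden_states, w) for w in query_slices], dim=-1)

# New path: one plain projection call.
q_plain = q_proj(hidden_states)
assert torch.allclose(q_sliced, q_plain, atol=1e-6)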
@@ -537,9 +533,9 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
bsz, q_len, _ = query_states.size()
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
@@ -552,7 +548,8 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, position_ids)
# cos, sin = self.rotary_emb(value_states, position_ids)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
@@ -610,17 +607,13 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
)
else:
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
if self.config.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
else:
attn_output = self.o_proj(attn_output)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
# return attn_output, attn_weights, past_key_value
return attn_output, attn_weights
return forward
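The rewritten attention forward now matches the newer transformers attention contract: it takes the precomputed position_embeddings (cos, sin) tuple instead of deriving them from position_ids, views q/k/v with an inferred -1 head count so the same code works when tensor parallelism leaves fewer heads on each rank, flattens the output back to (*input_shape, -1) before o_proj, and returns a two-element (attn_output, attn_weights) tuple with no past_key_value. A short sketch of the inferred-head-count reshape under a hypothetical TP sharding:

import torch

bsz, q_len, head_dim = 2, 16, 64
local_num_heads = 4  # e.g. 8 global heads sharded across tp_size=2 ranks
query_states = torch.randn(bsz, q_len, local_num_heads * head_dim)

# view(..., -1, head_dim) infers the local head count instead of trusting config.num_attention_heads.
q = query_states.view(bsz, q_len, -1, head_dim).transpose(1, 2)
assert q.shape == (bsz, local_num_heads, q_len, head_dim)

# On the way out, heads are flattened back and o_proj sees the local hidden width.
attn_output = q.transpose(1, 2).reshape(bsz, q_len, -1).contiguous()
assert attn_output.shape == (bsz, q_len, local_num_heads * head_dim)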

View File

@@ -36,19 +36,19 @@ class LlamaPolicy(Policy):
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaDecoderLayer,
LlamaFlashAttention2,
# LlamaFlashAttention2,
LlamaModel,
LlamaSdpaAttention,
# LlamaSdpaAttention,
)
ATTN_IMPLEMENTATION = {
"eager": LlamaAttention,
"flash_attention_2": LlamaFlashAttention2,
"sdpa": LlamaSdpaAttention,
}
# ATTN_IMPLEMENTATION = {
# "eager": LlamaAttention,
# "flash_attention_2": LlamaFlashAttention2,
# "sdpa": LlamaSdpaAttention,
# }
policy = {}
attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
attn_cls = LlamaAttention
# attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
embedding_cls = None
if self.shard_config.enable_tensor_parallelism:
embedding_cls = VocabParallelEmbedding1D
@@ -354,6 +354,7 @@ class LlamaPolicy(Policy):
stage_manager = self.pipeline_stage_manager
held_layers = []
held_layers.append(module.rotary_emb)
if stage_manager.is_interleave:
assert stage_manager.num_model_chunks is not None
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
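Two policy-level consequences of the upgrade appear here: the commented-out imports indicate that the separate LlamaFlashAttention2 / LlamaSdpaAttention classes are no longer available in the targeted transformers version (backend dispatch happens inside LlamaAttention), so the policy targets LlamaAttention unconditionally, and the rotary embedding module lives on LlamaModel, so every pipeline stage must hold module.rotary_emb. If older transformers releases still had to be supported, one hedged way to pick the attention class without hardcoding it might look like this (illustrative sketch, not what the commit does):

from transformers.models.llama import modeling_llama

def pick_attn_cls(origin_attn_implement: str):
    # Older releases expose per-backend classes; newer ones only ship LlamaAttention.
    names = {
        "eager": "LlamaAttention",
        "flash_attention_2": "LlamaFlashAttention2",
        "sdpa": "LlamaSdpaAttention",
    }
    name = names.get(origin_attn_implement, "LlamaAttention")
    return getattr(modeling_llama, name, modeling_llama.LlamaAttention)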

View File

@@ -225,6 +225,7 @@ class ModelSharder(object):
"""
if self.shard_config and self.shard_config.pipeline_stage_manager:
held_layers = self.policy.get_held_layers()
print("held_layers", held_layers)
set_tensors_to_none(self.model, exclude=set(held_layers))
return set(self._get_recursive_held_layers(held_layers))
return None

View File

@@ -16,4 +16,6 @@ def set_tensors_to_none(model: nn.Module, exclude: Set[nn.Module] = set()) -> No
for n, p in model.named_parameters(recurse=False):
setattr(model, n, None)
for n, buf in model.named_buffers(recurse=False):
import torch
print("buffer", n, torch.distributed.get_rank())
setattr(model, n, None)

View File

@@ -162,9 +162,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
[
# Double Ring Attention
{
"tp_size": 1,
"tp_size": 2,
"pp_size": 1,
"sp_size": 4,
"sp_size": 2,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "ring_attn",
@@ -226,12 +226,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"initial_scale": 1,
},
{
"tp_size": 2,
"tp_size": 1,
"pp_size": 1,
"sp_size": 1,
"sp_size": 2,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "ring",
"sequence_parallelism_mode": "all_to_all",
"enable_flash_attention": True,
"use_lazy_init": True,
"zero_stage": 2,
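The test matrix is reshuffled rather than extended: the double ring attention case now runs with tp_size=2 and sp_size=2 instead of a pure sp_size=4 split, and the former tp_size=2 "ring" case becomes a tp_size=1, sp_size=2 "all_to_all" case. Assuming these dicts are forwarded as keyword arguments to HybridParallelPlugin (an assumption about the test harness, which is not shown here; use_lazy_init looks like a harness-level flag), the all_to_all case corresponds roughly to this construction:

from colossalai.booster.plugin import HybridParallelPlugin

# NOTE: torch.distributed / colossalai must already be launched (e.g. via torchrun plus the
# ColossalAI launch helper) before constructing the plugin, since it builds TP/PP/SP process groups.
plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    sp_size=2,
    num_microbatches=1,
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="all_to_all",
    enable_flash_attention=True,
    zero_stage=2,
)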