Merge pull request #6276 from flybird11111/upgrade-transformers

Upgrade transformers
This commit is contained in:
Hanks 2025-04-24 17:30:40 +08:00 committed by GitHub
commit 8497ecc3e5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 22 additions and 51 deletions

View File

@ -4,7 +4,6 @@ from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.distributed
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@ -141,7 +140,9 @@ class LlamaPipelineForwards:
invert=(sp_mode != "ring_attn"),
)
else:
attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, hidden_states, cache_position)
attn_kwargs: torch.Tensor = self._update_causal_mask(
attention_mask, hidden_states, cache_position, past_key_values
)
# Support SP + PP. Later stages have already received the split input.
split_input = disable_pp or stage_manager.is_first_stage()
@ -177,6 +178,7 @@ class LlamaPipelineForwards:
all_self_attns = () if output_attentions else None
next_decoder_cache = None
start_idx, end_idx = (0, len(self.layers)) if disable_pp else (stage_index[0], stage_index[1])
position_embeddings = self.rotary_emb(hidden_states, position_ids)
num_ckpt_layers = 0
if self.gradient_checkpointing and self.training:
@ -204,6 +206,7 @@ class LlamaPipelineForwards:
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
@ -214,6 +217,7 @@ class LlamaPipelineForwards:
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
hidden_states = layer_outputs[0]
@ -486,8 +490,8 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[Union[torch.Tensor, Dict]] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
@ -505,30 +509,14 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
)
bsz, q_len, _ = hidden_states.size()
input_shape = hidden_states.shape[:-1]
# sp: modify sp_len when sequence parallel mode is ring
if is_share_sp_tp(sp_mode):
q_len *= sp_size
if self.config.pretraining_tp > 1:
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
query_states = torch.cat(query_states, dim=-1)
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
key_states = torch.cat(key_states, dim=-1)
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# sp: all-to-all comminucation when introducing sequence parallel
if sp_mode == "all_to_all":
@ -537,9 +525,9 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
bsz, q_len, _ = query_states.size()
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
@ -552,7 +540,7 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, position_ids)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
@ -610,17 +598,12 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
)
else:
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
if self.config.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
else:
attn_output = self.o_proj(attn_output)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
return attn_output, attn_weights
return forward

View File

@ -33,22 +33,9 @@ class LlamaPolicy(Policy):
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaDecoderLayer,
LlamaFlashAttention2,
LlamaModel,
LlamaSdpaAttention,
)
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel
ATTN_IMPLEMENTATION = {
"eager": LlamaAttention,
"flash_attention_2": LlamaFlashAttention2,
"sdpa": LlamaSdpaAttention,
}
policy = {}
attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
embedding_cls = None
if self.shard_config.enable_tensor_parallelism:
embedding_cls = VocabParallelEmbedding1D
@ -82,7 +69,7 @@ class LlamaPolicy(Policy):
num_kv_heads //= sp_size
decoder_attribute_replacement["num_key_value_heads"] = num_kv_heads
policy[attn_cls] = ModulePolicyDescription(
policy[LlamaAttention] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
)
if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
@ -91,7 +78,7 @@ class LlamaPolicy(Policy):
"forward": get_llama_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
},
policy=policy,
target_key=attn_cls,
target_key=LlamaAttention,
)
if self.pipeline_stage_manager is None:
@ -354,6 +341,7 @@ class LlamaPolicy(Policy):
stage_manager = self.pipeline_stage_manager
held_layers = []
held_layers.append(module.rotary_emb)
if stage_manager.is_interleave:
assert stage_manager.num_model_chunks is not None
layers_per_stage = stage_manager.distribute_layers(len(module.layers))