[upgrade]Upgrade mixtral (#6317)

* upgrade mixtral

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* upgrade infer

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* upgrade drafter

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* upgrade lazy

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* upgrade mixtral

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
flybird11111 2025-05-21 16:14:05 +08:00 committed by GitHub
parent 2aa295e959
commit d0e13b85fd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 92 additions and 126 deletions

View File

@ -6,19 +6,16 @@ from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.cache_utils import DynamicCache
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaConfig,
LlamaDecoderLayer,
LlamaDynamicNTKScalingRotaryEmbedding,
LlamaForCausalLM,
LlamaLinearScalingRotaryEmbedding,
LlamaMLP,
LlamaModel,
LlamaRMSNorm,
LlamaRotaryEmbedding,
)
from colossalai.inference.spec import GlideInput
@ -156,15 +153,11 @@ def glide_llama_model_forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
past_seen_tokens = 0
if use_cache: # kept for BC (cache positions)
if not isinstance(past_key_values, StaticCache):
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_seen_tokens = past_key_values.get_seq_length()
if use_cache and past_key_values is None:
past_key_values = DynamicCache()
if cache_position is None:
if isinstance(past_key_values, StaticCache):
raise ValueError("cache_position is a required argument when using StaticCache.")
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
@ -172,15 +165,17 @@ def glide_llama_model_forward(
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)
if hasattr(glide_input, "n_spec_tokens"):
position_ids = position_ids + glide_input.n_spec_tokens
# embed positions
hidden_states = inputs_embeds
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
for decoder_layer in self.layers:
if output_hidden_states:
@ -189,9 +184,9 @@ def glide_llama_model_forward(
# GlideLlamaDecoderLayer
layer_outputs = decoder_layer(
hidden_states,
position_embeddings=position_embeddings,
glide_input=glide_input,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
@ -200,9 +195,6 @@ def glide_llama_model_forward(
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
@ -212,16 +204,11 @@ def glide_llama_model_forward(
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = None
if use_cache:
next_cache = (
next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
)
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return tuple(v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
past_key_values=past_key_values,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
@ -267,31 +254,6 @@ class LlamaCrossAttention(nn.Module):
self.q_proj = nn.Linear(self.hidden_size, self.large_num_heads * self.large_head_dim, bias=False)
self.o_proj = nn.Linear(self.large_num_heads * self.large_head_dim, self.hidden_size, bias=False)
self._init_rope()
def _init_rope(self):
if self.config.rope_scaling is None:
self.rotary_emb = LlamaRotaryEmbedding(
self.large_head_dim,
max_position_embeddings=self.max_position_embeddings,
)
else:
scaling_type = self.config.rope_scaling["type"]
scaling_factor = self.config.rope_scaling["factor"]
if scaling_type == "linear":
self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
self.large_head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
)
elif scaling_type == "dynamic":
self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
self.large_head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
@ -299,9 +261,10 @@ class LlamaCrossAttention(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
position_ids: Optional[torch.LongTensor] = None,
glide_input: GlideInput = None, # Used for glimpsing main model's KV caches
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Optional[torch.Tensor]:
@ -319,8 +282,7 @@ class LlamaCrossAttention(nn.Module):
query_states = query_states.view(bsz, -1, self.large_num_heads, self.large_head_dim).transpose(1, 2)
# for RoPE
position_ids = position_ids + glide_input.n_spec_tokens
cos, sin = self.rotary_emb(query_states, position_ids)
cos, sin = position_embeddings
query_states = apply_single_rotary_pos_emb(query_states, cos, sin, position_ids)
query_states = query_states.transpose(1, 2)
query_states = query_states.reshape(-1, self.large_num_heads, self.large_head_dim)
@ -367,9 +329,10 @@ class GlideLlamaDecoderLayer(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: torch.Tensor = None,
position_ids: Optional[torch.LongTensor] = None,
glide_input: GlideInput = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
@ -399,10 +362,10 @@ class GlideLlamaDecoderLayer(nn.Module):
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
@ -425,9 +388,10 @@ class GlideLlamaDecoderLayer(nn.Module):
hidden_states = self.cross_attn(
hidden_states=hidden_states,
position_embeddings=position_embeddings,
position_ids=position_ids,
glide_input=glide_input,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
use_cache=True,
)
@ -441,9 +405,6 @@ class GlideLlamaDecoderLayer(nn.Module):
outputs = (hidden_states,)
if use_cache:
outputs += (present_key_value,)
return outputs

View File

@ -478,9 +478,9 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
attn_oproj=attn_oproj,
process_group=process_group,
model_shard_infer_config=model_shard_infer_config,
num_heads=module.num_heads,
hidden_size=module.hidden_size,
num_key_value_heads=module.num_key_value_heads,
num_heads=module.config.num_attention_heads,
hidden_size=module.config.hidden_size,
num_key_value_heads=module.config.num_key_value_heads,
)
return attn_layer

View File

@ -3,6 +3,7 @@ from typing import Optional, Tuple
import torch
import torch.nn as nn
from transformers import PreTrainedTokenizer
from transformers.cache_utils import DynamicCache
from colossalai.utils import get_current_device
@ -93,9 +94,8 @@ class Drafter:
for _ in range(n_spec_tokens):
# update past key values
kwargs["past_key_values"] = past_key_values
outputs = self._drafter_model(input_ids, **kwargs)
outputs = self._drafter_model(input_ids, past_key_values=past_key_values, **kwargs)
next_token_logits = outputs.logits[:, -1, :]
# NOTE Only use greedy search for speculating.
@ -114,6 +114,8 @@ class Drafter:
speculated_length = len(token_ids) # For now, only support bsz 1
logits = torch.concat(logits, dim=0)
token_ids = torch.concat(token_ids, dim=-1)
if isinstance(past_key_values, DynamicCache):
past_key_values = past_key_values.to_legacy_cache()
out = DrafterOutput(
speculated_length=speculated_length, logits=logits, next_tokens=token_ids, past_key_values=past_key_values

View File

@ -69,7 +69,7 @@ def new_from_pretrained(
_ = kwargs.pop("mirror", None)
from_pipeline = kwargs.pop("_from_pipeline", None)
from_auto_class = kwargs.pop("_from_auto", False)
_fast_init = kwargs.pop("_fast_init", True)
kwargs.pop("_fast_init", True)
torch_dtype = kwargs.pop("torch_dtype", None)
subfolder = kwargs.pop("subfolder", "")
commit_hash = kwargs.pop("_commit_hash", None)
@ -286,7 +286,8 @@ def new_from_pretrained(
config.name_or_path = pretrained_model_name_or_path
# Instantiate model.
init_contexts = [no_init_weights(_enable=_fast_init)]
# init_contexts = [no_init_weights(_enable=_fast_init)]
init_contexts = [no_init_weights()]
with ContextManagers(init_contexts):
model = cls(config, *model_args, **model_kwargs)

View File

@ -1,6 +1,6 @@
import inspect
import warnings
from typing import List, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.distributed as dist
@ -13,6 +13,7 @@ from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.models.mixtral.modeling_mixtral import (
MixtralModel,
MixtralSparseMoeBlock,
MoeCausalLMOutputWithPast,
MoeModelOutputWithPast,
@ -215,7 +216,7 @@ class MixtralPipelineForwards:
@staticmethod
def mixtral_model_forward(
self,
self: MixtralModel,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
@ -225,6 +226,7 @@ class MixtralPipelineForwards:
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
@ -340,11 +342,17 @@ class MixtralPipelineForwards:
)
use_cache = False
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_router_logits = () if output_router_logits else None
next_decoder_cache = None
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device
)
start_idx, end_idx = stage_index[0], stage_index[1]
for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
@ -370,6 +378,9 @@ class MixtralPipelineForwards:
None,
output_attentions,
output_router_logits,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
@ -380,6 +391,8 @@ class MixtralPipelineForwards:
output_attentions,
output_router_logits,
use_cache,
cache_position,
position_embeddings,
)
hidden_states = layer_outputs[0]
@ -559,14 +572,18 @@ class MixtralPipelineForwards:
def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
logger = logging.get_logger(__name__)
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.models.mixtral.modeling_mixtral import eager_attention_forward
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
@ -614,54 +631,23 @@ def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
# Because the input can be padded, the absolute sequence length depends on the max position id.
rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
use_sliding_windows = (
_flash_supports_window_size
and getattr(self.config, "sliding_window", None) is not None
and kv_seq_len > self.config.sliding_window
)
if not _flash_supports_window_size:
logger.warning_once(
"The current flash attention version does not support sliding window attention, for a more memory efficient implementation"
" make sure to upgrade flash-attn library."
)
if past_key_value is not None:
# Activate slicing cache only if the config has a value `sliding_windows` attribute
cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
if (
getattr(self.config, "sliding_window", None) is not None
and kv_seq_len > self.config.sliding_window
and cache_has_contents
):
slicing_tokens = 1 - self.config.sliding_window
past_key = past_key_value[self.layer_idx][0]
past_value = past_key_value[self.layer_idx][1]
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
past_value = past_value[:, :, slicing_tokens:, :].contiguous()
if past_key.shape[-2] != self.config.sliding_window - 1:
raise ValueError(
f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
f" {past_key.shape}"
)
if attention_mask is not None:
attention_mask = attention_mask[:, slicing_tokens:]
attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
dropout_rate = 0.0 if not self.training else self.attention_dropout
0.0 if not self.training else self.attention_dropout
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
@ -689,14 +675,27 @@ def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
attn_output = self._flash_attention_forward(
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
q_len,
dropout=dropout_rate,
use_sliding_windows=use_sliding_windows,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
sliding_window=getattr(self.config, "sliding_window", None), # main diff with Llama
**kwargs,
)
# sp: all-to-all comminucation when introducing sequence parallel
@ -712,7 +711,7 @@ def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
return attn_output, attn_weights
return forward
@ -731,6 +730,7 @@ def get_mixtral_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, MoeModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@ -788,7 +788,7 @@ def get_mixtral_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
" this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to "
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
)
if self._attn_implementation == "flash_attention_2":
if self.config._attn_implementation == "flash_attention_2":
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._attn_implementation == "sdpa" and not output_attentions:
@ -820,6 +820,16 @@ def get_mixtral_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
)
hidden_states = inputs_embeds
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
@ -840,6 +850,8 @@ def get_mixtral_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
output_attentions,
output_router_logits,
use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
else:
layer_outputs = decoder_layer(
@ -850,6 +862,8 @@ def get_mixtral_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
output_attentions=output_attentions,
output_router_logits=output_router_logits,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
hidden_states = layer_outputs[0]

View File

@ -40,21 +40,9 @@ class MixtralPolicy(Policy):
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
from transformers.models.mixtral.modeling_mixtral import (
MixtralAttention,
MixtralDecoderLayer,
MixtralFlashAttention2,
MixtralModel,
MixtralSdpaAttention,
)
from transformers.models.mixtral.modeling_mixtral import MixtralAttention, MixtralDecoderLayer, MixtralModel
ATTN_IMPLEMENTATION = {
"eager": MixtralAttention,
"flash_attention_2": MixtralFlashAttention2,
"sdpa": MixtralSdpaAttention,
}
policy = {}
attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
sp_mode = self.shard_config.sequence_parallelism_mode or None
sp_size = self.shard_config.sequence_parallel_size or None
@ -76,7 +64,7 @@ class MixtralPolicy(Policy):
num_kv_heads //= sp_size
decoder_attribute_replacement["num_key_value_heads"] = num_kv_heads
policy[attn_cls] = ModulePolicyDescription(
policy[MixtralAttention] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
)
if self.shard_config.enable_sequence_parallelism:
@ -89,7 +77,7 @@ class MixtralPolicy(Policy):
"forward": get_mixtral_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
},
policy=policy,
target_key=attn_cls,
target_key=MixtralAttention,
)
self.append_or_create_method_replacement(
description={
@ -330,7 +318,7 @@ class MixtralPolicy(Policy):
stage_manager = self.pipeline_stage_manager
held_layers = []
held_layers.append(module.rotary_emb)
if stage_manager.is_interleave:
assert stage_manager.num_model_chunks is not None
layers_per_stage = stage_manager.distribute_layers(len(module.layers))