[upgrade]Upgrade mixtral (#6317)

* upgrade mixtral

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* upgrade infer

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* upgrade drafter

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* upgrade lazy

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* upgrade mixtral

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Author: flybird11111
Date: 2025-05-21 16:14:05 +08:00 (committed by GitHub)
Parent: 2aa295e959
Commit: d0e13b85fd
6 changed files with 92 additions and 126 deletions

View File

@@ -6,19 +6,16 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
-from transformers.cache_utils import Cache, DynamicCache, StaticCache
+from transformers.cache_utils import DynamicCache
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.models.llama.modeling_llama import (
     LlamaAttention,
     LlamaConfig,
     LlamaDecoderLayer,
-    LlamaDynamicNTKScalingRotaryEmbedding,
     LlamaForCausalLM,
-    LlamaLinearScalingRotaryEmbedding,
     LlamaMLP,
     LlamaModel,
     LlamaRMSNorm,
-    LlamaRotaryEmbedding,
 )
 from colossalai.inference.spec import GlideInput
@@ -156,15 +153,11 @@ def glide_llama_model_forward(
     if inputs_embeds is None:
         inputs_embeds = self.embed_tokens(input_ids)
-    past_seen_tokens = 0
-    if use_cache:  # kept for BC (cache positions)
-        if not isinstance(past_key_values, StaticCache):
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            past_seen_tokens = past_key_values.get_seq_length()
+    if use_cache and past_key_values is None:
+        past_key_values = DynamicCache()
     if cache_position is None:
-        if isinstance(past_key_values, StaticCache):
-            raise ValueError("cache_position is a required argument when using StaticCache.")
+        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
         cache_position = torch.arange(
             past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
         )
@@ -172,15 +165,17 @@ def glide_llama_model_forward(
     if position_ids is None:
         position_ids = cache_position.unsqueeze(0)
-    attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
+    attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)
+    if hasattr(glide_input, "n_spec_tokens"):
+        position_ids = position_ids + glide_input.n_spec_tokens
     # embed positions
     hidden_states = inputs_embeds
+    position_embeddings = self.rotary_emb(hidden_states, position_ids)
     # decoder layers
     all_hidden_states = () if output_hidden_states else None
     all_self_attns = () if output_attentions else None
-    next_decoder_cache = None
     for decoder_layer in self.layers:
         if output_hidden_states:
@@ -189,9 +184,9 @@ def glide_llama_model_forward(
         # GlideLlamaDecoderLayer
         layer_outputs = decoder_layer(
             hidden_states,
+            position_embeddings=position_embeddings,
             glide_input=glide_input,
             attention_mask=attention_mask,
-            position_ids=position_ids,
             past_key_value=past_key_values,
             output_attentions=output_attentions,
             use_cache=use_cache,
@@ -200,9 +195,6 @@ def glide_llama_model_forward(
         hidden_states = layer_outputs[0]
-        if use_cache:
-            next_decoder_cache = layer_outputs[2 if output_attentions else 1]
         if output_attentions:
             all_self_attns += (layer_outputs[1],)
@@ -212,16 +204,11 @@ def glide_llama_model_forward(
     if output_hidden_states:
         all_hidden_states += (hidden_states,)
-    next_cache = None
-    if use_cache:
-        next_cache = (
-            next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
-        )
     if not return_dict:
-        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+        return tuple(v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None)
     return BaseModelOutputWithPast(
         last_hidden_state=hidden_states,
-        past_key_values=next_cache,
+        past_key_values=past_key_values,
         hidden_states=all_hidden_states,
         attentions=all_self_attns,
     )
@@ -267,31 +254,6 @@ class LlamaCrossAttention(nn.Module):
         self.q_proj = nn.Linear(self.hidden_size, self.large_num_heads * self.large_head_dim, bias=False)
         self.o_proj = nn.Linear(self.large_num_heads * self.large_head_dim, self.hidden_size, bias=False)
-        self._init_rope()
-    def _init_rope(self):
-        if self.config.rope_scaling is None:
-            self.rotary_emb = LlamaRotaryEmbedding(
-                self.large_head_dim,
-                max_position_embeddings=self.max_position_embeddings,
-            )
-        else:
-            scaling_type = self.config.rope_scaling["type"]
-            scaling_factor = self.config.rope_scaling["factor"]
-            if scaling_type == "linear":
-                self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
-                    self.large_head_dim,
-                    max_position_embeddings=self.max_position_embeddings,
-                    scaling_factor=scaling_factor,
-                )
-            elif scaling_type == "dynamic":
-                self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
-                    self.large_head_dim,
-                    max_position_embeddings=self.max_position_embeddings,
-                    scaling_factor=scaling_factor,
-                )
-            else:
-                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
@@ -299,9 +261,10 @@ class LlamaCrossAttention(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
+        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+        position_ids: Optional[torch.LongTensor] = None,
         glide_input: GlideInput = None,  # Used for glimpsing main model's KV caches
         attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
     ) -> Optional[torch.Tensor]:
@@ -319,8 +282,7 @@ class LlamaCrossAttention(nn.Module):
         query_states = query_states.view(bsz, -1, self.large_num_heads, self.large_head_dim).transpose(1, 2)
         # for RoPE
-        position_ids = position_ids + glide_input.n_spec_tokens
-        cos, sin = self.rotary_emb(query_states, position_ids)
+        cos, sin = position_embeddings
         query_states = apply_single_rotary_pos_emb(query_states, cos, sin, position_ids)
         query_states = query_states.transpose(1, 2)
         query_states = query_states.reshape(-1, self.large_num_heads, self.large_head_dim)
@@ -367,9 +329,10 @@ class GlideLlamaDecoderLayer(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
+        position_embeddings: torch.Tensor = None,
+        position_ids: Optional[torch.LongTensor] = None,
         glide_input: GlideInput = None,
         attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
@@ -399,10 +362,10 @@ class GlideLlamaDecoderLayer(nn.Module):
         hidden_states = self.input_layernorm(hidden_states)
         # Self Attention
-        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+        hidden_states, self_attn_weights = self.self_attn(
             hidden_states=hidden_states,
+            position_embeddings=position_embeddings,
             attention_mask=attention_mask,
-            position_ids=position_ids,
             past_key_value=past_key_value,
             output_attentions=output_attentions,
             use_cache=use_cache,
@@ -425,9 +388,10 @@ class GlideLlamaDecoderLayer(nn.Module):
         hidden_states = self.cross_attn(
             hidden_states=hidden_states,
+            position_embeddings=position_embeddings,
+            position_ids=position_ids,
             glide_input=glide_input,
             attention_mask=attention_mask,
-            position_ids=position_ids,
             output_attentions=output_attentions,
             use_cache=True,
         )
@@ -441,9 +405,6 @@ class GlideLlamaDecoderLayer(nn.Module):
         outputs = (hidden_states,)
-        if use_cache:
-            outputs += (present_key_value,)
         return outputs
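Not part of the diff: a minimal sketch of the cache-and-RoPE flow the rewritten forward above follows, assuming a recent transformers release where DynamicCache and a model-level rotary_emb(hidden_states, position_ids) are available. The helper name prepare_cache_and_positions is hypothetical, introduced only for illustration.

import torch
from transformers.cache_utils import DynamicCache


def prepare_cache_and_positions(model, inputs_embeds, past_key_values=None, use_cache=True):
    # Hypothetical helper mirroring the new forward: the cache object is created once,
    # cache_position indexes the freshly appended tokens, and RoPE is computed at the
    # model level instead of inside each attention layer.
    if use_cache and past_key_values is None:
        past_key_values = DynamicCache()
    past_seen = past_key_values.get_seq_length() if past_key_values is not None else 0
    cache_position = torch.arange(
        past_seen, past_seen + inputs_embeds.shape[1], device=inputs_embeds.device
    )
    position_ids = cache_position.unsqueeze(0)
    # (cos, sin) is handed to every decoder layer as position_embeddings.
    position_embeddings = model.rotary_emb(inputs_embeds, position_ids)
    return past_key_values, cache_position, position_ids, position_embeddings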

View File

@@ -478,9 +478,9 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
            attn_oproj=attn_oproj,
            process_group=process_group,
            model_shard_infer_config=model_shard_infer_config,
-            num_heads=module.num_heads,
-            hidden_size=module.hidden_size,
-            num_key_value_heads=module.num_key_value_heads,
+            num_heads=module.config.num_attention_heads,
+            hidden_size=module.config.hidden_size,
+            num_key_value_heads=module.config.num_key_value_heads,
         )
         return attn_layer
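Not part of the diff: newer transformers releases drop the per-module num_heads / hidden_size / num_key_value_heads attributes, which is why the hunk above reads them from the HF config instead. A small hedged check, using a default LlamaConfig purely for illustration:

from transformers import LlamaConfig

config = LlamaConfig()  # illustration only; any Llama-style config works
num_heads = config.num_attention_heads
hidden_size = config.hidden_size
num_key_value_heads = config.num_key_value_heads
# head_dim derived the usual way; an assumption, not something the diff states
head_dim = hidden_size // num_heads
print(num_heads, hidden_size, num_key_value_heads, head_dim)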

View File

@@ -3,6 +3,7 @@ from typing import Optional, Tuple
 import torch
 import torch.nn as nn
 from transformers import PreTrainedTokenizer
+from transformers.cache_utils import DynamicCache
 from colossalai.utils import get_current_device
@@ -93,9 +94,8 @@ class Drafter:
         for _ in range(n_spec_tokens):
             # update past key values
-            kwargs["past_key_values"] = past_key_values
-            outputs = self._drafter_model(input_ids, **kwargs)
+            outputs = self._drafter_model(input_ids, past_key_values=past_key_values, **kwargs)
             next_token_logits = outputs.logits[:, -1, :]
             # NOTE Only use greedy search for speculating.
@@ -114,6 +114,8 @@ class Drafter:
         speculated_length = len(token_ids)  # For now, only support bsz 1
         logits = torch.concat(logits, dim=0)
         token_ids = torch.concat(token_ids, dim=-1)
+        if isinstance(past_key_values, DynamicCache):
+            past_key_values = past_key_values.to_legacy_cache()
         out = DrafterOutput(
             speculated_length=speculated_length, logits=logits, next_tokens=token_ids, past_key_values=past_key_values
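Not part of the diff: a minimal sketch of the cache round-trip the drafter relies on, assuming transformers' DynamicCache API (from_legacy_cache / to_legacy_cache); the tensor shapes below are illustrative only.

import torch
from transformers.cache_utils import DynamicCache

# Legacy format: one (key, value) pair per layer, each [bsz, num_kv_heads, seq_len, head_dim].
legacy = tuple(
    (torch.zeros(1, 2, 4, 8), torch.zeros(1, 2, 4, 8)) for _ in range(2)
)

cache = DynamicCache.from_legacy_cache(legacy)  # what recent HF models consume internally
assert cache.get_seq_length() == 4
round_tripped = cache.to_legacy_cache()         # what DrafterOutput keeps storing
assert len(round_tripped) == 2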

View File

@@ -69,7 +69,7 @@ def new_from_pretrained(
    _ = kwargs.pop("mirror", None)
    from_pipeline = kwargs.pop("_from_pipeline", None)
    from_auto_class = kwargs.pop("_from_auto", False)
-    _fast_init = kwargs.pop("_fast_init", True)
+    kwargs.pop("_fast_init", True)
    torch_dtype = kwargs.pop("torch_dtype", None)
    subfolder = kwargs.pop("subfolder", "")
    commit_hash = kwargs.pop("_commit_hash", None)
@@ -286,7 +286,8 @@
    config.name_or_path = pretrained_model_name_or_path
    # Instantiate model.
-    init_contexts = [no_init_weights(_enable=_fast_init)]
+    # init_contexts = [no_init_weights(_enable=_fast_init)]
+    init_contexts = [no_init_weights()]
    with ContextManagers(init_contexts):
        model = cls(config, *model_args, **model_kwargs)
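Not part of the diff: the `_enable=_fast_init` keyword was removed from transformers' no_init_weights, so the lazy-init path above now enters the context unconditionally. A hedged sketch of the new usage with a tiny throwaway config (model choice and sizes are illustrative):

from transformers import LlamaConfig, LlamaForCausalLM
from transformers.modeling_utils import no_init_weights

config = LlamaConfig(
    hidden_size=64, intermediate_size=128, num_hidden_layers=2,
    num_attention_heads=4, num_key_value_heads=4, vocab_size=1024,
)
# Weight init is skipped inside the context; a checkpoint load is expected to fill the params.
with no_init_weights():
    model = LlamaForCausalLM(config)
print(sum(p.numel() for p in model.parameters()))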

View File

@@ -1,6 +1,6 @@
 import inspect
 import warnings
-from typing import List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union
 import torch
 import torch.distributed as dist
@@ -13,6 +13,7 @@ from transformers.modeling_attn_mask_utils import (
     _prepare_4d_causal_attention_mask_for_sdpa,
 )
 from transformers.models.mixtral.modeling_mixtral import (
+    MixtralModel,
     MixtralSparseMoeBlock,
     MoeCausalLMOutputWithPast,
     MoeModelOutputWithPast,
@@ -215,7 +216,7 @@ class MixtralPipelineForwards:
     @staticmethod
     def mixtral_model_forward(
-        self,
+        self: MixtralModel,
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
@@ -225,6 +226,7 @@ class MixtralPipelineForwards:
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         output_router_logits: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
         return_dict: Optional[bool] = None,
         stage_manager: Optional[PipelineStageManager] = None,
         hidden_states: Optional[torch.FloatTensor] = None,
@@ -340,11 +342,17 @@ class MixtralPipelineForwards:
                )
                use_cache = False
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
         all_router_logits = () if output_router_logits else None
-        next_decoder_cache = None
+        if cache_position is None:
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device
+            )
         start_idx, end_idx = stage_index[0], stage_index[1]
         for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
@@ -370,6 +378,9 @@ class MixtralPipelineForwards:
                    None,
                    output_attentions,
                    output_router_logits,
+                    use_cache,
+                    cache_position,
+                    position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
@@ -380,6 +391,8 @@ class MixtralPipelineForwards:
                    output_attentions,
                    output_router_logits,
                    use_cache,
+                    cache_position,
+                    position_embeddings,
                )
            hidden_states = layer_outputs[0]
@@ -559,14 +572,18 @@ class MixtralPipelineForwards:
 def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
     logger = logging.get_logger(__name__)
+    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
+    from transformers.models.mixtral.modeling_mixtral import eager_attention_forward
     def forward(
         self,
         hidden_states: torch.Tensor,
+        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
         use_cache: bool = False,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
@@ -614,54 +631,23 @@ def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         # Because the input can be padded, the absolute sequence length depends on the max position id.
-        rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
-        cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
-        use_sliding_windows = (
-            _flash_supports_window_size
-            and getattr(self.config, "sliding_window", None) is not None
-            and kv_seq_len > self.config.sliding_window
-        )
+        cos, sin = position_embeddings
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
         if not _flash_supports_window_size:
            logger.warning_once(
                "The current flash attention version does not support sliding window attention, for a more memory efficient implementation"
                " make sure to upgrade flash-attn library."
            )
         if past_key_value is not None:
-            # Activate slicing cache only if the config has a value `sliding_windows` attribute
-            cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
-            if (
-                getattr(self.config, "sliding_window", None) is not None
-                and kv_seq_len > self.config.sliding_window
-                and cache_has_contents
-            ):
-                slicing_tokens = 1 - self.config.sliding_window
-                past_key = past_key_value[self.layer_idx][0]
-                past_value = past_key_value[self.layer_idx][1]
-                past_key = past_key[:, :, slicing_tokens:, :].contiguous()
-                past_value = past_value[:, :, slicing_tokens:, :].contiguous()
-                if past_key.shape[-2] != self.config.sliding_window - 1:
-                    raise ValueError(
-                        f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
-                        f" {past_key.shape}"
-                    )
-                if attention_mask is not None:
-                    attention_mask = attention_mask[:, slicing_tokens:]
-                    attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
-            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
         # repeat k/v heads if n_kv_heads < n_heads
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
-        dropout_rate = 0.0 if not self.training else self.attention_dropout
+        0.0 if not self.training else self.attention_dropout
         # In PEFT, usually we cast the layer norms in float32 for training stability reasons
         # therefore the input hidden states gets silently casted in float32. Hence, we need
@@ -689,14 +675,27 @@ def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None
            query_states = query_states.transpose(1, 2)
            key_states = key_states.transpose(1, 2)
            value_states = value_states.transpose(1, 2)
-        attn_output = self._flash_attention_forward(
+        attention_interface: Callable = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
+                logger.warning_once(
+                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
+                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+                )
+            else:
+                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+        attn_output, attn_weights = attention_interface(
+            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
-            q_len,
-            dropout=dropout_rate,
-            use_sliding_windows=use_sliding_windows,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scaling,
+            sliding_window=getattr(self.config, "sliding_window", None),  # main diff with Llama
+            **kwargs,
        )
        # sp: all-to-all comminucation when introducing sequence parallel
@@ -712,7 +711,7 @@ def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None
        if not output_attentions:
            attn_weights = None
-        return attn_output, attn_weights, past_key_value
+        return attn_output, attn_weights
    return forward
@@ -731,6 +730,7 @@ def get_mixtral_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MoeModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -788,7 +788,7 @@ def get_mixtral_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
                " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to "
                " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
            )
-        if self._attn_implementation == "flash_attention_2":
+        if self.config._attn_implementation == "flash_attention_2":
            # 2d mask is passed through the layers
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        elif self._attn_implementation == "sdpa" and not output_attentions:
@@ -820,6 +820,16 @@ def get_mixtral_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
            )
        hidden_states = inputs_embeds
+        if cache_position is None:
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )
+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
@@ -840,6 +850,8 @@ def get_mixtral_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
                    output_attentions,
                    output_router_logits,
                    use_cache,
+                    cache_position=cache_position,
+                    position_embeddings=position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
@@ -850,6 +862,8 @@ def get_mixtral_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
                    output_attentions=output_attentions,
                    output_router_logits=output_router_logits,
                    use_cache=use_cache,
+                    cache_position=cache_position,
+                    position_embeddings=position_embeddings,
                )
            hidden_states = layer_outputs[0]
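Not part of the diff: a minimal sketch of the attention-dispatch pattern the new attention forward follows. It assumes a recent transformers release that exposes ALL_ATTENTION_FUNCTIONS and eager_attention_forward (the diff imports both); pick_attention_interface is a hypothetical helper, not code from the PR.

from typing import Callable

from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.models.mixtral.modeling_mixtral import eager_attention_forward


def pick_attention_interface(config, output_attentions: bool = False) -> Callable:
    # Eager is the default; "sdpa" cannot return attention weights, so it falls back
    # to eager in that case; every other backend (e.g. "flash_attention_2") is looked
    # up in the shared registry keyed by config._attn_implementation.
    impl = config._attn_implementation
    if impl == "eager" or (impl == "sdpa" and output_attentions):
        return eager_attention_forward
    return ALL_ATTENTION_FUNCTIONS[impl]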

View File

@@ -40,21 +40,9 @@ class MixtralPolicy(Policy):
        return self.model
    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        from transformers.models.mixtral.modeling_mixtral import (
-            MixtralAttention,
-            MixtralDecoderLayer,
-            MixtralFlashAttention2,
-            MixtralModel,
-            MixtralSdpaAttention,
-        )
-        ATTN_IMPLEMENTATION = {
-            "eager": MixtralAttention,
-            "flash_attention_2": MixtralFlashAttention2,
-            "sdpa": MixtralSdpaAttention,
-        }
+        from transformers.models.mixtral.modeling_mixtral import MixtralAttention, MixtralDecoderLayer, MixtralModel
        policy = {}
-        attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
        sp_mode = self.shard_config.sequence_parallelism_mode or None
        sp_size = self.shard_config.sequence_parallel_size or None
@@ -76,7 +64,7 @@ class MixtralPolicy(Policy):
                num_kv_heads //= sp_size
                decoder_attribute_replacement["num_key_value_heads"] = num_kv_heads
-            policy[attn_cls] = ModulePolicyDescription(
+            policy[MixtralAttention] = ModulePolicyDescription(
                attribute_replacement=decoder_attribute_replacement,
            )
        if self.shard_config.enable_sequence_parallelism:
@@ -89,7 +77,7 @@ class MixtralPolicy(Policy):
                    "forward": get_mixtral_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
                },
                policy=policy,
-                target_key=attn_cls,
+                target_key=MixtralAttention,
            )
            self.append_or_create_method_replacement(
                description={
@@ -330,7 +318,7 @@ class MixtralPolicy(Policy):
        stage_manager = self.pipeline_stage_manager
        held_layers = []
+        held_layers.append(module.rotary_emb)
        if stage_manager.is_interleave:
            assert stage_manager.num_model_chunks is not None
            layers_per_stage = stage_manager.distribute_layers(len(module.layers))
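Not part of the diff: in recent transformers releases the eager/sdpa/flash attention subclasses are gone and MixtralModel owns a single rotary_emb module, which is why the policy above targets MixtralAttention directly and appends module.rotary_emb to the held layers. A small hedged check with a tiny throwaway config (sizes are illustrative, and the rotary_emb attribute only exists on sufficiently new transformers versions):

from transformers import MixtralConfig, MixtralModel
from transformers.models.mixtral.modeling_mixtral import MixtralAttention

config = MixtralConfig(
    hidden_size=32, intermediate_size=64, num_hidden_layers=2,
    num_attention_heads=4, num_key_value_heads=2,
    num_local_experts=2, num_experts_per_tok=1, vocab_size=128,
)
model = MixtralModel(config)
# One attention class covers every backend, and RoPE lives on the model itself.
assert isinstance(model.layers[0].self_attn, MixtralAttention)
assert hasattr(model, "rotary_emb")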