Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-29 08:47:17 +00:00)
[upgrade]Upgrade mixtral (#6317)
* upgrade mixtral
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* upgrade infer
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix
* upgrade drafter
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix
* upgrade lazy
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* upgrade mixtral
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent 2aa295e959
commit d0e13b85fd
@@ -6,19 +6,16 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
-from transformers.cache_utils import Cache, DynamicCache, StaticCache
+from transformers.cache_utils import DynamicCache
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.models.llama.modeling_llama import (
     LlamaAttention,
     LlamaConfig,
     LlamaDecoderLayer,
-    LlamaDynamicNTKScalingRotaryEmbedding,
     LlamaForCausalLM,
-    LlamaLinearScalingRotaryEmbedding,
     LlamaMLP,
     LlamaModel,
     LlamaRMSNorm,
-    LlamaRotaryEmbedding,
 )
 
 from colossalai.inference.spec import GlideInput
@@ -156,15 +153,11 @@ def glide_llama_model_forward(
     if inputs_embeds is None:
         inputs_embeds = self.embed_tokens(input_ids)
 
-    past_seen_tokens = 0
-    if use_cache:  # kept for BC (cache positions)
-        if not isinstance(past_key_values, StaticCache):
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            past_seen_tokens = past_key_values.get_seq_length()
+    if use_cache and past_key_values is None:
+        past_key_values = DynamicCache()
 
     if cache_position is None:
-        if isinstance(past_key_values, StaticCache):
-            raise ValueError("cache_position is a required argument when using StaticCache.")
+        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
         cache_position = torch.arange(
             past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
         )
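Note on the hunk above: the upgraded forward replaces the legacy tuple cache and the StaticCache special-casing with transformers' DynamicCache. A minimal sketch of that bookkeeping, assuming a recent transformers release (the sequence length below is made up for illustration):

```python
import torch
from transformers.cache_utils import DynamicCache

past_key_values = DynamicCache()                      # empty cache when use_cache=True and none was passed in
past_seen_tokens = past_key_values.get_seq_length()   # 0 before any layer has written to the cache
seq_len = 8                                           # hypothetical prompt length
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_len)
# cache_position -> tensor([0, 1, ..., 7]): the absolute positions written in this forward pass
```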
@@ -172,15 +165,17 @@ def glide_llama_model_forward(
     if position_ids is None:
         position_ids = cache_position.unsqueeze(0)
 
-    attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
+    attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)
+    if hasattr(glide_input, "n_spec_tokens"):
+        position_ids = position_ids + glide_input.n_spec_tokens
 
     # embed positions
     hidden_states = inputs_embeds
+    position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
     # decoder layers
     all_hidden_states = () if output_hidden_states else None
     all_self_attns = () if output_attentions else None
-    next_decoder_cache = None
 
     for decoder_layer in self.layers:
         if output_hidden_states:
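The model-level forward now computes the rotary position embeddings once (`position_embeddings = self.rotary_emb(hidden_states, position_ids)`) and hands the resulting (cos, sin) pair to every decoder layer. A self-contained sketch of what such a pair looks like, with an assumed head_dim of 64 and the usual base of 10000:

```python
import torch

def rope_cos_sin(position_ids: torch.Tensor, head_dim: int = 64, base: float = 10000.0):
    # Standard RoPE tables: one cos/sin row per position, shared by all decoder layers.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    freqs = position_ids[..., None].float() * inv_freq   # (bsz, seq_len, head_dim / 2)
    emb = torch.cat((freqs, freqs), dim=-1)               # (bsz, seq_len, head_dim)
    return emb.cos(), emb.sin()

cos, sin = rope_cos_sin(torch.arange(8)[None])            # one sequence, positions 0..7
```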
@@ -189,9 +184,9 @@ def glide_llama_model_forward(
         # GlideLlamaDecoderLayer
         layer_outputs = decoder_layer(
             hidden_states,
+            position_embeddings=position_embeddings,
             glide_input=glide_input,
             attention_mask=attention_mask,
-            position_ids=position_ids,
             past_key_value=past_key_values,
             output_attentions=output_attentions,
             use_cache=use_cache,
@@ -200,9 +195,6 @@ def glide_llama_model_forward(
 
         hidden_states = layer_outputs[0]
 
-        if use_cache:
-            next_decoder_cache = layer_outputs[2 if output_attentions else 1]
-
         if output_attentions:
             all_self_attns += (layer_outputs[1],)
 
@@ -212,16 +204,11 @@ def glide_llama_model_forward(
     if output_hidden_states:
         all_hidden_states += (hidden_states,)
 
-    next_cache = None
-    if use_cache:
-        next_cache = (
-            next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
-        )
     if not return_dict:
-        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+        return tuple(v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None)
     return BaseModelOutputWithPast(
         last_hidden_state=hidden_states,
-        past_key_values=next_cache,
+        past_key_values=past_key_values,
         hidden_states=all_hidden_states,
         attentions=all_self_attns,
     )
@@ -267,31 +254,6 @@ class LlamaCrossAttention(nn.Module):
 
         self.q_proj = nn.Linear(self.hidden_size, self.large_num_heads * self.large_head_dim, bias=False)
         self.o_proj = nn.Linear(self.large_num_heads * self.large_head_dim, self.hidden_size, bias=False)
-        self._init_rope()
-
-    def _init_rope(self):
-        if self.config.rope_scaling is None:
-            self.rotary_emb = LlamaRotaryEmbedding(
-                self.large_head_dim,
-                max_position_embeddings=self.max_position_embeddings,
-            )
-        else:
-            scaling_type = self.config.rope_scaling["type"]
-            scaling_factor = self.config.rope_scaling["factor"]
-            if scaling_type == "linear":
-                self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
-                    self.large_head_dim,
-                    max_position_embeddings=self.max_position_embeddings,
-                    scaling_factor=scaling_factor,
-                )
-            elif scaling_type == "dynamic":
-                self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
-                    self.large_head_dim,
-                    max_position_embeddings=self.max_position_embeddings,
-                    scaling_factor=scaling_factor,
-                )
-            else:
-                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
 
     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
@@ -299,9 +261,10 @@ class LlamaCrossAttention(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
+        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+        position_ids: Optional[torch.LongTensor] = None,
         glide_input: GlideInput = None,  # Used for glimpsing main model's KV caches
         attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
     ) -> Optional[torch.Tensor]:
@@ -319,8 +282,7 @@ class LlamaCrossAttention(nn.Module):
         query_states = query_states.view(bsz, -1, self.large_num_heads, self.large_head_dim).transpose(1, 2)
 
         # for RoPE
-        position_ids = position_ids + glide_input.n_spec_tokens
-        cos, sin = self.rotary_emb(query_states, position_ids)
+        cos, sin = position_embeddings
         query_states = apply_single_rotary_pos_emb(query_states, cos, sin, position_ids)
         query_states = query_states.transpose(1, 2)
         query_states = query_states.reshape(-1, self.large_num_heads, self.large_head_dim)
@@ -367,9 +329,10 @@ class GlideLlamaDecoderLayer(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
+        position_embeddings: torch.Tensor = None,
+        position_ids: Optional[torch.LongTensor] = None,
         glide_input: GlideInput = None,
         attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
@@ -399,10 +362,10 @@ class GlideLlamaDecoderLayer(nn.Module):
         hidden_states = self.input_layernorm(hidden_states)
 
         # Self Attention
-        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+        hidden_states, self_attn_weights = self.self_attn(
             hidden_states=hidden_states,
+            position_embeddings=position_embeddings,
             attention_mask=attention_mask,
-            position_ids=position_ids,
             past_key_value=past_key_value,
             output_attentions=output_attentions,
             use_cache=use_cache,
@@ -425,9 +388,10 @@ class GlideLlamaDecoderLayer(nn.Module):
 
         hidden_states = self.cross_attn(
             hidden_states=hidden_states,
+            position_embeddings=position_embeddings,
+            position_ids=position_ids,
             glide_input=glide_input,
             attention_mask=attention_mask,
-            position_ids=position_ids,
             output_attentions=output_attentions,
             use_cache=True,
         )
@@ -441,9 +405,6 @@ class GlideLlamaDecoderLayer(nn.Module):
 
         outputs = (hidden_states,)
 
-        if use_cache:
-            outputs += (present_key_value,)
-
         return outputs
 
 
@@ -478,9 +478,9 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
             attn_oproj=attn_oproj,
             process_group=process_group,
             model_shard_infer_config=model_shard_infer_config,
-            num_heads=module.num_heads,
-            hidden_size=module.hidden_size,
-            num_key_value_heads=module.num_key_value_heads,
+            num_heads=module.config.num_attention_heads,
+            hidden_size=module.config.hidden_size,
+            num_key_value_heads=module.config.num_key_value_heads,
         )
 
         return attn_layer
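The NopadLlamaAttention hunk switches from module-level attributes (`module.num_heads`, `module.hidden_size`) to the corresponding config fields, since recent transformers attention modules no longer expose these directly. A small illustration of the equivalent reads, using a default LlamaConfig purely as a stand-in for `module.config`:

```python
from transformers import LlamaConfig

config = LlamaConfig()                              # stand-in for module.config
num_heads = config.num_attention_heads              # replaces module.num_heads
hidden_size = config.hidden_size                    # replaces module.hidden_size
num_key_value_heads = config.num_key_value_heads    # replaces module.num_key_value_heads
```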
@@ -3,6 +3,7 @@ from typing import Optional, Tuple
 import torch
 import torch.nn as nn
 from transformers import PreTrainedTokenizer
+from transformers.cache_utils import DynamicCache
 
 from colossalai.utils import get_current_device
 
@@ -93,9 +94,8 @@ class Drafter:
 
         for _ in range(n_spec_tokens):
             # update past key values
-            kwargs["past_key_values"] = past_key_values
 
-            outputs = self._drafter_model(input_ids, **kwargs)
+            outputs = self._drafter_model(input_ids, past_key_values=past_key_values, **kwargs)
             next_token_logits = outputs.logits[:, -1, :]
 
             # NOTE Only use greedy search for speculating.
@@ -114,6 +114,8 @@ class Drafter:
         speculated_length = len(token_ids)  # For now, only support bsz 1
         logits = torch.concat(logits, dim=0)
         token_ids = torch.concat(token_ids, dim=-1)
+        if isinstance(past_key_values, DynamicCache):
+            past_key_values = past_key_values.to_legacy_cache()
 
         out = DrafterOutput(
             speculated_length=speculated_length, logits=logits, next_tokens=token_ids, past_key_values=past_key_values
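In the Drafter, speculation may leave a DynamicCache behind, which is converted back to the legacy tuple layout before being stored in DrafterOutput. A hedged sketch of that conversion (the tensor shapes are hypothetical):

```python
import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
key = value = torch.zeros(1, 2, 4, 8)        # (bsz, num_kv_heads, seq_len, head_dim), made-up sizes
cache.update(key, value, layer_idx=0)        # append this layer's KV states
legacy = cache.to_legacy_cache()             # tuple of (key, value) pairs, one per layer
assert legacy[0][0].shape == (1, 2, 4, 8)
```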
@@ -69,7 +69,7 @@ def new_from_pretrained(
     _ = kwargs.pop("mirror", None)
     from_pipeline = kwargs.pop("_from_pipeline", None)
     from_auto_class = kwargs.pop("_from_auto", False)
-    _fast_init = kwargs.pop("_fast_init", True)
+    kwargs.pop("_fast_init", True)
     torch_dtype = kwargs.pop("torch_dtype", None)
     subfolder = kwargs.pop("subfolder", "")
     commit_hash = kwargs.pop("_commit_hash", None)
@@ -286,7 +286,8 @@ def new_from_pretrained(
     config.name_or_path = pretrained_model_name_or_path
 
     # Instantiate model.
-    init_contexts = [no_init_weights(_enable=_fast_init)]
+    # init_contexts = [no_init_weights(_enable=_fast_init)]
+    init_contexts = [no_init_weights()]
 
     with ContextManagers(init_contexts):
         model = cls(config, *model_args, **model_kwargs)
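The lazy-init path drops the `_enable=_fast_init` argument because newer transformers versions expose `no_init_weights()` without that flag. A minimal sketch of the pattern, assuming the `ContextManagers` helper still lives in `transformers.utils`:

```python
from transformers.modeling_utils import no_init_weights
from transformers.utils import ContextManagers

init_contexts = [no_init_weights()]          # no `_enable` flag in recent transformers
with ContextManagers(init_contexts):
    # model construction would go here, e.g. cls(config, *model_args, **model_kwargs)
    pass
```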
@@ -1,6 +1,6 @@
 import inspect
 import warnings
-from typing import List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
@@ -13,6 +13,7 @@ from transformers.modeling_attn_mask_utils import (
     _prepare_4d_causal_attention_mask_for_sdpa,
 )
 from transformers.models.mixtral.modeling_mixtral import (
+    MixtralModel,
     MixtralSparseMoeBlock,
     MoeCausalLMOutputWithPast,
     MoeModelOutputWithPast,
@@ -215,7 +216,7 @@ class MixtralPipelineForwards:
 
     @staticmethod
     def mixtral_model_forward(
-        self,
+        self: MixtralModel,
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
@@ -225,6 +226,7 @@ class MixtralPipelineForwards:
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         output_router_logits: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
         return_dict: Optional[bool] = None,
         stage_manager: Optional[PipelineStageManager] = None,
         hidden_states: Optional[torch.FloatTensor] = None,
@@ -340,11 +342,17 @@ class MixtralPipelineForwards:
                 )
                 use_cache = False
 
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
         all_router_logits = () if output_router_logits else None
-        next_decoder_cache = None
+        if cache_position is None:
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device
+            )
 
         start_idx, end_idx = stage_index[0], stage_index[1]
         for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
@@ -370,6 +378,9 @@ class MixtralPipelineForwards:
                     None,
                     output_attentions,
                     output_router_logits,
+                    use_cache,
+                    cache_position,
+                    position_embeddings,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -380,6 +391,8 @@ class MixtralPipelineForwards:
                     output_attentions,
                     output_router_logits,
                     use_cache,
+                    cache_position,
+                    position_embeddings,
                 )
             hidden_states = layer_outputs[0]
 
@@ -559,14 +572,18 @@ class MixtralPipelineForwards:
 
 def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
     logger = logging.get_logger(__name__)
+    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
+    from transformers.models.mixtral.modeling_mixtral import eager_attention_forward
 
     def forward(
         self,
         hidden_states: torch.Tensor,
+        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
         use_cache: bool = False,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
@@ -614,54 +631,23 @@ def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
         # Because the input can be padded, the absolute sequence length depends on the max position id.
-        rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
-        cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
+        cos, sin = position_embeddings
 
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
-        use_sliding_windows = (
-            _flash_supports_window_size
-            and getattr(self.config, "sliding_window", None) is not None
-            and kv_seq_len > self.config.sliding_window
-        )
         if not _flash_supports_window_size:
             logger.warning_once(
                 "The current flash attention version does not support sliding window attention, for a more memory efficient implementation"
                 " make sure to upgrade flash-attn library."
             )
         if past_key_value is not None:
-            # Activate slicing cache only if the config has a value `sliding_windows` attribute
-            cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
-            if (
-                getattr(self.config, "sliding_window", None) is not None
-                and kv_seq_len > self.config.sliding_window
-                and cache_has_contents
-            ):
-                slicing_tokens = 1 - self.config.sliding_window
-
-                past_key = past_key_value[self.layer_idx][0]
-                past_value = past_key_value[self.layer_idx][1]
-
-                past_key = past_key[:, :, slicing_tokens:, :].contiguous()
-                past_value = past_value[:, :, slicing_tokens:, :].contiguous()
-
-                if past_key.shape[-2] != self.config.sliding_window - 1:
-                    raise ValueError(
-                        f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
-                        f" {past_key.shape}"
-                    )
-
-                if attention_mask is not None:
-                    attention_mask = attention_mask[:, slicing_tokens:]
-                    attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
-
-            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
         # repeat k/v heads if n_kv_heads < n_heads
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
-        dropout_rate = 0.0 if not self.training else self.attention_dropout
+        0.0 if not self.training else self.attention_dropout
 
         # In PEFT, usually we cast the layer norms in float32 for training stability reasons
         # therefore the input hidden states gets silently casted in float32. Hence, we need
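Because cos/sin now arrive pre-indexed by position (via `position_embeddings`), `apply_rotary_pos_emb` is called without `position_ids`. A small sketch of that call with toy tensors, assuming a transformers version whose llama helper accepts this form (cos=1, sin=0 makes it an identity rotation):

```python
import torch
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb

bsz, num_heads, seq_len, head_dim = 1, 2, 4, 8
q = torch.randn(bsz, num_heads, seq_len, head_dim)
k = torch.randn(bsz, num_heads, seq_len, head_dim)
cos = torch.ones(bsz, seq_len, head_dim)     # toy tables; real ones come from the rotary module
sin = torch.zeros(bsz, seq_len, head_dim)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
assert torch.allclose(q_rot, q) and torch.allclose(k_rot, k)
```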
@@ -689,14 +675,27 @@ def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None
         query_states = query_states.transpose(1, 2)
         key_states = key_states.transpose(1, 2)
         value_states = value_states.transpose(1, 2)
-        attn_output = self._flash_attention_forward(
+
+        attention_interface: Callable = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
+                logger.warning_once(
+                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
+                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+                )
+            else:
+                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+        attn_output, attn_weights = attention_interface(
+            self,
             query_states,
             key_states,
             value_states,
             attention_mask,
-            q_len,
-            dropout=dropout_rate,
-            use_sliding_windows=use_sliding_windows,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scaling,
+            sliding_window=getattr(self.config, "sliding_window", None),  # main diff with Llama
+            **kwargs,
         )
 
         # sp: all-to-all comminucation when introducing sequence parallel
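The custom `_flash_attention_forward` call is replaced by the attention-interface dispatch that recent transformers releases provide: `eager_attention_forward` is the fallback, and `ALL_ATTENTION_FUNCTIONS` maps the configured backend name to a callable. A hedged sketch of the lookup, where the backend name is a stand-in for `config._attn_implementation`:

```python
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.models.mixtral.modeling_mixtral import eager_attention_forward

attn_implementation = "sdpa"                         # stand-in for self.config._attn_implementation
attention_interface = eager_attention_forward
if attn_implementation != "eager":
    attention_interface = ALL_ATTENTION_FUNCTIONS[attn_implementation]
# attention_interface(module, query, key, value, attention_mask, dropout=..., scaling=..., **kwargs)
```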
@@ -712,7 +711,7 @@ def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None
 
         if not output_attentions:
             attn_weights = None
-        return attn_output, attn_weights, past_key_value
+        return attn_output, attn_weights
 
     return forward
 
@@ -731,6 +730,7 @@ def get_mixtral_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         output_router_logits: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, MoeModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -788,7 +788,7 @@ def get_mixtral_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
                 " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to "
                 " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
             )
-        if self._attn_implementation == "flash_attention_2":
+        if self.config._attn_implementation == "flash_attention_2":
             # 2d mask is passed through the layers
             attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
         elif self._attn_implementation == "sdpa" and not output_attentions:
@@ -820,6 +820,16 @@ def get_mixtral_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
             )
         hidden_states = inputs_embeds
 
+        if cache_position is None:
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )
+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
+
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
@@ -840,6 +850,8 @@ def get_mixtral_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
                     output_attentions,
                     output_router_logits,
                     use_cache,
+                    cache_position=cache_position,
+                    position_embeddings=position_embeddings,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -850,6 +862,8 @@ def get_mixtral_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
                     output_attentions=output_attentions,
                     output_router_logits=output_router_logits,
                     use_cache=use_cache,
+                    cache_position=cache_position,
+                    position_embeddings=position_embeddings,
                 )
 
             hidden_states = layer_outputs[0]
@@ -40,21 +40,9 @@ class MixtralPolicy(Policy):
         return self.model
 
     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        from transformers.models.mixtral.modeling_mixtral import (
-            MixtralAttention,
-            MixtralDecoderLayer,
-            MixtralFlashAttention2,
-            MixtralModel,
-            MixtralSdpaAttention,
-        )
+        from transformers.models.mixtral.modeling_mixtral import MixtralAttention, MixtralDecoderLayer, MixtralModel
 
-        ATTN_IMPLEMENTATION = {
-            "eager": MixtralAttention,
-            "flash_attention_2": MixtralFlashAttention2,
-            "sdpa": MixtralSdpaAttention,
-        }
         policy = {}
-        attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
 
         sp_mode = self.shard_config.sequence_parallelism_mode or None
         sp_size = self.shard_config.sequence_parallel_size or None
@@ -76,7 +64,7 @@ class MixtralPolicy(Policy):
             num_kv_heads //= sp_size
             decoder_attribute_replacement["num_key_value_heads"] = num_kv_heads
 
-        policy[attn_cls] = ModulePolicyDescription(
+        policy[MixtralAttention] = ModulePolicyDescription(
             attribute_replacement=decoder_attribute_replacement,
         )
         if self.shard_config.enable_sequence_parallelism:
@@ -89,7 +77,7 @@ class MixtralPolicy(Policy):
                     "forward": get_mixtral_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
                 },
                 policy=policy,
-                target_key=attn_cls,
+                target_key=MixtralAttention,
             )
             self.append_or_create_method_replacement(
                 description={
@@ -330,7 +318,7 @@ class MixtralPolicy(Policy):
         stage_manager = self.pipeline_stage_manager
 
         held_layers = []
+        held_layers.append(module.rotary_emb)
         if stage_manager.is_interleave:
             assert stage_manager.num_model_chunks is not None
             layers_per_stage = stage_manager.distribute_layers(len(module.layers))