mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-26 15:32:22 +00:00
[upgrade]upgrade mistral (#6296)
* upgrade mistral * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
04516bb756
commit
6875a8a1cf
@ -4,10 +4,6 @@ from typing import List, Optional, Tuple, Union
|
|||||||
import torch
|
import torch
|
||||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
from transformers.cache_utils import Cache, DynamicCache
|
from transformers.cache_utils import Cache, DynamicCache
|
||||||
from transformers.modeling_attn_mask_utils import (
|
|
||||||
_prepare_4d_causal_attention_mask,
|
|
||||||
_prepare_4d_causal_attention_mask_for_sdpa,
|
|
||||||
)
|
|
||||||
from transformers.modeling_outputs import (
|
from transformers.modeling_outputs import (
|
||||||
BaseModelOutputWithPast,
|
BaseModelOutputWithPast,
|
||||||
CausalLMOutputWithPast,
|
CausalLMOutputWithPast,
|
||||||
@ -36,7 +32,7 @@ class MistralForwards:
|
|||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
stage_manager: Optional[PipelineStageManager] = None,
|
stage_manager: Optional[PipelineStageManager] = None,
|
||||||
hidden_states: Optional[torch.FloatTensor] = None,
|
hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
stage_index: Optional[List[int]] = None,
|
stage_index: Optional[List[int]] = None,
|
||||||
@ -50,8 +46,6 @@ class MistralForwards:
|
|||||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
)
|
)
|
||||||
|
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
|
|
||||||
# retrieve input_ids and inputs_embeds
|
# retrieve input_ids and inputs_embeds
|
||||||
if stage_manager.is_first_stage():
|
if stage_manager.is_first_stage():
|
||||||
if input_ids is not None and inputs_embeds is not None:
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
@ -67,20 +61,23 @@ class MistralForwards:
|
|||||||
else:
|
else:
|
||||||
input_shape = hidden_states.shape[:-1]
|
input_shape = hidden_states.shape[:-1]
|
||||||
batch_size, seq_length = input_shape
|
batch_size, seq_length = input_shape
|
||||||
device = hidden_states.device
|
hidden_states.device
|
||||||
|
|
||||||
past_key_values_length = 0
|
past_key_values_length = 0
|
||||||
|
|
||||||
if position_ids is None:
|
if use_cache and past_key_values is None:
|
||||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
past_key_values = DynamicCache()
|
||||||
position_ids = torch.arange(
|
|
||||||
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
|
||||||
)
|
|
||||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
|
||||||
else:
|
|
||||||
position_ids = position_ids.view(-1, seq_length).long()
|
|
||||||
|
|
||||||
if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
|
if cache_position is None:
|
||||||
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
|
cache_position = torch.arange(
|
||||||
|
past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device
|
||||||
|
)
|
||||||
|
|
||||||
|
if position_ids is None:
|
||||||
|
position_ids = cache_position.unsqueeze(0)
|
||||||
|
|
||||||
|
if attention_mask is not None and self.config._attn_implementation == "flash_attention_2" and use_cache:
|
||||||
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
|
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
|
||||||
if is_padding_right:
|
if is_padding_right:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -100,27 +97,9 @@ class MistralForwards:
|
|||||||
is_causal=True,
|
is_causal=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if self._attn_implementation == "flash_attention_2":
|
attention_mask = self._update_causal_mask(
|
||||||
# 2d mask is passed through the layers
|
attention_mask, hidden_states, cache_position, past_key_values, output_attentions
|
||||||
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
)
|
||||||
elif self._attn_implementation == "sdpa" and not output_attentions:
|
|
||||||
# output_attentions=True can not be supported when using SDPA, and we fall back on
|
|
||||||
# the manual implementation that requires a 4D causal mask in all cases.
|
|
||||||
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
|
||||||
attention_mask,
|
|
||||||
(batch_size, seq_length),
|
|
||||||
inputs_embeds,
|
|
||||||
past_key_values_length,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# 4d mask is passed through the layers
|
|
||||||
attention_mask = _prepare_4d_causal_attention_mask(
|
|
||||||
attention_mask,
|
|
||||||
(batch_size, seq_length),
|
|
||||||
hidden_states,
|
|
||||||
past_key_values_length,
|
|
||||||
sliding_window=self.config.sliding_window,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
if use_cache:
|
if use_cache:
|
||||||
@ -133,6 +112,8 @@ class MistralForwards:
|
|||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
all_self_attns = () if output_attentions else None
|
all_self_attns = () if output_attentions else None
|
||||||
|
|
||||||
|
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||||
|
|
||||||
start_idx, end_idx = stage_index[0], stage_index[1]
|
start_idx, end_idx = stage_index[0], stage_index[1]
|
||||||
num_ckpt_layers = 0
|
num_ckpt_layers = 0
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
@ -156,11 +137,13 @@ class MistralForwards:
|
|||||||
layer_outputs = self._gradient_checkpointing_func(
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
decoder_layer.__call__,
|
decoder_layer.__call__,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask=attention_mask,
|
||||||
position_ids,
|
position_ids=position_ids,
|
||||||
past_key_values,
|
past_key_value=past_key_values,
|
||||||
output_attentions,
|
output_attentions=output_attentions,
|
||||||
use_cache,
|
use_cache=use_cache,
|
||||||
|
cache_position=cache_position,
|
||||||
|
position_embeddings=position_embeddings,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
layer_outputs = decoder_layer(
|
layer_outputs = decoder_layer(
|
||||||
@ -170,6 +153,8 @@ class MistralForwards:
|
|||||||
past_key_value=past_key_values,
|
past_key_value=past_key_values,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
|
cache_position=cache_position,
|
||||||
|
position_embeddings=position_embeddings,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
@ -189,8 +174,6 @@ class MistralForwards:
|
|||||||
|
|
||||||
next_cache = None
|
next_cache = None
|
||||||
if stage_manager.is_last_stage():
|
if stage_manager.is_last_stage():
|
||||||
if not return_dict:
|
|
||||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
|
||||||
return BaseModelOutputWithPast(
|
return BaseModelOutputWithPast(
|
||||||
last_hidden_state=hidden_states,
|
last_hidden_state=hidden_states,
|
||||||
past_key_values=next_cache,
|
past_key_values=next_cache,
|
||||||
@ -212,7 +195,8 @@ class MistralForwards:
|
|||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||||
stage_manager: Optional[PipelineStageManager] = None,
|
stage_manager: Optional[PipelineStageManager] = None,
|
||||||
hidden_states: Optional[torch.FloatTensor] = None,
|
hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
stage_index: Optional[List[int]] = None,
|
stage_index: Optional[List[int]] = None,
|
||||||
@ -248,7 +232,6 @@ class MistralForwards:
|
|||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
)
|
)
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
|
|
||||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
outputs = MistralForwards.mistral_model_forward(
|
outputs = MistralForwards.mistral_model_forward(
|
||||||
@ -261,7 +244,7 @@ class MistralForwards:
|
|||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
cache_position=cache_position,
|
||||||
stage_manager=stage_manager,
|
stage_manager=stage_manager,
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
stage_index=stage_index,
|
stage_index=stage_index,
|
||||||
@ -278,10 +261,6 @@ class MistralForwards:
|
|||||||
if labels is not None:
|
if labels is not None:
|
||||||
loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.dtype)
|
loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.dtype)
|
||||||
|
|
||||||
if not return_dict:
|
|
||||||
output = (logits,) + outputs[1:]
|
|
||||||
return (loss,) + output if loss is not None else output
|
|
||||||
|
|
||||||
return CausalLMOutputWithPast(
|
return CausalLMOutputWithPast(
|
||||||
loss=loss,
|
loss=loss,
|
||||||
logits=logits,
|
logits=logits,
|
||||||
@ -305,7 +284,6 @@ class MistralForwards:
|
|||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
stage_manager: Optional[PipelineStageManager] = None,
|
stage_manager: Optional[PipelineStageManager] = None,
|
||||||
hidden_states: Optional[torch.FloatTensor] = None,
|
hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
stage_index: Optional[List[int]] = None,
|
stage_index: Optional[List[int]] = None,
|
||||||
@ -317,7 +295,6 @@ class MistralForwards:
|
|||||||
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
||||||
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||||
"""
|
"""
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
|
|
||||||
transformer_outputs = MistralForwards.mistral_model_forward(
|
transformer_outputs = MistralForwards.mistral_model_forward(
|
||||||
self.model,
|
self.model,
|
||||||
@ -329,7 +306,6 @@ class MistralForwards:
|
|||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
|
||||||
stage_manager=stage_manager,
|
stage_manager=stage_manager,
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
stage_index=stage_index,
|
stage_index=stage_index,
|
||||||
@ -383,9 +359,6 @@ class MistralForwards:
|
|||||||
elif self.config.problem_type == "multi_label_classification":
|
elif self.config.problem_type == "multi_label_classification":
|
||||||
loss_fct = BCEWithLogitsLoss()
|
loss_fct = BCEWithLogitsLoss()
|
||||||
loss = loss_fct(pooled_logits, labels)
|
loss = loss_fct(pooled_logits, labels)
|
||||||
if not return_dict:
|
|
||||||
output = (pooled_logits,) + transformer_outputs[1:]
|
|
||||||
return ((loss,) + output) if loss is not None else output
|
|
||||||
else:
|
else:
|
||||||
hidden_states = transformer_outputs.get("hidden_states")
|
hidden_states = transformer_outputs.get("hidden_states")
|
||||||
return {"hidden_states": hidden_states}
|
return {"hidden_states": hidden_states}
|
||||||
@ -413,7 +386,8 @@ def get_mistral_model_forward_for_flash_attn(shard_config: ShardConfig):
|
|||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
**flash_attn_kwargs,
|
||||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
@ -421,8 +395,6 @@ def get_mistral_model_forward_for_flash_attn(shard_config: ShardConfig):
|
|||||||
)
|
)
|
||||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
|
|
||||||
# retrieve input_ids and inputs_embeds
|
# retrieve input_ids and inputs_embeds
|
||||||
if input_ids is not None and inputs_embeds is not None:
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
||||||
@ -433,27 +405,22 @@ def get_mistral_model_forward_for_flash_attn(shard_config: ShardConfig):
|
|||||||
else:
|
else:
|
||||||
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
||||||
|
|
||||||
past_key_values_length = 0
|
|
||||||
|
|
||||||
if use_cache:
|
|
||||||
use_legacy_cache = not isinstance(past_key_values, Cache)
|
|
||||||
if use_legacy_cache:
|
|
||||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
|
||||||
past_key_values_length = past_key_values.get_usable_length(seq_length)
|
|
||||||
|
|
||||||
if position_ids is None:
|
|
||||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
|
||||||
position_ids = torch.arange(
|
|
||||||
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
|
||||||
)
|
|
||||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
|
||||||
else:
|
|
||||||
position_ids = position_ids.view(-1, seq_length).long()
|
|
||||||
|
|
||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
inputs_embeds = self.embed_tokens(input_ids)
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
|
if use_cache and past_key_values is None:
|
||||||
|
past_key_values = DynamicCache()
|
||||||
|
|
||||||
|
if cache_position is None:
|
||||||
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
|
cache_position = torch.arange(
|
||||||
|
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||||
|
)
|
||||||
|
|
||||||
|
if position_ids is None:
|
||||||
|
position_ids = cache_position.unsqueeze(0)
|
||||||
|
|
||||||
|
if attention_mask is not None and self.config._attn_implementation == "flash_attention_2" and use_cache:
|
||||||
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
|
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
|
||||||
if is_padding_right:
|
if is_padding_right:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -471,31 +438,11 @@ def get_mistral_model_forward_for_flash_attn(shard_config: ShardConfig):
|
|||||||
q_padding_mask=attention_mask,
|
q_padding_mask=attention_mask,
|
||||||
is_causal=True,
|
is_causal=True,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
if self._attn_implementation == "flash_attention_2":
|
|
||||||
# 2d mask is passed through the layers
|
|
||||||
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
|
||||||
elif self._attn_implementation == "sdpa" and not output_attentions:
|
|
||||||
# output_attentions=True can not be supported when using SDPA, and we fall back on
|
|
||||||
# the manual implementation that requires a 4D causal mask in all cases.
|
|
||||||
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
|
||||||
attention_mask,
|
|
||||||
(batch_size, seq_length),
|
|
||||||
inputs_embeds,
|
|
||||||
past_key_values_length,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# 4d mask is passed through the layers
|
|
||||||
attention_mask = _prepare_4d_causal_attention_mask(
|
|
||||||
attention_mask,
|
|
||||||
(batch_size, seq_length),
|
|
||||||
inputs_embeds,
|
|
||||||
past_key_values_length,
|
|
||||||
sliding_window=self.config.sliding_window,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = inputs_embeds
|
hidden_states = inputs_embeds
|
||||||
|
|
||||||
|
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
if use_cache:
|
if use_cache:
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
@ -506,37 +453,25 @@ def get_mistral_model_forward_for_flash_attn(shard_config: ShardConfig):
|
|||||||
# decoder layers
|
# decoder layers
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
all_self_attns = () if output_attentions else None
|
all_self_attns = () if output_attentions else None
|
||||||
next_decoder_cache = None
|
|
||||||
|
|
||||||
for decoder_layer in self.layers:
|
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
layer_outputs = decoder_layer(
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
hidden_states,
|
||||||
decoder_layer.__call__,
|
attention_mask=attention_mask,
|
||||||
hidden_states,
|
position_ids=position_ids,
|
||||||
attention_mask,
|
past_key_value=past_key_values,
|
||||||
position_ids,
|
output_attentions=output_attentions,
|
||||||
past_key_values,
|
use_cache=use_cache,
|
||||||
output_attentions,
|
cache_position=cache_position,
|
||||||
use_cache,
|
position_embeddings=position_embeddings,
|
||||||
)
|
**flash_attn_kwargs,
|
||||||
else:
|
)
|
||||||
layer_outputs = decoder_layer(
|
|
||||||
hidden_states,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_value=past_key_values,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
use_cache=use_cache,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
if use_cache:
|
|
||||||
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
all_self_attns += (layer_outputs[1],)
|
all_self_attns += (layer_outputs[1],)
|
||||||
|
|
||||||
@ -546,15 +481,12 @@ def get_mistral_model_forward_for_flash_attn(shard_config: ShardConfig):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
next_cache = None
|
# add hidden states from the last decoder layer
|
||||||
if use_cache:
|
if output_hidden_states:
|
||||||
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
if not return_dict:
|
|
||||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
|
||||||
return BaseModelOutputWithPast(
|
return BaseModelOutputWithPast(
|
||||||
last_hidden_state=hidden_states,
|
last_hidden_state=hidden_states,
|
||||||
past_key_values=next_cache,
|
past_key_values=past_key_values if use_cache else None,
|
||||||
hidden_states=all_hidden_states,
|
hidden_states=all_hidden_states,
|
||||||
attentions=all_self_attns,
|
attentions=all_self_attns,
|
||||||
)
|
)
|
||||||
@ -568,11 +500,10 @@ def get_mistral_flash_attention_forward(shard_config: ShardConfig):
|
|||||||
def forward(
|
def forward(
|
||||||
self: MistralAttention,
|
self: MistralAttention,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
|
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_value: Optional[Cache] = None,
|
past_key_value: Optional[Cache] = None,
|
||||||
output_attentions: bool = False,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
use_cache: bool = False,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
if "padding_mask" in kwargs:
|
if "padding_mask" in kwargs:
|
||||||
@ -585,9 +516,9 @@ def get_mistral_flash_attention_forward(shard_config: ShardConfig):
|
|||||||
key_states = self.k_proj(hidden_states)
|
key_states = self.k_proj(hidden_states)
|
||||||
value_states = self.v_proj(hidden_states)
|
value_states = self.v_proj(hidden_states)
|
||||||
|
|
||||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
kv_seq_len = key_states.shape[-2]
|
kv_seq_len = key_states.shape[-2]
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
@ -598,11 +529,12 @@ def get_mistral_flash_attention_forward(shard_config: ShardConfig):
|
|||||||
"with a layer index."
|
"with a layer index."
|
||||||
)
|
)
|
||||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
# cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
cos, sin = position_embeddings
|
||||||
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||||
|
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
|
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
||||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||||
|
|
||||||
# repeat k/v heads if n_kv_heads < n_heads
|
# repeat k/v heads if n_kv_heads < n_heads
|
||||||
@ -613,11 +545,11 @@ def get_mistral_flash_attention_forward(shard_config: ShardConfig):
|
|||||||
attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
|
attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
|
||||||
|
|
||||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
attn_output = attn_output.reshape(bsz, q_len, -1)
|
||||||
|
|
||||||
attn_output = self.o_proj(attn_output)
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
return attn_output, None, past_key_value
|
return attn_output, None
|
||||||
|
|
||||||
return forward
|
return forward
|
||||||
|
|
||||||
|
@ -38,24 +38,10 @@ class MistralPolicy(Policy):
|
|||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
||||||
from transformers.models.mistral.modeling_mistral import (
|
from transformers.models.mistral.modeling_mistral import MistralAttention, MistralDecoderLayer, MistralModel
|
||||||
MistralAttention,
|
|
||||||
MistralDecoderLayer,
|
|
||||||
MistralFlashAttention2,
|
|
||||||
MistralModel,
|
|
||||||
MistralSdpaAttention,
|
|
||||||
)
|
|
||||||
|
|
||||||
ATTN_IMPLEMENTATION = {
|
|
||||||
"eager": MistralAttention,
|
|
||||||
"flash_attention_2": MistralFlashAttention2,
|
|
||||||
"sdpa": MistralSdpaAttention,
|
|
||||||
}
|
|
||||||
|
|
||||||
policy = {}
|
policy = {}
|
||||||
|
|
||||||
attn_cls = ATTN_IMPLEMENTATION[self.model.config._attn_implementation]
|
|
||||||
|
|
||||||
embedding_cls = None
|
embedding_cls = None
|
||||||
if self.shard_config.enable_tensor_parallelism:
|
if self.shard_config.enable_tensor_parallelism:
|
||||||
embedding_cls = VocabParallelEmbedding1D
|
embedding_cls = VocabParallelEmbedding1D
|
||||||
@ -258,7 +244,7 @@ class MistralPolicy(Policy):
|
|||||||
"forward": get_mistral_flash_attention_forward(self.shard_config),
|
"forward": get_mistral_flash_attention_forward(self.shard_config),
|
||||||
},
|
},
|
||||||
policy=policy,
|
policy=policy,
|
||||||
target_key=attn_cls,
|
target_key=MistralAttention,
|
||||||
)
|
)
|
||||||
if self.pipeline_stage_manager is None:
|
if self.pipeline_stage_manager is None:
|
||||||
# replace llama model forward method
|
# replace llama model forward method
|
||||||
@ -316,6 +302,7 @@ class MistralPolicy(Policy):
|
|||||||
stage_manager = self.pipeline_stage_manager
|
stage_manager = self.pipeline_stage_manager
|
||||||
|
|
||||||
held_layers = []
|
held_layers = []
|
||||||
|
held_layers.append(module.rotary_emb)
|
||||||
if stage_manager.is_interleave:
|
if stage_manager.is_interleave:
|
||||||
assert stage_manager.num_model_chunks is not None
|
assert stage_manager.num_model_chunks is not None
|
||||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||||
|
@ -23,6 +23,7 @@ from tests.test_shardformer.test_model._utils import (
|
|||||||
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
|
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
|
||||||
|
|
||||||
|
|
||||||
|
@clear_cache_before_run()
|
||||||
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
|
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
|
||||||
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
|
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
|
||||||
model_fn, loss_fn, test_config
|
model_fn, loss_fn, test_config
|
||||||
@ -176,7 +177,6 @@ def check_mistral(rank, world_size, port):
|
|||||||
|
|
||||||
@pytest.mark.dist
|
@pytest.mark.dist
|
||||||
@rerun_if_address_is_in_use()
|
@rerun_if_address_is_in_use()
|
||||||
@clear_cache_before_run()
|
|
||||||
def test_mistral():
|
def test_mistral():
|
||||||
spawn(check_mistral, 4)
|
spawn(check_mistral, 4)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user