[shardformer] upgrade transformers to 4.39.3 (#5815)
* [shardformer] upgrade transformers for gpt2/gptj/whisper (#5807)
  * [shardformer] fix modeling of gpt2 and gptj
  * [shardformer] fix whisper modeling
  * [misc] update requirements
* [shardformer] upgrade transformers for mistral (#5808)
* [shardformer] upgrade transformers for llama (#5809)
* [inference] upgrade transformers (#5810)
* [gemini] update transformers for gemini (#5814)

Co-authored-by: ver217 <lhx0217@gmail.com>
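For readers picking up this bump, a quick way to confirm that the installed transformers build exposes the `position_ids` keyword the Whisper hunks below rely on. This check is an illustrative sketch, not part of the commit, and assumes `WhisperPositionalEmbedding.forward` is where the keyword landed in 4.39.x:

import inspect

import transformers
from transformers.models.whisper.modeling_whisper import WhisperPositionalEmbedding

# Illustrative compatibility probe; adjust if the upstream signature differs.
print("transformers", transformers.__version__)
params = inspect.signature(WhisperPositionalEmbedding.forward).parameters
print("supports position_ids:", "position_ids" in params)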
@@ -17,6 +17,7 @@ from transformers.modeling_outputs import (
     SequenceClassifierOutput,
 )
 from transformers.models.whisper.modeling_whisper import (
+    _HIDDEN_STATES_START_POSITION,
     WhisperDecoder,
     WhisperEncoder,
     WhisperForAudioClassification,
@@ -166,6 +167,7 @@ def get_whisper_decoder_forward_for_flash_attention(shard_config: ShardConfig):
         cross_attn_head_mask=None,
         past_key_values=None,
         inputs_embeds=None,
+        position_ids=None,
         use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
@@ -199,9 +201,13 @@ def get_whisper_decoder_forward_for_flash_attention(shard_config: ShardConfig):
 
         # embed positions
         if input_ids is not None:
-            positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
+            positions = self.embed_positions(
+                input_ids, past_key_values_length=past_key_values_length, position_ids=position_ids
+            )
         else:
-            positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length)
+            positions = self.embed_positions(
+                inputs_embeds, past_key_values_length=past_key_values_length, position_ids=position_ids
+            )
 
         hidden_states = inputs_embeds + positions
         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
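The hunk above is the core of the Whisper change: `embed_positions` now receives an explicit `position_ids`. A minimal, self-contained sketch of the lookup semantics that keyword implies, assuming explicit positions take precedence and the old KV-cache offset stays the fallback; the class below is illustrative, not the transformers implementation:

import torch
import torch.nn as nn


class PositionalEmbeddingSketch(nn.Embedding):
    """Illustrative stand-in for Whisper's learned positional embedding."""

    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0, position_ids=None):
        if position_ids is None:
            # Pre-upgrade behaviour: consecutive positions starting at the KV-cache length.
            return self.weight[past_key_values_length : past_key_values_length + input_ids.shape[1]]
        # New keyword: the caller pins the positions explicitly.
        return self.weight[position_ids]


embed = PositionalEmbeddingSketch(num_embeddings=448, embedding_dim=8)
tokens = torch.randint(0, 100, (1, 4))
# Passing position_ids that match the cache offset reproduces the old behaviour.
assert torch.equal(embed(tokens, past_key_values_length=2), embed(tokens, position_ids=torch.arange(2, 6)))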
@@ -599,6 +605,7 @@ class WhisperPipelineForwards:
         cross_attn_head_mask=None,
         past_key_values=None,
         inputs_embeds=None,
+        position_ids=None,
         use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
@@ -716,9 +723,13 @@ class WhisperPipelineForwards:
 
         # embed positions
         if input_ids is not None:
-            positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
+            positions = self.embed_positions(
+                input_ids, past_key_values_length=past_key_values_length, position_ids=position_ids
+            )
         else:
-            positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length)
+            positions = self.embed_positions(
+                inputs_embeds, past_key_values_length=past_key_values_length, position_ids=position_ids
+            )
 
         hidden_states = inputs_embeds + positions
         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
@@ -841,6 +852,7 @@ class WhisperPipelineForwards:
         encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
+        decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
@@ -944,6 +956,7 @@ class WhisperPipelineForwards:
             cross_attn_head_mask=cross_attn_head_mask,
             past_key_values=past_key_values,
             inputs_embeds=decoder_inputs_embeds,
+            position_ids=decoder_position_ids,
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
@@ -986,6 +999,7 @@ class WhisperPipelineForwards:
         encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
+        decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
@@ -1048,6 +1062,7 @@ class WhisperPipelineForwards:
             cross_attn_head_mask=cross_attn_head_mask,
             past_key_values=past_key_values,
             decoder_inputs_embeds=decoder_inputs_embeds,
+            decoder_position_ids=decoder_position_ids,
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
@@ -1118,6 +1133,12 @@ class WhisperPipelineForwards:
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
+
+        if self.config.use_weighted_layer_sum:
+            output_hidden_states = True
+        elif output_hidden_states is None:
+            output_hidden_states = self.config.output_hidden_states
+
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         # audio_classification only holds encoder
@@ -1138,7 +1159,8 @@ class WhisperPipelineForwards:
             return encoder_outputs
 
         if self.config.use_weighted_layer_sum:
-            hidden_states = torch.stack(encoder_outputs, dim=1)
+            hidden_states = encoder_outputs[_HIDDEN_STATES_START_POSITION]
+            hidden_states = torch.stack(hidden_states, dim=1)
             norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
             hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
         else:
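The hunk above pairs with the new `_HIDDEN_STATES_START_POSITION` import: when the encoder returns a plain tuple, the per-layer hidden states sit inside that tuple rather than being the tuple itself, so they must be pulled out before stacking. A minimal sketch of the weighted layer sum with made-up shapes; the tuple layout and the start-position value of 1 are assumptions for illustration, not taken from the library:

import torch
import torch.nn as nn

# Made-up shapes: embeddings + 3 layers = 4 hidden states, batch 2, seq 5, hidden 8.
all_hidden_states = tuple(torch.randn(2, 5, 8) for _ in range(4))
# Assumed tuple-style encoder output: (last_hidden_state, all_hidden_states, ...).
encoder_outputs = (all_hidden_states[-1], all_hidden_states)

HIDDEN_STATES_START_POSITION = 1  # illustrative stand-in for _HIDDEN_STATES_START_POSITION

hidden_states = encoder_outputs[HIDDEN_STATES_START_POSITION]  # the per-layer tuple, not the outer tuple
hidden_states = torch.stack(hidden_states, dim=1)              # (batch, num_layers, seq, hidden)

layer_weights = nn.Parameter(torch.ones(len(all_hidden_states)))
norm_weights = nn.functional.softmax(layer_weights, dim=-1)
pooled = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)  # (batch, seq, hidden)
print(pooled.shape)  # torch.Size([2, 5, 8])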