mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-04 10:34:41 +00:00)
[shardformer] upgrade transformers to 4.39.3 (#5815)
* [shardformer] upgrade transformers for gpt2/gptj/whisper (#5807)
  * [shardformer] fix modeling of gpt2 and gptj
  * [shardformer] fix whisper modeling
  * [misc] update requirements
* [shardformer] upgrade transformers for mistral (#5808)
* [shardformer] upgrade transformers for llama (#5809)
* [inference] upgrade transformers (#5810)
* [gemini] update transformers for gemini (#5814)

Co-authored-by: ver217 <lhx0217@gmail.com>
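The "[misc] update requirements" item pins the transformers release this PR targets. A minimal sketch, not part of the commit, for asserting the installed version at runtime; it assumes the `packaging` helper library is available:

    from packaging import version
    import transformers

    # 4.39.3 is the release this PR upgrades to; the assert itself is illustrative only.
    required = version.parse("4.39.3")
    installed = version.parse(transformers.__version__)
    assert installed >= required, f"transformers {installed} is older than {required}"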
@@ -32,6 +32,7 @@ def _get_attention_mask(
     hidden_states: torch.Tensor,
     past_key_values: Optional[Tuple[Tuple[torch.Tensor]]],
     attention_mask: Optional[torch.FloatTensor],
+    use_flash_attention_2: bool = False,
 ) -> Optional[Union[torch.Tensor, dict]]:
     batch_size, seq_len = hidden_states.shape[:2]
     past_key_values_length = 0
@@ -47,7 +48,7 @@ def _get_attention_mask(
             attention_mask,
             is_causal=True,
         )
-    elif attention_mask is not None:
+    elif use_flash_attention_2 and attention_mask is not None:
         if batch_size <= 0:
             raise ValueError("batch_size has to be defined and > 0")
         attention_mask = attention_mask.view(batch_size, -1)
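For context, the new `use_flash_attention_2` branch reflects that flash-attention-2 kernels consume the raw 2D padding mask of shape (batch_size, seq_len), while the eager path expands it into a 4D additive bias. A minimal, illustrative sketch of the two mask formats; the names and values below are hypothetical:

    import torch

    batch_size, seq_len = 2, 5
    padding_mask = torch.tensor([[1, 1, 1, 0, 0],
                                 [1, 1, 1, 1, 1]])  # (batch_size, seq_len), 1 = real token

    # flash-attention-2 path: keep the 2D mask, as in attention_mask.view(batch_size, -1)
    fa2_mask = padding_mask.view(batch_size, -1)

    # eager path: expand to an additive (batch_size, 1, 1, seq_len) bias, -inf on padding
    additive_mask = (1.0 - padding_mask[:, None, None, :].float()) * torch.finfo(torch.float32).min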
@@ -162,7 +163,9 @@ class GPTJPipelineForwards:
 
         output_shape = input_shape + (hidden_states.size(-1),)
 
-        attention_mask = _get_attention_mask(self, shard_config, hidden_states, past_key_values, attention_mask)
+        attention_mask = _get_attention_mask(
+            self, shard_config, hidden_states, past_key_values, attention_mask, self._use_flash_attention_2
+        )
 
         if self.gradient_checkpointing and self.training:
             if use_cache:
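The call sites now forward `self._use_flash_attention_2`, a private flag that recent transformers releases set on the model from the configured attention backend. A hedged sketch of how such a flag is typically derived; `_attn_implementation` is an internal transformers field and may change between versions:

    from transformers import AutoConfig

    config = AutoConfig.for_model("gptj")  # any GPT-J config serves for illustration
    use_flash_attention_2 = getattr(config, "_attn_implementation", "eager") == "flash_attention_2"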
@@ -419,7 +422,10 @@ class GPTJPipelineForwards:
                 sequence_lengths = -1
             else:
                 if input_ids is not None:
-                    sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
+                    # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+                    sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                    sequence_lengths = sequence_lengths % input_ids.shape[-1]
+                    sequence_lengths = sequence_lengths.to(logits.device)
                 else:
                     sequence_lengths = -1
                     logger.warning_once(
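The replacement logic matches the upstream transformers GPT-J code: instead of counting non-pad tokens and indexing from the end, it takes the argmax of the pad-token mask and wraps the result with a modulo, so rows without any padding still resolve to the last position and the graph stays ONNX-exportable. A small worked example with illustrative values:

    import torch

    pad_token_id = 0
    input_ids = torch.tensor([[5, 6, 7, 0, 0],    # padded row
                              [5, 6, 7, 8, 9]])   # row with no padding

    # first pad position minus one = index of the last real token
    sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
    # argmax returns 0 when no pad token exists, so the -1 wraps to the last index
    sequence_lengths = sequence_lengths % input_ids.shape[-1]
    print(sequence_lengths)  # tensor([2, 4])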
@@ -712,7 +718,9 @@ def gptj_model_forward_for_flash_attention(shard_config: ShardConfig):
 
         hidden_states = self.drop(hidden_states)
 
-        attention_mask = _get_attention_mask(self, shard_config, hidden_states, past_key_values, attention_mask)
+        attention_mask = _get_attention_mask(
+            self, shard_config, hidden_states, past_key_values, attention_mask, self._use_flash_attention_2
+        )
 
         output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
 
@@ -886,7 +894,9 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig):
         hidden_states = self.drop(hidden_states)
 
         output_shape = input_shape + (hidden_states.size(-1),)
-        attention_mask = _get_attention_mask(self, shard_config, hidden_states, past_key_values, attention_mask)
+        attention_mask = _get_attention_mask(
+            self, shard_config, hidden_states, past_key_values, attention_mask, self._use_flash_attention_2
+        )
 
         if self.gradient_checkpointing and self.training:
             if use_cache: