[upgrade]upgrade opt (#6307)

* upgrade opt

* fix
flybird11111 2025-05-21 16:13:32 +08:00 committed by GitHub
parent 5374601741
commit 2aa295e959
2 changed files with 14 additions and 18 deletions

View File

@@ -128,7 +128,7 @@ class OPTPipelineForwards:
# required mask seq length can be calculated via length of past
mask_seq_length = past_key_values_length + seq_length
# embed positions
- if self.decoder._use_flash_attention_2:
+ if self.decoder.config._attn_implementation == "flash_attention_2":
# 2d mask is passed through the layers
causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
attention_mask = (
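For orientation, a minimal sketch (not part of the patch) of the dispatch this hunk switches to: recent transformers releases record the chosen attention backend on the model config as `_attn_implementation`, so the mask handling branches on that string instead of the removed `_use_flash_attention_2` flag. The helper name below is made up for illustration.

import torch
from typing import Optional

def select_padding_mask(attn_implementation: str, attention_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    # Flash Attention 2 consumes the raw 2D (batch, seq_len) padding mask and only
    # needs it when at least one position is actually masked out; other backends
    # would expand it into a 4D causal mask instead (omitted here).
    if attn_implementation == "flash_attention_2":
        return attention_mask if (attention_mask is not None and 0 in attention_mask) else None
    return attention_mask

mask = torch.tensor([[1, 1, 1, 0]])
print(select_padding_mask("flash_attention_2", mask))  # returns the 2D mask unchanged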
@@ -542,6 +542,9 @@ class OPTPipelineForwards:
def get_opt_flash_attention_forward(shard_config: ShardConfig):
from transformers.models.opt.modeling_opt import OPTAttention
+ def _shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int):
+ return tensor.view(bsz, seq_len, num_heads, head_dim).transpose(1, 2).contiguous()
def forward(
self: OPTAttention,
hidden_states: torch.Tensor,
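A quick, self-contained check of what the new module-level _shape helper computes; the hunk below swaps the former self._shape(...) calls for it, passing num_heads and head_dim explicitly. Tensor sizes here are made up for illustration.

import torch

def _shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int):
    return tensor.view(bsz, seq_len, num_heads, head_dim).transpose(1, 2).contiguous()

bsz, seq_len, num_heads, head_dim = 2, 5, 4, 16
proj = torch.randn(bsz, seq_len, num_heads * head_dim)      # e.g. the output of k_proj / v_proj
shaped = _shape(proj, -1, bsz, num_heads, head_dim)
assert shaped.shape == (bsz, num_heads, seq_len, head_dim)  # heads moved in front of the sequence axis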
@@ -568,30 +571,20 @@ def get_opt_flash_attention_forward(shard_config: ShardConfig):
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
- key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
- value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+ key_states = _shape(self.k_proj(key_value_states), -1, bsz, self.num_heads, self.head_dim)
+ value_states = _shape(self.v_proj(key_value_states), -1, bsz, self.num_heads, self.head_dim)
elif past_key_value is not None:
# reuse k, v, self_attention
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ key_states = _shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
+ value_states = _shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ key_states = _shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
+ value_states = _shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
- query_states = self._shape(query_states, tgt_len, bsz)
+ query_states = _shape(query_states, tgt_len, bsz, self.num_heads, self.head_dim)
dropout_p = self.dropout if self.training else 0.0
attn_output = ColoAttention.attention(
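The past_key_value branches above keep their original control flow; only the reshape call changes. As a reminder of why the cache concatenation uses dim=2, a small illustrative snippet (shapes invented for the example):

import torch

bsz, num_heads, head_dim = 2, 4, 16
past_len, new_len = 7, 1  # cached tokens vs. tokens processed in the current step

# after _shape, key/value states are laid out as (bsz, num_heads, seq_len, head_dim)
past_key = torch.randn(bsz, num_heads, past_len, head_dim)
new_key = torch.randn(bsz, num_heads, new_len, head_dim)

key_states = torch.cat([past_key, new_key], dim=2)  # dim=2 is the sequence axis
assert key_states.shape == (bsz, num_heads, past_len + new_len, head_dim)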
@@ -630,6 +623,8 @@ def get_opt_decoder_forward_for_flash_attention(shard_config: ShardConfig):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.Tensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
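The two added parameters let the patched decoder forward accept the position_ids and cache_position keywords that newer transformers code paths may pass through; accepting them, even where unused, keeps the forward call-compatible. A tiny illustration of the failure mode this avoids (function names are made up):

def patched_forward(input_ids=None, attention_mask=None, position_ids=None, cache_position=None):
    # extra keyword arguments from newer callers are simply accepted
    return input_ids

def stale_forward(input_ids=None, attention_mask=None):
    return input_ids

kwargs = {"input_ids": [1, 2, 3], "cache_position": [0, 1, 2]}
patched_forward(**kwargs)   # fine
# stale_forward(**kwargs)   # TypeError: unexpected keyword argument 'cache_position'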

View File

@@ -53,6 +53,7 @@ config = transformers.OPTConfig(
num_hidden_layers=2,
num_attention_heads=4,
dropout=0,
attn_implementation="eager",
)
# register the following models
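Pinning attn_implementation="eager" in the test config keeps the registered OPT test models on the eager attention path rather than whatever default a newer transformers release would pick. A quick way to confirm where the setting lands, using only the fields shown in this hunk (the remaining fields take OPTConfig defaults):

import transformers

config = transformers.OPTConfig(
    num_hidden_layers=2,
    num_attention_heads=4,
    dropout=0,
    attn_implementation="eager",
)
print(config._attn_implementation)  # "eager", the value the patched forwards branch on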