mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-24 14:33:20 +00:00
parent
5374601741
commit
2aa295e959
@ -128,7 +128,7 @@ class OPTPipelineForwards:
|
||||
# required mask seq length can be calculated via length of past
|
||||
mask_seq_length = past_key_values_length + seq_length
|
||||
# embed positions
|
||||
if self.decoder._use_flash_attention_2:
|
||||
if self.decoder.config._attn_implementation == "flash_attention_2":
|
||||
# 2d mask is passed through the layers
|
||||
causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||
attention_mask = (
|
||||
@ -542,6 +542,9 @@ class OPTPipelineForwards:
|
||||
def get_opt_flash_attention_forward(shard_config: ShardConfig):
|
||||
from transformers.models.opt.modeling_opt import OPTAttention
|
||||
|
||||
def _shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int):
|
||||
return tensor.view(bsz, seq_len, num_heads, head_dim).transpose(1, 2).contiguous()
|
||||
|
||||
def forward(
|
||||
self: OPTAttention,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -568,30 +571,20 @@ def get_opt_flash_attention_forward(shard_config: ShardConfig):
|
||||
value_states = past_key_value[1]
|
||||
elif is_cross_attention:
|
||||
# cross_attentions
|
||||
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
|
||||
key_states = _shape(self.k_proj(key_value_states), -1, bsz, self.num_heads, self.head_dim)
|
||||
value_states = _shape(self.v_proj(key_value_states), -1, bsz, self.num_heads, self.head_dim)
|
||||
elif past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
key_states = _shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
|
||||
value_states = _shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
else:
|
||||
# self_attention
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
key_states = _shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
|
||||
value_states = _shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
|
||||
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_states, value_states)
|
||||
|
||||
query_states = self._shape(query_states, tgt_len, bsz)
|
||||
query_states = _shape(query_states, tgt_len, bsz, self.num_heads, self.head_dim)
|
||||
|
||||
dropout_p = self.dropout if self.training else 0.0
|
||||
attn_output = ColoAttention.attention(
|
||||
@ -630,6 +623,8 @@ def get_opt_decoder_forward_for_flash_attention(shard_config: ShardConfig):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
cache_position: Optional[torch.Tensor] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
|
@ -53,6 +53,7 @@ config = transformers.OPTConfig(
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=4,
|
||||
dropout=0,
|
||||
attn_implementation="eager",
|
||||
)
|
||||
|
||||
# register the following models
|
||||
|
Loading…
Reference in New Issue
Block a user