Mirror of https://github.com/hpcaitech/ColossalAI.git
commit 2aa295e959 (parent 5374601741)
@@ -128,7 +128,7 @@ class OPTPipelineForwards:
         # required mask seq length can be calculated via length of past
         mask_seq_length = past_key_values_length + seq_length
         # embed positions
-        if self.decoder._use_flash_attention_2:
+        if self.decoder.config._attn_implementation == "flash_attention_2":
             # 2d mask is passed through the layers
             causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
             attention_mask = (
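Note: this hunk tracks a transformers API change. The decoder no longer exposes the private `_use_flash_attention_2` flag; the selected attention backend is read from the config instead. A minimal sketch of the new check, assuming a recent transformers release (everything besides the two quoted conditions is illustrative):

# Minimal sketch, assuming a transformers version where the decoder no longer carries
# the private `_use_flash_attention_2` flag and the chosen backend lives on the config.
from transformers import OPTConfig

config = OPTConfig(num_hidden_layers=2, num_attention_heads=4, attn_implementation="flash_attention_2")

# old check:  if self.decoder._use_flash_attention_2:
# new check:  if self.decoder.config._attn_implementation == "flash_attention_2":
print(config._attn_implementation == "flash_attention_2")  # True for the config above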
@@ -542,6 +542,9 @@ class OPTPipelineForwards:
 def get_opt_flash_attention_forward(shard_config: ShardConfig):
     from transformers.models.opt.modeling_opt import OPTAttention
 
+    def _shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int):
+        return tensor.view(bsz, seq_len, num_heads, head_dim).transpose(1, 2).contiguous()
+
     def forward(
         self: OPTAttention,
         hidden_states: torch.Tensor,
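Note: `_shape` is now a module-level helper, so it has to be given `num_heads` and `head_dim` explicitly instead of reading them off `self`. A toy check of what it produces, with illustrative sizes:

# Toy check of the new module-level helper (sizes are illustrative).
import torch

def _shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int):
    return tensor.view(bsz, seq_len, num_heads, head_dim).transpose(1, 2).contiguous()

bsz, seq_len, num_heads, head_dim = 2, 5, 4, 8
proj = torch.randn(bsz, seq_len, num_heads * head_dim)    # e.g. output of k_proj / v_proj
print(_shape(proj, -1, bsz, num_heads, head_dim).shape)   # torch.Size([2, 4, 5, 8])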
@@ -568,30 +571,20 @@ def get_opt_flash_attention_forward(shard_config: ShardConfig):
             value_states = past_key_value[1]
         elif is_cross_attention:
             # cross_attentions
-            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
-            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+            key_states = _shape(self.k_proj(key_value_states), -1, bsz, self.num_heads, self.head_dim)
+            value_states = _shape(self.v_proj(key_value_states), -1, bsz, self.num_heads, self.head_dim)
         elif past_key_value is not None:
             # reuse k, v, self_attention
-            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
-            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+            key_states = _shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
+            value_states = _shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
             key_states = torch.cat([past_key_value[0], key_states], dim=2)
             value_states = torch.cat([past_key_value[1], value_states], dim=2)
         else:
             # self_attention
-            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
-            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+            key_states = _shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
+            value_states = _shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
 
-        if self.is_decoder:
-            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
-            # Further calls to cross_attention layer can then reuse all cross-attention
-            # key/value_states (first "if" case)
-            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
-            # all previous decoder key/value_states. Further calls to uni-directional self-attention
-            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
-            # if encoder bi-directional self-attention `past_key_value` is always `None`
-            past_key_value = (key_states, value_states)
-
-        query_states = self._shape(query_states, tgt_len, bsz)
+        query_states = _shape(query_states, tgt_len, bsz, self.num_heads, self.head_dim)
 
         dropout_p = self.dropout if self.training else 0.0
         attn_output = ColoAttention.attention(
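Note: the `if self.is_decoder:` caching block is dropped in this hunk and the query is reshaped with the shared helper, so q/k/v all reach the attention call as (bsz, num_heads, seq_len, head_dim). The sketch below only illustrates that layout; it uses torch's scaled_dot_product_attention as a stand-in, whereas the real code passes the tensors to ColoAttention.attention:

# Illustrative only: shows the (bsz, heads, seq, head_dim) layout the reshaped
# tensors arrive in; the actual diff calls ColoAttention.attention instead.
import torch
import torch.nn.functional as F

def _shape(tensor, seq_len, bsz, num_heads, head_dim):
    return tensor.view(bsz, seq_len, num_heads, head_dim).transpose(1, 2).contiguous()

bsz, tgt_len, num_heads, head_dim = 2, 5, 4, 8
q = _shape(torch.randn(bsz, tgt_len, num_heads * head_dim), tgt_len, bsz, num_heads, head_dim)
k = _shape(torch.randn(bsz, tgt_len, num_heads * head_dim), -1, bsz, num_heads, head_dim)
v = _shape(torch.randn(bsz, tgt_len, num_heads * head_dim), -1, bsz, num_heads, head_dim)
out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
print(out.shape)  # torch.Size([2, 4, 5, 8])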
@@ -630,6 +623,8 @@ def get_opt_decoder_forward_for_flash_attention(shard_config: ShardConfig):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        cache_position: Optional[torch.Tensor] = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
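Note: newer transformers versions pass `position_ids` and `cache_position` into the decoder forward, so the patched forward has to accept them even if it does not use them. A hedged toy sketch of why the signature change matters (only the two added argument names come from the diff; the toy functions are illustrative):

# Hedged sketch: a replacement forward that does not declare the new kwargs
# raises a TypeError when the caller supplies them.
from typing import Optional
import torch

def old_forward(output_attentions: Optional[bool] = None, return_dict: Optional[bool] = None):
    return "ok"

def new_forward(
    output_attentions: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    position_ids: Optional[torch.LongTensor] = None,   # newly accepted; may go unused here
    cache_position: Optional[torch.Tensor] = None,      # newly accepted; may go unused here
):
    return "ok"

kwargs = {"position_ids": None, "cache_position": None}
try:
    old_forward(**kwargs)
except TypeError as e:
    print("old signature rejects the new kwargs:", e)
print(new_forward(**kwargs))  # "ok"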
@@ -53,6 +53,7 @@ config = transformers.OPTConfig(
     num_hidden_layers=2,
     num_attention_heads=4,
     dropout=0,
+    attn_implementation="eager",
 )
 
 # register the following models
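Note: pinning `attn_implementation="eager"` in the test config presumably keeps the model zoo building plain `OPTAttention` layers, the class the patched forward above targets, rather than whatever backend a newer transformers release would auto-select. A hedged sketch (hidden_size is an illustrative value, not taken from the diff):

# Hedged sketch: build a tiny OPT model with the pinned backend and inspect the
# attention class it produces.
import transformers

config = transformers.OPTConfig(
    hidden_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    dropout=0,
    attn_implementation="eager",
)
model = transformers.OPTModel(config)
print(type(model.decoder.layers[0].self_attn).__name__)  # expected: "OPTAttention" on the eager path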