Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-02 01:28:31 +00:00
[shardformer] update colo attention to support custom mask (#5510)
* [feature] refactor colo attention (#5462)
* [extension] update api
* [feature] add colo attention
* [feature] update sdpa
* [feature] update npu attention
* [feature] update flash-attn
* [test] add flash attn test
* [test] update flash attn test
* [shardformer] update modeling to fit colo attention (#5465)
* [misc] refactor folder structure
* [shardformer] update llama flash-attn
* [shardformer] fix llama policy
* [devops] update tensornvme install
* [test] update llama test
* [shardformer] update colo attn kernel dispatch
* [shardformer] update blip2
* [shardformer] update chatglm
* [shardformer] update gpt2
* [shardformer] update gptj
* [shardformer] update opt
* [shardformer] update vit
* [shardformer] update colo attention mask prep
* [shardformer] update whisper
* [test] fix shardformer tests (#5514)
* [test] fix shardformer tests
* [test] fix shardformer tests
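For orientation before the hunks: a minimal sketch of the workflow this change introduces in the Llama modeling code. `ColoAttention.prepare_attn_kwargs` turns an optional HF-style padding mask (the custom mask named in the title) into a dict of attention kwargs, and the patched attention forward later unpacks that dict into `ColoAttention.attention`. Only those two calls mirror the diff below; the tensor sizes, the example padding pattern, and the device/dtype handling are illustrative assumptions, not part of the commit.

# Sketch only: the prepare_attn_kwargs / attention call pattern is taken from the
# hunks below; tensor sizes, the padding pattern, and the device/dtype handling
# are illustrative assumptions, not part of the commit.
import torch

from colossalai.shardformer.layer import ColoAttention  # resolves the diff's relative `..layer` import

bsz, num_heads, seq_len, head_dim = 2, 8, 16, 64
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float16 if device.type == "cuda" else torch.float32

# HF-style padding mask (1 = keep, 0 = padded): here the last 4 tokens of the
# second sample are padding -- this is the "custom mask" the title refers to.
padding_mask = torch.ones(bsz, seq_len, dtype=torch.long, device=device)
padding_mask[1, -4:] = 0

# Build the dict-style attention kwargs, as the patched llama_model_forward does.
attn_kwargs = ColoAttention.prepare_attn_kwargs(
    (bsz, 1, seq_len, seq_len), dtype, device, q_padding_mask=padding_mask, is_causal=True
)

# The patched LlamaAttention forward unpacks the same dict (kernel dispatch is
# assumed to pick whatever backend is available on this device).
q = torch.randn(bsz, num_heads, seq_len, head_dim, dtype=dtype, device=device)
k, v = torch.randn_like(q), torch.randn_like(q)
attn_output = ColoAttention.attention(q, k, v, **attn_kwargs)  # assumed shape: (bsz, num_heads, seq_len, head_dim)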
@@ -15,7 +15,9 @@ from transformers.utils import logging
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.shard import ShardConfig

-from ..layer import cross_entropy_1d
+from ..layer import ColoAttention, cross_entropy_1d

 try:
     from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
@@ -105,18 +107,25 @@ class LlamaPipelineForwards:

         # embed positions, for the first stage, hidden_states is the input embeddings,
         # for the other stages, hidden_states is the output of the previous stage
-        if attention_mask is None:
-            attention_mask = torch.ones(
-                (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
-            )
-        if LATEST_VERSION:
-            attention_mask = _prepare_4d_causal_attention_mask(
-                attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
-            )
-        else:
-            attention_mask = self._prepare_decoder_attention_mask(
-                attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
-            )
+        if shard_config.enable_flash_attention:
+            # in this case, attention_mask is a dict rather than a tensor
+            mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
+            attention_mask = ColoAttention.prepare_attn_kwargs(
+                mask_shape, hidden_states.dtype, hidden_states.device, q_padding_mask=attention_mask, is_causal=True
+            )
+        else:
+            if attention_mask is None:
+                attention_mask = torch.ones(
+                    (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
+                )
+            if LATEST_VERSION:
+                attention_mask = _prepare_4d_causal_attention_mask(
+                    attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+                )
+            else:
+                attention_mask = self._prepare_decoder_attention_mask(
+                    attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+                )

         if self.gradient_checkpointing and self.training:
             if use_cache:
@@ -262,6 +271,7 @@ class LlamaPipelineForwards:
            stage_manager=stage_manager,
            hidden_states=hidden_states,
            stage_index=stage_index,
            shard_config=shard_config,
        )
        past_key_values = None
@@ -352,6 +362,7 @@ class LlamaPipelineForwards:
            stage_manager=stage_manager,
            hidden_states=hidden_states,
            stage_index=stage_index,
            shard_config=shard_config,
        )

        if input_ids is not None:
@@ -420,8 +431,6 @@ class LlamaPipelineForwards:
 def get_llama_flash_attention_forward(shard_config: ShardConfig):
     from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb

-    from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention
-
     llama_version = 2
     try:
         from transformers.models.llama.modeling_llama import repeat_kv
@@ -432,7 +441,7 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig):
     def forward(
         self: LlamaAttention,
         hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
+        attention_mask: Optional[dict] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: bool = False,
@@ -466,31 +475,10 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig):
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)

-        me_input_shape = (bsz, q_len, self.num_heads, self.head_dim)
-        query_states = query_states.transpose(1, 2).contiguous().view(*me_input_shape)
-        key_states = key_states.transpose(1, 2).contiguous().view(*me_input_shape)
-        value_states = value_states.transpose(1, 2).contiguous().view(*me_input_shape)
-
-        flash_attention_mask = None
-        attn_mask_type = AttnMaskType.causal
-        if not getattr(shard_config, "causal_lm", False) and attention_mask != None:
-            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                raise ValueError(
-                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
-                )
-            flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
-            attn_mask_type = AttnMaskType.paddedcausal
-
-        if not hasattr(self, "attention"):
-            self.attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
-        attn_output = self.attention(
-            query_states,
-            key_states,
-            value_states,
-            attn_mask=flash_attention_mask,
-            attn_mask_type=attn_mask_type,
-            origin_attn_mask=attention_mask,
-        )
+        assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
+        attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

         attn_output = self.o_proj(attn_output)
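A note on the two lines added after the `ColoAttention.attention` call above: the added `transpose(1, 2)` and `reshape` suggest that the new static kernel consumes and returns head-major tensors of shape (bsz, num_heads, q_len, head_dim), i.e. the layout `query_states`/`key_states`/`value_states` already have at this point, whereas the removed code first flattened to (bsz, q_len, num_heads, head_dim). A tiny sketch of that shape bookkeeping; the concrete sizes and the output layout are inferred from the diff rather than stated by it.

# Shape bookkeeping implied by the added transpose/reshape; the sizes are
# illustrative and the head-major output layout is an inference from the diff.
import torch

bsz, num_heads, q_len, head_dim = 2, 8, 16, 64
hidden_size = num_heads * head_dim

attn_output = torch.randn(bsz, num_heads, q_len, head_dim)  # assumed ColoAttention.attention output layout
attn_output = attn_output.transpose(1, 2).contiguous()      # -> (bsz, q_len, num_heads, head_dim)
attn_output = attn_output.reshape(bsz, q_len, hidden_size)  # -> (bsz, q_len, hidden_size), ready for o_proj
print(attn_output.shape)  # torch.Size([2, 16, 512])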
@@ -499,6 +487,137 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig):
     return forward


+def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig):
+    logger = logging.get_logger(__name__)
+    assert shard_config.enable_flash_attention, "Flash Attention is not enabled."
+
+    def forward(
+        self: LlamaModel,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # retrieve input_ids and inputs_embeds
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+        elif input_ids is not None:
+            batch_size, seq_length = input_ids.shape
+        elif inputs_embeds is not None:
+            batch_size, seq_length, _ = inputs_embeds.shape
+        else:
+            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+        seq_length_with_past = seq_length
+        past_key_values_length = 0
+
+        if past_key_values is not None:
+            past_key_values_length = past_key_values[0][0].shape[2]
+            seq_length_with_past = seq_length_with_past + past_key_values_length
+
+        if position_ids is None:
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
+            position_ids = torch.arange(
+                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+            )
+            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+        else:
+            position_ids = position_ids.view(-1, seq_length).long()
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+        # embed positions
+        hidden_states = inputs_embeds
+
+        # in this case, attention_mask is a dict rather than a tensor
+        mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
+        attention_mask = ColoAttention.prepare_attn_kwargs(
+            mask_shape, hidden_states.dtype, hidden_states.device, q_padding_mask=attention_mask, is_causal=True
+        )
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        next_decoder_cache = () if use_cache else None
+
+        for idx, decoder_layer in enumerate(self.layers):
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+            if self.gradient_checkpointing and self.training:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        # None for past_key_value
+                        return module(*inputs, past_key_value, output_attentions)
+
+                    return custom_forward
+
+                layer_outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(decoder_layer),
+                    hidden_states,
+                    attention_mask,
+                    position_ids,
+                )
+            else:
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    position_ids=position_ids,
+                    past_key_value=past_key_value,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                )

+            hidden_states = layer_outputs[0]
+
+            if use_cache:
+                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+        hidden_states = self.norm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = next_decoder_cache if use_cache else None
+        if not return_dict:
+            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+        )
+
+    return forward
+
+
 def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
     from transformers import LlamaForCausalLM
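Neither hunk shows how these two factories get attached to a model; that happens in the Llama shardformer policy (the commit message's "[shardformer] fix llama policy"). Below is a rough sketch of that wiring, not taken from this commit: the module path of this file, the `append_or_create_method_replacement` helper, and the chosen target keys are assumptions based on how ColossalAI shardformer policies typically register method replacements.

# Hypothetical wiring sketch -- not part of this diff. It assumes this file lives at
# colossalai/shardformer/modeling/llama.py and that the policy object exposes an
# append_or_create_method_replacement helper, as other shardformer policies do.
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaModel

from colossalai.shardformer.modeling.llama import (
    get_llama_flash_attention_forward,
    get_llama_model_forward_for_flash_attn,
)
from colossalai.shardformer.shard import ShardConfig


def register_flash_attention_forwards(policy_obj, policy_dict: dict, shard_config: ShardConfig) -> dict:
    """Illustrative helper: register both replacement forwards when flash attention is enabled."""
    if shard_config.enable_flash_attention:
        # the attention modules receive the dict-style attention_mask and unpack it via **kwargs
        policy_obj.append_or_create_method_replacement(
            description={"forward": get_llama_flash_attention_forward(shard_config)},
            policy=policy_dict,
            target_key=LlamaAttention,
        )
        # the model-level forward builds that dict with ColoAttention.prepare_attn_kwargs
        policy_obj.append_or_create_method_replacement(
            description={"forward": get_llama_model_forward_for_flash_attn(shard_config)},
            policy=policy_dict,
            target_key=LlamaModel,
        )
    return policy_dict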