mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 02:26:51 +00:00
[Device]Support npu (#6159)
* support npu * support pretrain support pretrain fix * support lora fix fix * support chatglm fix fxi fix [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci fix fix [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci fix [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci fix fix fix * Update train.py * Update train.py * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -9,7 +9,7 @@ from transformers.utils import logging
|
||||
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer import ShardConfig
|
||||
from colossalai.shardformer.layer import AttnMaskType, ColoAttention
|
||||
from colossalai.shardformer.layer import ColoAttention
|
||||
from colossalai.shardformer.layer._operation import (
|
||||
all_to_all_comm,
|
||||
gather_sp_output,
|
||||
@@ -25,42 +25,7 @@ def get_flash_core_attention_forward():
|
||||
|
||||
def forward(self: CoreAttention, query_layer, key_layer, value_layer, attention_mask):
|
||||
query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
|
||||
if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
|
||||
attention_mask_type = AttnMaskType.CAUSAL
|
||||
attn_bias = torch.zeros(
|
||||
query_layer.shape[0],
|
||||
1,
|
||||
query_layer.shape[2],
|
||||
key_layer.shape[2],
|
||||
dtype=query_layer.dtype,
|
||||
device=query_layer.device,
|
||||
)
|
||||
temp_mask = (
|
||||
torch.ones(
|
||||
query_layer.shape[2],
|
||||
key_layer.shape[2],
|
||||
dtype=torch.bool,
|
||||
device=query_layer.device,
|
||||
)
|
||||
.tril(diagonal=0)
|
||||
.expand(query_layer.shape[0], 1, -1, -1)
|
||||
)
|
||||
attn_bias.masked_fill_(temp_mask.logical_not(), torch.finfo(query_layer.dtype).min)
|
||||
else:
|
||||
attention_mask_type = AttnMaskType.CUSTOM
|
||||
if attention_mask is not None:
|
||||
attn_bias = torch.zeros_like(attention_mask, dtype=query_layer.dtype)
|
||||
attn_bias.masked_fill_(attention_mask, torch.finfo(query_layer.dtype).min)
|
||||
dropout_p = self.attention_dropout.p if self.training else 0.0
|
||||
context_layer = ColoAttention.attention(
|
||||
query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
attention_mask=attn_bias,
|
||||
attention_mask_type=attention_mask_type,
|
||||
dropout_p=dropout_p,
|
||||
scale=1.0 / self.norm_factor,
|
||||
)
|
||||
context_layer = ColoAttention.attention(query_layer, key_layer, value_layer, **attention_mask)
|
||||
context_layer = context_layer.permute(2, 0, 1, 3)
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
|
||||
context_layer = context_layer.reshape(*new_context_layer_shape)
|
||||
@@ -180,9 +145,20 @@ class ChatGLMPipelineForwards:
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
if full_attention_mask is None:
|
||||
if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
|
||||
full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
|
||||
|
||||
if shard_config.enable_flash_attention:
|
||||
mask_shape = (batch_size, 1, seq_length, seq_length)
|
||||
full_attention_mask: dict = ColoAttention.prepare_attn_kwargs(
|
||||
mask_shape,
|
||||
hidden_states.dtype,
|
||||
hidden_states.device,
|
||||
q_padding_mask=attention_mask,
|
||||
is_causal=True,
|
||||
)
|
||||
else:
|
||||
if full_attention_mask is None:
|
||||
if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
|
||||
full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
|
||||
|
||||
# Support SP + PP
|
||||
sp_size = shard_config.sequence_parallel_size
|
||||
@@ -237,7 +213,7 @@ class ChatGLMPipelineForwards:
|
||||
layer_ret = torch.utils.checkpoint.checkpoint(
|
||||
layer,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
full_attention_mask,
|
||||
rotary_pos_emb,
|
||||
past_key_values[idx],
|
||||
use_cache,
|
||||
@@ -402,10 +378,19 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
if full_attention_mask is None:
|
||||
if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
|
||||
full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
|
||||
if shard_config.enable_flash_attention:
|
||||
mask_shape = (batch_size, 1, seq_length, seq_length)
|
||||
full_attention_mask: dict = ColoAttention.prepare_attn_kwargs(
|
||||
mask_shape,
|
||||
hidden_states.dtype,
|
||||
hidden_states.device,
|
||||
q_padding_mask=attention_mask,
|
||||
is_causal=True,
|
||||
)
|
||||
else:
|
||||
if full_attention_mask is None:
|
||||
if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
|
||||
full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
|
||||
|
||||
# Rotary positional embeddings
|
||||
rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
|
||||
@@ -652,3 +637,79 @@ def get_chatglm_sequence_parallel_attention_forward(shard_config: ShardConfig, s
|
||||
return output, kv_cache
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def get_flash_attention_forward_for_chat_glm_model():
|
||||
from .chatglm2_6b.modeling_chatglm import ChatGLMModel
|
||||
|
||||
def forward(
|
||||
self: ChatGLMModel,
|
||||
input_ids,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.BoolTensor] = None,
|
||||
full_attention_mask: Optional[torch.BoolTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
):
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
batch_size, seq_length = input_ids.shape
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embedding(input_ids)
|
||||
|
||||
if self.pre_seq_len is not None:
|
||||
if past_key_values is None:
|
||||
past_key_values = self.get_prompt(
|
||||
batch_size=batch_size, device=input_ids.device, dtype=inputs_embeds.dtype
|
||||
)
|
||||
if attention_mask is not None:
|
||||
attention_mask = torch.cat(
|
||||
[attention_mask.new_ones((batch_size, self.pre_seq_len)), attention_mask], dim=-1
|
||||
)
|
||||
|
||||
mask_shape = (batch_size, 1, seq_length, seq_length)
|
||||
full_attention_mask: dict = ColoAttention.prepare_attn_kwargs(
|
||||
mask_shape,
|
||||
inputs_embeds.dtype,
|
||||
inputs_embeds.device,
|
||||
q_padding_mask=attention_mask,
|
||||
is_causal=True,
|
||||
)
|
||||
|
||||
# Rotary positional embeddings
|
||||
rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
|
||||
if position_ids is not None:
|
||||
rotary_pos_emb = rotary_pos_emb[position_ids]
|
||||
else:
|
||||
rotary_pos_emb = rotary_pos_emb[None, :seq_length]
|
||||
rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
|
||||
|
||||
# Run encoder.
|
||||
hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
|
||||
inputs_embeds,
|
||||
full_attention_mask,
|
||||
rotary_pos_emb=rotary_pos_emb,
|
||||
kv_caches=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
|
||||
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=presents,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
)
|
||||
|
||||
return forward
|
||||
|
Reference in New Issue
Block a user