mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 02:26:51 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -1,26 +1,19 @@
|
||||
""" PyTorch ChatGLM model. """
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from torch.nn import CrossEntropyLoss, LayerNorm
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from transformers.utils import logging
|
||||
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer import ShardConfig
|
||||
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
|
||||
from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
|
||||
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
|
||||
ChatGLMForConditionalGeneration,
|
||||
ChatGLMModel,
|
||||
GLMBlock,
|
||||
)
|
||||
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
|
||||
|
||||
|
||||
def get_flash_core_attention_forward():
|
||||
|
||||
from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
|
||||
|
||||
from .chatglm2_6b.modeling_chatglm import CoreAttention
|
||||
@@ -30,15 +23,15 @@ def get_flash_core_attention_forward():
|
||||
if pytorch_major_version >= 2:
|
||||
query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
|
||||
if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
|
||||
context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
is_causal=True)
|
||||
context_layer = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_layer, key_layer, value_layer, is_causal=True
|
||||
)
|
||||
else:
|
||||
if attention_mask is not None:
|
||||
attention_mask = ~attention_mask
|
||||
context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
|
||||
attention_mask)
|
||||
context_layer = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_layer, key_layer, value_layer, attention_mask
|
||||
)
|
||||
context_layer = context_layer.permute(2, 0, 1, 3)
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
|
||||
context_layer = context_layer.reshape(*new_context_layer_shape)
|
||||
@@ -60,15 +53,15 @@ def get_flash_core_attention_forward():
|
||||
flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
|
||||
attn_mask_type = AttnMaskType.paddedcausal
|
||||
|
||||
attention = ColoAttention(embed_dim=self.hidden_size_per_partition,
|
||||
num_heads=self.num_attention_heads_per_partition,
|
||||
dropout=self.attention_dropout.p,
|
||||
scale=scale)
|
||||
context_layer = attention(query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
attn_mask=flash_attention_mask,
|
||||
attn_mask_type=attn_mask_type)
|
||||
attention = ColoAttention(
|
||||
embed_dim=self.hidden_size_per_partition,
|
||||
num_heads=self.num_attention_heads_per_partition,
|
||||
dropout=self.attention_dropout.p,
|
||||
scale=scale,
|
||||
)
|
||||
context_layer = attention(
|
||||
query_layer, key_layer, value_layer, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type
|
||||
)
|
||||
|
||||
context_layer = context_layer.permute(1, 0, -1).contiguous()
|
||||
|
||||
@@ -78,7 +71,6 @@ def get_flash_core_attention_forward():
|
||||
|
||||
|
||||
def get_jit_fused_glm_block_forward():
|
||||
|
||||
from .chatglm2_6b.modeling_chatglm import GLMBlock
|
||||
|
||||
def forward(
|
||||
@@ -129,9 +121,9 @@ def get_jit_fused_glm_block_forward():
|
||||
|
||||
|
||||
class ChatGLMPipelineForwards:
|
||||
'''
|
||||
"""
|
||||
This class serves as a micro library for ChatGLM model forwards under pipeline parallelism.
|
||||
'''
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def chatglm_model_forward(
|
||||
@@ -151,19 +143,20 @@ class ChatGLMPipelineForwards:
|
||||
shard_config: ShardConfig = None,
|
||||
):
|
||||
logger = logging.get_logger(__name__)
|
||||
output_hidden_states = (output_hidden_states
|
||||
if output_hidden_states is not None else self.config.output_hidden_states)
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
|
||||
if past_key_values:
|
||||
logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.')
|
||||
logger.warning_once("Non-empty past_key_values is not supported for pipeline models at the moment.")
|
||||
past_key_values = None
|
||||
if output_hidden_states:
|
||||
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
|
||||
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
|
||||
output_hidden_states = False
|
||||
if use_cache:
|
||||
logger.warning_once('use_cache=True is not supported for pipeline models at the moment.')
|
||||
logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
|
||||
use_cache = False
|
||||
if stage_manager.is_first_stage():
|
||||
batch_size, seq_length = input_ids.shape
|
||||
@@ -174,12 +167,13 @@ class ChatGLMPipelineForwards:
|
||||
seq_length, batch_size = hidden_states.shape[:2]
|
||||
if self.pre_seq_len is not None:
|
||||
if past_key_values is None:
|
||||
past_key_values = self.get_prompt(batch_size=batch_size,
|
||||
device=input_ids.device,
|
||||
dtype=inputs_embeds.dtype)
|
||||
past_key_values = self.get_prompt(
|
||||
batch_size=batch_size, device=input_ids.device, dtype=inputs_embeds.dtype
|
||||
)
|
||||
if attention_mask is not None:
|
||||
attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)), attention_mask],
|
||||
dim=-1)
|
||||
attention_mask = torch.cat(
|
||||
[attention_mask.new_ones((batch_size, self.pre_seq_len)), attention_mask], dim=-1
|
||||
)
|
||||
if full_attention_mask is None:
|
||||
if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
|
||||
full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
|
||||
@@ -196,37 +190,41 @@ class ChatGLMPipelineForwards:
|
||||
if self.encoder.gradient_checkpointing and self.encoder.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
all_self_attentions = None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
start_idx, end_idx = stage_index[0], stage_index[1]
|
||||
|
||||
if shard_config.enable_sequence_parallelism:
|
||||
hidden_states = split_forward_gather_backward(hidden_states,
|
||||
dim=0,
|
||||
process_group=shard_config.tensor_parallel_process_group)
|
||||
hidden_states = split_forward_gather_backward(
|
||||
hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group
|
||||
)
|
||||
for idx in range(start_idx, end_idx):
|
||||
layer = self.encoder._get_layer(idx)
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
if self.encoder.gradient_checkpointing and self.encoder.training:
|
||||
layer_ret = torch.utils.checkpoint.checkpoint(layer, hidden_states, attention_mask, rotary_pos_emb,
|
||||
past_key_values[idx], use_cache)
|
||||
layer_ret = torch.utils.checkpoint.checkpoint(
|
||||
layer, hidden_states, attention_mask, rotary_pos_emb, past_key_values[idx], use_cache
|
||||
)
|
||||
else:
|
||||
layer_ret = layer(hidden_states,
|
||||
full_attention_mask,
|
||||
rotary_pos_emb,
|
||||
kv_cache=past_key_values[idx],
|
||||
use_cache=use_cache)
|
||||
layer_ret = layer(
|
||||
hidden_states,
|
||||
full_attention_mask,
|
||||
rotary_pos_emb,
|
||||
kv_cache=past_key_values[idx],
|
||||
use_cache=use_cache,
|
||||
)
|
||||
hidden_states, kv_cache = layer_ret
|
||||
if use_cache:
|
||||
presents = presents + (kv_cache,)
|
||||
|
||||
if shard_config.enable_sequence_parallelism:
|
||||
hidden_states = gather_forward_split_backward(hidden_states,
|
||||
dim=0,
|
||||
process_group=shard_config.tensor_parallel_process_group)
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group
|
||||
)
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
if stage_manager.is_last_stage():
|
||||
@@ -235,7 +233,8 @@ class ChatGLMPipelineForwards:
|
||||
hidden_states = self.encoder.final_layernorm(hidden_states)
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
|
||||
v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=presents,
|
||||
@@ -243,28 +242,30 @@ class ChatGLMPipelineForwards:
|
||||
attentions=all_self_attentions,
|
||||
)
|
||||
else:
|
||||
return {'hidden_states': hidden_states}
|
||||
return {"hidden_states": hidden_states}
|
||||
|
||||
@staticmethod
|
||||
def chatglm_for_conditional_generation_forward(self: ChatGLMForConditionalGeneration,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
return_last_logit: Optional[bool] = False,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None):
|
||||
logger = logging.get_logger(__name__)
|
||||
def chatglm_for_conditional_generation_forward(
|
||||
self: ChatGLMForConditionalGeneration,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
return_last_logit: Optional[bool] = False,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
):
|
||||
logging.get_logger(__name__)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
transformer_outputs = ChatGLMPipelineForwards.chatglm_model_forward(
|
||||
self.transformer,
|
||||
input_ids=input_ids,
|
||||
@@ -312,7 +313,6 @@ class ChatGLMPipelineForwards:
|
||||
|
||||
|
||||
def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
@@ -325,10 +325,11 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
):
|
||||
output_hidden_states = (output_hidden_states
|
||||
if output_hidden_states is not None else self.config.output_hidden_states)
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
batch_size, seq_length = input_ids.shape
|
||||
|
||||
@@ -365,9 +366,9 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
||||
|
||||
# Run encoder.
|
||||
# [seq_len, batch_size, hidden_size] -> [seq_len/TP_size, batch_size, hidden_size]
|
||||
inputs_embeds = split_forward_gather_backward(inputs_embeds,
|
||||
dim=0,
|
||||
process_group=shard_config.tensor_parallel_process_group)
|
||||
inputs_embeds = split_forward_gather_backward(
|
||||
inputs_embeds, dim=0, process_group=shard_config.tensor_parallel_process_group
|
||||
)
|
||||
hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
|
||||
inputs_embeds,
|
||||
full_attention_mask,
|
||||
@@ -377,17 +378,21 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
|
||||
hidden_states = gather_forward_split_backward(hidden_states,
|
||||
dim=0,
|
||||
process_group=shard_config.tensor_parallel_process_group)
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [
|
||||
hidden_states,
|
||||
presents,
|
||||
all_hidden_states,
|
||||
all_self_attentions,
|
||||
] if v is not None)
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
presents,
|
||||
all_hidden_states,
|
||||
all_self_attentions,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
|
Reference in New Issue
Block a user