[shardformer] fix chatglm implementation (#5644)
* [shardformer] fix chatglm policy
* [shardformer] fix chatglm flash attn
* [shardformer] update readme
* [shardformer] fix chatglm init
* [shardformer] fix chatglm test
* [pipeline] fix chatglm merge batch
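For context, the sequence-parallel and flash-attention code paths touched below are switched on through ShardConfig. A minimal sketch, assuming the field names visible in this diff (enable_sequence_parallelism, sequence_parallelism_mode, tensor_parallel_process_group) plus an enable_flash_attention flag; exact constructor arguments may differ across ColossalAI versions:

from colossalai.shardformer import ShardConfig

# Placeholder: use an initialized torch.distributed ProcessGroup for tensor parallelism.
tp_group = None

# Field names taken from this diff; enable_flash_attention is assumed to exist
# on the same ShardConfig and may differ by version.
shard_config = ShardConfig(
    tensor_parallel_process_group=tp_group,
    enable_flash_attention=True,               # route CoreAttention through ColoAttention
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="split_gather",  # the mode handled in chatglm_model_forward
)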
@@ -12,7 +12,6 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig
 from colossalai.shardformer.layer import AttnMaskType, ColoAttention
 from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
-from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
 
 
 def get_flash_core_attention_forward():
@@ -31,7 +30,12 @@ def get_flash_core_attention_forward():
                 device=query_layer.device,
             )
             temp_mask = (
-                torch.ones(query_layer.shape[2], key_layer.shape[2], dtype=torch.bool, device=query_layer.device)
+                torch.ones(
+                    query_layer.shape[2],
+                    key_layer.shape[2],
+                    dtype=torch.bool,
+                    device=query_layer.device,
+                )
                 .tril(diagonal=0)
                 .expand(query_layer.shape[0], 1, -1, -1)
             )
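As a side note, the temp_mask construction above builds a standard causal mask. A minimal standalone sketch of the same idea in plain PyTorch with made-up sizes (not the ColossalAI code path; the masked_fill_ step reflects how such a bias is typically applied afterwards):

import torch

# Made-up sizes for illustration: batch of 2, query/key length of 4.
batch, q_len, kv_len = 2, 4, 4

# Lower-triangular boolean mask: position i may attend to positions <= i.
temp_mask = (
    torch.ones(q_len, kv_len, dtype=torch.bool)
    .tril(diagonal=0)
    .expand(batch, 1, -1, -1)  # broadcast over the batch and head dimensions
)

# Additive bias: 0 where attention is allowed, a large negative value elsewhere.
attn_bias = torch.zeros(batch, 1, q_len, kv_len)
attn_bias.masked_fill_(temp_mask.logical_not(), torch.finfo(attn_bias.dtype).min)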
@@ -49,6 +53,7 @@ def get_flash_core_attention_forward():
                 attention_mask=attn_bias,
                 attention_mask_type=attention_mask_type,
                 dropout_p=dropout_p,
+                scale=1.0 / self.norm_factor,
             )
             context_layer = context_layer.permute(2, 0, 1, 3)
             new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
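The added scale=1.0 / self.norm_factor makes the attention softmax scaling explicit rather than relying on a default 1/sqrt(head_dim). A rough illustration with torch.nn.functional.scaled_dot_product_attention (its scale keyword needs PyTorch >= 2.1); norm_factor here is a stand-in for the module attribute, and this is not the ColoAttention call itself:

import math
import torch
import torch.nn.functional as F

# Hypothetical shapes: (batch, heads, seq_len, head_dim).
q = torch.randn(2, 8, 16, 64)
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)

# Stand-in for self.norm_factor; in ChatGLM it is derived from the head dimension.
norm_factor = math.sqrt(q.shape[-1])

# Passing scale explicitly overrides the default 1/sqrt(head_dim) scaling.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True, scale=1.0 / norm_factor)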
@@ -115,7 +120,7 @@ class ChatGLMPipelineForwards:
 
     @staticmethod
     def chatglm_model_forward(
-        self: ChatGLMModel,
+        self: "ChatGLMModel",
         input_ids,
         position_ids: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.BoolTensor] = None,
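The annotation change to "ChatGLMModel" (and to "ChatGLMForConditionalGeneration" further down) turns the hint into a string forward reference, which is what allows the module-level import removed in the first hunk to go away. A minimal sketch of the pattern, with hypothetical names:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by static type checkers, never at runtime, so importing
    # this module no longer pulls in the heavy modeling dependency.
    from some_package.modeling import SomeModel  # hypothetical names

def some_forward(self: "SomeModel", hidden_states):
    # The quoted annotation stays a plain string at runtime and is only
    # resolved if something calls typing.get_type_hints() on this function.
    return hidden_states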
@@ -194,7 +199,9 @@ class ChatGLMPipelineForwards:
         if shard_config and shard_config.enable_sequence_parallelism:
             if shard_config.sequence_parallelism_mode == "split_gather":
                 hidden_states = split_forward_gather_backward(
-                    hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group
+                    hidden_states,
+                    dim=0,
+                    process_group=shard_config.tensor_parallel_process_group,
                 )
         for idx in range(start_idx, end_idx):
             layer = self.encoder._get_layer(idx)
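split_forward_gather_backward keeps only this rank's slice of dim 0 (the sequence dimension in ChatGLM's [seq_len, batch, hidden] layout) on the forward pass and all-gathers gradients on the backward; gather_forward_split_backward, used in the next hunk, is its inverse. A rough standalone analogue of just the forward-pass behaviour using plain torch.distributed, not the ColossalAI autograd functions:

import torch
import torch.distributed as dist

def split_forward(hidden_states: torch.Tensor, dim: int, group) -> torch.Tensor:
    # Forward-pass behaviour only: keep this rank's slice along `dim`.
    # (The real autograd function also all-gathers gradients in backward.)
    world_size = dist.get_world_size(group)
    rank = dist.get_rank(group)
    return hidden_states.chunk(world_size, dim=dim)[rank].contiguous()

def gather_forward(hidden_states: torch.Tensor, dim: int, group) -> torch.Tensor:
    # Forward-pass behaviour only: reassemble the full tensor from all ranks.
    # (The real autograd function splits gradients again in backward.)
    world_size = dist.get_world_size(group)
    chunks = [torch.empty_like(hidden_states) for _ in range(world_size)]
    dist.all_gather(chunks, hidden_states, group=group)
    return torch.cat(chunks, dim=dim)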
@@ -224,7 +231,9 @@ class ChatGLMPipelineForwards:
         if shard_config and shard_config.enable_sequence_parallelism:
             if shard_config.sequence_parallelism_mode == "split_gather":
                 hidden_states = gather_forward_split_backward(
-                    hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group
+                    hidden_states,
+                    dim=0,
+                    process_group=shard_config.tensor_parallel_process_group,
                 )
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
@@ -254,7 +263,7 @@ class ChatGLMPipelineForwards:
 
     @staticmethod
     def chatglm_for_conditional_generation_forward(
-        self: ChatGLMForConditionalGeneration,
+        self: "ChatGLMForConditionalGeneration",
         input_ids: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,