Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-28 04:55:25 +00:00)
[shardformer] fix chatglm implementation (#5644)
* [shardformer] fix chatglm policy
* [shardformer] fix chatglm flash attn
* [shardformer] update readme
* [shardformer] fix chatglm init
* [shardformer] fix chatglm test
* [pipeline] fix chatglm merge batch
colossalai/shardformer/policies/chatglm2.py
@@ -7,7 +7,6 @@ from torch import Tensor
 
 import colossalai.shardformer.layer as col_nn
 from colossalai.shardformer.modeling.chatglm2 import ChatGLMPipelineForwards
-from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
 
 from ..modeling.chatglm2 import (
     get_chatglm_sequence_parallel_forward_fn,
@@ -17,7 +16,11 @@ from ..modeling.chatglm2 import (
 from ..modeling.jit import get_jit_fused_dropout_add_func
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
-__all__ = ["ChatGLMPolicy", "ChatGLMModelPolicy", "ChatGLMForConditionalGenerationPolicy"]
+__all__ = [
+    "ChatGLMPolicy",
+    "ChatGLMModelPolicy",
+    "ChatGLMForConditionalGenerationPolicy",
+]
 
 
 class ChatGLMPolicy(Policy):
@@ -34,8 +37,6 @@ class ChatGLMPolicy(Policy):
         return self.model
 
     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMModel, CoreAttention, GLMBlock
-
         policy = {}
 
         embedding_cls = None
@@ -67,7 +68,27 @@ class ChatGLMPolicy(Policy):
         sp_partial_derived = sp_mode == "split_gather"
 
         if self.shard_config.enable_tensor_parallelism:
-            policy[GLMBlock] = ModulePolicyDescription(
+            assert (
+                self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"num_attention_heads {self.model.config.num_attention_heads} should be divisible by tensor_parallel_size {self.shard_config.tensor_parallel_size}"
+            attn_kwargs = {
+                "self_attention.qkv_hidden_size": (
+                    self.model.config.kv_channels * self.model.config.num_attention_heads * 3
+                )
+                // self.shard_config.tensor_parallel_size,
+            }
+            if self.model.config.multi_query_attention:
+                assert (
+                    self.model.config.multi_query_group_num % self.shard_config.tensor_parallel_size == 0
+                ), f"multi_query_group_num {self.model.config.multi_query_group_num} should be divisible by tensor_parallel_size {self.shard_config.tensor_parallel_size}"
+                attn_kwargs["self_attention.num_multi_query_groups_per_partition"] = (
+                    self.model.config.multi_query_group_num // self.shard_config.tensor_parallel_size
+                )
+                attn_kwargs["self_attention.qkv_hidden_size"] = (
+                    self.model.config.kv_channels * self.model.config.num_attention_heads
+                    + 2 * self.model.config.kv_channels * self.model.config.multi_query_group_num
+                ) // self.shard_config.tensor_parallel_size
+            policy["GLMBlock"] = ModulePolicyDescription(
                 attribute_replacement={
                     "self_attention.num_attention_heads_per_partition": self.model.config.num_attention_heads
                     // self.shard_config.tensor_parallel_size,
@@ -75,22 +96,23 @@ class ChatGLMPolicy(Policy):
                         self.model.config.kv_channels * self.model.config.num_attention_heads
                     )
                     // self.shard_config.tensor_parallel_size,
-                    "self_attention.qkv_hidden_size": (
-                        self.model.config.kv_channels * self.model.config.num_attention_heads * 3
-                    )
-                    // self.shard_config.tensor_parallel_size,
                     "self_attention.core_attention.num_attention_heads_per_partition": self.model.config.num_attention_heads
                     // self.shard_config.tensor_parallel_size,
                     "self_attention.core_attention.hidden_size_per_partition": self.model.config.kv_channels
                     * self.model.config.num_attention_heads
                     // self.shard_config.tensor_parallel_size,
+                    **attn_kwargs,
                 },
                 param_replacement=[],
                 sub_module_replacement=[
                     SubModuleReplacementDescription(
                         suffix="self_attention.query_key_value",
                         target_module=col_nn.Linear1D_Col,
-                        kwargs={"seq_parallel_mode": sp_mode, "seq_parallel_dim": 0, "overlap": overlap},
+                        kwargs={
+                            "seq_parallel_mode": sp_mode,
+                            "seq_parallel_dim": 0,
+                            "overlap": overlap,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="self_attention.dense",
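The arithmetic behind attn_kwargs deserves a note: ChatGLM2 uses multi-query attention, so the fused query_key_value projection carries num_attention_heads query heads but only multi_query_group_num key/value groups, and each tensor-parallel rank gets an uneven split of query versus key/value channels. Below is a minimal sketch of that sizing, assuming the ChatGLM2-6B configuration values (kv_channels=128, 32 heads, 2 KV groups) purely for illustration:

```python
# Sketch of the per-rank attention sizing computed in the diff above.
# Config numbers follow ChatGLM2-6B (kv_channels=128, 32 heads, 2 KV
# groups); treat them as illustrative, not authoritative.
kv_channels = 128
num_attention_heads = 32
multi_query_group_num = 2
tensor_parallel_size = 2

# Both head counts must split evenly across ranks, hence the asserts in the policy.
assert num_attention_heads % tensor_parallel_size == 0
assert multi_query_group_num % tensor_parallel_size == 0

# Without multi-query attention, Q, K and V are all full-width: 3x the projection size.
qkv_plain = (kv_channels * num_attention_heads * 3) // tensor_parallel_size  # 6144

# With multi-query attention, K and V shrink to multi_query_group_num groups.
qkv_mqa = (
    kv_channels * num_attention_heads          # queries: 4096 channels
    + 2 * kv_channels * multi_query_group_num  # keys + values: 512 channels
) // tensor_parallel_size                      # 2304 per rank

print(qkv_plain, qkv_mqa)  # 6144 2304
```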
@@ -114,7 +136,7 @@ class ChatGLMPolicy(Policy):
                     ),
                 ],
                 policy=policy,
-                target_key=ChatGLMModel,
+                target_key="ChatGLMModel",
             )
         # optimization configuration
         self.append_or_create_submodule_replacement(
@@ -131,7 +153,7 @@ class ChatGLMPolicy(Policy):
                 ),
             ],
             policy=policy,
-            target_key=GLMBlock,
+            target_key="GLMBlock",
         )
 
         if self.model.config.post_layer_norm:
@@ -143,7 +165,7 @@ class ChatGLMPolicy(Policy):
                 )
             ],
             policy=policy,
-            target_key=ChatGLMModel,
+            target_key="ChatGLMModel",
         )
 
         # use flash attention
@@ -153,7 +175,7 @@ class ChatGLMPolicy(Policy):
                     "forward": get_flash_core_attention_forward(),
                 },
                 policy=policy,
-                target_key=CoreAttention,
+                target_key="CoreAttention",
             )
 
         # use sequence parallel
@@ -161,7 +183,7 @@ class ChatGLMPolicy(Policy):
             self.append_or_create_method_replacement(
                 description={"forward": get_chatglm_sequence_parallel_forward_fn(self.shard_config)},
                 policy=policy,
-                target_key=ChatGLMModel,
+                target_key="ChatGLMModel",
             )
 
         # use jit fused operator
@@ -172,7 +194,7 @@ class ChatGLMPolicy(Policy):
                     "dropout_add": get_jit_fused_dropout_add_func(),
                 },
                 policy=policy,
-                target_key=GLMBlock,
+                target_key="GLMBlock",
            )
 
         return policy
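The change repeated across these hunks, target_key=GLMBlock becoming target_key="GLMBlock", is the core of the fix. ChatGLM's modeling classes are typically loaded dynamically (for example via trust_remote_code=True), so the policy cannot reliably import them; the old code pulled them from a bundled chatglm2_6b copy instead. Matching by class name avoids the import entirely. A rough sketch of the idea follows; the helper name is hypothetical, and the real lookup lives inside shardformer's sharding machinery:

```python
from typing import Union

import torch.nn as nn


def matches_target(module: nn.Module, target_key: Union[str, type]) -> bool:
    """Hypothetical sketch: accept either a class object or a class name.

    A string key compares against the module's class name, so it also works
    for classes fetched with trust_remote_code that we cannot import here.
    """
    if isinstance(target_key, str):
        return type(module).__name__ == target_key
    return isinstance(module, target_key)

# A string key matches any module whose class is literally named "GLMBlock",
# no matter which package (or downloaded checkpoint) defined that class.
```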
@@ -220,7 +242,10 @@ class ChatGLMPolicy(Policy):
         stage_index = stage_manager.get_stage_index(layers_per_stage)
         method_replacement = {
             "forward": partial(
-                new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
+                new_forward,
+                stage_manager=stage_manager,
+                stage_index=stage_index,
+                shard_config=self.shard_config,
             )
         }
         self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
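The reformatted partial(...) call shows how the pipeline forward gets its per-stage context: functools.partial pre-binds stage_manager, stage_index, and shard_config onto the replacement forward, so callers keep the model's usual signature. A self-contained sketch of the pattern, with illustrative names:

```python
from functools import partial


def staged_forward(x, *, stage_manager, stage_index, shard_config):
    # Illustrative stand-in for a pipeline forward: only this stage's
    # slice of layers runs; the bound context says which slice that is.
    start, end = stage_index
    return f"run layers [{start}, {end}) with {shard_config} via {stage_manager}"

# Pre-bind the pipeline context once; callers still just pass the input.
forward = partial(
    staged_forward,
    stage_manager="stage_mgr",
    stage_index=(0, 14),
    shard_config="tp=2",
)
print(forward("batch"))  # run layers [0, 14) with tp=2 via stage_mgr
```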
@@ -234,7 +259,9 @@ class ChatGLMModelPolicy(ChatGLMPolicy):
 
         if self.pipeline_stage_manager is not None:
             self.set_pipeline_forward(
-                model_cls=ChatGLMModel, new_forward=ChatGLMPipelineForwards.chatglm_model_forward, policy=policy
+                model_cls="ChatGLMModel",
+                new_forward=ChatGLMPipelineForwards.chatglm_model_forward,
+                policy=policy,
             )
         return policy
 
@@ -252,7 +279,7 @@ class ChatGLMForConditionalGenerationPolicy(ChatGLMModelPolicy):
 
         if self.pipeline_stage_manager is not None:
             self.set_pipeline_forward(
-                model_cls=ChatGLMForConditionalGeneration,
+                model_cls="ChatGLMForConditionalGeneration",
                 new_forward=ChatGLMPipelineForwards.chatglm_for_conditional_generation_forward,
                 policy=policy,
             )
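For context, a policy like this is normally handed to ShardFormer when sharding a model. The sketch below follows the shape of the shardformer README; the exact ShardConfig fields and the optimize() signature are assumptions that may differ between releases:

```python
# Hedged sketch: wiring the ChatGLM policy into ShardFormer.
# ShardConfig fields and the optimize() return value may vary by release.
from transformers import AutoModel

from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.chatglm2 import ChatGLMModelPolicy

model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
shard_config = ShardConfig(enable_tensor_parallelism=True)
shard_former = ShardFormer(shard_config=shard_config)

# With string target keys, the policy matches GLMBlock/ChatGLMModel by name,
# even though those classes were downloaded alongside the checkpoint.
sharded_model, shared_params = shard_former.optimize(model, policy=ChatGLMModelPolicy())
```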