[shardformer] fix chatglm implementation (#5644)

* [shardformer] fix chatglm policy

* [shardformer] fix chatglm flash attn

* [shardformer] update readme

* [shardformer] fix chatglm init

* [shardformer] fix chatglm test

* [pipeline] fix chatglm merge batch
Author: Hongxin Liu
Date: 2024-04-25 14:41:17 +08:00 (committed by GitHub)
parent 5d88ef1aaf
commit bbb2c21f16
11 changed files with 193 additions and 117 deletions


@@ -7,7 +7,6 @@ from torch import Tensor
import colossalai.shardformer.layer as col_nn
from colossalai.shardformer.modeling.chatglm2 import ChatGLMPipelineForwards
- from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
from ..modeling.chatglm2 import (
get_chatglm_sequence_parallel_forward_fn,
@@ -17,7 +16,11 @@ from ..modeling.chatglm2 import (
from ..modeling.jit import get_jit_fused_dropout_add_func
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
- __all__ = ["ChatGLMPolicy", "ChatGLMModelPolicy", "ChatGLMForConditionalGenerationPolicy"]
+ __all__ = [
+ "ChatGLMPolicy",
+ "ChatGLMModelPolicy",
+ "ChatGLMForConditionalGenerationPolicy",
+ ]
class ChatGLMPolicy(Policy):
@@ -34,8 +37,6 @@ class ChatGLMPolicy(Policy):
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
- from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMModel, CoreAttention, GLMBlock
policy = {}
embedding_cls = None
@@ -67,7 +68,27 @@ class ChatGLMPolicy(Policy):
sp_partial_derived = sp_mode == "split_gather"
if self.shard_config.enable_tensor_parallelism:
- policy[GLMBlock] = ModulePolicyDescription(
+ assert (
+ self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+ ), f"num_attention_heads {self.model.config.num_attention_heads} should be divisible by tensor_parallel_size {self.shard_config.tensor_parallel_size}"
+ attn_kwargs = {
+ "self_attention.qkv_hidden_size": (
+ self.model.config.kv_channels * self.model.config.num_attention_heads * 3
+ )
+ // self.shard_config.tensor_parallel_size,
+ }
+ if self.model.config.multi_query_attention:
+ assert (
+ self.model.config.multi_query_group_num % self.shard_config.tensor_parallel_size == 0
+ ), f"multi_query_group_num {self.model.config.multi_query_group_num} should be divisible by tensor_parallel_size {self.shard_config.tensor_parallel_size}"
+ attn_kwargs["self_attention.num_multi_query_groups_per_partition"] = (
+ self.model.config.multi_query_group_num // self.shard_config.tensor_parallel_size
+ )
+ attn_kwargs["self_attention.qkv_hidden_size"] = (
+ self.model.config.kv_channels * self.model.config.num_attention_heads
+ + 2 * self.model.config.kv_channels * self.model.config.multi_query_group_num
+ ) // self.shard_config.tensor_parallel_size
+ policy["GLMBlock"] = ModulePolicyDescription(
attribute_replacement={
"self_attention.num_attention_heads_per_partition": self.model.config.num_attention_heads
// self.shard_config.tensor_parallel_size,
@@ -75,22 +96,23 @@ class ChatGLMPolicy(Policy):
self.model.config.kv_channels * self.model.config.num_attention_heads
)
// self.shard_config.tensor_parallel_size,
- "self_attention.qkv_hidden_size": (
- self.model.config.kv_channels * self.model.config.num_attention_heads * 3
- )
- // self.shard_config.tensor_parallel_size,
"self_attention.core_attention.num_attention_heads_per_partition": self.model.config.num_attention_heads
// self.shard_config.tensor_parallel_size,
"self_attention.core_attention.hidden_size_per_partition": self.model.config.kv_channels
* self.model.config.num_attention_heads
// self.shard_config.tensor_parallel_size,
+ **attn_kwargs,
},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attention.query_key_value",
target_module=col_nn.Linear1D_Col,
- kwargs={"seq_parallel_mode": sp_mode, "seq_parallel_dim": 0, "overlap": overlap},
+ kwargs={
+ "seq_parallel_mode": sp_mode,
+ "seq_parallel_dim": 0,
+ "overlap": overlap,
+ },
),
SubModuleReplacementDescription(
suffix="self_attention.dense",
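As a quick sanity check of the attention arithmetic introduced in the hunks above (and spliced into `attribute_replacement` through `**attn_kwargs`), the same expressions are evaluated below with concrete numbers. The config values match the published THUDM/chatglm2-6b configuration (num_attention_heads=32, kv_channels=128, multi_query_group_num=2) with a tensor_parallel_size of 2; treat them as illustrative assumptions rather than values read from this repository.

```python
# Illustrative numbers only; they mirror the attn_kwargs expressions in the policy above.
kv_channels = 128              # assumed chatglm2-6b config value
num_attention_heads = 32       # assumed chatglm2-6b config value
multi_query_group_num = 2      # assumed: chatglm2-6b uses multi-query attention with 2 KV groups
tensor_parallel_size = 2       # must divide both counts, per the new asserts

# Plain multi-head attention: Q, K and V are all num_attention_heads wide.
qkv_hidden_size_mha = (kv_channels * num_attention_heads * 3) // tensor_parallel_size
# (128 * 32 * 3) // 2 = 6144

# Multi-query attention: K and V carry only multi_query_group_num heads.
qkv_hidden_size_mqa = (
    kv_channels * num_attention_heads
    + 2 * kv_channels * multi_query_group_num
) // tensor_parallel_size
# (4096 + 512) // 2 = 2304 output features of query_key_value per tensor-parallel rank

num_attention_heads_per_partition = num_attention_heads // tensor_parallel_size       # 16
num_multi_query_groups_per_partition = multi_query_group_num // tensor_parallel_size  # 1

print(qkv_hidden_size_mha, qkv_hidden_size_mqa,
      num_attention_heads_per_partition, num_multi_query_groups_per_partition)
```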
@@ -114,7 +136,7 @@ class ChatGLMPolicy(Policy):
),
],
policy=policy,
- target_key=ChatGLMModel,
+ target_key="ChatGLMModel",
)
# optimization configuration
self.append_or_create_submodule_replacement(
@@ -131,7 +153,7 @@ class ChatGLMPolicy(Policy):
),
],
policy=policy,
- target_key=GLMBlock,
+ target_key="GLMBlock",
)
if self.model.config.post_layer_norm:
@@ -143,7 +165,7 @@ class ChatGLMPolicy(Policy):
)
],
policy=policy,
- target_key=ChatGLMModel,
+ target_key="ChatGLMModel",
)
# use flash attention
@@ -153,7 +175,7 @@ class ChatGLMPolicy(Policy):
"forward": get_flash_core_attention_forward(),
},
policy=policy,
- target_key=CoreAttention,
+ target_key="CoreAttention",
)
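The `get_flash_core_attention_forward()` helper comes from the `..modeling.chatglm2` module imported at the top and is not part of this diff. Conceptually it swaps CoreAttention's eager softmax attention for a fused/flash kernel. Below is only a minimal stand-in sketch using PyTorch's built-in scaled_dot_product_attention; the tensor layout, mask handling, and function signature are assumptions, not ColossalAI's actual implementation.

```python
import torch
import torch.nn.functional as F


def sketch_flash_core_attention(query, key, value, attention_mask=None):
    # Assumed layout: [batch, num_heads, seq_len, head_dim], which is what
    # F.scaled_dot_product_attention expects; the real ChatGLM CoreAttention
    # would first permute from its [seq, batch, heads, head_dim] layout.
    context = F.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=attention_mask,
        dropout_p=0.0,
        is_causal=attention_mask is None,  # causal masking when no explicit mask is supplied
    )
    # Merge heads back into the hidden dimension: [batch, seq_len, heads * head_dim].
    batch, num_heads, seq_len, head_dim = context.shape
    return context.transpose(1, 2).reshape(batch, seq_len, num_heads * head_dim)
```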
# use sequence parallel
@@ -161,7 +183,7 @@ class ChatGLMPolicy(Policy):
self.append_or_create_method_replacement(
description={"forward": get_chatglm_sequence_parallel_forward_fn(self.shard_config)},
policy=policy,
- target_key=ChatGLMModel,
+ target_key="ChatGLMModel",
)
# use jit fused operator
@@ -172,7 +194,7 @@ class ChatGLMPolicy(Policy):
"dropout_add": get_jit_fused_dropout_add_func(),
},
policy=policy,
- target_key=GLMBlock,
+ target_key="GLMBlock",
)
return policy
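The recurring change in the hunks above, e.g. `target_key=GLMBlock` becoming `target_key="GLMBlock"` (and likewise for ChatGLMModel and CoreAttention), is what allows the `chatglm2_6b` import to be dropped: the policy now matches modules by class name instead of by class object, presumably so it also works when the ChatGLM classes are loaded through Hugging Face's trust_remote_code path rather than from the bundled copy. A hypothetical matcher illustrating the idea (not ColossalAI's actual sharder code):

```python
from typing import Union

import torch.nn as nn


def matches_target(module: nn.Module, target_key: Union[str, type]) -> bool:
    """Hypothetical helper: a string target_key is compared against the class name,
    so the policy never has to import the concrete (possibly remote-code) class."""
    if isinstance(target_key, str):
        return type(module).__name__ == target_key
    return isinstance(module, target_key)


# Usage sketch: walk the wrapped model and apply the matching descriptions.
# for name, submodule in model.named_modules():
#     if matches_target(submodule, "GLMBlock"):
#         ...  # apply the ModulePolicyDescription registered under "GLMBlock"
```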
@@ -220,7 +242,10 @@ class ChatGLMPolicy(Policy):
stage_index = stage_manager.get_stage_index(layers_per_stage)
method_replacement = {
"forward": partial(
- new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
+ new_forward,
+ stage_manager=stage_manager,
+ stage_index=stage_index,
+ shard_config=self.shard_config,
)
}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
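For context on the `partial(...)` call reformatted above: `stage_manager`, `stage_index`, and `shard_config` are bound into the replacement forward once, at policy-application time, so the module still exposes an ordinary `forward(...)` to callers while executing only its own slice of layers. A stripped-down illustration of the pattern; the class, names, and layer-slice logic are placeholders, not the actual ChatGLMPipelineForwards code.

```python
from functools import partial

import torch
import torch.nn as nn


class ToyStagedModel(nn.Module):
    def __init__(self, num_layers: int = 4, dim: int = 8):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_layers))


def toy_pipeline_forward(self, hidden_states, *, stage_manager=None, stage_index=None, shard_config=None):
    # Run only this stage's slice of layers; a real pipeline forward would also
    # exchange activations with neighbouring stages via stage_manager.
    start, end = stage_index
    for layer in self.layers[start:end]:
        hidden_states = layer(hidden_states)
    return hidden_states


model = ToyStagedModel()
# Mirror the method replacement above: bind the pipeline context once.
model.forward = partial(toy_pipeline_forward, model, stage_manager=None, stage_index=(0, 2), shard_config=None)
out = model.forward(torch.randn(3, 8))  # executes layers [0, 2) only
```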
@@ -234,7 +259,9 @@ class ChatGLMModelPolicy(ChatGLMPolicy):
if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(
- model_cls=ChatGLMModel, new_forward=ChatGLMPipelineForwards.chatglm_model_forward, policy=policy
+ model_cls="ChatGLMModel",
+ new_forward=ChatGLMPipelineForwards.chatglm_model_forward,
+ policy=policy,
)
return policy
@@ -252,7 +279,7 @@ class ChatGLMForConditionalGenerationPolicy(ChatGLMModelPolicy):
if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(
- model_cls=ChatGLMForConditionalGeneration,
+ model_cls="ChatGLMForConditionalGeneration",
new_forward=ChatGLMPipelineForwards.chatglm_for_conditional_generation_forward,
policy=policy,
)
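Finally, a rough end-to-end sketch of how the fixed policy gets exercised: the model is loaded through Hugging Face's trust_remote_code path and sharded with ShardFormer, which selects the ChatGLM policy and resolves its string target keys by class name. Distributed initialization and the exact ShardConfig flags differ between ColossalAI releases, so read this as an assumption-laden outline rather than a verified recipe.

```python
# Assumption-laden outline; run under torchrun with the distributed environment
# and ColossalAI launch performed beforehand (the exact launch API varies by release).
import torch.distributed as dist
from transformers import AutoModel

from colossalai.shardformer import ShardConfig, ShardFormer

model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)

shard_config = ShardConfig(
    tensor_parallel_process_group=dist.group.WORLD,  # all ranks form one TP group in this sketch
    enable_tensor_parallelism=True,
    enable_flash_attention=True,  # exercises the CoreAttention forward replacement above
)
sharded_model, shared_params = ShardFormer(shard_config=shard_config).optimize(model)
```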