Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-28 04:55:25 +00:00)
[shardformer] made tensor parallelism configurable (#4144)
* [shardformer] made tensor parallelism configurable
* polish code
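What "configurable" means here: the OPT policy no longer hard-codes tensor parallelism; it only applies the Linear1D / VocabParallelEmbedding1D replacements when `shard_config.enable_tensor_parallelism` is set. A minimal sketch of toggling these flags is below; the ShardFormer entry point and its return value are assumptions about the surrounding API, not part of this diff.

# Minimal sketch: the flags that OPTPolicy.module_policy() reads from self.shard_config.
# enable_tensor_parallelism / enable_fused_normalization / tensor_parallel_size are the
# fields referenced in this diff; the ShardFormer construction and optimize() call are
# assumptions and may differ between shardformer versions.
from colossalai.shardformer import ShardConfig, ShardFormer
from transformers import OPTForCausalLM

model = OPTForCausalLM.from_pretrained("facebook/opt-125m")

shard_config = ShardConfig(
    enable_tensor_parallelism=True,    # shard q/k/v/out_proj, fc1/fc2, embeddings, lm_head
    enable_fused_normalization=True,   # swap LayerNorm for FusedLayerNorm where present
)

sharded = ShardFormer(shard_config=shard_config).optimize(model)  # assumed entry point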
@@ -29,66 +29,67 @@ class OPTPolicy(Policy):
     def module_policy(self):
         from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer
 
-        base_policy = {
-            OPTDecoder:
-                ModulePolicyDescription(sub_module_replacement=[
-                    SubModuleReplacementDescription(
-                        suffix="embed_tokens",
-                        target_module=VocabParallelEmbedding1D,
-                    )
-                ]),
-            OPTDecoderLayer:
-                ModulePolicyDescription(sub_module_replacement=[
-                    SubModuleReplacementDescription(
-                        suffix="fc1",
-                        target_module=Linear1D_Col,
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="fc2",
-                        target_module=Linear1D_Row,
-                    )
-                ]),
-            OPTAttention:
-                ModulePolicyDescription(attribute_replacement={
-                    "embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
-                    "num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size
-                },
-                                        sub_module_replacement=[
-                                            SubModuleReplacementDescription(
-                                                suffix="q_proj",
-                                                target_module=Linear1D_Col,
-                                            ),
-                                            SubModuleReplacementDescription(
-                                                suffix="k_proj",
-                                                target_module=Linear1D_Col,
-                                            ),
-                                            SubModuleReplacementDescription(
-                                                suffix="v_proj",
-                                                target_module=Linear1D_Col,
-                                            ),
-                                            SubModuleReplacementDescription(
-                                                suffix="out_proj",
-                                                target_module=Linear1D_Row,
-                                            ),
-                                        ]),
-        }
+        policy = {}
+
+        if self.shard_config.enable_tensor_parallelism:
+            policy[OPTDecoder] = ModulePolicyDescription(sub_module_replacement=[
+                SubModuleReplacementDescription(
+                    suffix="embed_tokens",
+                    target_module=VocabParallelEmbedding1D,
+                )
+            ])
+            policy[OPTDecoderLayer] = ModulePolicyDescription(sub_module_replacement=[
+                SubModuleReplacementDescription(
+                    suffix="fc1",
+                    target_module=Linear1D_Col,
+                ),
+                SubModuleReplacementDescription(
+                    suffix="fc2",
+                    target_module=Linear1D_Row,
+                )
+            ])
+
+            policy[OPTAttention] = ModulePolicyDescription(attribute_replacement={
+                "embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+                "num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size
+            },
+                                                           sub_module_replacement=[
+                                                               SubModuleReplacementDescription(
+                                                                   suffix="q_proj",
+                                                                   target_module=Linear1D_Col,
+                                                               ),
+                                                               SubModuleReplacementDescription(
+                                                                   suffix="k_proj",
+                                                                   target_module=Linear1D_Col,
+                                                               ),
+                                                               SubModuleReplacementDescription(
+                                                                   suffix="v_proj",
+                                                                   target_module=Linear1D_Col,
+                                                               ),
+                                                               SubModuleReplacementDescription(
+                                                                   suffix="out_proj",
+                                                                   target_module=Linear1D_Row,
+                                                               ),
+                                                           ])
 
         # optimization configuration
         if self.shard_config.enable_fused_normalization:
-            base_policy[OPTDecoder].sub_module_replacement.append(
-                SubModuleReplacementDescription(suffix="final_layer_norm",
-                                                target_module=FusedLayerNorm,
-                                                ignore_if_not_exist=True))
-            base_policy[OPTDecoderLayer].sub_module_replacement.extend([
+            self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
+                suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True),
+                                                        policy=policy,
+                                                        target_key=OPTDecoder)
+            self.append_or_create_submodule_replacement(description=[
                 SubModuleReplacementDescription(suffix="self_attn_layer_norm",
                                                 target_module=FusedLayerNorm,
                                                 ignore_if_not_exist=True),
                 SubModuleReplacementDescription(suffix="final_layer_norm",
                                                 target_module=FusedLayerNorm,
                                                 ignore_if_not_exist=True)
-            ])
+            ],
+                                                        policy=policy,
+                                                        target_key=OPTDecoderLayer)
 
-        return base_policy
+        return policy
 
     def postprocess(self):
         return self.model
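A quick sanity check on the attribute_replacement arithmetic above, using OPT-125m's published config (hidden_size 768, 12 attention heads) as an illustrative assumption: with tensor_parallel_size = 2, each rank's OPTAttention keeps the per-head dimension intact.

# Worked example of the attribute_replacement above (OPT-125m values assumed).
hidden_size = 768
num_attention_heads = 12
tensor_parallel_size = 2

embed_dim_per_rank = hidden_size // tensor_parallel_size           # 384
num_heads_per_rank = num_attention_heads // tensor_parallel_size   # 6
assert embed_dim_per_rank // num_heads_per_rank == hidden_size // num_attention_heads  # head_dim stays 64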
@@ -106,15 +107,12 @@ class OPTForCausalLMPolicy(OPTPolicy):
         from transformers.models.opt.modeling_opt import OPTForCausalLM
 
         policy = super().module_policy()
-        new_item = {
-            OPTForCausalLM:
-                ModulePolicyDescription(sub_module_replacement=[
-                    SubModuleReplacementDescription(
-                        suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True))
-                ])
-        }
 
-        policy.update(new_item)
+        if self.shard_config.enable_tensor_parallelism:
+            self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
+                suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)),
+                                                        policy=policy,
+                                                        target_key=OPTForCausalLM)
         return policy
 
     def postprocess(self):
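Both hunks route their replacements through the base policy's append_or_create_submodule_replacement helper instead of mutating a hard-coded dict. The sketch below illustrates the call pattern (description, policy, target_key) with stand-in dataclasses; it is a simplified illustration, not the actual helper in colossalai.shardformer.

# Simplified, self-contained illustration of the helper pattern used in this commit.
# The dataclasses are stand-ins for shardformer's ModulePolicyDescription /
# SubModuleReplacementDescription; the real helper lives on the base Policy class.
from dataclasses import dataclass, field
from typing import Any, Dict, List

@dataclass
class SubModuleReplacementDescription:
    suffix: str
    target_module: Any
    kwargs: dict = field(default_factory=dict)
    ignore_if_not_exist: bool = False

@dataclass
class ModulePolicyDescription:
    attribute_replacement: Dict[str, Any] = field(default_factory=dict)
    sub_module_replacement: List[SubModuleReplacementDescription] = field(default_factory=list)

def append_or_create_submodule_replacement(description, policy, target_key):
    """Append description(s) to policy[target_key], creating the entry if it is missing."""
    if not isinstance(description, (list, tuple)):
        description = [description]
    if target_key in policy:
        policy[target_key].sub_module_replacement.extend(description)
    else:
        policy[target_key] = ModulePolicyDescription(sub_module_replacement=list(description))
    return policy

# Usage mirroring the OPTForCausalLM hunk above (string stand-ins for the real classes):
policy = {}
append_or_create_submodule_replacement(
    description=SubModuleReplacementDescription(suffix="lm_head",
                                                target_module="Linear1D_Col",
                                                kwargs=dict(gather_output=True)),
    policy=policy,
    target_key="OPTForCausalLM")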