[shardformer] made tensor parallelism configurable (#4144)

* [shardformer] made tensor parallelism configurable

* polish code
This commit is contained in:
Frank Lee
2023-07-04 09:57:03 +08:00
parent 74257cb446
commit 1fb0d95df0
15 changed files with 819 additions and 673 deletions

View File

@@ -28,58 +28,58 @@ class LlamaPolicy(Policy):
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
base_policy = {
LlamaDecoderLayer:
ModulePolicyDescription(
attribute_replacement={
"self_attn.hidden_size":
self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attn.num_heads":
self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="mlp.gate_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="mlp.up_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="mlp.down_proj",
target_module=Linear1D_Row,
)
],
),
LlamaModel:
ModulePolicyDescription(sub_module_replacement=[
policy = {}
if self.shard_config.enable_tensor_parallelism:
policy[LlamaDecoderLayer] = ModulePolicyDescription(
attribute_replacement={
"self_attn.hidden_size":
self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attn.num_heads":
self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=VocabParallelEmbedding1D,
suffix="self_attn.q_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="mlp.gate_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="mlp.up_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="mlp.down_proj",
target_module=Linear1D_Row,
)
])
}
],
)
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=VocabParallelEmbedding1D,
),
policy=policy,
target_key=LlamaModel)
# optimization configuration
if self.shard_config.enable_fused_normalization:
base_policy[LlamaDecoderLayer].sub_module_replacement.extend([
self.append_or_create_submodule_replacement(description=[
SubModuleReplacementDescription(
suffix="input_layernorm",
target_module=FusedRMSNorm,
@@ -88,15 +88,18 @@ class LlamaPolicy(Policy):
suffix="post_attention_layernorm",
target_module=FusedRMSNorm,
)
])
],
policy=policy,
target_key=LlamaDecoderLayer)
base_policy[LlamaModel].sub_module_replacement.append(
SubModuleReplacementDescription(
suffix="norm",
target_module=FusedRMSNorm,
))
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="norm",
target_module=FusedRMSNorm,
),
policy=policy,
target_key=LlamaModel)
return base_policy
return policy
def postprocess(self):
return self.model
@@ -108,15 +111,17 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
from transformers import LlamaForCausalLM
policy = super().module_policy()
# add a new item for casual lm
new_item = {
LlamaForCausalLM:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True))
])
}
policy.update(new_item)
if self.shard_config.enable_tensor_parallelism:
# add a new item for casual lm
new_item = {
LlamaForCausalLM:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True))
])
}
policy.update(new_item)
return policy
@@ -127,13 +132,14 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
policy = super().module_policy()
# add a new item for sequence classification
new_item = {
LlamaForSequenceClassification:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True))
])
}
policy.update(new_item)
if self.shard_config.enable_tensor_parallelism:
# add a new item for sequence classification
new_item = {
LlamaForSequenceClassification:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True))
])
}
policy.update(new_item)
return policy