[shardformer] made tensor parallelism configurable (#4144)

* [shardformer] made tensor parallelism configurable

* polish code
This commit is contained in:
Frank Lee
2023-07-04 09:57:03 +08:00
parent 74257cb446
commit 1fb0d95df0
15 changed files with 819 additions and 673 deletions

View File

@@ -42,116 +42,126 @@ class T5BasePolicy(Policy):
T5Stack,
)
base_policy = {
T5Stack:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=Embedding1D,
)
]),
T5LayerSelfAttention:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
),
]),
T5LayerCrossAttention:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
)
]),
T5Attention:
ModulePolicyDescription(attribute_replacement={
"d_model":
self.model.config.d_model // self.shard_config.tensor_parallel_size,
"n_heads":
self.model.config.num_heads // self.shard_config.tensor_parallel_size,
"inner_dim":
self.model.config.num_heads * self.model.config.d_kv // self.shard_config.tensor_parallel_size
},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="q",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="k",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="v",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="o",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(suffix="relative_attention_bias",
target_module=Embedding1D,
kwargs=dict(gather_output=False),
ignore_if_not_exist=True)
]),
T5LayerFF:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
),
]),
T5DenseGatedActDense:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="wi_0",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="wi_1",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="wo", target_module=Linear1D_Col, kwargs=dict(gather_output=True)),
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
)
]),
T5DenseActDense:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="wi",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="wo",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
)
])
}
policy = {}
if self.shard_config.enable_tensor_parallelism:
policy[T5Stack] = ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=Embedding1D,
)
])
policy[T5LayerSelfAttention] = ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
),
])
policy[T5LayerCrossAttention] = ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
)
])
policy[T5Attention] = ModulePolicyDescription(attribute_replacement={
"d_model":
self.model.config.d_model // self.shard_config.tensor_parallel_size,
"n_heads":
self.model.config.num_heads // self.shard_config.tensor_parallel_size,
"inner_dim":
self.model.config.num_heads * self.model.config.d_kv // self.shard_config.tensor_parallel_size
},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="q",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="k",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="v",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="o",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="relative_attention_bias",
target_module=Embedding1D,
kwargs=dict(gather_output=False),
ignore_if_not_exist=True)
])
policy[T5LayerFF] = ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
),
])
policy[T5DenseGatedActDense] = ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="wi_0",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="wi_1",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="wo", target_module=Linear1D_Col, kwargs=dict(gather_output=True)),
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
)
])
policy[T5DenseActDense] = ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="wi",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="wo",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
)
])
# optimization configuration
if self.shard_config.enable_fused_normalization:
base_policy[T5LayerFF].sub_module_replacement.append(
SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm))
base_policy[T5LayerSelfAttention].sub_module_replacement.append(
SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm))
base_policy[T5LayerCrossAttention].sub_module_replacement.append(
SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm))
base_policy[T5Stack].sub_module_replacement.append(
SubModuleReplacementDescription(suffix="final_layer_norm", target_module=FusedRMSNorm))
return base_policy
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="layer_norm",
target_module=FusedRMSNorm,
),
policy=policy,
target_key=T5LayerFF)
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="layer_norm",
target_module=FusedRMSNorm,
),
policy=policy,
target_key=T5LayerFF)
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="layer_norm", target_module=FusedRMSNorm),
policy=policy,
target_key=T5LayerSelfAttention)
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="layer_norm", target_module=FusedRMSNorm),
policy=policy,
target_key=T5LayerCrossAttention)
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="final_layer_norm", target_module=FusedRMSNorm),
policy=policy,
target_key=T5Stack)
return policy
def postprocess(self):
binding_map = [["shared", "encoder.embed_tokens"], ["shared", "decoder.embed_tokens"]]
@@ -166,14 +176,15 @@ class T5ModelPolicy(T5BasePolicy):
def module_policy(self):
from transformers import T5Model
base_policy = super().module_policy()
base_policy[T5Model] = ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="shared",
target_module=VocabParallelEmbedding1D,
)
])
),
policy=base_policy,
target_key=T5Model)
return base_policy
@@ -183,14 +194,19 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
from transformers import T5ForConditionalGeneration
policy = super().module_policy()
policy[T5ForConditionalGeneration] = ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="shared",
target_module=VocabParallelEmbedding1D,
),
SubModuleReplacementDescription(
suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True))
])
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(description=[
SubModuleReplacementDescription(
suffix="shared",
target_module=VocabParallelEmbedding1D,
),
SubModuleReplacementDescription(suffix="lm_head",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True))
],
policy=policy,
target_key=T5ForConditionalGeneration)
return policy
def postprocess(self):
@@ -212,12 +228,14 @@ class T5EncoderPolicy(T5BasePolicy):
from transformers import T5EncoderModel
base_policy = super().module_policy()
base_policy[T5EncoderModel] = ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="shared",
target_module=VocabParallelEmbedding1D,
)
])
),
policy=base_policy,
target_key=T5EncoderModel)
return base_policy
def postprocess(self):