Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-27 12:43:02 +00:00)
[shardformer] made tensor parallelism configurable (#4144)

* [shardformer] made tensor parallelism configurable
* polish code
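The pattern is the same in every hunk below: instead of unconditionally registering the Linear1D_Col / Linear1D_Row / Embedding1D replacements, each policy now starts from an empty dict and only fills in the tensor-parallel entries when shard_config.enable_tensor_parallelism is set, while the fused-normalization replacements go through the append_or_create_submodule_replacement helper so they apply with or without tensor parallelism. Below is a minimal sketch of that gating logic; MiniShardConfig, MiniDescription and append_or_create are made-up stand-ins for illustration, not ColossalAI classes, and only the control flow mirrors the diff.

# Hedged sketch of the "configurable policy" pattern introduced by this commit.
# MiniShardConfig / MiniDescription / append_or_create are illustrative stand-ins,
# not the ColossalAI API; only the control flow mirrors the diff below.
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class MiniShardConfig:
    enable_tensor_parallelism: bool = True
    enable_fused_normalization: bool = False
    tensor_parallel_size: int = 1


@dataclass
class MiniDescription:                      # plays the role of ModulePolicyDescription
    sub_module_replacement: List[str] = field(default_factory=list)


def append_or_create(policy: Dict[str, MiniDescription], key: str, suffix: str) -> None:
    # Append to an existing entry, or create one if the key is absent --
    # the behaviour that append_or_create_submodule_replacement's name and usage suggest.
    policy.setdefault(key, MiniDescription()).sub_module_replacement.append(suffix)


def build_policy(cfg: MiniShardConfig) -> Dict[str, MiniDescription]:
    policy: Dict[str, MiniDescription] = {}
    if cfg.enable_tensor_parallelism:        # TP replacements are now optional
        policy["T5Attention"] = MiniDescription(["q", "k", "v", "o"])
        policy["T5DenseActDense"] = MiniDescription(["wi", "wo"])
    if cfg.enable_fused_normalization:       # fused norm works with or without TP
        append_or_create(policy, "T5LayerFF", "layer_norm")
        append_or_create(policy, "T5Stack", "final_layer_norm")
    return policy


print(build_policy(MiniShardConfig(enable_tensor_parallelism=False,
                                   enable_fused_normalization=True)))
# -> only the layer-norm entries, created on demand; no TP replacements.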
@@ -42,116 +42,126 @@ class T5BasePolicy(Policy):
             T5Stack,
         )

-        base_policy = {
-            T5Stack:
-                ModulePolicyDescription(sub_module_replacement=[
-                    SubModuleReplacementDescription(
-                        suffix="dropout",
-                        target_module=DropoutForParallelInput,
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="embed_tokens",
-                        target_module=Embedding1D,
-                    )
-                ]),
-            T5LayerSelfAttention:
-                ModulePolicyDescription(sub_module_replacement=[
-                    SubModuleReplacementDescription(
-                        suffix="dropout",
-                        target_module=DropoutForParallelInput,
-                    ),
-                ]),
-            T5LayerCrossAttention:
-                ModulePolicyDescription(sub_module_replacement=[
-                    SubModuleReplacementDescription(
-                        suffix="dropout",
-                        target_module=DropoutForParallelInput,
-                    )
-                ]),
-            T5Attention:
-                ModulePolicyDescription(attribute_replacement={
-                    "d_model":
-                        self.model.config.d_model // self.shard_config.tensor_parallel_size,
-                    "n_heads":
-                        self.model.config.num_heads // self.shard_config.tensor_parallel_size,
-                    "inner_dim":
-                        self.model.config.num_heads * self.model.config.d_kv // self.shard_config.tensor_parallel_size
-                },
-                                        sub_module_replacement=[
-                                            SubModuleReplacementDescription(
-                                                suffix="q",
-                                                target_module=Linear1D_Col,
-                                            ),
-                                            SubModuleReplacementDescription(
-                                                suffix="k",
-                                                target_module=Linear1D_Col,
-                                            ),
-                                            SubModuleReplacementDescription(
-                                                suffix="v",
-                                                target_module=Linear1D_Col,
-                                            ),
-                                            SubModuleReplacementDescription(
-                                                suffix="o",
-                                                target_module=Linear1D_Row,
-                                            ),
-                                            SubModuleReplacementDescription(suffix="relative_attention_bias",
-                                                                            target_module=Embedding1D,
-                                                                            kwargs=dict(gather_output=False),
-                                                                            ignore_if_not_exist=True)
-                                        ]),
-            T5LayerFF:
-                ModulePolicyDescription(sub_module_replacement=[
-                    SubModuleReplacementDescription(
-                        suffix="dropout",
-                        target_module=DropoutForParallelInput,
-                    ),
-                ]),
-            T5DenseGatedActDense:
-                ModulePolicyDescription(sub_module_replacement=[
-                    SubModuleReplacementDescription(
-                        suffix="wi_0",
-                        target_module=Linear1D_Col,
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="wi_1",
-                        target_module=Linear1D_Row,
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="wo", target_module=Linear1D_Col, kwargs=dict(gather_output=True)),
-                    SubModuleReplacementDescription(
-                        suffix="dropout",
-                        target_module=DropoutForParallelInput,
-                    )
-                ]),
-            T5DenseActDense:
-                ModulePolicyDescription(sub_module_replacement=[
-                    SubModuleReplacementDescription(
-                        suffix="wi",
-                        target_module=Linear1D_Col,
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="wo",
-                        target_module=Linear1D_Row,
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="dropout",
-                        target_module=DropoutForParallelInput,
-                    )
-                ])
-        }
+        policy = {}
+
+        if self.shard_config.enable_tensor_parallelism:
+            policy[T5Stack] = ModulePolicyDescription(sub_module_replacement=[
+                SubModuleReplacementDescription(
+                    suffix="dropout",
+                    target_module=DropoutForParallelInput,
+                ),
+                SubModuleReplacementDescription(
+                    suffix="embed_tokens",
+                    target_module=Embedding1D,
+                )
+            ])
+            policy[T5LayerSelfAttention] = ModulePolicyDescription(sub_module_replacement=[
+                SubModuleReplacementDescription(
+                    suffix="dropout",
+                    target_module=DropoutForParallelInput,
+                ),
+            ])
+            policy[T5LayerCrossAttention] = ModulePolicyDescription(sub_module_replacement=[
+                SubModuleReplacementDescription(
+                    suffix="dropout",
+                    target_module=DropoutForParallelInput,
+                )
+            ])
+            policy[T5Attention] = ModulePolicyDescription(attribute_replacement={
+                "d_model":
+                    self.model.config.d_model // self.shard_config.tensor_parallel_size,
+                "n_heads":
+                    self.model.config.num_heads // self.shard_config.tensor_parallel_size,
+                "inner_dim":
+                    self.model.config.num_heads * self.model.config.d_kv // self.shard_config.tensor_parallel_size
+            },
+                                                          sub_module_replacement=[
+                                                              SubModuleReplacementDescription(
+                                                                  suffix="q",
+                                                                  target_module=Linear1D_Col,
+                                                              ),
+                                                              SubModuleReplacementDescription(
+                                                                  suffix="k",
+                                                                  target_module=Linear1D_Col,
+                                                              ),
+                                                              SubModuleReplacementDescription(
+                                                                  suffix="v",
+                                                                  target_module=Linear1D_Col,
+                                                              ),
+                                                              SubModuleReplacementDescription(
+                                                                  suffix="o",
+                                                                  target_module=Linear1D_Row,
+                                                              ),
+                                                              SubModuleReplacementDescription(
+                                                                  suffix="relative_attention_bias",
+                                                                  target_module=Embedding1D,
+                                                                  kwargs=dict(gather_output=False),
+                                                                  ignore_if_not_exist=True)
+                                                          ])
+            policy[T5LayerFF] = ModulePolicyDescription(sub_module_replacement=[
+                SubModuleReplacementDescription(
+                    suffix="dropout",
+                    target_module=DropoutForParallelInput,
+                ),
+            ])
+            policy[T5DenseGatedActDense] = ModulePolicyDescription(sub_module_replacement=[
+                SubModuleReplacementDescription(
+                    suffix="wi_0",
+                    target_module=Linear1D_Col,
+                ),
+                SubModuleReplacementDescription(
+                    suffix="wi_1",
+                    target_module=Linear1D_Row,
+                ),
+                SubModuleReplacementDescription(
+                    suffix="wo", target_module=Linear1D_Col, kwargs=dict(gather_output=True)),
+                SubModuleReplacementDescription(
+                    suffix="dropout",
+                    target_module=DropoutForParallelInput,
+                )
+            ])
+            policy[T5DenseActDense] = ModulePolicyDescription(sub_module_replacement=[
+                SubModuleReplacementDescription(
+                    suffix="wi",
+                    target_module=Linear1D_Col,
+                ),
+                SubModuleReplacementDescription(
+                    suffix="wo",
+                    target_module=Linear1D_Row,
+                ),
+                SubModuleReplacementDescription(
+                    suffix="dropout",
+                    target_module=DropoutForParallelInput,
+                )
+            ])

         # optimization configuration
         if self.shard_config.enable_fused_normalization:
-            base_policy[T5LayerFF].sub_module_replacement.append(
-                SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm))
-            base_policy[T5LayerSelfAttention].sub_module_replacement.append(
-                SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm))
-            base_policy[T5LayerCrossAttention].sub_module_replacement.append(
-                SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm))
-            base_policy[T5Stack].sub_module_replacement.append(
-                SubModuleReplacementDescription(suffix="final_layer_norm", target_module=FusedRMSNorm))
-
-        return base_policy
+            self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
+                suffix="layer_norm",
+                target_module=FusedRMSNorm,
+            ),
+                                                        policy=policy,
+                                                        target_key=T5LayerFF)
+            self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
+                suffix="layer_norm",
+                target_module=FusedRMSNorm,
+            ),
+                                                        policy=policy,
+                                                        target_key=T5LayerFF)
+            self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
+                suffix="layer_norm", target_module=FusedRMSNorm),
+                                                        policy=policy,
+                                                        target_key=T5LayerSelfAttention)
+            self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
+                suffix="layer_norm", target_module=FusedRMSNorm),
+                                                        policy=policy,
+                                                        target_key=T5LayerCrossAttention)
+            self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
+                suffix="final_layer_norm", target_module=FusedRMSNorm),
+                                                        policy=policy,
+                                                        target_key=T5Stack)
+        return policy

     def postprocess(self):
         binding_map = [["shared", "encoder.embed_tokens"], ["shared", "decoder.embed_tokens"]]
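As a sanity check on the attribute_replacement arithmetic above: the attention dimensions are simply divided across the tensor-parallel group. With t5-small-like values (d_model=512, num_heads=8, d_kv=64; these concrete numbers are illustrative, not taken from the commit) and tensor_parallel_size=2:

# Illustrative arithmetic only; mirrors the attribute_replacement in the hunk above.
d_model, num_heads, d_kv = 512, 8, 64          # t5-small-like values (assumed)
tp_size = 2

sharded = {
    "d_model": d_model // tp_size,             # 256
    "n_heads": num_heads // tp_size,           # 4 attention heads kept on each rank
    "inner_dim": num_heads * d_kv // tp_size,  # 256: each rank holds half of q/k/v/o
}
print(sharded)   # {'d_model': 256, 'n_heads': 4, 'inner_dim': 256}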
@@ -166,14 +176,15 @@ class T5ModelPolicy(T5BasePolicy):

     def module_policy(self):
         from transformers import T5Model

         base_policy = super().module_policy()
-        base_policy[T5Model] = ModulePolicyDescription(sub_module_replacement=[
-            SubModuleReplacementDescription(
+
+        if self.shard_config.enable_tensor_parallelism:
+            self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
                 suffix="shared",
                 target_module=VocabParallelEmbedding1D,
-            )
-        ])
+            ),
+                                                        policy=base_policy,
+                                                        target_key=T5Model)
         return base_policy

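Here (and in the encoder policy below) the shared token embedding becomes VocabParallelEmbedding1D, i.e. the vocabulary dimension is split across ranks. A conceptual NumPy sketch of that lookup scheme follows (the usual Megatron-style vocab parallelism, written here for illustration only; it is not the VocabParallelEmbedding1D implementation, and all sizes are made up):

# Conceptual sketch of a vocab-parallel embedding lookup (illustrative only).
import numpy as np

vocab_size, hidden, world_size = 8, 4, 2
full_table = np.arange(vocab_size * hidden, dtype=np.float32).reshape(vocab_size, hidden)
shards = np.split(full_table, world_size, axis=0)   # each rank keeps vocab_size/world_size rows

def rank_lookup(rank: int, token_ids: np.ndarray) -> np.ndarray:
    part = vocab_size // world_size
    lo, hi = rank * part, (rank + 1) * part
    local = np.where((token_ids >= lo) & (token_ids < hi), token_ids - lo, 0)
    out = shards[rank][local]                        # fancy indexing returns a copy
    out[(token_ids < lo) | (token_ids >= hi)] = 0.0  # zero rows owned by other ranks
    return out

ids = np.array([1, 5, 6])
# Summing the per-rank partial results (the all-reduce) reconstructs the full lookup:
combined = sum(rank_lookup(r, ids) for r in range(world_size))
assert np.allclose(combined, full_table[ids])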
@@ -183,14 +194,19 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
         from transformers import T5ForConditionalGeneration

         policy = super().module_policy()
-        policy[T5ForConditionalGeneration] = ModulePolicyDescription(sub_module_replacement=[
-            SubModuleReplacementDescription(
-                suffix="shared",
-                target_module=VocabParallelEmbedding1D,
-            ),
-            SubModuleReplacementDescription(
-                suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True))
-        ])
+
+        if self.shard_config.enable_tensor_parallelism:
+            self.append_or_create_submodule_replacement(description=[
+                SubModuleReplacementDescription(
+                    suffix="shared",
+                    target_module=VocabParallelEmbedding1D,
+                ),
+                SubModuleReplacementDescription(suffix="lm_head",
+                                                target_module=Linear1D_Col,
+                                                kwargs=dict(gather_output=True))
+            ],
+                                                        policy=policy,
+                                                        target_key=T5ForConditionalGeneration)
         return policy

     def postprocess(self):
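lm_head stays on Linear1D_Col with kwargs=dict(gather_output=True): the vocabulary (output) dimension is column-sharded, and gathering the per-rank partial logits reproduces the full-vocabulary output for the loss and for decoding. A small NumPy consistency check of that idea (sizes are made up; this is not the Linear1D_Col implementation):

# Column-parallel linear with gathered output, shape/consistency check (illustrative only).
import numpy as np

hidden, vocab, tp_size = 8, 12, 4
x = np.random.rand(2, hidden)                 # (batch, hidden)
w = np.random.rand(hidden, vocab)             # full lm_head weight
w_shards = np.split(w, tp_size, axis=1)       # each rank keeps vocab/tp_size output columns

local = [x @ w_r for w_r in w_shards]         # per-rank partial logits, shape (2, 3)
gathered = np.concatenate(local, axis=1)      # gather_output=True: full (2, 12) logits
assert np.allclose(gathered, x @ w)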
@@ -212,12 +228,14 @@ class T5EncoderPolicy(T5BasePolicy):
         from transformers import T5EncoderModel

         base_policy = super().module_policy()
-        base_policy[T5EncoderModel] = ModulePolicyDescription(sub_module_replacement=[
-            SubModuleReplacementDescription(
+
+        if self.shard_config.enable_tensor_parallelism:
+            self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
                 suffix="shared",
                 target_module=VocabParallelEmbedding1D,
-            )
-        ])
+            ),
+                                                        policy=base_policy,
+                                                        target_key=T5EncoderModel)
         return base_policy

     def postprocess(self):