Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-28 13:05:26 +00:00
[shardformer] made tensor parallelism configurable (#4144)

* [shardformer] made tensor parallelism configurable
* polish code
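For orientation: this change gates ShardFormer's tensor-parallel substitutions behind a ShardConfig flag instead of applying them unconditionally. The sketch below shows how the toggle is meant to be used; the flag names come from this diff, while the ShardFormer entry point and return values are assumed (and a distributed environment must already be initialized), so treat the call names as illustrative rather than authoritative.

    # Hedged sketch: flag names come from this diff; the ShardFormer API is assumed.
    from transformers import BertConfig, BertForPreTraining

    from colossalai.shardformer import ShardConfig, ShardFormer

    model = BertForPreTraining(BertConfig())

    shard_config = ShardConfig(
        enable_tensor_parallelism=True,    # toggle the Linear1D_Col/Linear1D_Row substitutions below
        enable_fused_normalization=True,   # toggle the FusedLayerNorm substitutions below
    )
    shard_former = ShardFormer(shard_config=shard_config)
    sharded_model, shared_params = shard_former.optimize(model)   # assumed entry point; requires colossalai.launch first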
@@ -33,89 +33,114 @@ class BertPolicy(Policy):
     def module_policy(self):
         from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer
 
-        base_policy = {
-            BertLayer:
-                ModulePolicyDescription(
-                    attribute_replacement={
-                        # 1. shard hidden size
-                        "attention.self.all_head_size":
-                            self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
-                        "crossattention.self.all_head_size":
-                            self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
-                        # 2. shard number of heads
-                        "attention.self.num_attention_heads":
-                            self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
-                        "crossattention.self.num_attention_heads":
-                            self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
-                    },
-                    sub_module_replacement=[
-                        SubModuleReplacementDescription(
-                            suffix="attention.self.query",
-                            target_module=col_nn.Linear1D_Col,
-                        ),
-                        SubModuleReplacementDescription(
-                            suffix="attention.self.key",
-                            target_module=col_nn.Linear1D_Col,
-                        ),
-                        SubModuleReplacementDescription(
-                            suffix="attention.self.value",
-                            target_module=col_nn.Linear1D_Col,
-                        ),
-                        SubModuleReplacementDescription(
-                            suffix="attention.self.dropout",
-                            target_module=col_nn.DropoutForParallelInput,
-                        ),
-                        SubModuleReplacementDescription(
-                            suffix="attention.output.dense",
-                            target_module=col_nn.Linear1D_Row,
-                        ),
-                        SubModuleReplacementDescription(
-                            suffix="attention.output.dropout",
-                            target_module=col_nn.DropoutForParallelInput,
-                        ),
-                        SubModuleReplacementDescription(
-                            suffix="intermediate.dense",
-                            target_module=col_nn.Linear1D_Col,
-                        ),
-                        SubModuleReplacementDescription(
-                            suffix="output.dense",
-                            target_module=col_nn.Linear1D_Row,
-                        ),
-                        SubModuleReplacementDescription(
-                            suffix="output.dropout",
-                            target_module=col_nn.DropoutForParallelInput,
-                        )
-                    ]),
-            BertEmbeddings:
-                ModulePolicyDescription(sub_module_replacement=[
-                    SubModuleReplacementDescription(
-                        suffix="word_embeddings",
-                        target_module=col_nn.VocabParallelEmbedding1D,
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="dropout",
-                        target_module=col_nn.DropoutForReplicatedInput,
-                    )
-                ])
-        }
+        policy = {}
+
+        if self.shard_config.enable_tensor_parallelism:
+            policy[BertLayer] = ModulePolicyDescription(attribute_replacement={
+                "attention.self.all_head_size":
+                    self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+                "crossattention.self.all_head_size":
+                    self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+                "attention.self.num_attention_heads":
+                    self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+                "crossattention.self.num_attention_heads":
+                    self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+            },
+                sub_module_replacement=[
+                    SubModuleReplacementDescription(
+                        suffix="attention.self.query",
+                        target_module=col_nn.Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="attention.self.key",
+                        target_module=col_nn.Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="attention.self.value",
+                        target_module=col_nn.Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="attention.self.dropout",
+                        target_module=col_nn.DropoutForParallelInput,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="attention.output.dense",
+                        target_module=col_nn.Linear1D_Row,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="attention.output.dropout",
+                        target_module=col_nn.DropoutForParallelInput,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="intermediate.dense",
+                        target_module=col_nn.Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="output.dense",
+                        target_module=col_nn.Linear1D_Row,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="output.dropout",
+                        target_module=col_nn.DropoutForParallelInput,
+                    )
+                ])
+
+            policy[BertEmbeddings] = ModulePolicyDescription(sub_module_replacement=[
+                SubModuleReplacementDescription(
+                    suffix="word_embeddings",
+                    target_module=col_nn.VocabParallelEmbedding1D,
+                ),
+                SubModuleReplacementDescription(
+                    suffix="dropout",
+                    target_module=col_nn.DropoutForReplicatedInput,
+                )
+            ])
 
         # optimization configuration
         if self.shard_config.enable_fused_normalization:
             # Handle bert layer
-            base_policy[BertLayer].sub_module_replacement.append(
-                SubModuleReplacementDescription(
-                    suffix="attention.output.LayerNorm",
-                    target_module=col_nn.FusedLayerNorm,
-                ))
-            base_policy[BertLayer].sub_module_replacement.append(
-                SubModuleReplacementDescription(
-                    suffix="output.LayerNorm",
-                    target_module=col_nn.FusedLayerNorm,
-                ))
-            base_policy[BertEmbeddings].sub_module_replacement.append(
-                SubModuleReplacementDescription(
-                    suffix="LayerNorm",
-                    target_module=col_nn.FusedLayerNorm,
-                ))
-        return base_policy
+            self.append_or_create_submodule_replacement(description=[
+                SubModuleReplacementDescription(
+                    suffix="attention.output.LayerNorm",
+                    target_module=col_nn.FusedLayerNorm,
+                ),
+                SubModuleReplacementDescription(
+                    suffix="output.LayerNorm",
+                    target_module=col_nn.FusedLayerNorm,
+                )
+            ],
+                policy=policy,
+                target_key=BertLayer)
+
+            # handle embedding layer
+            self.append_or_create_submodule_replacement(description=[
+                SubModuleReplacementDescription(
+                    suffix="LayerNorm",
+                    target_module=col_nn.FusedLayerNorm,
+                )],
+                policy=policy,
+                target_key=BertEmbeddings)
+        return policy
+
+    def add_lm_head_policy(self, base_policy):
+        from transformers.models.bert.modeling_bert import BertLMPredictionHead
+
+        # optimize for tensor parallelism
+        if self.shard_config.enable_tensor_parallelism:
+            self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
+                suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}),
+                policy=base_policy,
+                target_key=BertLMPredictionHead)
+
+        # optimize with fused normalization
+        if self.shard_config.enable_fused_normalization:
+            # Handle bert lm prediction head
+            self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
+                suffix="transform.LayerNorm",
+                target_module=col_nn.FusedLayerNorm,
+            ),
+                policy=base_policy,
+                target_key=BertLMPredictionHead)
+        return base_policy
 
     def postprocess(self):
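The `append_or_create_submodule_replacement` helper used above comes from the shared policy base class; it lets the fused-LayerNorm descriptions be attached whether or not the tensor-parallel branch already created an entry for the target key. Below is a rough sketch of the contract it has to provide, not the repository's implementation; `ModulePolicyDescription` and `SubModuleReplacementDescription` refer to the same classes used throughout this file.

    # Illustrative sketch of the helper's contract only (the real helper is a method on the Policy base class).
    def append_or_create_submodule_replacement(description, policy, target_key):
        # accept either one SubModuleReplacementDescription or a list of them
        if not isinstance(description, (list, tuple)):
            description = [description]
        if target_key not in policy:
            # no entry yet (e.g. tensor parallelism disabled): create one holding just these replacements
            policy[target_key] = ModulePolicyDescription(sub_module_replacement=list(description))
        else:
            # an entry exists (e.g. created by the tensor-parallel branch above): extend it in place
            if policy[target_key].sub_module_replacement is None:
                policy[target_key].sub_module_replacement = []
            policy[target_key].sub_module_replacement.extend(description)
        return policy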
@@ -136,35 +161,14 @@ class BertForPretrainingPolicy(BertPolicy):
         super().__init__()
 
     def module_policy(self):
-        from transformers.models.bert.modeling_bert import BertLMPredictionHead
         module_policy = super().module_policy()
-        addon_module = {
-            BertLMPredictionHead:
-                ModulePolicyDescription(sub_module_replacement=[
-                    SubModuleReplacementDescription(
-                        suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}),
-                ])
-        }
-
-        # optimization configuration
-        if self.shard_config.enable_fused_normalization:
-            addon_module[BertLMPredictionHead].sub_module_replacement.append(
-                SubModuleReplacementDescription(
-                    suffix="transform.LayerNorm",
-                    target_module=col_nn.FusedLayerNorm,
-                ))
-
-        # append extra policy
-        module_policy.update(addon_module)
+        module_policy = self.add_lm_head_policy(module_policy)
         return module_policy
 
     def postprocess(self):
         binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
         for k, v in binding_map.items():
             param = getattr_(self.model, k)
             param = nn.Parameter(param)
             setattr_(self.model, k, param)
             setattr_(self.model, v, param)
         return self.model
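The `postprocess` above re-ties the word-embedding weight and the prediction-head decoder weight after sharding, so both dotted paths in `binding_map` end up pointing at a single parameter. The same tying pattern in plain PyTorch, with toy modules standing in for the real BERT submodules (sizes are illustrative):

    import torch.nn as nn

    # Toy stand-ins for the word embedding and the LM prediction-head decoder.
    embedding = nn.Embedding(num_embeddings=30522, embedding_dim=768)
    decoder = nn.Linear(in_features=768, out_features=30522, bias=False)

    # Tie the two weights: after this, both modules share the same Parameter object,
    # which is what the binding_map loop achieves via getattr_/setattr_ on dotted paths.
    decoder.weight = embedding.weight
    assert decoder.weight is embedding.weight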
@@ -176,31 +180,14 @@ class BertLMHeadModelPolicy(BertPolicy):
         super().__init__()
 
     def module_policy(self):
-        from transformers.models.bert.modeling_bert import BertLMPredictionHead
         module_policy = super().module_policy()
-        addon_module = {
-            BertLMPredictionHead:
-                ModulePolicyDescription(sub_module_replacement=[
-                    SubModuleReplacementDescription(
-                        suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}),
-                ])
-        }
-        if self.shard_config.enable_fused_normalization:
-            addon_module[BertLMPredictionHead].sub_module_replacement.append(
-                SubModuleReplacementDescription(
-                    suffix="transform.LayerNorm",
-                    target_module=col_nn.FusedLayerNorm,
-                ))
-        module_policy.update(addon_module)
+        module_policy = self.add_lm_head_policy(module_policy)
         return module_policy
 
     def postprocess(self):
         binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
         for k, v in binding_map.items():
             param = getattr_(self.model, k)
             param = nn.Parameter(param)
             setattr_(self.model, k, param)
             setattr_(self.model, v, param)
         return self.model
 
@@ -212,34 +199,14 @@ class BertForMaskedLMPolicy(BertPolicy):
         super().__init__()
 
     def module_policy(self):
-        from transformers.models.bert.modeling_bert import BertLMPredictionHead
         module_policy = super().module_policy()
-        addon_module = {
-            BertLMPredictionHead:
-                ModulePolicyDescription(sub_module_replacement=[
-                    SubModuleReplacementDescription(
-                        suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}),
-                ])
-        }
-
-        # optimization configuration
-        if self.shard_config.enable_fused_normalization:
-            addon_module[BertLMPredictionHead].sub_module_replacement.append(
-                SubModuleReplacementDescription(
-                    suffix="transform.LayerNorm",
-                    target_module=col_nn.FusedLayerNorm,
-                ))
-
-        module_policy.update(addon_module)
+        module_policy = self.add_lm_head_policy(module_policy)
         return module_policy
 
     def postprocess(self):
         binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
         for k, v in binding_map.items():
             param = getattr_(self.model, k)
             param = nn.Parameter(param)
             setattr_(self.model, k, param)
             setattr_(self.model, v, param)
         return self.model
 
@@ -254,16 +221,18 @@ class BertForSequenceClassificationPolicy(BertPolicy):
         from transformers.models.bert.modeling_bert import BertForSequenceClassification
 
         module_policy = super().module_policy()
-        addon_module = {
-            BertForSequenceClassification:
-                ModulePolicyDescription(sub_module_replacement=[
-                    SubModuleReplacementDescription(
-                        suffix="dropout",
-                        target_module=col_nn.DropoutForParallelInput,
-                    )
-                ])
-        }
-        module_policy.update(addon_module)
+
+        if self.shard_config.enable_tensor_parallelism:
+            addon_module = {
+                BertForSequenceClassification:
+                    ModulePolicyDescription(sub_module_replacement=[
+                        SubModuleReplacementDescription(
+                            suffix="dropout",
+                            target_module=col_nn.DropoutForParallelInput,
+                        )
+                    ])
+            }
+            module_policy.update(addon_module)
         return module_policy
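A note on the dropout wrappers used throughout these policies: as far as this diff shows, `DropoutForParallelInput` is applied where the incoming activations are sharded across the tensor-parallel group (each rank should draw its own mask), while `DropoutForReplicatedInput` is used on the replicated embedding output (all ranks should agree on one mask). The exact RNG handling lives inside ColossalAI; the snippet below only illustrates the replicated case conceptually with explicit seeding and is not the library's implementation.

    import torch

    def replicated_dropout(x: torch.Tensor, p: float, shared_seed: int) -> torch.Tensor:
        # Conceptual only: every rank seeds the same generator, so replicated activations
        # stay bit-identical after dropout across the tensor-parallel group.
        gen = torch.Generator(device=x.device).manual_seed(shared_seed)
        mask = (torch.rand(x.shape, generator=gen, device=x.device) >= p).to(x.dtype)
        return x * mask / (1.0 - p)

    y = replicated_dropout(torch.randn(2, 4), p=0.1, shared_seed=1234)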
@@ -277,16 +246,18 @@ class BertForTokenClassificationPolicy(BertPolicy):
         from transformers.models.bert.modeling_bert import BertForTokenClassification
 
         module_policy = super().module_policy()
-        addon_module = {
-            BertForTokenClassification:
-                ModulePolicyDescription(sub_module_replacement=[
-                    SubModuleReplacementDescription(
-                        suffix="dropout",
-                        target_module=col_nn.DropoutForParallelInput,
-                    )
-                ])
-        }
-        module_policy.update(addon_module)
+
+        if self.shard_config.enable_tensor_parallelism:
+            addon_module = {
+                BertForTokenClassification:
+                    ModulePolicyDescription(sub_module_replacement=[
+                        SubModuleReplacementDescription(
+                            suffix="dropout",
+                            target_module=col_nn.DropoutForParallelInput,
+                        )
+                    ])
+            }
+            module_policy.update(addon_module)
         return module_policy
 
@@ -307,14 +278,16 @@ class BertForMultipleChoicePolicy(BertPolicy):
         from transformers.models.bert.modeling_bert import BertForMultipleChoice
 
         module_policy = super().module_policy()
-        addon_module = {
-            BertForMultipleChoice:
-                ModulePolicyDescription(sub_module_replacement=[
-                    SubModuleReplacementDescription(
-                        suffix="dropout",
-                        target_module=col_nn.DropoutForParallelInput,
-                    )
-                ])
-        }
-        module_policy.update(addon_module)
+
+        if self.shard_config.enable_tensor_parallelism:
+            addon_module = {
+                BertForMultipleChoice:
+                    ModulePolicyDescription(sub_module_replacement=[
+                        SubModuleReplacementDescription(
+                            suffix="dropout",
+                            target_module=col_nn.DropoutForParallelInput,
+                        )
+                    ])
+            }
+            module_policy.update(addon_module)
         return module_policy
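Finally, a note on the linear substitutions used in the BertLayer policy above: `attention.self.query/key/value` and `intermediate.dense` become `Linear1D_Col` (weights split along the output dimension), while `attention.output.dense` and `output.dense` become `Linear1D_Row` (weights split along the input dimension), the usual Megatron-style pairing in which the row-parallel side reduces partial sums once per block. A toy, single-process shape check of that pairing (sizes are illustrative and no communication is performed):

    import torch

    hidden, tp = 768, 4                      # hidden size and tensor-parallel degree
    x = torch.randn(2, 16, hidden)           # (batch, seq, hidden), replicated on every rank

    # Column-parallel linear: each rank holds a hidden x (hidden // tp) slice of the weight,
    # so its local output covers only part of the last dimension.
    w_col_local = torch.randn(hidden, hidden // tp)
    y_local = x @ w_col_local                # (2, 16, hidden // tp) on each rank

    # Row-parallel linear: each rank holds (hidden // tp) x hidden and consumes the local slice;
    # summing the per-rank partial outputs (an all-reduce in practice) restores the full result.
    w_row_local = torch.randn(hidden // tp, hidden)
    partial = y_local @ w_row_local          # (2, 16, hidden) partial sum on each rank
    print(partial.shape)                     # torch.Size([2, 16, 768])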