[shardformer] made tensor parallelism configurable (#4144)

* [shardformer] made tensor parallelism configurable

* polish code
Frank Lee
2023-07-04 09:57:03 +08:00
parent 74257cb446
commit 1fb0d95df0
15 changed files with 819 additions and 673 deletions

View File

@@ -126,3 +126,28 @@ class Policy(ABC):
the classifier layer
"""
pass
    def append_or_create_submodule_replacement(
            self,
            description: Union[SubModuleReplacementDescription, List[SubModuleReplacementDescription]],
            policy: Dict[Union[str, nn.Module], ModulePolicyDescription],
            target_key: Union[str, nn.Module]) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        r"""
        Append a submodule replacement description to the policy entry for the given key, creating the entry if it does not exist yet.

        Args:
            description (Union[SubModuleReplacementDescription, List[SubModuleReplacementDescription]]): the submodule replacement description(s) to be appended
            policy (Dict[Union[str, nn.Module], ModulePolicyDescription]): the policy to be updated
            target_key (Union[str, nn.Module]): the key of the policy to be updated

        Returns:
            Dict[Union[str, nn.Module], ModulePolicyDescription]: the updated policy
        """
        # convert a single description to a list for uniform handling
        if isinstance(description, SubModuleReplacementDescription):
            description = [description]

        # append to an existing entry or create a new one
        if target_key in policy:
            policy[target_key].sub_module_replacement.extend(description)
        else:
            policy[target_key] = ModulePolicyDescription(sub_module_replacement=description)

        return policy
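The helper above is the pattern every policy file below switches to. A minimal, self-contained sketch of its append-or-create behaviour follows for readers skimming the diff; the stand-in dataclasses and the placeholder Linear class are illustrative only and are not the real shardformer classes.

# Self-contained sketch of the append-or-create pattern (stand-ins, not colossalai code).
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class SubModuleReplacementDescription:    # stand-in mirroring the fields used above
    suffix: str
    target_module: type


@dataclass
class ModulePolicyDescription:            # stand-in: only the field the helper touches
    sub_module_replacement: List[SubModuleReplacementDescription] = field(default_factory=list)


def append_or_create_submodule_replacement(description, policy, target_key):
    # same logic as the new Policy helper: extend an existing entry or create one
    if isinstance(description, SubModuleReplacementDescription):
        description = [description]
    if target_key in policy:
        policy[target_key].sub_module_replacement.extend(description)
    else:
        policy[target_key] = ModulePolicyDescription(sub_module_replacement=description)
    return policy


class Linear:    # placeholder target module type
    pass


policy: Dict[str, ModulePolicyDescription] = {}
# first call creates the entry for "layer" ...
append_or_create_submodule_replacement(
    SubModuleReplacementDescription(suffix="attn.q_proj", target_module=Linear), policy, "layer")
# ... second call appends to the same entry
append_or_create_submodule_replacement(
    SubModuleReplacementDescription(suffix="attn.k_proj", target_module=Linear), policy, "layer")
assert len(policy["layer"].sub_module_replacement) == 2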

View File

@@ -33,89 +33,114 @@ class BertPolicy(Policy):
def module_policy(self):
from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer
base_policy = {
BertLayer:
ModulePolicyDescription(
attribute_replacement={
# 1. shard hidden size
"attention.self.all_head_size":
self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"crossattention.self.all_head_size":
self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
# 2. shard number of heads
"attention.self.num_attention_heads":
self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
"crossattention.self.num_attention_heads":
self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="attention.self.query",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.self.key",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.self.value",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.self.dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="attention.output.dense",
target_module=col_nn.Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="attention.output.dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="intermediate.dense",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="output.dense",
target_module=col_nn.Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="output.dropout",
target_module=col_nn.DropoutForParallelInput,
)
]),
BertEmbeddings:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="word_embeddings",
target_module=col_nn.VocabParallelEmbedding1D,
),
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForReplicatedInput,
)
])
}
policy = {}
if self.shard_config.enable_tensor_parallelism:
policy[BertLayer] = ModulePolicyDescription(attribute_replacement={
"attention.self.all_head_size":
self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"crossattention.self.all_head_size":
self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"attention.self.num_attention_heads":
self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
"crossattention.self.num_attention_heads":
self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="attention.self.query",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.self.key",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.self.value",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.self.dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="attention.output.dense",
target_module=col_nn.Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="attention.output.dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="intermediate.dense",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="output.dense",
target_module=col_nn.Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="output.dropout",
target_module=col_nn.DropoutForParallelInput,
)
])
policy[BertEmbeddings] = ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="word_embeddings",
target_module=col_nn.VocabParallelEmbedding1D,
),
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForReplicatedInput,
)
])
        # optimization configuration
        if self.shard_config.enable_fused_normalization:
            base_policy[BertLayer].sub_module_replacement.append(
                SubModuleReplacementDescription(
                    suffix="attention.output.LayerNorm",
                    target_module=col_nn.FusedLayerNorm,
                ))
            base_policy[BertLayer].sub_module_replacement.append(
                SubModuleReplacementDescription(
                    suffix="output.LayerNorm",
                    target_module=col_nn.FusedLayerNorm,
                ))
            base_policy[BertEmbeddings].sub_module_replacement.append(
                SubModuleReplacementDescription(
                    suffix="LayerNorm",
                    target_module=col_nn.FusedLayerNorm,
                ))
            # Handle bert layer
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="attention.output.LayerNorm",
                        target_module=col_nn.FusedLayerNorm,
                    ),
                    SubModuleReplacementDescription(
                        suffix="output.LayerNorm",
                        target_module=col_nn.FusedLayerNorm,
                    )
                ],
                policy=policy,
                target_key=BertLayer)
            # handle embedding layer
            self.append_or_create_submodule_replacement(
                description=[SubModuleReplacementDescription(
                    suffix="LayerNorm",
                    target_module=col_nn.FusedLayerNorm,
                )],
                policy=policy,
                target_key=BertEmbeddings)
return policy
def add_lm_head_policy(self, base_policy):
from transformers.models.bert.modeling_bert import BertLMPredictionHead
# optimize for tensor parallelism
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}),
policy=base_policy,
target_key=BertLMPredictionHead)
# optimize with fused normalization
if self.shard_config.enable_fused_normalization:
# Handle bert lm prediction head
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="transform.LayerNorm",
target_module=col_nn.FusedLayerNorm,
),
policy=base_policy,
target_key=BertLMPredictionHead)
return base_policy
def postprocess(self):
@@ -136,35 +161,14 @@ class BertForPretrainingPolicy(BertPolicy):
super().__init__()
def module_policy(self):
from transformers.models.bert.modeling_bert import BertLMPredictionHead
module_policy = super().module_policy()
addon_module = {
BertLMPredictionHead:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}),
])
}
# optimization configuration
if self.shard_config.enable_fused_normalization:
addon_module[BertLMPredictionHead].sub_module_replacement.append(
SubModuleReplacementDescription(
suffix="transform.LayerNorm",
target_module=col_nn.FusedLayerNorm,
))
# append extra policy
module_policy.update(addon_module)
module_policy = self.add_lm_head_policy(module_policy)
return module_policy
def postprocess(self):
binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
for k, v in binding_map.items():
param = getattr_(self.model, k)
param = nn.Parameter(param)
setattr_(self.model, k, param)
setattr_(self.model, v, param)
return self.model
@@ -176,31 +180,14 @@ class BertLMHeadModelPolicy(BertPolicy):
super().__init__()
def module_policy(self):
from transformers.models.bert.modeling_bert import BertLMPredictionHead
module_policy = super().module_policy()
addon_module = {
BertLMPredictionHead:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}),
])
}
if self.shard_config.enable_fused_normalization:
addon_module[BertLMPredictionHead].sub_module_replacement.append(
SubModuleReplacementDescription(
suffix="transform.LayerNorm",
target_module=col_nn.FusedLayerNorm,
))
module_policy.update(addon_module)
module_policy = self.add_lm_head_policy(module_policy)
return module_policy
def postprocess(self):
binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
for k, v in binding_map.items():
param = getattr_(self.model, k)
param = nn.Parameter(param)
setattr_(self.model, k, param)
setattr_(self.model, v, param)
return self.model
@@ -212,34 +199,14 @@ class BertForMaskedLMPolicy(BertPolicy):
super().__init__()
def module_policy(self):
from transformers.models.bert.modeling_bert import BertLMPredictionHead
module_policy = super().module_policy()
addon_module = {
BertLMPredictionHead:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}),
])
}
# optimization configuration
if self.shard_config.enable_fused_normalization:
addon_module[BertLMPredictionHead].sub_module_replacement.append(
SubModuleReplacementDescription(
suffix="transform.LayerNorm",
target_module=col_nn.FusedLayerNorm,
))
module_policy.update(addon_module)
module_policy = self.add_lm_head_policy(module_policy)
return module_policy
def postprocess(self):
binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
for k, v in binding_map.items():
param = getattr_(self.model, k)
param = nn.Parameter(param)
setattr_(self.model, k, param)
setattr_(self.model, v, param)
return self.model
@@ -254,16 +221,18 @@ class BertForSequenceClassificationPolicy(BertPolicy):
from transformers.models.bert.modeling_bert import BertForSequenceClassification
module_policy = super().module_policy()
addon_module = {
BertForSequenceClassification:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForParallelInput,
)
])
}
module_policy.update(addon_module)
if self.shard_config.enable_tensor_parallelism:
addon_module = {
BertForSequenceClassification:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForParallelInput,
)
])
}
module_policy.update(addon_module)
return module_policy
@@ -277,16 +246,18 @@ class BertForTokenClassificationPolicy(BertPolicy):
from transformers.models.bert.modeling_bert import BertForTokenClassification
module_policy = super().module_policy()
addon_module = {
BertForTokenClassification:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForParallelInput,
)
])
}
module_policy.update(addon_module)
if self.shard_config.enable_tensor_parallelism:
addon_module = {
BertForTokenClassification:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForParallelInput,
)
])
}
module_policy.update(addon_module)
return module_policy
@@ -307,14 +278,16 @@ class BertForMultipleChoicePolicy(BertPolicy):
from transformers.models.bert.modeling_bert import BertForMultipleChoice
module_policy = super().module_policy()
addon_module = {
BertForMultipleChoice:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForParallelInput,
)
])
}
module_policy.update(addon_module)
if self.shard_config.enable_tensor_parallelism:
addon_module = {
BertForMultipleChoice:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForParallelInput,
)
])
}
module_policy.update(addon_module)
return module_policy

View File

@@ -85,57 +85,53 @@ class BloomPolicy(Policy):
def module_policy(self):
from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel
base_policy = {
BloomBlock:
ModulePolicyDescription(
attribute_replacement={
# 1. shard hidden size
"self_attention.hidden_size":
self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attention.split_size":
self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
# 2. shard number of heads
"self_attention.num_heads":
self.model.config.n_head // self.shard_config.tensor_parallel_size,
},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attention.query_key_value",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attention.dense",
target_module=col_nn.Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="self_attention.attention_dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="mlp.dense_h_to_4h",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="mlp.dense_4h_to_h",
target_module=col_nn.Linear1D_Row,
),
]),
            BloomModel:
                ModulePolicyDescription(
                    attribute_replacement={
                        "num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
                    },
                    method_replacement={"build_alibi_tensor": build_bloom_alibi_tensor},
                    sub_module_replacement=[
                        SubModuleReplacementDescription(
                            suffix="word_embeddings",
                            target_module=col_nn.VocabParallelEmbedding1D,
                        )
                    ])
        }
        policy = {}
        if self.shard_config.enable_tensor_parallelism:
            policy[BloomBlock] = ModulePolicyDescription(
                attribute_replacement={
                    "self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                    "self_attention.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                    "self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
                },
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="self_attention.query_key_value",
                        target_module=col_nn.Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attention.dense",
                        target_module=col_nn.Linear1D_Row,
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attention.attention_dropout",
                        target_module=col_nn.DropoutForParallelInput,
                    ),
                    SubModuleReplacementDescription(
                        suffix="mlp.dense_h_to_4h",
                        target_module=col_nn.Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="mlp.dense_4h_to_h",
                        target_module=col_nn.Linear1D_Row,
                    ),
                ])
            policy[BloomModel] = ModulePolicyDescription(
                attribute_replacement={
                    "num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
                },
                method_replacement={"build_alibi_tensor": build_bloom_alibi_tensor},
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="word_embeddings",
                        target_module=col_nn.VocabParallelEmbedding1D,
                    )
                ])
        # optimization configuration
        if self.shard_config.enable_fused_normalization:
            base_policy[BloomModel].sub_module_replacement.extend([
                SubModuleReplacementDescription(
                    suffix="ln_f",
                    target_module=col_nn.FusedLayerNorm,
                ),
                SubModuleReplacementDescription(
                    suffix="word_embeddings_layernorm",
                    target_module=col_nn.FusedLayerNorm,
                )
            ])
            base_policy[BloomBlock].sub_module_replacement.extend([
                SubModuleReplacementDescription(
                    suffix="input_layernorm",
                    target_module=col_nn.FusedLayerNorm,
                ),
                SubModuleReplacementDescription(
                    suffix="post_attention_layernorm",
                    target_module=col_nn.FusedLayerNorm,
                )
            ])
            # handle bloom model
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="ln_f",
                        target_module=col_nn.FusedLayerNorm,
                    ),
                    SubModuleReplacementDescription(
                        suffix="word_embeddings_layernorm",
                        target_module=col_nn.FusedLayerNorm,
                    )
                ],
                policy=policy,
                target_key=BloomModel)
            # handle bloom block
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="input_layernorm",
                        target_module=col_nn.FusedLayerNorm,
                    ),
                    SubModuleReplacementDescription(
                        suffix="post_attention_layernorm",
                        target_module=col_nn.FusedLayerNorm,
                    )
                ],
                policy=policy,
                target_key=BloomBlock)
        return base_policy
        return policy
def postprocess(self):
return self.model
@@ -171,19 +173,19 @@ class BloomForCausalLMPolicy(BloomPolicy):
def module_policy(self):
from transformers.models.bloom.modeling_bloom import BloomForCausalLM
policy = super().module_policy()
# add a new item for casual lm
new_item = {
BloomForCausalLM:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True))
])
}
policy.update(new_item)
# handle tensor parallelism
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)),
policy=policy,
target_key=BloomForCausalLM)
return policy
def postprocess(self):
binding_map = {"transformer.word_embeddings.weight": "lm_head.weight"}
for k, v in binding_map.items():
param = getattr_(self.model, k)
@@ -191,7 +193,6 @@ class BloomForCausalLMPolicy(BloomPolicy):
param = nn.Parameter(param)
# tie weights
setattr_(self.model, k, param)
setattr_(self.model, v, param)
return self.model
@@ -201,15 +202,14 @@ class BloomForSequenceClassificationPolicy(BloomPolicy):
def module_policy(self):
from transformers.models.bloom.modeling_bloom import BloomForSequenceClassification
policy = super().module_policy()
# add a new item for casual lm
new_item = {
BloomForSequenceClassification:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="score", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True))
])
}
policy.update(new_item)
# handle tensor parallelism
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="score", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)),
policy=policy,
target_key=BloomForSequenceClassification)
return policy
@@ -218,19 +218,21 @@ class BloomForTokenClassificationPolicy(BloomPolicy):
def module_policy(self):
from transformers.models.bloom.modeling_bloom import BloomForTokenClassification
policy = super().module_policy()
# add a new item for casual lm
new_item = {
BloomForTokenClassification:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="classifier", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)),
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForReplicatedInput,
),
])
}
policy.update(new_item)
# handle tensor parallelism
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(description=[
SubModuleReplacementDescription(suffix="classifier",
target_module=col_nn.Linear1D_Col,
kwargs=dict(gather_output=True)),
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForReplicatedInput,
),
],
policy=policy,
target_key=BloomForTokenClassification)
return policy

View File

@@ -31,67 +31,67 @@ class GPT2Policy(Policy):
def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model
base_policy = {
GPT2Model:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="wte",
target_module=col_nn.VocabParallelEmbedding1D,
),
]),
GPT2Block:
ModulePolicyDescription(attribute_replacement={
"attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="attn.c_attn",
target_module=col_nn.GPT2FusedLinearConv1D_Col,
kwargs={
"n_fused": 3,
},
),
SubModuleReplacementDescription(
suffix="attn.c_proj",
target_module=col_nn.GPT2FusedLinearConv1D_Row,
),
SubModuleReplacementDescription(
suffix="mlp.c_fc",
target_module=col_nn.GPT2FusedLinearConv1D_Col,
kwargs={
"n_fused": 1,
},
),
SubModuleReplacementDescription(
suffix="mlp.c_proj",
target_module=col_nn.GPT2FusedLinearConv1D_Row,
),
SubModuleReplacementDescription(
suffix="attn.attn_dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="attn.resid_dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="mlp.dropout",
target_module=col_nn.DropoutForParallelInput,
),
])
}
policy = {}
if self.shard_config.enable_tensor_parallelism:
policy[GPT2Model] = ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="wte",
target_module=col_nn.VocabParallelEmbedding1D,
),
])
policy[GPT2Block] = ModulePolicyDescription(attribute_replacement={
"attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="attn.c_attn",
target_module=col_nn.GPT2FusedLinearConv1D_Col,
kwargs={
"n_fused": 3,
},
),
SubModuleReplacementDescription(
suffix="attn.c_proj",
target_module=col_nn.GPT2FusedLinearConv1D_Row,
),
SubModuleReplacementDescription(
suffix="mlp.c_fc",
target_module=col_nn.GPT2FusedLinearConv1D_Col,
kwargs={
"n_fused": 1,
},
),
SubModuleReplacementDescription(
suffix="mlp.c_proj",
target_module=col_nn.GPT2FusedLinearConv1D_Row,
),
SubModuleReplacementDescription(
suffix="attn.attn_dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="attn.resid_dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="mlp.dropout",
target_module=col_nn.DropoutForParallelInput,
),
])
# optimization configuration
if self.shard_config.enable_fused_normalization:
            base_policy[GPT2Model].sub_module_replacement.append(
                SubModuleReplacementDescription(
                    suffix="ln_f",
                    target_module=col_nn.FusedLayerNorm,
                ))
            base_policy[GPT2Block].sub_module_replacement.extend([
                SubModuleReplacementDescription(
                    suffix="ln_1",
                    target_module=col_nn.FusedLayerNorm,
                ),
                SubModuleReplacementDescription(
                    suffix="ln_2",
                    target_module=col_nn.FusedLayerNorm,
                ),
                SubModuleReplacementDescription(suffix="ln_cross_attn",
                                                target_module=col_nn.FusedLayerNorm,
                                                ignore_if_not_exist=True)
            ])
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(
                    suffix="ln_f",
                    target_module=col_nn.FusedLayerNorm,
                ),
                policy=policy,
                target_key=GPT2Model)
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="ln_1",
                        target_module=col_nn.FusedLayerNorm,
                    ),
                    SubModuleReplacementDescription(
                        suffix="ln_2",
                        target_module=col_nn.FusedLayerNorm,
                    ),
                    SubModuleReplacementDescription(suffix="ln_cross_attn",
                                                    target_module=col_nn.FusedLayerNorm,
                                                    ignore_if_not_exist=True)
                ],
                policy=policy,
                target_key=GPT2Block)
        return base_policy
        return policy
def postprocess(self):
return self.model
@@ -128,22 +129,22 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
module_policy = super().module_policy()
addon_module = {
GPT2LMHeadModel:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True})
])
}
module_policy.update(addon_module)
if self.shard_config.enable_tensor_parallelism:
addon_module = {
GPT2LMHeadModel:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True})
])
}
module_policy.update(addon_module)
return module_policy
def postprocess(self):
binding_map = {"transformer.wte.weight": "lm_head.weight"}
for k, v in binding_map.items():
param = getattr_(self.model, k)
param = nn.Parameter(param)
setattr_(self.model, k, param)
setattr_(self.model, v, param)
return self.model
@@ -158,22 +159,22 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):
from transformers.models.gpt2.modeling_gpt2 import GPT2DoubleHeadsModel
module_policy = super().module_policy()
addon_module = {
GPT2DoubleHeadsModel:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True})
])
}
module_policy.update(addon_module)
if self.shard_config.enable_tensor_parallelism:
addon_module = {
GPT2DoubleHeadsModel:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True})
])
}
module_policy.update(addon_module)
return module_policy
def postprocess(self):
binding_map = {"transformer.wte.weight": "lm_head.weight"}
for k, v in binding_map.items():
param = getattr_(self.model, k)
param = nn.Parameter(param)
setattr_(self.model, k, param)
setattr_(self.model, v, param)
return self.model

View File

@@ -28,58 +28,58 @@ class LlamaPolicy(Policy):
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
base_policy = {
LlamaDecoderLayer:
ModulePolicyDescription(
attribute_replacement={
"self_attn.hidden_size":
self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attn.num_heads":
self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="mlp.gate_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="mlp.up_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="mlp.down_proj",
target_module=Linear1D_Row,
)
],
),
            LlamaModel:
                ModulePolicyDescription(sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="embed_tokens",
                        target_module=VocabParallelEmbedding1D,
                    )
                ])
        }
        policy = {}
        if self.shard_config.enable_tensor_parallelism:
            policy[LlamaDecoderLayer] = ModulePolicyDescription(
                attribute_replacement={
                    "self_attn.hidden_size":
                        self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                    "self_attn.num_heads":
                        self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
                },
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="self_attn.q_proj",
                        target_module=Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.k_proj",
                        target_module=Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.v_proj",
                        target_module=Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.o_proj",
                        target_module=Linear1D_Row,
                    ),
                    SubModuleReplacementDescription(
                        suffix="mlp.gate_proj",
                        target_module=Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="mlp.up_proj",
                        target_module=Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="mlp.down_proj",
                        target_module=Linear1D_Row,
                    )
                ],
            )
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=VocabParallelEmbedding1D,
),
policy=policy,
target_key=LlamaModel)
# optimization configuration
if self.shard_config.enable_fused_normalization:
            base_policy[LlamaDecoderLayer].sub_module_replacement.extend([
                SubModuleReplacementDescription(
                    suffix="input_layernorm",
                    target_module=FusedRMSNorm,
                ),
                SubModuleReplacementDescription(
                    suffix="post_attention_layernorm",
                    target_module=FusedRMSNorm,
                )
            ])
            base_policy[LlamaModel].sub_module_replacement.append(
                SubModuleReplacementDescription(
                    suffix="norm",
                    target_module=FusedRMSNorm,
                ))
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="input_layernorm",
                        target_module=FusedRMSNorm,
                    ),
                    SubModuleReplacementDescription(
                        suffix="post_attention_layernorm",
                        target_module=FusedRMSNorm,
                    )
                ],
                policy=policy,
                target_key=LlamaDecoderLayer)
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(
                    suffix="norm",
                    target_module=FusedRMSNorm,
                ),
                policy=policy,
                target_key=LlamaModel)
        return base_policy
        return policy
def postprocess(self):
return self.model
@@ -108,15 +111,17 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
from transformers import LlamaForCausalLM
policy = super().module_policy()
# add a new item for casual lm
new_item = {
LlamaForCausalLM:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True))
])
}
policy.update(new_item)
if self.shard_config.enable_tensor_parallelism:
# add a new item for casual lm
new_item = {
LlamaForCausalLM:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True))
])
}
policy.update(new_item)
return policy
@@ -127,13 +132,14 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
policy = super().module_policy()
# add a new item for sequence classification
new_item = {
LlamaForSequenceClassification:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True))
])
}
policy.update(new_item)
if self.shard_config.enable_tensor_parallelism:
# add a new item for sequence classification
new_item = {
LlamaForSequenceClassification:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True))
])
}
policy.update(new_item)
return policy

View File

@@ -29,66 +29,67 @@ class OPTPolicy(Policy):
def module_policy(self):
from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer
base_policy = {
OPTDecoder:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=VocabParallelEmbedding1D,
)
]),
OPTDecoderLayer:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="fc1",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="fc2",
target_module=Linear1D_Row,
)
]),
OPTAttention:
ModulePolicyDescription(attribute_replacement={
"embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size
},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="q_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="k_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="v_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="out_proj",
target_module=Linear1D_Row,
),
]),
}
policy = {}
if self.shard_config.enable_tensor_parallelism:
policy[OPTDecoder] = ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=VocabParallelEmbedding1D,
)
])
policy[OPTDecoderLayer] = ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="fc1",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="fc2",
target_module=Linear1D_Row,
)
])
policy[OPTAttention] = ModulePolicyDescription(attribute_replacement={
"embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size
},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="q_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="k_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="v_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="out_proj",
target_module=Linear1D_Row,
),
])
# optimization configuration
if self.shard_config.enable_fused_normalization:
            base_policy[OPTDecoder].sub_module_replacement.append(
                SubModuleReplacementDescription(suffix="final_layer_norm",
                                                target_module=FusedLayerNorm,
                                                ignore_if_not_exist=True))
            base_policy[OPTDecoderLayer].sub_module_replacement.extend([
                SubModuleReplacementDescription(suffix="self_attn_layer_norm",
                                                target_module=FusedLayerNorm,
                                                ignore_if_not_exist=True),
                SubModuleReplacementDescription(suffix="final_layer_norm",
                                                target_module=FusedLayerNorm,
                                                ignore_if_not_exist=True)
            ])
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(
                    suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True),
                policy=policy,
                target_key=OPTDecoder)
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(suffix="self_attn_layer_norm",
                                                    target_module=FusedLayerNorm,
                                                    ignore_if_not_exist=True),
                    SubModuleReplacementDescription(suffix="final_layer_norm",
                                                    target_module=FusedLayerNorm,
                                                    ignore_if_not_exist=True)
                ],
                policy=policy,
                target_key=OPTDecoderLayer)
        return base_policy
        return policy
def postprocess(self):
return self.model
@@ -106,15 +107,12 @@ class OPTForCausalLMPolicy(OPTPolicy):
from transformers.models.opt.modeling_opt import OPTForCausalLM
policy = super().module_policy()
new_item = {
OPTForCausalLM:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True))
])
}
policy.update(new_item)
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)),
policy=policy,
target_key=OPTForCausalLM)
return policy
def postprocess(self):

View File

@@ -42,116 +42,126 @@ class T5BasePolicy(Policy):
T5Stack,
)
base_policy = {
T5Stack:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=Embedding1D,
)
]),
T5LayerSelfAttention:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
),
]),
T5LayerCrossAttention:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
)
]),
T5Attention:
ModulePolicyDescription(attribute_replacement={
"d_model":
self.model.config.d_model // self.shard_config.tensor_parallel_size,
"n_heads":
self.model.config.num_heads // self.shard_config.tensor_parallel_size,
"inner_dim":
self.model.config.num_heads * self.model.config.d_kv // self.shard_config.tensor_parallel_size
},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="q",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="k",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="v",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="o",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(suffix="relative_attention_bias",
target_module=Embedding1D,
kwargs=dict(gather_output=False),
ignore_if_not_exist=True)
]),
T5LayerFF:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
),
]),
T5DenseGatedActDense:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="wi_0",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="wi_1",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="wo", target_module=Linear1D_Col, kwargs=dict(gather_output=True)),
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
)
]),
T5DenseActDense:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="wi",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="wo",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
)
])
}
policy = {}
if self.shard_config.enable_tensor_parallelism:
policy[T5Stack] = ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=Embedding1D,
)
])
policy[T5LayerSelfAttention] = ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
),
])
policy[T5LayerCrossAttention] = ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
)
])
policy[T5Attention] = ModulePolicyDescription(attribute_replacement={
"d_model":
self.model.config.d_model // self.shard_config.tensor_parallel_size,
"n_heads":
self.model.config.num_heads // self.shard_config.tensor_parallel_size,
"inner_dim":
self.model.config.num_heads * self.model.config.d_kv // self.shard_config.tensor_parallel_size
},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="q",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="k",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="v",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="o",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="relative_attention_bias",
target_module=Embedding1D,
kwargs=dict(gather_output=False),
ignore_if_not_exist=True)
])
policy[T5LayerFF] = ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
),
])
policy[T5DenseGatedActDense] = ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="wi_0",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="wi_1",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="wo", target_module=Linear1D_Col, kwargs=dict(gather_output=True)),
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
)
])
policy[T5DenseActDense] = ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="wi",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="wo",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
)
])
# optimization configuration
if self.shard_config.enable_fused_normalization:
base_policy[T5LayerFF].sub_module_replacement.append(
SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm))
base_policy[T5LayerSelfAttention].sub_module_replacement.append(
SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm))
base_policy[T5LayerCrossAttention].sub_module_replacement.append(
SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm))
base_policy[T5Stack].sub_module_replacement.append(
SubModuleReplacementDescription(suffix="final_layer_norm", target_module=FusedRMSNorm))
return base_policy
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="layer_norm",
target_module=FusedRMSNorm,
),
policy=policy,
target_key=T5LayerFF)
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="layer_norm",
target_module=FusedRMSNorm,
),
policy=policy,
target_key=T5LayerFF)
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="layer_norm", target_module=FusedRMSNorm),
policy=policy,
target_key=T5LayerSelfAttention)
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="layer_norm", target_module=FusedRMSNorm),
policy=policy,
target_key=T5LayerCrossAttention)
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="final_layer_norm", target_module=FusedRMSNorm),
policy=policy,
target_key=T5Stack)
return policy
def postprocess(self):
binding_map = [["shared", "encoder.embed_tokens"], ["shared", "decoder.embed_tokens"]]
@@ -166,14 +176,15 @@ class T5ModelPolicy(T5BasePolicy):
def module_policy(self):
from transformers import T5Model
base_policy = super().module_policy()
        base_policy[T5Model] = ModulePolicyDescription(sub_module_replacement=[
            SubModuleReplacementDescription(
                suffix="shared",
                target_module=VocabParallelEmbedding1D,
            )
        ])
        if self.shard_config.enable_tensor_parallelism:
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(
                    suffix="shared",
                    target_module=VocabParallelEmbedding1D,
                ),
                policy=base_policy,
                target_key=T5Model)
return base_policy
@@ -183,14 +194,19 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
from transformers import T5ForConditionalGeneration
policy = super().module_policy()
policy[T5ForConditionalGeneration] = ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="shared",
target_module=VocabParallelEmbedding1D,
),
SubModuleReplacementDescription(
suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True))
])
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(description=[
SubModuleReplacementDescription(
suffix="shared",
target_module=VocabParallelEmbedding1D,
),
SubModuleReplacementDescription(suffix="lm_head",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True))
],
policy=policy,
target_key=T5ForConditionalGeneration)
return policy
def postprocess(self):
@@ -212,12 +228,14 @@ class T5EncoderPolicy(T5BasePolicy):
from transformers import T5EncoderModel
base_policy = super().module_policy()
        base_policy[T5EncoderModel] = ModulePolicyDescription(sub_module_replacement=[
            SubModuleReplacementDescription(
                suffix="shared",
                target_module=VocabParallelEmbedding1D,
            )
        ])
        if self.shard_config.enable_tensor_parallelism:
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(
                    suffix="shared",
                    target_module=VocabParallelEmbedding1D,
                ),
                policy=base_policy,
                target_key=T5EncoderModel)
return base_policy
def postprocess(self):

View File

@@ -13,11 +13,12 @@ class ShardConfig:
Args:
tensor_parallel_process_group (int): The process group for tensor parallelism, defaults to None, which is the global process group.
enable_tensor_parallelism (bool): Whether to turn on tensor parallelism, default is True.
enable_fused_normalization (bool): Whether to use fused layernorm, default is False.
enable_tensor_parallelism (bool): Whether to use tensor parallelism, default is True.
enable_all_optimization (bool): Whether to turn on all optimization, default is False.
"""
tensor_parallel_process_group: ProcessGroup = None
enable_tensor_parallelism: bool = True
enable_fused_normalization: bool = False
enable_all_optimization: bool = False
@@ -33,8 +34,11 @@ class ShardConfig:
return self._tensor_parallel_size
def __post_init__(self):
# get the parallel size
self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group)
if not self.enable_tensor_parallelism:
self._tensor_parallel_size = 1
else:
# get the parallel size
self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group)
# turn on all optimization if all_optimization is set to True
if self.enable_all_optimization:
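Taken together, the policy and config changes above mean tensor parallelism can now be switched off per run. A short usage sketch follows; the import path and the assumption that torch.distributed is initialized before enabling tensor parallelism are mine, not part of this diff.

# Hedged usage sketch (assumes ShardConfig is importable from colossalai.shardformer,
# which re-exports the shard config module touched in this diff).
import torch.distributed as dist

from colossalai.shardformer import ShardConfig

# With tensor parallelism disabled, __post_init__ skips dist.get_world_size()
# and forces the parallel size to 1, so the policies above register no
# Linear1D / VocabParallelEmbedding1D replacements.
config = ShardConfig(enable_tensor_parallelism=False, enable_fused_normalization=True)
assert config.tensor_parallel_size == 1

# With the default (enable_tensor_parallelism=True), the size comes from the
# process group, so the distributed backend must be initialized first.
if dist.is_initialized():
    tp_config = ShardConfig(tensor_parallel_process_group=None)    # None -> global group
    assert tp_config.tensor_parallel_size == dist.get_world_size()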