[shardformer] made tensor parallelism configurable (#4144)

* [shardformer] made tensor parallelism configurable

* polish code
Frank Lee 2023-07-04 09:57:03 +08:00
parent 74257cb446
commit 1fb0d95df0
15 changed files with 819 additions and 673 deletions

View File

@@ -126,3 +126,28 @@ class Policy(ABC):
            the classifier layer
        """
        pass

    def append_or_create_submodule_replacement(
            self, description: Union[SubModuleReplacementDescription, List[SubModuleReplacementDescription]],
            policy: Dict[Union[str, nn.Module], ModulePolicyDescription],
            target_key: Union[str, nn.Module]) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        r"""
        Append or create a new submodule replacement description to the policy for the given key.

        Args:
            description (Union[SubModuleReplacementDescription, List[SubModuleReplacementDescription]]): the submodule replacement description to be appended
            policy (Dict[Union[str, nn.Module], ModulePolicyDescription]): the policy to be updated
            target_key (Union[str, nn.Module]): the key of the policy to be updated
        """
        # convert to list
        if isinstance(description, SubModuleReplacementDescription):
            description = [description]

        # append to the existing entry, or create a new one
        if target_key in policy:
            policy[target_key].sub_module_replacement.extend(description)
        else:
            policy[target_key] = ModulePolicyDescription(sub_module_replacement=description)

        return policy
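For context, the append-or-create pattern above can be exercised in isolation. The sketch below is illustrative only: it uses minimal stand-in dataclasses instead of the real shardformer description classes and a free function instead of the Policy method, but the branching logic is the same.

from dataclasses import dataclass, field
from typing import Dict, List, Union


@dataclass
class SubModuleReplacementDescription:    # simplified stand-in, illustration only
    suffix: str
    target_module: str


@dataclass
class ModulePolicyDescription:    # simplified stand-in, illustration only
    sub_module_replacement: List[SubModuleReplacementDescription] = field(default_factory=list)


def append_or_create(description: Union[SubModuleReplacementDescription, List[SubModuleReplacementDescription]],
                     policy: Dict[str, ModulePolicyDescription],
                     target_key: str) -> Dict[str, ModulePolicyDescription]:
    # same append-or-create logic as the Policy method above
    if isinstance(description, SubModuleReplacementDescription):
        description = [description]
    if target_key in policy:
        policy[target_key].sub_module_replacement.extend(description)
    else:
        policy[target_key] = ModulePolicyDescription(sub_module_replacement=description)
    return policy


policy = {}
policy = append_or_create(SubModuleReplacementDescription("dropout", "DropoutForParallelInput"), policy, "BertLayer")
policy = append_or_create(SubModuleReplacementDescription("output.LayerNorm", "FusedLayerNorm"), policy, "BertLayer")
assert len(policy["BertLayer"].sub_module_replacement) == 2    # first call created the entry, second appended

The point of the helper is that callers no longer need to know whether a ModulePolicyDescription already exists for a key; the fused-normalization branches in the policies below rely on this when tensor parallelism is disabled and no entry was created up front.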

View File

@@ -33,89 +33,114 @@ class BertPolicy(Policy):

    def module_policy(self):
        from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer

        policy = {}

        if self.shard_config.enable_tensor_parallelism:
            policy[BertLayer] = ModulePolicyDescription(
                attribute_replacement={
                    "attention.self.all_head_size":
                        self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                    "crossattention.self.all_head_size":
                        self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                    "attention.self.num_attention_heads":
                        self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
                    "crossattention.self.num_attention_heads":
                        self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
                },
                sub_module_replacement=[
                    SubModuleReplacementDescription(suffix="attention.self.query", target_module=col_nn.Linear1D_Col),
                    SubModuleReplacementDescription(suffix="attention.self.key", target_module=col_nn.Linear1D_Col),
                    SubModuleReplacementDescription(suffix="attention.self.value", target_module=col_nn.Linear1D_Col),
                    SubModuleReplacementDescription(suffix="attention.self.dropout",
                                                    target_module=col_nn.DropoutForParallelInput),
                    SubModuleReplacementDescription(suffix="attention.output.dense",
                                                    target_module=col_nn.Linear1D_Row),
                    SubModuleReplacementDescription(suffix="attention.output.dropout",
                                                    target_module=col_nn.DropoutForParallelInput),
                    SubModuleReplacementDescription(suffix="intermediate.dense", target_module=col_nn.Linear1D_Col),
                    SubModuleReplacementDescription(suffix="output.dense", target_module=col_nn.Linear1D_Row),
                    SubModuleReplacementDescription(suffix="output.dropout",
                                                    target_module=col_nn.DropoutForParallelInput),
                ])

            policy[BertEmbeddings] = ModulePolicyDescription(sub_module_replacement=[
                SubModuleReplacementDescription(suffix="word_embeddings",
                                                target_module=col_nn.VocabParallelEmbedding1D),
                SubModuleReplacementDescription(suffix="dropout", target_module=col_nn.DropoutForReplicatedInput),
            ])

        # optimization configuration
        if self.shard_config.enable_fused_normalization:
            # handle bert layer
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(suffix="attention.output.LayerNorm",
                                                    target_module=col_nn.FusedLayerNorm),
                    SubModuleReplacementDescription(suffix="output.LayerNorm", target_module=col_nn.FusedLayerNorm),
                ],
                policy=policy,
                target_key=BertLayer)
            # handle embedding layer
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(suffix="LayerNorm", target_module=col_nn.FusedLayerNorm),
                ],
                policy=policy,
                target_key=BertEmbeddings)

        return policy

    def add_lm_head_policy(self, base_policy):
        from transformers.models.bert.modeling_bert import BertLMPredictionHead

        # optimize for tensor parallelism
        if self.shard_config.enable_tensor_parallelism:
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(suffix="decoder",
                                                            target_module=col_nn.Linear1D_Col,
                                                            kwargs={"gather_output": True}),
                policy=base_policy,
                target_key=BertLMPredictionHead)

        # optimize with fused normalization
        if self.shard_config.enable_fused_normalization:
            # handle bert lm prediction head
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(suffix="transform.LayerNorm",
                                                            target_module=col_nn.FusedLayerNorm),
                policy=base_policy,
                target_key=BertLMPredictionHead)

        return base_policy

    def postprocess(self):
@@ -136,35 +161,14 @@ class BertForPretrainingPolicy(BertPolicy):
        super().__init__()

    def module_policy(self):
        module_policy = super().module_policy()
        module_policy = self.add_lm_head_policy(module_policy)
        return module_policy

    def postprocess(self):
        binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
        for k, v in binding_map.items():
            param = getattr_(self.model, k)
            setattr_(self.model, v, param)
        return self.model

@@ -176,31 +180,14 @@ class BertLMHeadModelPolicy(BertPolicy):
        super().__init__()

    def module_policy(self):
        module_policy = super().module_policy()
        module_policy = self.add_lm_head_policy(module_policy)
        return module_policy

    def postprocess(self):
        binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
        for k, v in binding_map.items():
            param = getattr_(self.model, k)
            setattr_(self.model, v, param)
        return self.model

@@ -212,34 +199,14 @@ class BertForMaskedLMPolicy(BertPolicy):
        super().__init__()

    def module_policy(self):
        module_policy = super().module_policy()
        module_policy = self.add_lm_head_policy(module_policy)
        return module_policy

    def postprocess(self):
        binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
        for k, v in binding_map.items():
            param = getattr_(self.model, k)
            setattr_(self.model, v, param)
        return self.model
@@ -254,16 +221,18 @@ class BertForSequenceClassificationPolicy(BertPolicy):
        from transformers.models.bert.modeling_bert import BertForSequenceClassification

        module_policy = super().module_policy()

        if self.shard_config.enable_tensor_parallelism:
            addon_module = {
                BertForSequenceClassification:
                    ModulePolicyDescription(sub_module_replacement=[
                        SubModuleReplacementDescription(suffix="dropout",
                                                        target_module=col_nn.DropoutForParallelInput),
                    ])
            }
            module_policy.update(addon_module)

        return module_policy

@@ -277,16 +246,18 @@ class BertForTokenClassificationPolicy(BertPolicy):
        from transformers.models.bert.modeling_bert import BertForTokenClassification

        module_policy = super().module_policy()

        if self.shard_config.enable_tensor_parallelism:
            addon_module = {
                BertForTokenClassification:
                    ModulePolicyDescription(sub_module_replacement=[
                        SubModuleReplacementDescription(suffix="dropout",
                                                        target_module=col_nn.DropoutForParallelInput),
                    ])
            }
            module_policy.update(addon_module)

        return module_policy

@@ -307,14 +278,16 @@ class BertForMultipleChoicePolicy(BertPolicy):
        from transformers.models.bert.modeling_bert import BertForMultipleChoice

        module_policy = super().module_policy()

        if self.shard_config.enable_tensor_parallelism:
            addon_module = {
                BertForMultipleChoice:
                    ModulePolicyDescription(sub_module_replacement=[
                        SubModuleReplacementDescription(suffix="dropout",
                                                        target_module=col_nn.DropoutForParallelInput),
                    ])
            }
            module_policy.update(addon_module)

        return module_policy

View File

@@ -85,57 +85,53 @@ class BloomPolicy(Policy):

    def module_policy(self):
        from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel

        policy = {}

        if self.shard_config.enable_tensor_parallelism:
            policy[BloomBlock] = ModulePolicyDescription(
                attribute_replacement={
                    "self_attention.hidden_size":
                        self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                    "self_attention.split_size":
                        self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                    "self_attention.num_heads":
                        self.model.config.n_head // self.shard_config.tensor_parallel_size,
                },
                sub_module_replacement=[
                    SubModuleReplacementDescription(suffix="self_attention.query_key_value",
                                                    target_module=col_nn.Linear1D_Col),
                    SubModuleReplacementDescription(suffix="self_attention.dense", target_module=col_nn.Linear1D_Row),
                    SubModuleReplacementDescription(suffix="self_attention.attention_dropout",
                                                    target_module=col_nn.DropoutForParallelInput),
                    SubModuleReplacementDescription(suffix="mlp.dense_h_to_4h", target_module=col_nn.Linear1D_Col),
                    SubModuleReplacementDescription(suffix="mlp.dense_4h_to_h", target_module=col_nn.Linear1D_Row),
                ])

            policy[BloomModel] = ModulePolicyDescription(
                attribute_replacement={
                    "num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
                },
                method_replacement={"build_alibi_tensor": build_bloom_alibi_tensor},
                sub_module_replacement=[
                    SubModuleReplacementDescription(suffix="word_embeddings",
                                                    target_module=col_nn.VocabParallelEmbedding1D),
                ])

        # optimization configuration
        if self.shard_config.enable_fused_normalization:
            # handle bloom model
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(suffix="ln_f", target_module=col_nn.FusedLayerNorm),
@@ -144,8 +140,12 @@ class BloomPolicy(Policy):
                    SubModuleReplacementDescription(suffix="word_embeddings_layernorm",
                                                    target_module=col_nn.FusedLayerNorm),
                ],
                policy=policy,
                target_key=BloomModel)
            # handle bloom block
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(suffix="input_layernorm", target_module=col_nn.FusedLayerNorm),
@@ -154,9 +154,11 @@ class BloomPolicy(Policy):
                    SubModuleReplacementDescription(suffix="post_attention_layernorm",
                                                    target_module=col_nn.FusedLayerNorm),
                ],
                policy=policy,
                target_key=BloomBlock)

        return policy

    def postprocess(self):
        return self.model

@@ -171,19 +173,19 @@ class BloomForCausalLMPolicy(BloomPolicy):
    def module_policy(self):
        from transformers.models.bloom.modeling_bloom import BloomForCausalLM
        policy = super().module_policy()

        # handle tensor parallelism
        if self.shard_config.enable_tensor_parallelism:
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(suffix="lm_head",
                                                            target_module=col_nn.Linear1D_Col,
                                                            kwargs=dict(gather_output=True)),
                policy=policy,
                target_key=BloomForCausalLM)
        return policy

    def postprocess(self):
        binding_map = {"transformer.word_embeddings.weight": "lm_head.weight"}
        for k, v in binding_map.items():
            param = getattr_(self.model, k)
@@ -191,7 +193,6 @@ class BloomForCausalLMPolicy(BloomPolicy):
            param = nn.Parameter(param)

            # tie weights
            setattr_(self.model, v, param)
        return self.model

@@ -201,15 +202,14 @@ class BloomForSequenceClassificationPolicy(BloomPolicy):
    def module_policy(self):
        from transformers.models.bloom.modeling_bloom import BloomForSequenceClassification
        policy = super().module_policy()

        # handle tensor parallelism
        if self.shard_config.enable_tensor_parallelism:
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(suffix="score",
                                                            target_module=col_nn.Linear1D_Col,
                                                            kwargs=dict(gather_output=True)),
                policy=policy,
                target_key=BloomForSequenceClassification)
        return policy

@@ -218,19 +218,21 @@ class BloomForTokenClassificationPolicy(BloomPolicy):
    def module_policy(self):
        from transformers.models.bloom.modeling_bloom import BloomForTokenClassification
        policy = super().module_policy()

        # handle tensor parallelism
        if self.shard_config.enable_tensor_parallelism:
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(suffix="classifier",
                                                    target_module=col_nn.Linear1D_Col,
                                                    kwargs=dict(gather_output=True)),
                    SubModuleReplacementDescription(suffix="dropout",
                                                    target_module=col_nn.DropoutForReplicatedInput),
                ],
                policy=policy,
                target_key=BloomForTokenClassification)
        return policy

View File

@@ -31,67 +31,67 @@ class GPT2Policy(Policy):

    def module_policy(self):
        from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model

        policy = {}

        if self.shard_config.enable_tensor_parallelism:
            policy[GPT2Model] = ModulePolicyDescription(sub_module_replacement=[
                SubModuleReplacementDescription(suffix="wte", target_module=col_nn.VocabParallelEmbedding1D),
            ])

            policy[GPT2Block] = ModulePolicyDescription(
                attribute_replacement={
                    "attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                    "attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                    "attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
                },
                sub_module_replacement=[
                    SubModuleReplacementDescription(suffix="attn.c_attn",
                                                    target_module=col_nn.GPT2FusedLinearConv1D_Col,
                                                    kwargs={"n_fused": 3}),
                    SubModuleReplacementDescription(suffix="attn.c_proj",
                                                    target_module=col_nn.GPT2FusedLinearConv1D_Row),
                    SubModuleReplacementDescription(suffix="mlp.c_fc",
                                                    target_module=col_nn.GPT2FusedLinearConv1D_Col,
                                                    kwargs={"n_fused": 1}),
                    SubModuleReplacementDescription(suffix="mlp.c_proj",
                                                    target_module=col_nn.GPT2FusedLinearConv1D_Row),
                    SubModuleReplacementDescription(suffix="attn.attn_dropout",
                                                    target_module=col_nn.DropoutForParallelInput),
                    SubModuleReplacementDescription(suffix="attn.resid_dropout",
                                                    target_module=col_nn.DropoutForParallelInput),
                    SubModuleReplacementDescription(suffix="mlp.dropout",
                                                    target_module=col_nn.DropoutForParallelInput),
                ])

        # optimization configuration
        if self.shard_config.enable_fused_normalization:
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(suffix="ln_f", target_module=col_nn.FusedLayerNorm),
                policy=policy,
                target_key=GPT2Model)

            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(suffix="ln_1", target_module=col_nn.FusedLayerNorm),
@@ -103,9 +103,10 @@ class GPT2Policy(Policy):
                    SubModuleReplacementDescription(suffix="ln_cross_attn",
                                                    target_module=col_nn.FusedLayerNorm,
                                                    ignore_if_not_exist=True),
                ],
                policy=policy,
                target_key=GPT2Block)

        return policy

    def postprocess(self):
        return self.model

@@ -128,22 +129,22 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
        from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel

        module_policy = super().module_policy()

        if self.shard_config.enable_tensor_parallelism:
            addon_module = {
                GPT2LMHeadModel:
                    ModulePolicyDescription(sub_module_replacement=[
                        SubModuleReplacementDescription(suffix="lm_head",
                                                        target_module=col_nn.Linear1D_Col,
                                                        kwargs={"gather_output": True}),
                    ])
            }
            module_policy.update(addon_module)
        return module_policy

    def postprocess(self):
        binding_map = {"transformer.wte.weight": "lm_head.weight"}
        for k, v in binding_map.items():
            param = getattr_(self.model, k)
            setattr_(self.model, v, param)
        return self.model

@@ -158,22 +159,22 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):
        from transformers.models.gpt2.modeling_gpt2 import GPT2DoubleHeadsModel

        module_policy = super().module_policy()

        if self.shard_config.enable_tensor_parallelism:
            addon_module = {
                GPT2DoubleHeadsModel:
                    ModulePolicyDescription(sub_module_replacement=[
                        SubModuleReplacementDescription(suffix="lm_head",
                                                        target_module=col_nn.Linear1D_Col,
                                                        kwargs={"gather_output": True}),
                    ])
            }
            module_policy.update(addon_module)
        return module_policy

    def postprocess(self):
        binding_map = {"transformer.wte.weight": "lm_head.weight"}
        for k, v in binding_map.items():
            param = getattr_(self.model, k)
            setattr_(self.model, v, param)
        return self.model

View File

@@ -28,58 +28,58 @@ class LlamaPolicy(Policy):

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel

        policy = {}

        if self.shard_config.enable_tensor_parallelism:
            policy[LlamaDecoderLayer] = ModulePolicyDescription(
                attribute_replacement={
                    "self_attn.hidden_size":
                        self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                    "self_attn.num_heads":
                        self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
                },
                sub_module_replacement=[
                    SubModuleReplacementDescription(suffix="self_attn.q_proj", target_module=Linear1D_Col),
                    SubModuleReplacementDescription(suffix="self_attn.k_proj", target_module=Linear1D_Col),
                    SubModuleReplacementDescription(suffix="self_attn.v_proj", target_module=Linear1D_Col),
                    SubModuleReplacementDescription(suffix="self_attn.o_proj", target_module=Linear1D_Row),
                    SubModuleReplacementDescription(suffix="mlp.gate_proj", target_module=Linear1D_Col),
                    SubModuleReplacementDescription(suffix="mlp.up_proj", target_module=Linear1D_Col),
                    SubModuleReplacementDescription(suffix="mlp.down_proj", target_module=Linear1D_Row),
                ],
            )

            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(suffix="embed_tokens",
                                                            target_module=VocabParallelEmbedding1D),
                policy=policy,
                target_key=LlamaModel)

        # optimization configuration
        if self.shard_config.enable_fused_normalization:
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(suffix="input_layernorm", target_module=FusedRMSNorm),
@@ -88,15 +88,18 @@ class LlamaPolicy(Policy):
                    SubModuleReplacementDescription(suffix="post_attention_layernorm", target_module=FusedRMSNorm),
                ],
                policy=policy,
                target_key=LlamaDecoderLayer)

            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(suffix="norm", target_module=FusedRMSNorm),
                policy=policy,
                target_key=LlamaModel)

        return policy

    def postprocess(self):
        return self.model

@@ -108,15 +111,17 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
        from transformers import LlamaForCausalLM

        policy = super().module_policy()

        if self.shard_config.enable_tensor_parallelism:
            # add a new item for causal lm
            new_item = {
                LlamaForCausalLM:
                    ModulePolicyDescription(sub_module_replacement=[
                        SubModuleReplacementDescription(suffix="lm_head",
                                                        target_module=Linear1D_Col,
                                                        kwargs=dict(gather_output=True)),
                    ])
            }
            policy.update(new_item)
        return policy

@@ -127,13 +132,14 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
        policy = super().module_policy()

        if self.shard_config.enable_tensor_parallelism:
            # add a new item for sequence classification
            new_item = {
                LlamaForSequenceClassification:
                    ModulePolicyDescription(sub_module_replacement=[
                        SubModuleReplacementDescription(suffix="score",
                                                        target_module=Linear1D_Col,
                                                        kwargs=dict(gather_output=True)),
                    ])
            }
            policy.update(new_item)
        return policy

View File

@@ -29,66 +29,67 @@ class OPTPolicy(Policy):

    def module_policy(self):
        from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer

        policy = {}

        if self.shard_config.enable_tensor_parallelism:
            policy[OPTDecoder] = ModulePolicyDescription(sub_module_replacement=[
                SubModuleReplacementDescription(suffix="embed_tokens", target_module=VocabParallelEmbedding1D),
            ])

            policy[OPTDecoderLayer] = ModulePolicyDescription(sub_module_replacement=[
                SubModuleReplacementDescription(suffix="fc1", target_module=Linear1D_Col),
                SubModuleReplacementDescription(suffix="fc2", target_module=Linear1D_Row),
            ])

            policy[OPTAttention] = ModulePolicyDescription(
                attribute_replacement={
                    "embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                    "num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
                },
                sub_module_replacement=[
                    SubModuleReplacementDescription(suffix="q_proj", target_module=Linear1D_Col),
                    SubModuleReplacementDescription(suffix="k_proj", target_module=Linear1D_Col),
                    SubModuleReplacementDescription(suffix="v_proj", target_module=Linear1D_Col),
                    SubModuleReplacementDescription(suffix="out_proj", target_module=Linear1D_Row),
                ])

        # optimization configuration
        if self.shard_config.enable_fused_normalization:
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(suffix="final_layer_norm",
                                                            target_module=FusedLayerNorm,
                                                            ignore_if_not_exist=True),
                policy=policy,
                target_key=OPTDecoder)

            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(suffix="self_attn_layer_norm",
                                                    target_module=FusedLayerNorm,
                                                    ignore_if_not_exist=True),
                    SubModuleReplacementDescription(suffix="final_layer_norm",
                                                    target_module=FusedLayerNorm,
                                                    ignore_if_not_exist=True),
                ],
                policy=policy,
                target_key=OPTDecoderLayer)

        return policy

    def postprocess(self):
        return self.model

@@ -106,15 +107,12 @@ class OPTForCausalLMPolicy(OPTPolicy):
        from transformers.models.opt.modeling_opt import OPTForCausalLM

        policy = super().module_policy()

        if self.shard_config.enable_tensor_parallelism:
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(suffix="lm_head",
                                                            target_module=Linear1D_Col,
                                                            kwargs=dict(gather_output=True)),
                policy=policy,
                target_key=OPTForCausalLM)
        return policy

    def postprocess(self):

View File

@@ -42,116 +42,126 @@ class T5BasePolicy(Policy):
            T5Stack,
        )

        policy = {}

        if self.shard_config.enable_tensor_parallelism:
            policy[T5Stack] = ModulePolicyDescription(sub_module_replacement=[
                SubModuleReplacementDescription(suffix="dropout", target_module=DropoutForParallelInput),
                SubModuleReplacementDescription(suffix="embed_tokens", target_module=Embedding1D),
            ])
            policy[T5LayerSelfAttention] = ModulePolicyDescription(sub_module_replacement=[
                SubModuleReplacementDescription(suffix="dropout", target_module=DropoutForParallelInput),
            ])
            policy[T5LayerCrossAttention] = ModulePolicyDescription(sub_module_replacement=[
                SubModuleReplacementDescription(suffix="dropout", target_module=DropoutForParallelInput),
            ])
            policy[T5Attention] = ModulePolicyDescription(
                attribute_replacement={
                    "d_model":
                        self.model.config.d_model // self.shard_config.tensor_parallel_size,
                    "n_heads":
                        self.model.config.num_heads // self.shard_config.tensor_parallel_size,
                    "inner_dim":
                        self.model.config.num_heads * self.model.config.d_kv // self.shard_config.tensor_parallel_size,
                },
                sub_module_replacement=[
                    SubModuleReplacementDescription(suffix="q", target_module=Linear1D_Col),
                    SubModuleReplacementDescription(suffix="k", target_module=Linear1D_Col),
                    SubModuleReplacementDescription(suffix="v", target_module=Linear1D_Col),
                    SubModuleReplacementDescription(suffix="o", target_module=Linear1D_Row),
                    SubModuleReplacementDescription(suffix="relative_attention_bias",
                                                    target_module=Embedding1D,
                                                    kwargs=dict(gather_output=False),
                                                    ignore_if_not_exist=True),
                ])
            policy[T5LayerFF] = ModulePolicyDescription(sub_module_replacement=[
                SubModuleReplacementDescription(suffix="dropout", target_module=DropoutForParallelInput),
            ])
            policy[T5DenseGatedActDense] = ModulePolicyDescription(sub_module_replacement=[
                SubModuleReplacementDescription(suffix="wi_0", target_module=Linear1D_Col),
                SubModuleReplacementDescription(suffix="wi_1", target_module=Linear1D_Row),
                SubModuleReplacementDescription(suffix="wo",
                                                target_module=Linear1D_Col,
                                                kwargs=dict(gather_output=True)),
                SubModuleReplacementDescription(suffix="dropout", target_module=DropoutForParallelInput),
            ])
            policy[T5DenseActDense] = ModulePolicyDescription(sub_module_replacement=[
                SubModuleReplacementDescription(suffix="wi", target_module=Linear1D_Col),
                SubModuleReplacementDescription(suffix="wo", target_module=Linear1D_Row),
                SubModuleReplacementDescription(suffix="dropout", target_module=DropoutForParallelInput),
            ])

        # optimization configuration
        if self.shard_config.enable_fused_normalization:
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm),
                policy=policy,
                target_key=T5LayerFF)
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm),
                policy=policy,
                target_key=T5LayerFF)
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm),
                policy=policy,
                target_key=T5LayerSelfAttention)
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm),
                policy=policy,
                target_key=T5LayerCrossAttention)
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(suffix="final_layer_norm", target_module=FusedRMSNorm),
                policy=policy,
                target_key=T5Stack)

        return policy

    def postprocess(self):
        binding_map = [["shared", "encoder.embed_tokens"], ["shared", "decoder.embed_tokens"]]

@@ -166,14 +176,15 @@ class T5ModelPolicy(T5BasePolicy):
    def module_policy(self):
        from transformers import T5Model
        base_policy = super().module_policy()

        if self.shard_config.enable_tensor_parallelism:
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(suffix="shared",
                                                            target_module=VocabParallelEmbedding1D),
                policy=base_policy,
                target_key=T5Model)
        return base_policy

@@ -183,14 +194,19 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
        from transformers import T5ForConditionalGeneration

        policy = super().module_policy()

        if self.shard_config.enable_tensor_parallelism:
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(suffix="shared", target_module=VocabParallelEmbedding1D),
                    SubModuleReplacementDescription(suffix="lm_head",
                                                    target_module=Linear1D_Col,
                                                    kwargs=dict(gather_output=True)),
                ],
                policy=policy,
                target_key=T5ForConditionalGeneration)
        return policy

    def postprocess(self):

@@ -212,12 +228,14 @@ class T5EncoderPolicy(T5BasePolicy):
        from transformers import T5EncoderModel
        base_policy = super().module_policy()

        if self.shard_config.enable_tensor_parallelism:
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(suffix="shared",
                                                            target_module=VocabParallelEmbedding1D),
                policy=base_policy,
                target_key=T5EncoderModel)
        return base_policy

    def postprocess(self):

View File

@@ -13,11 +13,12 @@ class ShardConfig:
    Args:
        tensor_parallel_process_group (int): The process group for tensor parallelism, defaults to None, which is the global process group.
        enable_tensor_parallelism (bool): Whether to turn on tensor parallelism, default is True.
        enable_fused_normalization (bool): Whether to use fused layernorm, default is False.
        enable_all_optimization (bool): Whether to turn on all optimization, default is False.
    """
    tensor_parallel_process_group: ProcessGroup = None
    enable_tensor_parallelism: bool = True
    enable_fused_normalization: bool = False
    enable_all_optimization: bool = False

@@ -33,8 +34,11 @@ class ShardConfig:
        return self._tensor_parallel_size

    def __post_init__(self):
        if not self.enable_tensor_parallelism:
            self._tensor_parallel_size = 1
        else:
            # get the parallel size
            self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group)

        # turn on all optimization if all_optimization is set to True
        if self.enable_all_optimization:
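Assuming a distributed environment has already been initialized (for example via colossalai.launch), the new flag is expected to behave roughly as follows; this is a usage sketch, not part of the commit:

from colossalai.shardformer import ShardConfig

# Tensor parallelism on: the parallel size follows the process group's world size.
tp_config = ShardConfig(enable_fused_normalization=True, enable_tensor_parallelism=True)
print(tp_config.tensor_parallel_size)    # e.g. 2 when launched with 2 ranks

# Tensor parallelism off: the parallel size collapses to 1 and the policies only
# apply the normalization replacements (if enabled).
no_tp_config = ShardConfig(enable_fused_normalization=True, enable_tensor_parallelism=False)
print(no_tp_config.tensor_parallel_size)    # 1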

View File

@@ -3,12 +3,13 @@ import copy

from colossalai.shardformer import ShardConfig, ShardFormer


def build_model(model_fn, enable_fused_normalization=True, enable_tensor_parallelism=True):
    # create new model
    org_model = model_fn().cuda()

    # shard model
    shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
                               enable_tensor_parallelism=enable_tensor_parallelism)
    model_copy = copy.deepcopy(org_model)
    shard_former = ShardFormer(shard_config=shard_config)
    sharded_model = shard_former.optimize(model_copy).cuda()
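As a usage sketch (model_fn here stands for any model factory from the test model zoo, as in the tests below), the updated helper lets a test toggle both features independently:

# Shard with fused layernorm enabled but tensor parallelism disabled; the tests
# below iterate over all combinations of these two flags.
org_model, sharded_model = build_model(model_fn,
                                       enable_fused_normalization=True,
                                       enable_tensor_parallelism=False)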

View File

@@ -3,7 +3,14 @@ import torch

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
from colossalai.testing import (
    assert_hf_output_close,
    clear_cache_before_run,
    parameterize,
    rerun_if_address_is_in_use,
    spawn,
)
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, run_forward
@@ -33,34 +40,48 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
    # compare self attention grad
    org_grad = bert.encoder.layer[0].attention.self.query.weight.grad
    shard_grad = sharded_bert.encoder.layer[0].attention.self.query.weight.grad
    shard_weight = sharded_bert.encoder.layer[0].attention.self.query.weight

    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
        shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
        all_shard_grad = torch.cat(shard_grad_list, dim=0)
    else:
        all_shard_grad = shard_grad

    assert torch.allclose(org_grad, all_shard_grad,
                          atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"

    # compare embedding grad
    org_grad = bert.embeddings.word_embeddings.weight.grad
    shard_grad = sharded_bert.embeddings.word_embeddings.weight.grad
    shard_weight = sharded_bert.embeddings.word_embeddings.weight

    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
        shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
        all_shard_grad = torch.cat(shard_grad_list, dim=0)
    else:
        all_shard_grad = shard_grad

    assert torch.allclose(org_grad, all_shard_grad,
                          atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"


@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
def run_bert_test(enable_fused_normalization, enable_tensor_parallelism):
    sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
    torch.cuda.empty_cache()


def check_bert(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_bert_test()


@pytest.mark.dist
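The stacked parameterize decorators above are what drive the new flag through the sharding tests: the intent is that run_bert_test executes once for each of the four flag combinations. A simplified stand-in (not colossalai.testing's actual implementation) showing that mechanism:

import functools


def parameterize_sketch(name, values):
    # illustration only: calls the wrapped function once per value of the named argument
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            for value in values:
                fn(*args, **{**kwargs, name: value})
        return wrapper
    return decorator


@parameterize_sketch('enable_fused_normalization', [True, False])
@parameterize_sketch('enable_tensor_parallelism', [True, False])
def fake_test(enable_fused_normalization, enable_tensor_parallelism):
    print(enable_fused_normalization, enable_tensor_parallelism)


fake_test()    # runs the body for all four flag combinations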

View File

@ -3,7 +3,14 @@ import torch
import colossalai import colossalai
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
from colossalai.testing import (
assert_hf_output_close,
clear_cache_before_run,
parameterize,
rerun_if_address_is_in_use,
spawn,
)
from tests.kit.model_zoo import model_zoo from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, run_forward from tests.test_shardformer.test_model._utils import build_model, run_forward
@@ -32,10 +39,14 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
     # check attention grad
     org_grad = bloom.h[0].self_attention.query_key_value.weight.grad
     shard_grad = sharded_bloom.h[0].self_attention.query_key_value.weight.grad
-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
-    torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    shard_weight = sharded_bloom.h[0].self_attention.query_key_value.weight
+
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
+        torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    else:
+        all_shard_grad = shard_grad

     assert torch.allclose(org_grad, all_shard_grad,
                           atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
@@ -43,25 +54,33 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
     # check embedding weights
     org_grad = bloom.word_embeddings.weight.grad
     shard_grad = sharded_bloom.word_embeddings.weight.grad
-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
-    torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    shard_weight = sharded_bloom.word_embeddings.weight
+
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
+        torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    else:
+        all_shard_grad = shard_grad

     assert torch.allclose(org_grad, all_shard_grad,
                           atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"


+@parameterize('enable_fused_normalization', [True, False])
+@parameterize('enable_tensor_parallelism', [True, False])
+def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism):
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
+    torch.cuda.empty_cache()
+
+
 def check_bloom(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
-    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn)
-        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
-    torch.cuda.empty_cache()
+    run_bloom_test()


 @pytest.mark.dist
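
The new run_bloom_test (and its siblings in the other files) relies on stacking two colossalai.testing.parameterize decorators to sweep the 2 x 2 grid of (enable_fused_normalization, enable_tensor_parallelism). The snippet below is a minimal stand-in that only illustrates that expansion; the real decorator lives in colossalai.testing and is implemented differently:

import functools
from typing import Any, Callable, List


def parameterize_sketch(arg_name: str, values: List[Any]) -> Callable:
    """Run the decorated function once per value, passed as a keyword argument."""

    def decorator(fn: Callable) -> Callable:

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            for value in values:
                fn(*args, **{**kwargs, arg_name: value})

        return wrapper

    return decorator


@parameterize_sketch('enable_fused_normalization', [True, False])
@parameterize_sketch('enable_tensor_parallelism', [True, False])
def demo(enable_fused_normalization, enable_tensor_parallelism):
    print(enable_fused_normalization, enable_tensor_parallelism)


demo()    # prints all four flag combinations

A single run_bloom_test() call inside check_bloom therefore exercises fused and unfused normalization both with and without tensor parallelism within one spawned process group.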

View File

@@ -3,7 +3,14 @@ import torch
 import colossalai
 from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn
+from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
+from colossalai.testing import (
+    assert_hf_output_close,
+    clear_cache_before_run,
+    parameterize,
+    rerun_if_address_is_in_use,
+    spawn,
+)
 from tests.kit.model_zoo import model_zoo
 from tests.test_shardformer.test_model._utils import build_model, run_forward
@@ -32,11 +39,14 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
     # check mlp grad
     org_grad = org_model.h[0].mlp.c_fc.weight.grad
     shard_grad = sharded_model.h[0].mlp.c_fc.weight.grad
-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
-    shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=1)
+    shard_weight = sharded_model.h[0].mlp.c_fc.weight
+
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
+        shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=1)
+    else:
+        all_shard_grad = shard_grad

     assert torch.allclose(
         org_grad, all_shard_grad,
         atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"
@@ -44,25 +54,33 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
     # check embedding weights
     org_grad = org_model.wte.weight.grad
     shard_grad = sharded_model.wte.weight.grad
-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
-    shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    shard_weight = sharded_model.wte.weight
+
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
+        shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    else:
+        all_shard_grad = shard_grad

     assert torch.allclose(
         org_grad, all_shard_grad,
         atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"


+@parameterize('enable_fused_normalization', [True, False])
+@parameterize('enable_tensor_parallelism', [True, False])
+def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism):
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
+    torch.cuda.empty_cache()
+
+
 def check_gpt2(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
-    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn)
-        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
-    torch.cuda.empty_cache()
+    run_gpt2_test()


 @pytest.mark.dist
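
One detail worth noting in the gpt2 hunks: the mlp.c_fc check concatenates the gathered shards along dim=1, while the wte check uses dim=0. GPT-2's Conv1D keeps its weight as (in_features, out_features), so splitting the output features shards the second axis, whereas the (vocab, hidden) embedding table is split along the first. A small self-contained sanity check of those shapes (illustrative only, toy dimensions assumed):

import torch

tp_size = 2
hidden = 8

# Conv1D-style weight: (in_features, out_features); splitting the output
# features shards the second axis, hence torch.cat(..., dim=1) in the test.
c_fc = torch.randn(hidden, 4 * hidden)
c_fc_shards = torch.chunk(c_fc, tp_size, dim=1)
assert torch.equal(torch.cat(c_fc_shards, dim=1), c_fc)

# Embedding table: (vocab, hidden); the vocabulary dimension is sharded,
# hence torch.cat(..., dim=0) for wte.
wte = torch.randn(50, hidden)
wte_shards = torch.chunk(wte, tp_size, dim=0)
assert torch.equal(torch.cat(wte_shards, dim=0), wte)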

View File

@@ -5,7 +5,14 @@ import torch
 import colossalai
 from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn
+from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
+from colossalai.testing import (
+    assert_hf_output_close,
+    clear_cache_before_run,
+    parameterize,
+    rerun_if_address_is_in_use,
+    spawn,
+)
 from tests.kit.model_zoo import model_zoo
 from tests.test_shardformer.test_model._utils import build_model, run_forward
@@ -37,33 +44,46 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
     # check attention grad
     org_grad = llama_model.layers[0].self_attn.q_proj.weight.grad
     shard_grad = shard_llama_model.layers[0].self_attn.q_proj.weight.grad
-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
-    shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    shard_weight = shard_llama_model.layers[0].self_attn.q_proj.weight
+
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
+        shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    else:
+        all_shard_grad = shard_grad

     assert torch.allclose(org_grad, all_shard_grad,
                           atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}"

     # check embedding grad
     org_grad = llama_model.embed_tokens.weight.grad
     shard_grad = shard_llama_model.embed_tokens.weight.grad
-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
-    shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    shard_weight = shard_llama_model.embed_tokens.weight
+
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
+        shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    else:
+        all_shard_grad = shard_grad

     assert torch.allclose(org_grad, all_shard_grad,
                           atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}"


+@parameterize('enable_fused_normalization', [True, False])
+@parameterize('enable_tensor_parallelism', [True, False])
+def run_gpt2_llama(enable_fused_normalization, enable_tensor_parallelism):
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
+    torch.cuda.empty_cache()
+
+
 def check_llama(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
-    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn)
-        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
-    torch.cuda.empty_cache()
+    run_gpt2_llama()


 @pytest.mark.dist
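
The llama checks allocate four gather buffers (range(4)) rather than two, so this test runs with four ranks. The spawn entry point sits outside the hunks shown here; the sketch below only illustrates the usual shape of that entry point, with the test name and nprocs value inferred rather than taken from this diff:

import pytest

from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn


@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_llama():
    # check_llama is the per-rank launcher defined earlier in this file;
    # 4 matches the four gather buffers used in the checks above.
    spawn(check_llama, 4)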

View File

@@ -6,10 +6,11 @@ import torch
 import colossalai
 from colossalai.logging import disable_existing_loggers
+from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
 from colossalai.testing import (
     assert_hf_output_close,
-    check_state_dict_equal,
     clear_cache_before_run,
+    parameterize,
     rerun_if_address_is_in_use,
     spawn,
 )
@@ -42,32 +43,46 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
     # check attention grad
     org_grad = opt_model.decoder.layers[0].self_attn.q_proj.weight.grad
     shard_grad = shard_opt_model.decoder.layers[0].self_attn.q_proj.weight.grad
-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
-    torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    shard_weight = shard_opt_model.decoder.layers[0].self_attn.q_proj.weight
+
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
+        torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    else:
+        all_shard_grad = shard_grad

     assert torch.allclose(org_grad, all_shard_grad,
                           atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"

     # check embedding grad
     org_grad = opt_model.decoder.embed_tokens.weight.grad
     shard_grad = shard_opt_model.decoder.embed_tokens.weight.grad
-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
-    torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    shard_weight = shard_opt_model.decoder.embed_tokens.weight
+
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
+        torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    else:
+        all_shard_grad = shard_grad

     assert torch.allclose(org_grad, all_shard_grad,
                           atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"


+@parameterize('enable_fused_normalization', [True, False])
+@parameterize('enable_tensor_parallelism', [True, False])
+def run_t5_test(enable_fused_normalization, enable_tensor_parallelism):
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
+    torch.cuda.empty_cache()
+
+
 def check_OPTModel(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
-    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn)
-        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
-    torch.cuda.empty_cache()
+    run_t5_test()


 @pytest.mark.dist
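
The enable_fused_normalization / enable_tensor_parallelism flags that every run_*_test forwards to build_model ultimately populate a ShardConfig, which is what the policies in this commit branch on. The actual wiring lives in tests/test_shardformer/test_model/_utils.py, which this diff does not show, so the helper below is only an illustrative sketch (make_shard_config is an invented name):

import torch.distributed as dist

from colossalai.shardformer import ShardConfig


def make_shard_config(enable_fused_normalization: bool, enable_tensor_parallelism: bool) -> ShardConfig:
    # ShardConfig derives its tensor parallel size from the current process
    # group, so this must run after colossalai.launch (e.g. inside
    # check_OPTModel above).
    assert dist.is_initialized()
    return ShardConfig(enable_fused_normalization=enable_fused_normalization,
                       enable_tensor_parallelism=enable_tensor_parallelism)

With enable_tensor_parallelism=False the policies skip the Linear1D_Col/Linear1D_Row replacements, which is exactly the configuration covered by the new else: all_shard_grad = shard_grad branches.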

View File

@@ -5,7 +5,14 @@ import torch
 import colossalai
 from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn
+from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
+from colossalai.testing import (
+    assert_hf_output_close,
+    clear_cache_before_run,
+    parameterize,
+    rerun_if_address_is_in_use,
+    spawn,
+)
 from tests.kit.model_zoo import model_zoo
 from tests.test_shardformer.test_model._utils import build_model, run_forward
@@ -27,19 +34,28 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
     # check attention grad
     org_grad = org_model.encoder.block[0].layer[0].SelfAttention.q.weight.grad
     shard_grad = sharded_model.encoder.block[0].layer[0].SelfAttention.q.weight.grad
-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
-    shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    shard_weight = sharded_model.encoder.block[0].layer[0].SelfAttention.q.weight
+
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
+        shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    else:
+        all_shard_grad = shard_grad

     assert torch.allclose(org_grad, all_shard_grad,
                           atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}"

     # check self attention embed
     org_grad = org_model.encoder.block[0].layer[0].SelfAttention.relative_attention_bias.weight.grad
     shard_grad = sharded_model.encoder.block[0].layer[0].SelfAttention.relative_attention_bias.weight.grad
-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
-    torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=1)
+    shard_weight = sharded_model.encoder.block[0].layer[0].SelfAttention.relative_attention_bias.weight
+
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
+        torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=1)
+    else:
+        all_shard_grad = shard_grad

     assert torch.allclose(org_grad, all_shard_grad,
                           atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
@@ -52,23 +68,32 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
     assert sharded_model.shared.weight.data.data_ptr() == sharded_model.lm_head.weight.data.data_ptr()
     shard_grad = sharded_model.shared.weight.grad
-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
-    torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    shard_weight = sharded_model.shared.weight
+
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
+        torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    else:
+        all_shard_grad = shard_grad

     assert torch.allclose(org_grad, all_shard_grad,
                           atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"


+@parameterize('enable_fused_normalization', [True, False])
+@parameterize('enable_tensor_parallelism', [True, False])
+def run_t5_test(enable_fused_normalization, enable_tensor_parallelism):
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
+    torch.cuda.empty_cache()
+
+
 def check_t5(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
-    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn)
-        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
-    torch.cuda.empty_cache()
+    run_t5_test()


 @pytest.mark.dist
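
The t5 hunks keep asserting that shared.weight and lm_head.weight report the same data_ptr() after sharding, i.e. that weight tying survives the policy rewrite. A small standalone illustration of what that pointer comparison distinguishes (plain PyTorch, not part of this diff):

import torch.nn as nn

embed = nn.Embedding(32, 8)

tied_head = nn.Linear(8, 32, bias=False)
tied_head.weight = embed.weight                                   # shared storage
assert embed.weight.data.data_ptr() == tied_head.weight.data.data_ptr()

copied_head = nn.Linear(8, 32, bias=False)
copied_head.weight = nn.Parameter(embed.weight.detach().clone())  # separate storage
assert embed.weight.data.data_ptr() != copied_head.weight.data.data_ptr()

If a policy replaced only one of the two tied modules with a new sharded parameter, the pointers would diverge and this assertion would catch the broken tying even when the gradient values still happened to match.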