Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-21 13:11:27 +00:00)

[shardformer] made tensor parallelism configurable (#4144)

* [shardformer] made tensor parallelism configurable
* polish code

parent 74257cb446
commit 1fb0d95df0
@@ -126,3 +126,28 @@ class Policy(ABC):
            the classifier layer
        """
        pass

    def append_or_create_submodule_replacement(
            self,
            description: Union[SubModuleReplacementDescription, List[SubModuleReplacementDescription]],
            policy: Dict[Union[str, nn.Module], ModulePolicyDescription],
            target_key: Union[str, nn.Module]) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        r"""
        Append a submodule replacement description to the policy for the given key, creating the entry if it does not exist yet.

        Args:
            description (Union[SubModuleReplacementDescription, List[SubModuleReplacementDescription]]): the submodule replacement description(s) to be appended
            policy (Dict[Union[str, nn.Module], ModulePolicyDescription]): the policy to be updated
            target_key (Union[str, nn.Module]): the key of the policy to be updated
        """
        # convert a single description to a list
        if isinstance(description, SubModuleReplacementDescription):
            description = [description]

        # append to an existing description or create a new one
        if target_key in policy:
            policy[target_key].sub_module_replacement.extend(description)
        else:
            policy[target_key] = ModulePolicyDescription(sub_module_replacement=description)

        return policy
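A minimal usage sketch of the new helper (illustration only; SomeBlock is a hypothetical module class standing in for the real targets such as BertLayer used in the hunks below):

    # Sketch: how a policy's module_policy() is expected to combine the two switches.
    policy = {}
    if self.shard_config.enable_tensor_parallelism:
        policy[SomeBlock] = ModulePolicyDescription(sub_module_replacement=[
            SubModuleReplacementDescription(suffix="dense", target_module=col_nn.Linear1D_Col),
        ])
    if self.shard_config.enable_fused_normalization:
        # Safe whether or not SomeBlock already has an entry: the helper extends the
        # existing ModulePolicyDescription or creates a fresh one.
        self.append_or_create_submodule_replacement(
            description=SubModuleReplacementDescription(suffix="LayerNorm",
                                                        target_module=col_nn.FusedLayerNorm),
            policy=policy,
            target_key=SomeBlock)
    return policy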
@@ -33,89 +33,114 @@ class BertPolicy(Policy):
    def module_policy(self):
        from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer

        policy = {}

        if self.shard_config.enable_tensor_parallelism:
            policy[BertLayer] = ModulePolicyDescription(
                attribute_replacement={
                    "attention.self.all_head_size":
                        self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                    "crossattention.self.all_head_size":
                        self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                    "attention.self.num_attention_heads":
                        self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
                    "crossattention.self.num_attention_heads":
                        self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
                },
                sub_module_replacement=[
                    SubModuleReplacementDescription(suffix="attention.self.query", target_module=col_nn.Linear1D_Col),
                    SubModuleReplacementDescription(suffix="attention.self.key", target_module=col_nn.Linear1D_Col),
                    SubModuleReplacementDescription(suffix="attention.self.value", target_module=col_nn.Linear1D_Col),
                    SubModuleReplacementDescription(suffix="attention.self.dropout",
                                                    target_module=col_nn.DropoutForParallelInput),
                    SubModuleReplacementDescription(suffix="attention.output.dense", target_module=col_nn.Linear1D_Row),
                    SubModuleReplacementDescription(suffix="attention.output.dropout",
                                                    target_module=col_nn.DropoutForParallelInput),
                    SubModuleReplacementDescription(suffix="intermediate.dense", target_module=col_nn.Linear1D_Col),
                    SubModuleReplacementDescription(suffix="output.dense", target_module=col_nn.Linear1D_Row),
                    SubModuleReplacementDescription(suffix="output.dropout",
                                                    target_module=col_nn.DropoutForParallelInput),
                ])

            policy[BertEmbeddings] = ModulePolicyDescription(sub_module_replacement=[
                SubModuleReplacementDescription(suffix="word_embeddings",
                                                target_module=col_nn.VocabParallelEmbedding1D),
                SubModuleReplacementDescription(suffix="dropout",
                                                target_module=col_nn.DropoutForReplicatedInput),
            ])

        # optimization configuration
        if self.shard_config.enable_fused_normalization:
            # handle bert layer
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(suffix="attention.output.LayerNorm",
                                                    target_module=col_nn.FusedLayerNorm),
                    SubModuleReplacementDescription(suffix="output.LayerNorm",
                                                    target_module=col_nn.FusedLayerNorm),
                ],
                policy=policy,
                target_key=BertLayer)

            # handle embedding layer
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(suffix="LayerNorm", target_module=col_nn.FusedLayerNorm),
                ],
                policy=policy,
                target_key=BertEmbeddings)
        return policy

    def add_lm_head_policy(self, base_policy):
        from transformers.models.bert.modeling_bert import BertLMPredictionHead

        # optimize for tensor parallelism
        if self.shard_config.enable_tensor_parallelism:
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(suffix="decoder",
                                                            target_module=col_nn.Linear1D_Col,
                                                            kwargs={"gather_output": True}),
                policy=base_policy,
                target_key=BertLMPredictionHead)

        # optimize with fused normalization
        if self.shard_config.enable_fused_normalization:
            # handle bert lm prediction head
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(suffix="transform.LayerNorm",
                                                            target_module=col_nn.FusedLayerNorm),
                policy=base_policy,
                target_key=BertLMPredictionHead)
        return base_policy

    def postprocess(self):
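A reading note on the guard above (an inference from the code, not text carried by the diff): when enable_tensor_parallelism is False, the whole tensor-parallel branch is skipped, so the dictionary returned by module_policy() contains at most the fused-normalization entries that append_or_create_submodule_replacement creates on demand, roughly:

    # Approximate shape of BertPolicy.module_policy() output with
    # enable_tensor_parallelism=False and enable_fused_normalization=True:
    # {BertLayer:      ModulePolicyDescription(sub_module_replacement=[<FusedLayerNorm entries>]),
    #  BertEmbeddings: ModulePolicyDescription(sub_module_replacement=[<FusedLayerNorm entry>])}
    # With both flags off, module_policy() returns an empty dict and no submodule
    # replacements are registered.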
@@ -136,35 +161,14 @@ class BertForPretrainingPolicy(BertPolicy):
        super().__init__()

    def module_policy(self):
        module_policy = super().module_policy()
        module_policy = self.add_lm_head_policy(module_policy)
        return module_policy

    def postprocess(self):
        binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
        for k, v in binding_map.items():
            param = getattr_(self.model, k)
            setattr_(self.model, v, param)
        return self.model

@@ -176,31 +180,14 @@ class BertLMHeadModelPolicy(BertPolicy):
        super().__init__()

    def module_policy(self):
        module_policy = super().module_policy()
        module_policy = self.add_lm_head_policy(module_policy)
        return module_policy

    def postprocess(self):
        binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
        for k, v in binding_map.items():
            param = getattr_(self.model, k)
            setattr_(self.model, v, param)
        return self.model

@@ -212,34 +199,14 @@ class BertForMaskedLMPolicy(BertPolicy):
        super().__init__()

    def module_policy(self):
        module_policy = super().module_policy()
        module_policy = self.add_lm_head_policy(module_policy)
        return module_policy

    def postprocess(self):
        binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
        for k, v in binding_map.items():
            param = getattr_(self.model, k)
            setattr_(self.model, v, param)
        return self.model
@@ -254,16 +221,18 @@ class BertForSequenceClassificationPolicy(BertPolicy):
        from transformers.models.bert.modeling_bert import BertForSequenceClassification

        module_policy = super().module_policy()

        if self.shard_config.enable_tensor_parallelism:
            addon_module = {
                BertForSequenceClassification:
                    ModulePolicyDescription(sub_module_replacement=[
                        SubModuleReplacementDescription(suffix="dropout",
                                                        target_module=col_nn.DropoutForParallelInput),
                    ])
            }
            module_policy.update(addon_module)
        return module_policy

@@ -277,16 +246,18 @@ class BertForTokenClassificationPolicy(BertPolicy):
        from transformers.models.bert.modeling_bert import BertForTokenClassification

        module_policy = super().module_policy()

        if self.shard_config.enable_tensor_parallelism:
            addon_module = {
                BertForTokenClassification:
                    ModulePolicyDescription(sub_module_replacement=[
                        SubModuleReplacementDescription(suffix="dropout",
                                                        target_module=col_nn.DropoutForParallelInput),
                    ])
            }
            module_policy.update(addon_module)
        return module_policy

@@ -307,14 +278,16 @@ class BertForMultipleChoicePolicy(BertPolicy):
        from transformers.models.bert.modeling_bert import BertForMultipleChoice

        module_policy = super().module_policy()

        if self.shard_config.enable_tensor_parallelism:
            addon_module = {
                BertForMultipleChoice:
                    ModulePolicyDescription(sub_module_replacement=[
                        SubModuleReplacementDescription(suffix="dropout",
                                                        target_module=col_nn.DropoutForParallelInput),
                    ])
            }
            module_policy.update(addon_module)
        return module_policy
@@ -85,57 +85,53 @@ class BloomPolicy(Policy):
    def module_policy(self):
        from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel

        policy = {}

        if self.shard_config.enable_tensor_parallelism:
            policy[BloomBlock] = ModulePolicyDescription(
                attribute_replacement={
                    "self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                    "self_attention.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                    "self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
                },
                sub_module_replacement=[
                    SubModuleReplacementDescription(suffix="self_attention.query_key_value",
                                                    target_module=col_nn.Linear1D_Col),
                    SubModuleReplacementDescription(suffix="self_attention.dense", target_module=col_nn.Linear1D_Row),
                    SubModuleReplacementDescription(suffix="self_attention.attention_dropout",
                                                    target_module=col_nn.DropoutForParallelInput),
                    SubModuleReplacementDescription(suffix="mlp.dense_h_to_4h", target_module=col_nn.Linear1D_Col),
                    SubModuleReplacementDescription(suffix="mlp.dense_4h_to_h", target_module=col_nn.Linear1D_Row),
                ])

            policy[BloomModel] = ModulePolicyDescription(
                attribute_replacement={
                    "num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
                },
                method_replacement={"build_alibi_tensor": build_bloom_alibi_tensor},
                sub_module_replacement=[
                    SubModuleReplacementDescription(suffix="word_embeddings",
                                                    target_module=col_nn.VocabParallelEmbedding1D),
                ])

        # optimization configuration
        if self.shard_config.enable_fused_normalization:
            # handle bloom model
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(suffix="ln_f", target_module=col_nn.FusedLayerNorm),

@@ -144,8 +140,12 @@ class BloomPolicy(Policy):
                    SubModuleReplacementDescription(suffix="word_embeddings_layernorm",
                                                    target_module=col_nn.FusedLayerNorm),
                ],
                policy=policy,
                target_key=BloomModel)

            # handle bloom block
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(suffix="input_layernorm", target_module=col_nn.FusedLayerNorm),

@@ -154,9 +154,11 @@ class BloomPolicy(Policy):
                    SubModuleReplacementDescription(suffix="post_attention_layernorm",
                                                    target_module=col_nn.FusedLayerNorm),
                ],
                policy=policy,
                target_key=BloomBlock)

        return policy

    def postprocess(self):
        return self.model
@@ -171,19 +173,19 @@ class BloomForCausalLMPolicy(BloomPolicy):
    def module_policy(self):
        from transformers.models.bloom.modeling_bloom import BloomForCausalLM
        policy = super().module_policy()

        # handle tensor parallelism
        if self.shard_config.enable_tensor_parallelism:
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(suffix="lm_head",
                                                            target_module=col_nn.Linear1D_Col,
                                                            kwargs=dict(gather_output=True)),
                policy=policy,
                target_key=BloomForCausalLM)
        return policy

    def postprocess(self):
        binding_map = {"transformer.word_embeddings.weight": "lm_head.weight"}

        for k, v in binding_map.items():
            param = getattr_(self.model, k)

@@ -191,7 +193,6 @@ class BloomForCausalLMPolicy(BloomPolicy):
            param = nn.Parameter(param)

            # tie weights
            setattr_(self.model, v, param)
        return self.model

@@ -201,15 +202,14 @@ class BloomForSequenceClassificationPolicy(BloomPolicy):
    def module_policy(self):
        from transformers.models.bloom.modeling_bloom import BloomForSequenceClassification
        policy = super().module_policy()

        # handle tensor parallelism
        if self.shard_config.enable_tensor_parallelism:
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(suffix="score",
                                                            target_module=col_nn.Linear1D_Col,
                                                            kwargs=dict(gather_output=True)),
                policy=policy,
                target_key=BloomForSequenceClassification)
        return policy

@@ -218,19 +218,21 @@ class BloomForTokenClassificationPolicy(BloomPolicy):
    def module_policy(self):
        from transformers.models.bloom.modeling_bloom import BloomForTokenClassification
        policy = super().module_policy()

        # handle tensor parallelism
        if self.shard_config.enable_tensor_parallelism:
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(suffix="classifier",
                                                    target_module=col_nn.Linear1D_Col,
                                                    kwargs=dict(gather_output=True)),
                    SubModuleReplacementDescription(suffix="dropout",
                                                    target_module=col_nn.DropoutForReplicatedInput),
                ],
                policy=policy,
                target_key=BloomForTokenClassification)

        return policy
@@ -31,67 +31,67 @@ class GPT2Policy(Policy):
    def module_policy(self):
        from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model

        policy = {}

        if self.shard_config.enable_tensor_parallelism:
            policy[GPT2Model] = ModulePolicyDescription(sub_module_replacement=[
                SubModuleReplacementDescription(suffix="wte", target_module=col_nn.VocabParallelEmbedding1D),
            ])
            policy[GPT2Block] = ModulePolicyDescription(
                attribute_replacement={
                    "attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                    "attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                    "attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
                },
                sub_module_replacement=[
                    SubModuleReplacementDescription(suffix="attn.c_attn",
                                                    target_module=col_nn.GPT2FusedLinearConv1D_Col,
                                                    kwargs={"n_fused": 3}),
                    SubModuleReplacementDescription(suffix="attn.c_proj",
                                                    target_module=col_nn.GPT2FusedLinearConv1D_Row),
                    SubModuleReplacementDescription(suffix="mlp.c_fc",
                                                    target_module=col_nn.GPT2FusedLinearConv1D_Col,
                                                    kwargs={"n_fused": 1}),
                    SubModuleReplacementDescription(suffix="mlp.c_proj",
                                                    target_module=col_nn.GPT2FusedLinearConv1D_Row),
                    SubModuleReplacementDescription(suffix="attn.attn_dropout",
                                                    target_module=col_nn.DropoutForParallelInput),
                    SubModuleReplacementDescription(suffix="attn.resid_dropout",
                                                    target_module=col_nn.DropoutForParallelInput),
                    SubModuleReplacementDescription(suffix="mlp.dropout",
                                                    target_module=col_nn.DropoutForParallelInput),
                ])

        # optimization configuration
        if self.shard_config.enable_fused_normalization:
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(suffix="ln_f", target_module=col_nn.FusedLayerNorm),
                policy=policy,
                target_key=GPT2Model)

            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(suffix="ln_1", target_module=col_nn.FusedLayerNorm),

@@ -103,9 +103,10 @@ class GPT2Policy(Policy):
                    SubModuleReplacementDescription(suffix="ln_cross_attn",
                                                    target_module=col_nn.FusedLayerNorm,
                                                    ignore_if_not_exist=True),
                ],
                policy=policy,
                target_key=GPT2Block)
        return policy

    def postprocess(self):
        return self.model
@@ -128,22 +129,22 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
        from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel

        module_policy = super().module_policy()

        if self.shard_config.enable_tensor_parallelism:
            addon_module = {
                GPT2LMHeadModel:
                    ModulePolicyDescription(sub_module_replacement=[
                        SubModuleReplacementDescription(suffix="lm_head",
                                                        target_module=col_nn.Linear1D_Col,
                                                        kwargs={"gather_output": True})
                    ])
            }
            module_policy.update(addon_module)
        return module_policy

    def postprocess(self):
        binding_map = {"transformer.wte.weight": "lm_head.weight"}
        for k, v in binding_map.items():
            param = getattr_(self.model, k)
            setattr_(self.model, v, param)
        return self.model

@@ -158,22 +159,22 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):
        from transformers.models.gpt2.modeling_gpt2 import GPT2DoubleHeadsModel

        module_policy = super().module_policy()

        if self.shard_config.enable_tensor_parallelism:
            addon_module = {
                GPT2DoubleHeadsModel:
                    ModulePolicyDescription(sub_module_replacement=[
                        SubModuleReplacementDescription(suffix="lm_head",
                                                        target_module=col_nn.Linear1D_Col,
                                                        kwargs={"gather_output": True})
                    ])
            }
            module_policy.update(addon_module)
        return module_policy

    def postprocess(self):
        binding_map = {"transformer.wte.weight": "lm_head.weight"}
        for k, v in binding_map.items():
            param = getattr_(self.model, k)
            setattr_(self.model, v, param)
        return self.model
@@ -28,58 +28,58 @@ class LlamaPolicy(Policy):
    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel

        policy = {}

        if self.shard_config.enable_tensor_parallelism:
            policy[LlamaDecoderLayer] = ModulePolicyDescription(
                attribute_replacement={
                    "self_attn.hidden_size":
                        self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                    "self_attn.num_heads":
                        self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
                },
                sub_module_replacement=[
                    SubModuleReplacementDescription(suffix="self_attn.q_proj", target_module=Linear1D_Col),
                    SubModuleReplacementDescription(suffix="self_attn.k_proj", target_module=Linear1D_Col),
                    SubModuleReplacementDescription(suffix="self_attn.v_proj", target_module=Linear1D_Col),
                    SubModuleReplacementDescription(suffix="self_attn.o_proj", target_module=Linear1D_Row),
                    SubModuleReplacementDescription(suffix="mlp.gate_proj", target_module=Linear1D_Col),
                    SubModuleReplacementDescription(suffix="mlp.up_proj", target_module=Linear1D_Col),
                    SubModuleReplacementDescription(suffix="mlp.down_proj", target_module=Linear1D_Row),
                ],
            )

            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(suffix="embed_tokens",
                                                            target_module=VocabParallelEmbedding1D),
                policy=policy,
                target_key=LlamaModel)

        # optimization configuration
        if self.shard_config.enable_fused_normalization:
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(suffix="input_layernorm", target_module=FusedRMSNorm),

@@ -88,15 +88,18 @@ class LlamaPolicy(Policy):
                    SubModuleReplacementDescription(suffix="post_attention_layernorm", target_module=FusedRMSNorm),
                ],
                policy=policy,
                target_key=LlamaDecoderLayer)

            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(suffix="norm", target_module=FusedRMSNorm),
                policy=policy,
                target_key=LlamaModel)

        return policy

    def postprocess(self):
        return self.model
@@ -108,15 +111,17 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
        from transformers import LlamaForCausalLM

        policy = super().module_policy()

        if self.shard_config.enable_tensor_parallelism:
            # add a new item for causal lm
            new_item = {
                LlamaForCausalLM:
                    ModulePolicyDescription(sub_module_replacement=[
                        SubModuleReplacementDescription(suffix="lm_head",
                                                        target_module=Linear1D_Col,
                                                        kwargs=dict(gather_output=True))
                    ])
            }
            policy.update(new_item)
        return policy

@@ -127,13 +132,14 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
        policy = super().module_policy()

        if self.shard_config.enable_tensor_parallelism:
            # add a new item for sequence classification
            new_item = {
                LlamaForSequenceClassification:
                    ModulePolicyDescription(sub_module_replacement=[
                        SubModuleReplacementDescription(suffix="score",
                                                        target_module=Linear1D_Col,
                                                        kwargs=dict(gather_output=True))
                    ])
            }
            policy.update(new_item)
        return policy
@@ -29,66 +29,67 @@ class OPTPolicy(Policy):
    def module_policy(self):
        from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer

        policy = {}

        if self.shard_config.enable_tensor_parallelism:
            policy[OPTDecoder] = ModulePolicyDescription(sub_module_replacement=[
                SubModuleReplacementDescription(suffix="embed_tokens", target_module=VocabParallelEmbedding1D),
            ])
            policy[OPTDecoderLayer] = ModulePolicyDescription(sub_module_replacement=[
                SubModuleReplacementDescription(suffix="fc1", target_module=Linear1D_Col),
                SubModuleReplacementDescription(suffix="fc2", target_module=Linear1D_Row),
            ])

            policy[OPTAttention] = ModulePolicyDescription(
                attribute_replacement={
                    "embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                    "num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size
                },
                sub_module_replacement=[
                    SubModuleReplacementDescription(suffix="q_proj", target_module=Linear1D_Col),
                    SubModuleReplacementDescription(suffix="k_proj", target_module=Linear1D_Col),
                    SubModuleReplacementDescription(suffix="v_proj", target_module=Linear1D_Col),
                    SubModuleReplacementDescription(suffix="out_proj", target_module=Linear1D_Row),
                ])

        # optimization configuration
        if self.shard_config.enable_fused_normalization:
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(suffix="final_layer_norm",
                                                            target_module=FusedLayerNorm,
                                                            ignore_if_not_exist=True),
                policy=policy,
                target_key=OPTDecoder)
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(suffix="self_attn_layer_norm",
                                                    target_module=FusedLayerNorm,
                                                    ignore_if_not_exist=True),
                    SubModuleReplacementDescription(suffix="final_layer_norm",
                                                    target_module=FusedLayerNorm,
                                                    ignore_if_not_exist=True),
                ],
                policy=policy,
                target_key=OPTDecoderLayer)

        return policy

    def postprocess(self):
        return self.model

@@ -106,15 +107,12 @@ class OPTForCausalLMPolicy(OPTPolicy):
        from transformers.models.opt.modeling_opt import OPTForCausalLM

        policy = super().module_policy()

        if self.shard_config.enable_tensor_parallelism:
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(suffix="lm_head",
                                                            target_module=Linear1D_Col,
                                                            kwargs=dict(gather_output=True)),
                policy=policy,
                target_key=OPTForCausalLM)
        return policy

    def postprocess(self):
|
@ -42,116 +42,126 @@ class T5BasePolicy(Policy):
|
|||||||
T5Stack,
|
T5Stack,
|
||||||
)
|
)
|
||||||
|
|
||||||
base_policy = {
|
policy = {}
|
||||||
T5Stack:
|
|
||||||
ModulePolicyDescription(sub_module_replacement=[
|
if self.shard_config.enable_tensor_parallelism:
|
||||||
SubModuleReplacementDescription(
|
policy[T5Stack] = ModulePolicyDescription(sub_module_replacement=[
|
||||||
suffix="dropout",
|
SubModuleReplacementDescription(
|
||||||
target_module=DropoutForParallelInput,
|
suffix="dropout",
|
||||||
),
|
target_module=DropoutForParallelInput,
|
||||||
SubModuleReplacementDescription(
|
),
|
||||||
suffix="embed_tokens",
|
SubModuleReplacementDescription(
|
||||||
target_module=Embedding1D,
|
suffix="embed_tokens",
|
||||||
)
|
target_module=Embedding1D,
|
||||||
]),
|
)
|
||||||
T5LayerSelfAttention:
|
])
|
||||||
ModulePolicyDescription(sub_module_replacement=[
|
policy[T5LayerSelfAttention] = ModulePolicyDescription(sub_module_replacement=[
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="dropout",
|
suffix="dropout",
|
||||||
target_module=DropoutForParallelInput,
|
target_module=DropoutForParallelInput,
|
||||||
),
|
),
|
||||||
]),
|
])
|
||||||
T5LayerCrossAttention:
|
policy[T5LayerCrossAttention] = ModulePolicyDescription(sub_module_replacement=[
|
||||||
ModulePolicyDescription(sub_module_replacement=[
|
SubModuleReplacementDescription(
|
||||||
SubModuleReplacementDescription(
|
suffix="dropout",
|
||||||
suffix="dropout",
|
target_module=DropoutForParallelInput,
|
||||||
target_module=DropoutForParallelInput,
|
)
|
||||||
)
|
])
|
||||||
]),
|
policy[T5Attention] = ModulePolicyDescription(attribute_replacement={
|
||||||
T5Attention:
|
"d_model":
|
||||||
ModulePolicyDescription(attribute_replacement={
|
self.model.config.d_model // self.shard_config.tensor_parallel_size,
|
||||||
"d_model":
|
"n_heads":
|
||||||
self.model.config.d_model // self.shard_config.tensor_parallel_size,
|
self.model.config.num_heads // self.shard_config.tensor_parallel_size,
|
||||||
"n_heads":
|
"inner_dim":
|
||||||
self.model.config.num_heads // self.shard_config.tensor_parallel_size,
|
self.model.config.num_heads * self.model.config.d_kv // self.shard_config.tensor_parallel_size
|
||||||
"inner_dim":
|
},
|
||||||
self.model.config.num_heads * self.model.config.d_kv // self.shard_config.tensor_parallel_size
|
sub_module_replacement=[
|
||||||
},
|
SubModuleReplacementDescription(
|
||||||
sub_module_replacement=[
|
suffix="q",
|
||||||
SubModuleReplacementDescription(
|
target_module=Linear1D_Col,
|
||||||
suffix="q",
|
),
|
||||||
target_module=Linear1D_Col,
|
SubModuleReplacementDescription(
|
||||||
),
|
suffix="k",
|
||||||
SubModuleReplacementDescription(
|
target_module=Linear1D_Col,
|
||||||
suffix="k",
|
),
|
||||||
target_module=Linear1D_Col,
|
SubModuleReplacementDescription(
|
||||||
),
|
suffix="v",
|
||||||
SubModuleReplacementDescription(
|
target_module=Linear1D_Col,
|
||||||
suffix="v",
|
),
|
||||||
target_module=Linear1D_Col,
|
SubModuleReplacementDescription(
|
||||||
),
|
suffix="o",
|
||||||
SubModuleReplacementDescription(
|
target_module=Linear1D_Row,
|
||||||
suffix="o",
|
),
|
||||||
target_module=Linear1D_Row,
|
SubModuleReplacementDescription(
|
||||||
),
|
suffix="relative_attention_bias",
|
||||||
SubModuleReplacementDescription(suffix="relative_attention_bias",
|
target_module=Embedding1D,
|
||||||
target_module=Embedding1D,
|
kwargs=dict(gather_output=False),
|
||||||
kwargs=dict(gather_output=False),
|
ignore_if_not_exist=True)
|
||||||
ignore_if_not_exist=True)
|
])
|
||||||
]),
|
policy[T5LayerFF] = ModulePolicyDescription(sub_module_replacement=[
|
||||||
T5LayerFF:
|
SubModuleReplacementDescription(
|
||||||
ModulePolicyDescription(sub_module_replacement=[
|
suffix="dropout",
|
||||||
SubModuleReplacementDescription(
|
target_module=DropoutForParallelInput,
|
||||||
suffix="dropout",
|
),
|
||||||
target_module=DropoutForParallelInput,
|
])
|
||||||
),
|
policy[T5DenseGatedActDense] = ModulePolicyDescription(sub_module_replacement=[
|
||||||
]),
|
SubModuleReplacementDescription(
|
||||||
T5DenseGatedActDense:
|
suffix="wi_0",
|
||||||
ModulePolicyDescription(sub_module_replacement=[
|
target_module=Linear1D_Col,
|
||||||
SubModuleReplacementDescription(
|
),
|
||||||
suffix="wi_0",
|
SubModuleReplacementDescription(
|
||||||
target_module=Linear1D_Col,
|
suffix="wi_1",
|
||||||
),
|
target_module=Linear1D_Row,
|
||||||
SubModuleReplacementDescription(
|
),
|
||||||
suffix="wi_1",
|
SubModuleReplacementDescription(
|
||||||
target_module=Linear1D_Row,
|
suffix="wo", target_module=Linear1D_Col, kwargs=dict(gather_output=True)),
|
||||||
),
|
SubModuleReplacementDescription(
|
||||||
SubModuleReplacementDescription(
|
suffix="dropout",
|
||||||
suffix="wo", target_module=Linear1D_Col, kwargs=dict(gather_output=True)),
|
target_module=DropoutForParallelInput,
|
||||||
SubModuleReplacementDescription(
|
)
|
||||||
suffix="dropout",
|
])
|
||||||
target_module=DropoutForParallelInput,
|
policy[T5DenseActDense] = ModulePolicyDescription(sub_module_replacement=[
|
||||||
)
|
SubModuleReplacementDescription(
|
||||||
]),
|
suffix="wi",
|
||||||
T5DenseActDense:
|
target_module=Linear1D_Col,
|
||||||
ModulePolicyDescription(sub_module_replacement=[
|
),
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="wi",
|
suffix="wo",
|
||||||
target_module=Linear1D_Col,
|
target_module=Linear1D_Row,
|
||||||
),
|
),
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="wo",
|
suffix="dropout",
|
||||||
target_module=Linear1D_Row,
|
target_module=DropoutForParallelInput,
|
||||||
),
|
)
|
||||||
SubModuleReplacementDescription(
|
])
|
||||||
suffix="dropout",
|
|
||||||
target_module=DropoutForParallelInput,
|
|
||||||
)
|
|
||||||
])
|
|
||||||
}
|
|
||||||
|
|
||||||
# optimization configuration
|
# optimization configuration
|
||||||
if self.shard_config.enable_fused_normalization:
|
if self.shard_config.enable_fused_normalization:
|
||||||
base_policy[T5LayerFF].sub_module_replacement.append(
|
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
|
||||||
SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm))
|
suffix="layer_norm",
|
||||||
base_policy[T5LayerSelfAttention].sub_module_replacement.append(
|
target_module=FusedRMSNorm,
|
||||||
SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm))
|
),
|
||||||
base_policy[T5LayerCrossAttention].sub_module_replacement.append(
|
policy=policy,
|
||||||
SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm))
|
target_key=T5LayerFF)
|
||||||
base_policy[T5Stack].sub_module_replacement.append(
|
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
|
||||||
SubModuleReplacementDescription(suffix="final_layer_norm", target_module=FusedRMSNorm))
|
suffix="layer_norm",
|
||||||
|
target_module=FusedRMSNorm,
|
||||||
return base_policy
|
),
|
||||||
|
policy=policy,
|
||||||
|
target_key=T5LayerFF)
|
||||||
|
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
|
||||||
|
suffix="layer_norm", target_module=FusedRMSNorm),
|
||||||
|
policy=policy,
|
||||||
|
target_key=T5LayerSelfAttention)
|
||||||
|
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
|
||||||
|
suffix="layer_norm", target_module=FusedRMSNorm),
|
||||||
|
policy=policy,
|
||||||
|
target_key=T5LayerCrossAttention)
|
||||||
|
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
|
||||||
|
suffix="final_layer_norm", target_module=FusedRMSNorm),
|
||||||
|
policy=policy,
|
||||||
|
target_key=T5Stack)
|
||||||
|
return policy
|
||||||
|
|
||||||
def postprocess(self):
|
def postprocess(self):
|
||||||
binding_map = [["shared", "encoder.embed_tokens"], ["shared", "decoder.embed_tokens"]]
|
binding_map = [["shared", "encoder.embed_tokens"], ["shared", "decoder.embed_tokens"]]
|
||||||
@ -166,14 +176,15 @@ class T5ModelPolicy(T5BasePolicy):
|
|||||||
|
|
||||||
def module_policy(self):
|
def module_policy(self):
|
||||||
from transformers import T5Model
|
from transformers import T5Model
|
||||||
|
|
||||||
base_policy = super().module_policy()
|
base_policy = super().module_policy()
|
||||||
base_policy[T5Model] = ModulePolicyDescription(sub_module_replacement=[
|
|
||||||
SubModuleReplacementDescription(
|
if self.shard_config.enable_tensor_parallelism:
|
||||||
|
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
|
||||||
suffix="shared",
|
suffix="shared",
|
||||||
target_module=VocabParallelEmbedding1D,
|
target_module=VocabParallelEmbedding1D,
|
||||||
)
|
),
|
||||||
])
|
policy=base_policy,
|
||||||
|
target_key=T5Model)
|
||||||
return base_policy
|
return base_policy
|
||||||
|
|
||||||
|
|
||||||
@@ -183,14 +194,19 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
         from transformers import T5ForConditionalGeneration

         policy = super().module_policy()
-        policy[T5ForConditionalGeneration] = ModulePolicyDescription(sub_module_replacement=[
-            SubModuleReplacementDescription(
-                suffix="shared",
-                target_module=VocabParallelEmbedding1D,
-            ),
-            SubModuleReplacementDescription(
-                suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True))
-        ])
+
+        if self.shard_config.enable_tensor_parallelism:
+            self.append_or_create_submodule_replacement(description=[
+                SubModuleReplacementDescription(
+                    suffix="shared",
+                    target_module=VocabParallelEmbedding1D,
+                ),
+                SubModuleReplacementDescription(suffix="lm_head",
+                                                target_module=Linear1D_Col,
+                                                kwargs=dict(gather_output=True))
+            ],
+                                                        policy=policy,
+                                                        target_key=T5ForConditionalGeneration)
         return policy

     def postprocess(self):
@@ -212,12 +228,14 @@ class T5EncoderPolicy(T5BasePolicy):
         from transformers import T5EncoderModel

         base_policy = super().module_policy()
-        base_policy[T5EncoderModel] = ModulePolicyDescription(sub_module_replacement=[
-            SubModuleReplacementDescription(
-                suffix="shared",
-                target_module=VocabParallelEmbedding1D,
-            )
-        ])
+
+        if self.shard_config.enable_tensor_parallelism:
+            self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
+                suffix="shared",
+                target_module=VocabParallelEmbedding1D,
+            ),
+                                                        policy=base_policy,
+                                                        target_key=T5EncoderModel)
         return base_policy

     def postprocess(self):
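Taken together, the policy hunks above make the embedding and linear replacements conditional on the new flag. A minimal usage sketch, not part of this commit, assuming colossalai.launch(...) has already initialised the default process group and a CUDA device:

import copy

from transformers import T5Config, T5EncoderModel

from colossalai.shardformer import ShardConfig, ShardFormer

# Build a tiny T5 encoder and shard it twice: once with tensor parallelism on, once off.
model = T5EncoderModel(T5Config(num_layers=2, d_model=128, num_heads=4, d_ff=256, vocab_size=1000))

for tp in (True, False):
    config = ShardConfig(enable_fused_normalization=True, enable_tensor_parallelism=tp)
    sharded = ShardFormer(shard_config=config).optimize(copy.deepcopy(model)).cuda()
    # With tp=True the shared embedding is swapped for VocabParallelEmbedding1D as in the hunks
    # above; with tp=False only the fused-normalization replacement is applied.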
@@ -13,11 +13,12 @@ class ShardConfig:

     Args:
         tensor_parallel_process_group (int): The process group for tensor parallelism, defaults to None, which is the global process group.
+        enable_tensor_parallelism (bool): Whether to turn on tensor parallelism, default is True.
         enable_fused_normalization (bool): Whether to use fused layernorm, default is False.
-        enable_tensor_parallelism (bool): Whether to use tensor parallelism, default is True.
         enable_all_optimization (bool): Whether to turn on all optimization, default is False.
     """
     tensor_parallel_process_group: ProcessGroup = None
+    enable_tensor_parallelism: bool = True
     enable_fused_normalization: bool = False
     enable_all_optimization: bool = False
@@ -33,8 +34,11 @@ class ShardConfig:
         return self._tensor_parallel_size

     def __post_init__(self):
-        # get the parallel size
-        self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group)
+        if not self.enable_tensor_parallelism:
+            self._tensor_parallel_size = 1
+        else:
+            # get the parallel size
+            self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group)

         # turn on all optimization if all_optimization is set to True
         if self.enable_all_optimization:
@@ -3,12 +3,13 @@ import copy
 from colossalai.shardformer import ShardConfig, ShardFormer


-def build_model(model_fn):
+def build_model(model_fn, enable_fused_normalization=True, enable_tensor_parallelism=True):
     # create new model
     org_model = model_fn().cuda()

     # shard model
-    shard_config = ShardConfig(enable_fused_normalization=True)
+    shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
+                               enable_tensor_parallelism=enable_tensor_parallelism)
     model_copy = copy.deepcopy(org_model)
     shard_former = ShardFormer(shard_config=shard_config)
     sharded_model = shard_former.optimize(model_copy).cuda()
@@ -3,7 +3,14 @@ import torch

 import colossalai
 from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn
+from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
+from colossalai.testing import (
+    assert_hf_output_close,
+    clear_cache_before_run,
+    parameterize,
+    rerun_if_address_is_in_use,
+    spawn,
+)
 from tests.kit.model_zoo import model_zoo
 from tests.test_shardformer.test_model._utils import build_model, run_forward
@@ -33,34 +40,48 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
     # compare self attention grad
     org_grad = bert.encoder.layer[0].attention.self.query.weight.grad
     shard_grad = sharded_bert.encoder.layer[0].attention.self.query.weight.grad
+    shard_weight = sharded_bert.encoder.layer[0].attention.self.query.weight

-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
-    shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
+        shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    else:
+        all_shard_grad = shard_grad
     assert torch.allclose(org_grad, all_shard_grad,
                           atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"

     # compare embedding grad
     org_grad = bert.embeddings.word_embeddings.weight.grad
     shard_grad = sharded_bert.embeddings.word_embeddings.weight.grad
+    shard_weight = sharded_bert.embeddings.word_embeddings.weight

-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
-    shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
+        shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    else:
+        all_shard_grad = shard_grad
     assert torch.allclose(org_grad, all_shard_grad,
                           atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"


+@parameterize('enable_fused_normalization', [True, False])
+@parameterize('enable_tensor_parallelism', [True, False])
+def run_bert_test(enable_fused_normalization, enable_tensor_parallelism):
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
+
+    torch.cuda.empty_cache()
+
+
 def check_bert(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-
-    sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
-    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn)
-        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
-
-    torch.cuda.empty_cache()
+    run_bert_test()


 @pytest.mark.dist
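The gather-only-if-sharded pattern above is repeated, with small variations, in every test file this commit touches. A hypothetical helper that captures it (not part of the diff; assumes an initialised process group and reuses the d_tensor checks imported above):

import torch
import torch.distributed as dist

from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor


def all_gather_if_sharded(weight: torch.Tensor, grad: torch.Tensor, world_size: int, dim: int = 0) -> torch.Tensor:
    # Only gather when the parameter was actually sharded; with enable_tensor_parallelism=False
    # the weight stays a plain tensor and the local grad is already the full grad.
    if is_distributed_tensor(weight) or is_customized_distributed_tensor(weight):
        grad_list = [torch.zeros_like(grad) for _ in range(world_size)]
        dist.all_gather(grad_list, grad)
        return torch.cat(grad_list, dim=dim)
    return grad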
@@ -3,7 +3,14 @@ import torch

 import colossalai
 from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn
+from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
+from colossalai.testing import (
+    assert_hf_output_close,
+    clear_cache_before_run,
+    parameterize,
+    rerun_if_address_is_in_use,
+    spawn,
+)
 from tests.kit.model_zoo import model_zoo
 from tests.test_shardformer.test_model._utils import build_model, run_forward
@@ -32,10 +39,14 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
     # check attention grad
     org_grad = bloom.h[0].self_attention.query_key_value.weight.grad
     shard_grad = sharded_bloom.h[0].self_attention.query_key_value.weight.grad
+    shard_weight = sharded_bloom.h[0].self_attention.query_key_value.weight

-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
-    torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
+        torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    else:
+        all_shard_grad = shard_grad

     assert torch.allclose(org_grad, all_shard_grad,
                           atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
@@ -43,25 +54,33 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
     # check embedding weights
     org_grad = bloom.word_embeddings.weight.grad
     shard_grad = sharded_bloom.word_embeddings.weight.grad
+    shard_weight = sharded_bloom.word_embeddings.weight

-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
-    torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
+        torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    else:
+        all_shard_grad = shard_grad

     assert torch.allclose(org_grad, all_shard_grad,
                           atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"


+@parameterize('enable_fused_normalization', [True, False])
+@parameterize('enable_tensor_parallelism', [True, False])
+def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism):
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
+    torch.cuda.empty_cache()
+
+
 def check_bloom(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-
-    sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
-    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn)
-        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
-
-    torch.cuda.empty_cache()
+    run_bloom_test()


 @pytest.mark.dist
@@ -3,7 +3,14 @@ import torch

 import colossalai
 from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn
+from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
+from colossalai.testing import (
+    assert_hf_output_close,
+    clear_cache_before_run,
+    parameterize,
+    rerun_if_address_is_in_use,
+    spawn,
+)
 from tests.kit.model_zoo import model_zoo
 from tests.test_shardformer.test_model._utils import build_model, run_forward
@@ -32,11 +39,14 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
     # check mlp grad
     org_grad = org_model.h[0].mlp.c_fc.weight.grad
     shard_grad = sharded_model.h[0].mlp.c_fc.weight.grad
+    shard_weight = sharded_model.h[0].mlp.c_fc.weight

-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
-    shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=1)
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
+        shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=1)
+    else:
+        all_shard_grad = shard_grad
     assert torch.allclose(
         org_grad, all_shard_grad,
         atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"
@@ -44,25 +54,33 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
     # check embedding weights
     org_grad = org_model.wte.weight.grad
     shard_grad = sharded_model.wte.weight.grad
-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
-    shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    shard_weight = sharded_model.wte.weight
+
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
+        shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    else:
+        all_shard_grad = shard_grad
     assert torch.allclose(
         org_grad, all_shard_grad,
         atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"


+@parameterize('enable_fused_normalization', [True, False])
+@parameterize('enable_tensor_parallelism', [True, False])
+def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism):
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
+    torch.cuda.empty_cache()
+
+
 def check_gpt2(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-
-    sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
-    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn)
-        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
-
-    torch.cuda.empty_cache()
+    run_gpt2_test()


 @pytest.mark.dist
@@ -5,7 +5,14 @@ import torch

 import colossalai
 from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn
+from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
+from colossalai.testing import (
+    assert_hf_output_close,
+    clear_cache_before_run,
+    parameterize,
+    rerun_if_address_is_in_use,
+    spawn,
+)
 from tests.kit.model_zoo import model_zoo
 from tests.test_shardformer.test_model._utils import build_model, run_forward
@@ -37,33 +44,46 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
     # check attention grad
     org_grad = llama_model.layers[0].self_attn.q_proj.weight.grad
     shard_grad = shard_llama_model.layers[0].self_attn.q_proj.weight.grad
-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
-    shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    shard_weight = shard_llama_model.layers[0].self_attn.q_proj.weight
+
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
+        shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    else:
+        all_shard_grad = shard_grad
     assert torch.allclose(org_grad, all_shard_grad,
                           atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}"

     # check embedding grad
     org_grad = llama_model.embed_tokens.weight.grad
     shard_grad = shard_llama_model.embed_tokens.weight.grad
-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
-    shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    shard_weight = shard_llama_model.embed_tokens.weight
+
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
+        shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    else:
+        all_shard_grad = shard_grad
     assert torch.allclose(org_grad, all_shard_grad,
                           atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}"


+@parameterize('enable_fused_normalization', [True, False])
+@parameterize('enable_tensor_parallelism', [True, False])
+def run_gpt2_llama(enable_fused_normalization, enable_tensor_parallelism):
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
+    torch.cuda.empty_cache()
+
+
 def check_llama(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-
-    sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
-
-    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn)
-        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
-
-    torch.cuda.empty_cache()
+    run_gpt2_llama()


 @pytest.mark.dist
@@ -6,10 +6,11 @@ import torch

 import colossalai
 from colossalai.logging import disable_existing_loggers
+from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
 from colossalai.testing import (
     assert_hf_output_close,
-    check_state_dict_equal,
     clear_cache_before_run,
+    parameterize,
     rerun_if_address_is_in_use,
     spawn,
 )
@@ -42,32 +43,46 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
     # check attention grad
     org_grad = opt_model.decoder.layers[0].self_attn.q_proj.weight.grad
     shard_grad = shard_opt_model.decoder.layers[0].self_attn.q_proj.weight.grad
-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
-    torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    shard_weight = shard_opt_model.decoder.layers[0].self_attn.q_proj.weight
+
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
+        torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    else:
+        all_shard_grad = shard_grad
     assert torch.allclose(org_grad, all_shard_grad,
                           atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"

     # check embedding grad
     org_grad = opt_model.decoder.embed_tokens.weight.grad
     shard_grad = shard_opt_model.decoder.embed_tokens.weight.grad
-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
-    torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    shard_weight = shard_opt_model.decoder.embed_tokens.weight
+
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
+        torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    else:
+        all_shard_grad = shard_grad
     assert torch.allclose(org_grad, all_shard_grad,
                           atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"


+@parameterize('enable_fused_normalization', [True, False])
+@parameterize('enable_tensor_parallelism', [True, False])
+def run_t5_test(enable_fused_normalization, enable_tensor_parallelism):
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
+    torch.cuda.empty_cache()
+
+
 def check_OPTModel(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-
-    sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
-    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn)
-        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
-
-    torch.cuda.empty_cache()
+    run_t5_test()


 @pytest.mark.dist
@@ -5,7 +5,14 @@ import torch

 import colossalai
 from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn
+from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
+from colossalai.testing import (
+    assert_hf_output_close,
+    clear_cache_before_run,
+    parameterize,
+    rerun_if_address_is_in_use,
+    spawn,
+)
 from tests.kit.model_zoo import model_zoo
 from tests.test_shardformer.test_model._utils import build_model, run_forward
@@ -27,19 +34,28 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
     # check attention grad
     org_grad = org_model.encoder.block[0].layer[0].SelfAttention.q.weight.grad
     shard_grad = sharded_model.encoder.block[0].layer[0].SelfAttention.q.weight.grad
+    shard_weight = sharded_model.encoder.block[0].layer[0].SelfAttention.q.weight

-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
-    shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
+        shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    else:
+        all_shard_grad = shard_grad
     assert torch.allclose(org_grad, all_shard_grad,
                           atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}"

     # check self attention embed
     org_grad = org_model.encoder.block[0].layer[0].SelfAttention.relative_attention_bias.weight.grad
     shard_grad = sharded_model.encoder.block[0].layer[0].SelfAttention.relative_attention_bias.weight.grad
-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
-    torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=1)
+    shard_weight = sharded_model.encoder.block[0].layer[0].SelfAttention.relative_attention_bias.weight
+
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
+        torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=1)
+    else:
+        all_shard_grad = shard_grad
     assert torch.allclose(org_grad, all_shard_grad,
                           atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
@@ -52,23 +68,32 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
     assert sharded_model.shared.weight.data.data_ptr() == sharded_model.lm_head.weight.data.data_ptr()

     shard_grad = sharded_model.shared.weight.grad
-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
-    torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    shard_weight = sharded_model.shared.weight
+
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
+        torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    else:
+        all_shard_grad = shard_grad
     assert torch.allclose(org_grad, all_shard_grad,
                           atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"


+@parameterize('enable_fused_normalization', [True, False])
+@parameterize('enable_tensor_parallelism', [True, False])
+def run_t5_test(enable_fused_normalization, enable_tensor_parallelism):
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
+    torch.cuda.empty_cache()
+
+
 def check_t5(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-
-    sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
-
-    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn)
-        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
-    torch.cuda.empty_cache()
+    run_t5_test()


 @pytest.mark.dist
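Not shown in this diff: each check_* function is still driven by a @pytest.mark.dist entry point via spawn. A sketch of that pattern, with the decorators and world size as assumptions rather than a quotation of the file:

import pytest

from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn


@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_t5():
    # launches world_size processes, each running check_t5(rank, world_size, port)
    spawn(check_t5, 2)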