[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions


@@ -7,7 +7,6 @@ from torch.nn import Module
import colossalai.shardformer.layer as col_nn
from .._utils import getattr_, setattr_
from ..modeling.bert import (
BertPipelineForwards,
bert_sequence_parallel_forward_fn,
@@ -19,14 +18,20 @@ from ..modeling.jit import get_jit_fused_dropout_add_func
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = [
'BertPolicy', 'BertModelPolicy', 'BertForPreTrainingPolicy', 'BertLMHeadModelPolicy', 'BertForMaskedLMPolicy',
'BertForNextSentencePredictionPolicy', 'BertForSequenceClassificationPolicy', 'BertForTokenClassificationPolicy',
'BertForMultipleChoicePolicy', 'BertForQuestionAnsweringPolicy'
"BertPolicy",
"BertModelPolicy",
"BertForPreTrainingPolicy",
"BertLMdHeadModelPolicy",
"BertForMaskedLMPolicy",
"BertForNextSentencePredictionPolicy",
"BertForSequenceClassificationPolicy",
"BertForTokenClassificationPolicy",
"BertForMultipleChoicePolicy",
"BertForQuestionAnsweringPolicy",
]
class BertPolicy(Policy):
def config_sanity_check(self):
pass
@@ -58,136 +63,140 @@ class BertPolicy(Policy):
use_sequence_parallel = self.shard_config.enable_sequence_parallelism
overlap = self.shard_config.enable_sequence_overlap
if self.shard_config.enable_tensor_parallelism:
policy[BertLayer] = ModulePolicyDescription(attribute_replacement={
"attention.self.all_head_size":
self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"crossattention.self.all_head_size":
self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"attention.self.num_attention_heads":
self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
"crossattention.self.num_attention_heads":
self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="attention.self.query",
target_module=col_nn.Linear1D_Col,
kwargs={
"seq_parallel": use_sequence_parallel,
"overlap": overlap
},
),
SubModuleReplacementDescription(
suffix="attention.self.key",
target_module=col_nn.Linear1D_Col,
kwargs={
"seq_parallel": use_sequence_parallel,
"overlap": overlap
},
),
SubModuleReplacementDescription(
suffix="attention.self.value",
target_module=col_nn.Linear1D_Col,
kwargs={
"seq_parallel": use_sequence_parallel,
"overlap": overlap
},
),
SubModuleReplacementDescription(
suffix="attention.self.dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="attention.output.dense",
target_module=col_nn.Linear1D_Row,
kwargs={"seq_parallel": use_sequence_parallel},
),
SubModuleReplacementDescription(
suffix="attention.output.dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="intermediate.dense",
target_module=col_nn.Linear1D_Col,
kwargs={
"seq_parallel": use_sequence_parallel,
"overlap": overlap
},
),
SubModuleReplacementDescription(
suffix="output.dense",
target_module=col_nn.Linear1D_Row,
kwargs={"seq_parallel": use_sequence_parallel},
),
SubModuleReplacementDescription(
suffix="output.dropout",
target_module=col_nn.DropoutForParallelInput,
)
])
policy[BertLayer] = ModulePolicyDescription(
attribute_replacement={
"attention.self.all_head_size": self.model.config.hidden_size
// self.shard_config.tensor_parallel_size,
"crossattention.self.all_head_size": self.model.config.hidden_size
// self.shard_config.tensor_parallel_size,
"attention.self.num_attention_heads": self.model.config.num_attention_heads
// self.shard_config.tensor_parallel_size,
"crossattention.self.num_attention_heads": self.model.config.num_attention_heads
// self.shard_config.tensor_parallel_size,
},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="attention.self.query",
target_module=col_nn.Linear1D_Col,
kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap},
),
SubModuleReplacementDescription(
suffix="attention.self.key",
target_module=col_nn.Linear1D_Col,
kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap},
),
SubModuleReplacementDescription(
suffix="attention.self.value",
target_module=col_nn.Linear1D_Col,
kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap},
),
SubModuleReplacementDescription(
suffix="attention.self.dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="attention.output.dense",
target_module=col_nn.Linear1D_Row,
kwargs={"seq_parallel": use_sequence_parallel},
),
SubModuleReplacementDescription(
suffix="attention.output.dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="intermediate.dense",
target_module=col_nn.Linear1D_Col,
kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap},
),
SubModuleReplacementDescription(
suffix="output.dense",
target_module=col_nn.Linear1D_Row,
kwargs={"seq_parallel": use_sequence_parallel},
),
SubModuleReplacementDescription(
suffix="output.dropout",
target_module=col_nn.DropoutForParallelInput,
),
],
)
policy[BertEmbeddings] = ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="word_embeddings",
target_module=col_nn.VocabParallelEmbedding1D,
),
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForReplicatedInput,
)
])
policy[BertEmbeddings] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="word_embeddings",
target_module=col_nn.VocabParallelEmbedding1D,
),
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForReplicatedInput,
),
]
)
if use_sequence_parallel:
self.append_or_create_method_replacement(
description={'forward': bert_sequence_parallel_forward_fn(self.shard_config)},
description={"forward": bert_sequence_parallel_forward_fn(self.shard_config)},
policy=policy,
target_key=BertModel)
target_key=BertModel,
)
# optimization configuration
if self.shard_config.enable_fused_normalization:
# Handle bert layer
self.append_or_create_submodule_replacement(description=[
SubModuleReplacementDescription(
suffix="attention.output.LayerNorm",
target_module=col_nn.FusedLayerNorm,
),
SubModuleReplacementDescription(
suffix="output.LayerNorm",
target_module=col_nn.FusedLayerNorm,
)
],
policy=policy,
target_key=BertLayer)
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="attention.output.LayerNorm",
target_module=col_nn.FusedLayerNorm,
),
SubModuleReplacementDescription(
suffix="output.LayerNorm",
target_module=col_nn.FusedLayerNorm,
),
],
policy=policy,
target_key=BertLayer,
)
# handle embedding layer
self.append_or_create_submodule_replacement(
description=[SubModuleReplacementDescription(
suffix="LayerNorm",
target_module=col_nn.FusedLayerNorm,
)],
description=[
SubModuleReplacementDescription(
suffix="LayerNorm",
target_module=col_nn.FusedLayerNorm,
)
],
policy=policy,
target_key=BertEmbeddings)
target_key=BertEmbeddings,
)
# use flash attention
if self.shard_config.enable_flash_attention:
self.append_or_create_method_replacement(description={
'forward': get_bert_flash_attention_forward(),
},
policy=policy,
target_key=BertSelfAttention)
self.append_or_create_method_replacement(
description={
"forward": get_bert_flash_attention_forward(),
},
policy=policy,
target_key=BertSelfAttention,
)
# use jit operator
if self.shard_config.enable_jit_fused:
self.append_or_create_method_replacement(description={
'forward': get_jit_fused_bert_self_output_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
},
policy=policy,
target_key=BertSelfOutput)
self.append_or_create_method_replacement(description={
'forward': get_jit_fused_bert_output_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
},
policy=policy,
target_key=BertOutput)
self.append_or_create_method_replacement(
description={
"forward": get_jit_fused_bert_self_output_forward(),
"dropout_add": get_jit_fused_dropout_add_func(),
},
policy=policy,
target_key=BertSelfOutput,
)
self.append_or_create_method_replacement(
description={
"forward": get_jit_fused_bert_output_forward(),
"dropout_add": get_jit_fused_dropout_add_func(),
},
policy=policy,
target_key=BertOutput,
)
return policy
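
Editor's note for readers following the policy logic rather than the reformatting: the `attribute_replacement` block above simply divides the attention width and head count by `tensor_parallel_size`, so each rank keeps one shard of the query/key/value projections. A minimal, library-free sketch of that arithmetic (bert-base sizes and `tp_size = 4` are illustrative assumptions, not values from this commit):

```python
# Sketch only: the per-rank attention sizes that attribute_replacement installs.
hidden_size = 768             # bert-base hidden size (assumed for illustration)
num_attention_heads = 12      # bert-base head count (assumed for illustration)
tp_size = 4                   # hypothetical shard_config.tensor_parallel_size

# Each tensor-parallel rank keeps 1/tp_size of the projection width and heads,
# mirroring "attention.self.all_head_size" and "attention.self.num_attention_heads".
all_head_size = hidden_size // tp_size          # 192
num_heads = num_attention_heads // tp_size      # 3

assert all_head_size * tp_size == hidden_size
assert num_heads * tp_size == num_attention_heads
print(all_head_size, num_heads)
```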
@@ -196,31 +205,37 @@ class BertPolicy(Policy):
# optimize for tensor parallelism
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}),
policy=base_policy,
target_key=BertLMPredictionHead)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}
),
policy=base_policy,
target_key=BertLMPredictionHead,
)
# optimize with fused normalization
if self.shard_config.enable_fused_normalization:
# Handle bert lm prediction head
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="transform.LayerNorm",
target_module=col_nn.FusedLayerNorm,
),
policy=base_policy,
target_key=BertLMPredictionHead)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="transform.LayerNorm",
target_module=col_nn.FusedLayerNorm,
),
policy=base_policy,
target_key=BertLMPredictionHead,
)
return base_policy
def add_lm_prediction_policy(self, base_policy):
from transformers.models.bert.modeling_bert import BertLMPredictionHead
method_replacement = {
'_save_to_state_dict': col_nn.ParallelModule._save_to_state_dict,
'_load_from_state_dict': col_nn.ParallelModule._load_from_state_dict,
"_save_to_state_dict": col_nn.ParallelModule._save_to_state_dict,
"_load_from_state_dict": col_nn.ParallelModule._load_from_state_dict,
}
self.append_or_create_method_replacement(description=method_replacement,
policy=base_policy,
target_key=BertLMPredictionHead)
self.append_or_create_method_replacement(
description=method_replacement, policy=base_policy, target_key=BertLMPredictionHead
)
return base_policy
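
`add_lm_prediction_policy` relies on the same method-replacement mechanism used throughout this file: a dict mapping method names to callables that the shardformer attaches to the target class. A toy, ColossalAI-free sketch of that pattern (the class and replacement function here are hypothetical stand-ins, not library APIs):

```python
# Toy illustration of a method_replacement dict being applied to a target class.
class TinyPredictionHead:
    def forward(self, x):
        return x + 1

def parallel_forward(self, x):
    # hypothetical replacement; a real one would handle sharded weights
    return x + 100

method_replacement = {"forward": parallel_forward}

# What append_or_create_method_replacement ultimately amounts to:
# attach each callable to the target class so all instances use it.
for name, fn in method_replacement.items():
    setattr(TinyPredictionHead, name, fn)

print(TinyPredictionHead().forward(1))  # -> 101
```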
def postprocess(self):
@@ -228,7 +243,7 @@ class BertPolicy(Policy):
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
"""If under pipeline parallel setting, replacing the original forward method of huggingface
to customized forward method, and add this changing to policy."""
to customized forward method, and add this changing to policy."""
if self.pipeline_stage_manager:
stage_manager = self.pipeline_stage_manager
if self.model.__class__.__name__ == "BertModel":
@@ -239,15 +254,13 @@ class BertPolicy(Policy):
layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = {
'forward':
partial(new_forward,
stage_manager=stage_manager,
stage_index=stage_index,
shard_config=self.shard_config)
"forward": partial(
new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
)
}
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=model_cls)
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=model_cls
)
return
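
`set_pipeline_forward` binds the stage bookkeeping into the replacement forward with `functools.partial`, so callers keep the original call signature. A small stand-alone sketch of that binding (the stage objects below are toy stand-ins, not the real `PipelineStageManager`):

```python
from functools import partial

def pipeline_forward(x, *, stage_manager=None, stage_index=None, shard_config=None):
    start, end = stage_index
    return f"stage {stage_manager['stage']}: run layers [{start}, {end}) on {x!r}"

stage_manager = {"stage": 1, "num_stages": 4}   # toy stand-in
stage_index = (3, 6)                            # layers owned by this stage

# The policy stores this partial as the model's new forward; the extra
# pipeline arguments are frozen in, while the data argument stays free.
bound_forward = partial(pipeline_forward,
                        stage_manager=stage_manager,
                        stage_index=stage_index,
                        shard_config=None)
print(bound_forward("batch_0"))
```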
@@ -255,7 +268,7 @@ class BertPolicy(Policy):
"""Get pipeline layers for current stage."""
assert self.pipeline_stage_manager is not None
if self.model.__class__.__name__ == 'BertModel':
if self.model.__class__.__name__ == "BertModel":
module = self.model
else:
module = self.model.bert
@@ -275,17 +288,17 @@ class BertPolicy(Policy):
# BertModel
class BertModelPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
policy = super().module_policy()
from transformers.models.bert.modeling_bert import BertModel
if self.pipeline_stage_manager:
self.set_pipeline_forward(model_cls=BertModel,
new_forward=BertPipelineForwards.bert_model_forward,
policy=policy)
self.set_pipeline_forward(
model_cls=BertModel, new_forward=BertPipelineForwards.bert_model_forward, policy=policy
)
return policy
def get_held_layers(self) -> List[Module]:
@@ -300,7 +313,6 @@ class BertModelPolicy(BertPolicy):
# BertForPreTraining
class BertForPreTrainingPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
@@ -309,10 +321,13 @@ class BertForPreTrainingPolicy(BertPolicy):
policy = self.add_lm_head_policy(policy)
policy = self.add_lm_prediction_policy(policy)
from transformers.models.bert.modeling_bert import BertForPreTraining
if self.pipeline_stage_manager:
self.set_pipeline_forward(model_cls=BertForPreTraining,
new_forward=BertPipelineForwards.bert_for_pretraining_forward,
policy=policy)
self.set_pipeline_forward(
model_cls=BertForPreTraining,
new_forward=BertPipelineForwards.bert_for_pretraining_forward,
policy=policy,
)
return policy
def get_held_layers(self) -> List[Module]:
@@ -329,16 +344,17 @@ class BertForPreTrainingPolicy(BertPolicy):
if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
if id(model.bert.embeddings.word_embeddings.weight) == id(model.cls.predictions.decoder.weight):
# tie weights
return [{
0: model.bert.embeddings.word_embeddings.weight,
self.pipeline_stage_manager.num_stages - 1: model.cls.predictions.decoder.weight
}]
return [
{
0: model.bert.embeddings.word_embeddings.weight,
self.pipeline_stage_manager.num_stages - 1: model.cls.predictions.decoder.weight,
}
]
return []
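
`get_shared_params` only reports a tied group when the embedding matrix and the decoder weight are literally the same tensor object, hence the `id()` comparison. A minimal PyTorch sketch of that check (toy sizes, with a four-stage pipeline assumed):

```python
import torch.nn as nn

embedding = nn.Embedding(10, 4)
decoder = nn.Linear(4, 10, bias=False)
decoder.weight = embedding.weight          # tie the two weights

num_stages = 4                             # assumed pipeline depth
if id(embedding.weight) == id(decoder.weight):
    # first stage owns the embedding copy, last stage owns the decoder copy
    shared_params = [{0: embedding.weight, num_stages - 1: decoder.weight}]
else:
    shared_params = []
print(len(shared_params), list(shared_params[0].keys()) if shared_params else [])
```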
# BertLMHeadModel
class BertLMHeadModelPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
@@ -347,10 +363,11 @@ class BertLMHeadModelPolicy(BertPolicy):
policy = self.add_lm_head_policy(policy)
policy = self.add_lm_prediction_policy(policy)
from transformers.models.bert.modeling_bert import BertLMHeadModel
if self.pipeline_stage_manager:
self.set_pipeline_forward(model_cls=BertLMHeadModel,
new_forward=BertPipelineForwards.bert_lm_head_model_forward,
policy=policy)
self.set_pipeline_forward(
model_cls=BertLMHeadModel, new_forward=BertPipelineForwards.bert_lm_head_model_forward, policy=policy
)
return policy
def get_held_layers(self) -> List[Module]:
@@ -368,16 +385,17 @@ class BertLMHeadModelPolicy(BertPolicy):
if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
if id(bert_model.embeddings.word_embeddings.weight) == id(self.model.cls.predictions.decoder.weight):
# tie weights
return [{
0: bert_model.embeddings.word_embeddings.weight,
self.pipeline_stage_manager.num_stages - 1: self.model.cls.predictions.decoder.weight
}]
return [
{
0: bert_model.embeddings.word_embeddings.weight,
self.pipeline_stage_manager.num_stages - 1: self.model.cls.predictions.decoder.weight,
}
]
return []
# BertForMaskedLM
class BertForMaskedLMPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
@@ -386,10 +404,11 @@ class BertForMaskedLMPolicy(BertPolicy):
policy = self.add_lm_head_policy(policy)
policy = self.add_lm_prediction_policy(policy)
from transformers.models.bert.modeling_bert import BertForMaskedLM
if self.pipeline_stage_manager:
self.set_pipeline_forward(model_cls=BertForMaskedLM,
new_forward=BertPipelineForwards.bert_for_masked_lm_forward,
policy=policy)
self.set_pipeline_forward(
model_cls=BertForMaskedLM, new_forward=BertPipelineForwards.bert_for_masked_lm_forward, policy=policy
)
return policy
def get_held_layers(self) -> List[Module]:
@@ -407,16 +426,17 @@ class BertForMaskedLMPolicy(BertPolicy):
if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
if id(bert_model.embeddings.word_embeddings.weight) == id(self.model.cls.predictions.decoder.weight):
# tie weights
return [{
0: bert_model.embeddings.word_embeddings.weight,
self.pipeline_stage_manager.num_stages - 1: self.model.cls.predictions.decoder.weight
}]
return [
{
0: bert_model.embeddings.word_embeddings.weight,
self.pipeline_stage_manager.num_stages - 1: self.model.cls.predictions.decoder.weight,
}
]
return []
# BertForSequenceClassification
class BertForSequenceClassificationPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
@@ -427,19 +447,22 @@ class BertForSequenceClassificationPolicy(BertPolicy):
if self.shard_config.enable_tensor_parallelism:
addon_module = {
BertForSequenceClassification:
ModulePolicyDescription(sub_module_replacement=[
BertForSequenceClassification: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForParallelInput,
)
])
]
)
}
policy.update(addon_module)
if self.pipeline_stage_manager:
self.set_pipeline_forward(model_cls=BertForSequenceClassification,
new_forward=BertPipelineForwards.bert_for_sequence_classification_forward,
policy=policy)
self.set_pipeline_forward(
model_cls=BertForSequenceClassification,
new_forward=BertPipelineForwards.bert_for_sequence_classification_forward,
policy=policy,
)
return policy
@@ -461,7 +484,6 @@ class BertForSequenceClassificationPolicy(BertPolicy):
# BertForTokenClassification
class BertForTokenClassificationPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
@@ -472,19 +494,22 @@ class BertForTokenClassificationPolicy(BertPolicy):
if self.shard_config.enable_tensor_parallelism:
addon_module = {
BertForTokenClassification:
ModulePolicyDescription(sub_module_replacement=[
BertForTokenClassification: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForParallelInput,
)
])
]
)
}
policy.update(addon_module)
if self.pipeline_stage_manager:
self.set_pipeline_forward(model_cls=BertForTokenClassification,
new_forward=BertPipelineForwards.bert_for_token_classification_forward,
policy=policy)
self.set_pipeline_forward(
model_cls=BertForTokenClassification,
new_forward=BertPipelineForwards.bert_for_token_classification_forward,
policy=policy,
)
return policy
@@ -506,17 +531,19 @@ class BertForTokenClassificationPolicy(BertPolicy):
# BertForNextSentencePrediction
class BertForNextSentencePredictionPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
policy = super().module_policy()
from transformers.models.bert.modeling_bert import BertForNextSentencePrediction
if self.pipeline_stage_manager:
self.set_pipeline_forward(model_cls=BertForNextSentencePrediction,
new_forward=BertPipelineForwards.bert_for_next_sentence_prediction_forward,
policy=policy)
self.set_pipeline_forward(
model_cls=BertForNextSentencePrediction,
new_forward=BertPipelineForwards.bert_for_next_sentence_prediction_forward,
policy=policy,
)
return policy
@@ -537,7 +564,6 @@ class BertForNextSentencePredictionPolicy(BertPolicy):
# BertForMultipleChoice
class BertForMultipleChoicePolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
@@ -548,19 +574,22 @@ class BertForMultipleChoicePolicy(BertPolicy):
if self.shard_config.enable_tensor_parallelism:
addon_module = {
BertForMultipleChoice:
ModulePolicyDescription(sub_module_replacement=[
BertForMultipleChoice: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForParallelInput,
)
])
]
)
}
policy.update(addon_module)
if self.pipeline_stage_manager:
self.set_pipeline_forward(model_cls=BertForMultipleChoice,
new_forward=BertPipelineForwards.bert_for_multiple_choice_forward,
policy=policy)
self.set_pipeline_forward(
model_cls=BertForMultipleChoice,
new_forward=BertPipelineForwards.bert_for_multiple_choice_forward,
policy=policy,
)
return policy
@@ -581,17 +610,19 @@ class BertForMultipleChoicePolicy(BertPolicy):
class BertForQuestionAnsweringPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.bert.modeling_bert import BertForQuestionAnswering
policy = super().module_policy()
if self.pipeline_stage_manager:
self.set_pipeline_forward(model_cls=BertForQuestionAnswering,
new_forward=BertPipelineForwards.bert_for_question_answering_forward,
policy=policy)
self.set_pipeline_forward(
model_cls=BertForQuestionAnswering,
new_forward=BertPipelineForwards.bert_for_question_answering_forward,
policy=policy,
)
return policy
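
All of the `get_held_layers` implementations referenced above split the encoder layers across pipeline stages via `Policy.distribute_layers` and `Policy.get_stage_index`. A library-free sketch of that bookkeeping (the real helpers may distribute the remainder differently):

```python
from typing import List, Tuple

def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
    # split as evenly as possible; here later stages absorb the remainder
    base, rem = divmod(num_layers, num_stages)
    return [base + (1 if s >= num_stages - rem else 0) for s in range(num_stages)]

def get_stage_index(layers_per_stage: List[int], stage: int) -> Tuple[int, int]:
    start = sum(layers_per_stage[:stage])
    return start, start + layers_per_stage[stage]

layers_per_stage = distribute_layers(12, 4)        # e.g. BERT-base on 4 stages
print(layers_per_stage)                            # [3, 3, 3, 3]
print(get_stage_index(layers_per_stage, 1))        # (3, 6)
```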