[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
Author: Hongxin Liu
Date: 2023-09-19 14:20:26 +08:00 (committed by GitHub)
Parent: 3c6b831c26
Commit: 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions
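
Note on the diff below: the changes are mechanical reformatting from the updated pre-commit hooks (Black-style output: double quotes, trailing "magic" commas, and uniform 4-space hanging indents instead of argument-aligned wrapping). A minimal before/after sketch of that kind of change, using a hypothetical function rather than anything taken from this commit:

    # Hypothetical example (not from this commit): the same call before and after
    # a Black-style reformat.
    def describe(suffix, target_module=None, ignore_if_not_exist=False):
        """Stand-in for an API such as SubModuleReplacementDescription."""
        return (suffix, target_module, ignore_if_not_exist)

    # Old style: single quotes, arguments aligned under the opening parenthesis.
    old = describe(suffix='final_layer_norm',
                   target_module=None,
                   ignore_if_not_exist=True)

    # New style: double quotes, 4-space hanging indent, trailing comma.
    new = describe(
        suffix="final_layer_norm",
        target_module=None,
        ignore_if_not_exist=True,
    )

    assert old == new  # the reformat changes layout only, not behavior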


@@ -13,13 +13,15 @@ from ..modeling.opt import OPTPipelineForwards, get_jit_fused_opt_decoder_layer_
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = [
- 'OPTPolicy', 'OPTModelPolicy', 'OPTForCausalLMPolicy', 'OPTForSequenceClassificationPolicy',
- 'OPTForQuestionAnsweringPolicy'
+ "OPTPolicy",
+ "OPTModelPolicy",
+ "OPTForCausalLMPolicy",
+ "OPTForSequenceClassificationPolicy",
+ "OPTForQuestionAnsweringPolicy",
]
class OPTPolicy(Policy):
def config_sanity_check(self):
pass
@@ -45,79 +47,94 @@ class OPTPolicy(Policy):
warnings.warn("OPT doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
if self.shard_config.enable_tensor_parallelism:
- policy[OPTDecoder] = ModulePolicyDescription(sub_module_replacement=[
- SubModuleReplacementDescription(
- suffix="embed_tokens",
- target_module=VocabParallelEmbedding1D,
- )
- ])
- policy[OPTDecoderLayer] = ModulePolicyDescription(sub_module_replacement=[
- SubModuleReplacementDescription(
- suffix="fc1",
- target_module=Linear1D_Col,
- ),
- SubModuleReplacementDescription(
- suffix="fc2",
- target_module=Linear1D_Row,
- )
- ])
+ policy[OPTDecoder] = ModulePolicyDescription(
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="embed_tokens",
+ target_module=VocabParallelEmbedding1D,
+ )
+ ]
+ )
+ policy[OPTDecoderLayer] = ModulePolicyDescription(
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="fc1",
+ target_module=Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="fc2",
+ target_module=Linear1D_Row,
+ ),
+ ]
+ )
- policy[OPTAttention] = ModulePolicyDescription(attribute_replacement={
- "embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
- "num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size
- },
- sub_module_replacement=[
- SubModuleReplacementDescription(
- suffix="q_proj",
- target_module=Linear1D_Col,
- ),
- SubModuleReplacementDescription(
- suffix="k_proj",
- target_module=Linear1D_Col,
- ),
- SubModuleReplacementDescription(
- suffix="v_proj",
- target_module=Linear1D_Col,
- ),
- SubModuleReplacementDescription(
- suffix="out_proj",
- target_module=Linear1D_Row,
- ),
- ])
+ policy[OPTAttention] = ModulePolicyDescription(
+ attribute_replacement={
+ "embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+ "num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+ },
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="q_proj",
+ target_module=Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="k_proj",
+ target_module=Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="v_proj",
+ target_module=Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="out_proj",
+ target_module=Linear1D_Row,
+ ),
+ ],
+ )
# optimization configuration
if self.shard_config.enable_fused_normalization:
- self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
- suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True),
- policy=policy,
- target_key=OPTDecoder)
- self.append_or_create_submodule_replacement(description=[
- SubModuleReplacementDescription(suffix="self_attn_layer_norm",
- target_module=FusedLayerNorm,
- ignore_if_not_exist=True),
- SubModuleReplacementDescription(suffix="final_layer_norm",
- target_module=FusedLayerNorm,
- ignore_if_not_exist=True)
- ],
- policy=policy,
- target_key=OPTDecoderLayer)
+ self.append_or_create_submodule_replacement(
+ description=SubModuleReplacementDescription(
+ suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True
+ ),
+ policy=policy,
+ target_key=OPTDecoder,
+ )
+ self.append_or_create_submodule_replacement(
+ description=[
+ SubModuleReplacementDescription(
+ suffix="self_attn_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True
+ ),
+ SubModuleReplacementDescription(
+ suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True
+ ),
+ ],
+ policy=policy,
+ target_key=OPTDecoderLayer,
+ )
# use flash attention
if self.shard_config.enable_flash_attention:
- self.append_or_create_method_replacement(description={
- 'forward': get_opt_flash_attention_forward(),
- },
- policy=policy,
- target_key=OPTAttention)
+ self.append_or_create_method_replacement(
+ description={
+ "forward": get_opt_flash_attention_forward(),
+ },
+ policy=policy,
+ target_key=OPTAttention,
+ )
# use jit fused operator
if self.shard_config.enable_jit_fused:
- self.append_or_create_method_replacement(description={
- 'forward': get_jit_fused_opt_decoder_layer_forward(),
- 'dropout_add': get_jit_fused_dropout_add_func(),
- },
- policy=policy,
- target_key=OPTDecoderLayer)
+ self.append_or_create_method_replacement(
+ description={
+ "forward": get_jit_fused_opt_decoder_layer_forward(),
+ "dropout_add": get_jit_fused_dropout_add_func(),
+ },
+ policy=policy,
+ target_key=OPTDecoderLayer,
+ )
return policy
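
The attribute_replacement entries above shard the attention module by plain integer division. A small worked example with assumed configuration numbers (illustrative only, not taken from this commit):

    # Assumed example numbers: an OPT-like config sharded over a tensor-parallel group of 4.
    hidden_size = 2048
    num_attention_heads = 32
    tensor_parallel_size = 4

    embed_dim_per_rank = hidden_size // tensor_parallel_size          # 512
    num_heads_per_rank = num_attention_heads // tensor_parallel_size  # 8

    # head_dim stays the same on every rank, so the attention math is unchanged:
    assert hidden_size // num_attention_heads == embed_dim_per_rank // num_heads_per_rank  # both 64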
@@ -128,7 +145,7 @@ class OPTPolicy(Policy):
"""Get pipeline layers for current stage."""
assert self.pipeline_stage_manager is not None
- if self.model.__class__.__name__ == 'OPTModel':
+ if self.model.__class__.__name__ == "OPTModel":
module = self.model.decoder
else:
module = self.model.model.decoder
@@ -149,24 +166,23 @@ class OPTPolicy(Policy):
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
"""If under pipeline parallel setting, replacing the original forward method of huggingface
- to customized forward method, and add this changing to policy."""
+ to customized forward method, and add this changing to policy."""
if self.pipeline_stage_manager:
stage_manager = self.pipeline_stage_manager
- if self.model.__class__.__name__ == 'OPTModel':
+ if self.model.__class__.__name__ == "OPTModel":
module = self.model.decoder
else:
module = self.model.model.decoder
layers_per_stage = Policy.distribute_layers(len(module.layers), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
- method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
- self.append_or_create_method_replacement(description=method_replacement,
- policy=policy,
- target_key=model_cls)
+ method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
+ self.append_or_create_method_replacement(
+ description=method_replacement, policy=policy, target_key=model_cls
+ )
class OPTModelPolicy(OPTPolicy):
def __init__(self) -> None:
super().__init__()
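
The set_pipeline_forward method shown above assigns decoder layers to pipeline stages via Policy.distribute_layers and Policy.get_stage_index. A rough sketch of the even-split idea, using hypothetical stand-in helpers rather than the real ColossalAI implementations (which may balance remainders differently):

    # Hypothetical stand-ins for Policy.distribute_layers / Policy.get_stage_index.
    def distribute_layers_evenly(num_layers, num_stages):
        base, rem = divmod(num_layers, num_stages)
        return [base + (1 if i < rem else 0) for i in range(num_stages)]

    def get_stage_slice(layers_per_stage, stage):
        start = sum(layers_per_stage[:stage])
        return start, start + layers_per_stage[stage]

    layers_per_stage = distribute_layers_evenly(24, 4)  # e.g. 24 decoder layers, 4 stages
    print(layers_per_stage)                             # [6, 6, 6, 6]
    print(get_stage_slice(layers_per_stage, 1))         # (6, 12): stage 1 runs layers 6..11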
@@ -175,9 +191,9 @@ class OPTModelPolicy(OPTPolicy):
policy = super().module_policy()
if self.pipeline_stage_manager:
- self.set_pipeline_forward(model_cls=OPTModel,
- new_forward=OPTPipelineForwards.opt_model_forward,
- policy=policy)
+ self.set_pipeline_forward(
+ model_cls=OPTModel, new_forward=OPTPipelineForwards.opt_model_forward, policy=policy
+ )
return policy
def get_held_layers(self) -> List[nn.Module]:
@@ -189,20 +205,22 @@ class OPTModelPolicy(OPTPolicy):
class OPTForCausalLMPolicy(OPTPolicy):
def module_policy(self):
from transformers.models.opt.modeling_opt import OPTForCausalLM
policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
- self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
- suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)),
- policy=policy,
- target_key=OPTForCausalLM)
+ self.append_or_create_submodule_replacement(
+ description=SubModuleReplacementDescription(
+ suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
+ ),
+ policy=policy,
+ target_key=OPTForCausalLM,
+ )
if self.pipeline_stage_manager:
- self.set_pipeline_forward(model_cls=OPTForCausalLM,
- new_forward=OPTPipelineForwards.opt_for_causal_lm_forward,
- policy=policy)
+ self.set_pipeline_forward(
+ model_cls=OPTForCausalLM, new_forward=OPTPipelineForwards.opt_for_causal_lm_forward, policy=policy
+ )
return policy
@@ -223,7 +241,7 @@ class OPTForCausalLMPolicy(OPTPolicy):
def postprocess(self):
if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
binding_map = {
- 'model.decoder.embed_tokens': 'lm_head',
+ "model.decoder.embed_tokens": "lm_head",
}
for k, v in binding_map.items():
@@ -235,7 +253,6 @@ class OPTForCausalLMPolicy(OPTPolicy):
class OPTForSequenceClassificationPolicy(OPTPolicy):
def __init__(self) -> None:
super().__init__()
@@ -244,9 +261,11 @@ class OPTForSequenceClassificationPolicy(OPTPolicy):
policy = super().module_policy()
if self.pipeline_stage_manager:
- self.set_pipeline_forward(model_cls=OPTForSequenceClassification,
- new_forward=OPTPipelineForwards.opt_for_sequence_classification_forward,
- policy=policy)
+ self.set_pipeline_forward(
+ model_cls=OPTForSequenceClassification,
+ new_forward=OPTPipelineForwards.opt_for_sequence_classification_forward,
+ policy=policy,
+ )
return policy
@@ -262,7 +281,6 @@ class OPTForSequenceClassificationPolicy(OPTPolicy):
class OPTForQuestionAnsweringPolicy(OPTPolicy):
def __init__(self) -> None:
super().__init__()
@@ -271,9 +289,11 @@ class OPTForQuestionAnsweringPolicy(OPTPolicy):
policy = super().module_policy()
if self.pipeline_stage_manager:
- self.set_pipeline_forward(model_cls=OPTForQuestionAnswering,
- new_forward=OPTPipelineForwards.opt_for_question_answering_forward,
- policy=policy)
+ self.set_pipeline_forward(
+ model_cls=OPTForQuestionAnswering,
+ new_forward=OPTPipelineForwards.opt_for_question_answering_forward,
+ policy=policy,
+ )
return policy