[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
Author: Hongxin Liu
Date: 2023-09-19 14:20:26 +08:00
Committed by: GitHub
Parent: 3c6b831c26
Commit: 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions
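Most of the churn in this commit is mechanical reformatting, and the hunks below repeat the same pattern throughout: single-quoted strings become double-quoted, call sites are exploded to one argument per line with a trailing comma, and continuation lines aligned to the opening parenthesis give way to plain 4-space indents, consistent with black's output. A minimal before/after illustration of that pattern, using a hypothetical function rather than anything from the repository:

def make_description(suffix, target_module, kwargs=None):
    # hypothetical stand-in, used only to show the formatting change
    return (suffix, target_module, kwargs or {})

# before: hanging indent aligned to the opening parenthesis, single quotes
old_style = make_description('attn.c_proj', 'Linear1D_Row',
                             kwargs={'seq_parallel': True})

# after: arguments exploded one per line, double quotes, trailing comma
new_style = make_description(
    "attn.c_proj",
    "Linear1D_Row",
    kwargs={"seq_parallel": True},
)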


@@ -5,18 +5,20 @@ from torch import Tensor, nn
import colossalai.shardformer.layer as col_nn
from .._utils import getattr_, setattr_
from ..modeling.gpt2 import GPT2PipelineForwards, get_gpt2_flash_attention_forward, gpt2_sequence_parallel_forward_fn
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = [
'GPT2Policy', 'GPT2ModelPolicy', 'GPT2LMHeadModelPolicy', 'GPT2DoubleHeadsModelPolicy',
'GPT2ForTokenClassificationPolicy', 'GPT2ForSequenceClassificationPolicy'
"GPT2Policy",
"GPT2ModelPolicy",
"GPT2LMHeadModelPolicy",
"GPT2DoubleHeadsModelPolicy",
"GPT2ForTokenClassificationPolicy",
"GPT2ForSequenceClassificationPolicy",
]
class GPT2Policy(Policy):
def config_sanity_check(self):
pass
@@ -40,16 +42,18 @@ class GPT2Policy(Policy):
use_sequence_parallel = self.shard_config.enable_sequence_parallelism
overlap = self.shard_config.enable_sequence_overlap
if self.shard_config.enable_tensor_parallelism:
policy[GPT2Model] = ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="wte",
target_module=col_nn.VocabParallelEmbedding1D,
),
SubModuleReplacementDescription(
suffix="drop",
target_module=col_nn.DropoutForParallelInput,
),
])
policy[GPT2Model] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="wte",
target_module=col_nn.VocabParallelEmbedding1D,
),
SubModuleReplacementDescription(
suffix="drop",
target_module=col_nn.DropoutForParallelInput,
),
]
)
policy[GPT2Block] = ModulePolicyDescription(
attribute_replacement={
@@ -61,31 +65,27 @@ class GPT2Policy(Policy):
SubModuleReplacementDescription(
suffix="attn.c_attn",
target_module=col_nn.GPT2FusedLinearConv1D_Col,
kwargs={"n_fused": 3, "seq_parallel": use_sequence_parallel, "overlap": overlap},
),
SubModuleReplacementDescription(
suffix="attn.c_proj",
target_module=col_nn.GPT2FusedLinearConv1D_Row,
kwargs={
"n_fused": 3,
"seq_parallel": use_sequence_parallel,
"overlap": overlap
},
),
SubModuleReplacementDescription(suffix="attn.c_proj",
target_module=col_nn.GPT2FusedLinearConv1D_Row,
kwargs={
"seq_parallel": use_sequence_parallel,
}),
SubModuleReplacementDescription(
suffix="mlp.c_fc",
target_module=col_nn.GPT2FusedLinearConv1D_Col,
kwargs={"n_fused": 1, "seq_parallel": use_sequence_parallel, "overlap": overlap},
),
SubModuleReplacementDescription(
suffix="mlp.c_proj",
target_module=col_nn.GPT2FusedLinearConv1D_Row,
kwargs={
"n_fused": 1,
"seq_parallel": use_sequence_parallel,
"overlap": overlap
},
),
SubModuleReplacementDescription(suffix="mlp.c_proj",
target_module=col_nn.GPT2FusedLinearConv1D_Row,
kwargs={
"seq_parallel": use_sequence_parallel,
}),
SubModuleReplacementDescription(
suffix="attn.attn_dropout",
target_module=col_nn.DropoutForParallelInput,
@@ -98,39 +98,46 @@ class GPT2Policy(Policy):
suffix="mlp.dropout",
target_module=col_nn.DropoutForParallelInput,
),
])
],
)
# optimization configuration
if self.shard_config.enable_fused_normalization:
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="ln_f",
target_module=col_nn.FusedLayerNorm,
),
policy=policy,
target_key=GPT2Model)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="ln_f",
target_module=col_nn.FusedLayerNorm,
),
policy=policy,
target_key=GPT2Model,
)
self.append_or_create_submodule_replacement(description=[
SubModuleReplacementDescription(
suffix="ln_1",
target_module=col_nn.FusedLayerNorm,
),
SubModuleReplacementDescription(
suffix="ln_2",
target_module=col_nn.FusedLayerNorm,
),
SubModuleReplacementDescription(suffix="ln_cross_attn",
target_module=col_nn.FusedLayerNorm,
ignore_if_not_exist=True)
],
policy=policy,
target_key=GPT2Block)
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="ln_1",
target_module=col_nn.FusedLayerNorm,
),
SubModuleReplacementDescription(
suffix="ln_2",
target_module=col_nn.FusedLayerNorm,
),
SubModuleReplacementDescription(
suffix="ln_cross_attn", target_module=col_nn.FusedLayerNorm, ignore_if_not_exist=True
),
],
policy=policy,
target_key=GPT2Block,
)
if self.shard_config.enable_flash_attention:
self.append_or_create_method_replacement(description={
'forward': get_gpt2_flash_attention_forward(),
},
policy=policy,
target_key=GPT2Attention)
self.append_or_create_method_replacement(
description={
"forward": get_gpt2_flash_attention_forward(),
},
policy=policy,
target_key=GPT2Attention,
)
if self.shard_config.enable_sequence_parallelism:
policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)}
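Every SubModuleReplacementDescription above targets a child by dotted suffix (for example "attn.c_attn" or "mlp.dropout") relative to the module the policy entry keys on, and the sharder swaps that child for its parallel counterpart; the getattr_/setattr_ helpers imported at the top of the file exist for this kind of traversal. A minimal, self-contained sketch of the suffix-walk pattern, with a hypothetical helper rather than ColossalAI's actual implementation:

import torch.nn as nn

def replace_by_suffix(root: nn.Module, suffix: str, factory) -> None:
    # walk the dotted path, e.g. "mlp.dropout" -> root.mlp.dropout,
    # then swap the leaf child for whatever `factory` builds from it
    *parents, leaf = suffix.split(".")
    holder = root
    for name in parents:
        holder = getattr(holder, name)
    setattr(holder, leaf, factory(getattr(holder, leaf)))

# usage (illustrative only): replace a dropout with an identity module
block = nn.Module()
block.mlp = nn.Module()
block.mlp.dropout = nn.Dropout(0.1)
replace_by_suffix(block, "mlp.dropout", lambda old: nn.Identity())
assert isinstance(block.mlp.dropout, nn.Identity)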
@@ -144,7 +151,7 @@ class GPT2Policy(Policy):
"""Get pipeline layers for current stage."""
assert self.pipeline_stage_manager is not None
if self.model.__class__.__name__ == 'GPT2Model':
if self.model.__class__.__name__ == "GPT2Model":
module = self.model
else:
module = self.model.transformer
@@ -164,11 +171,11 @@ class GPT2Policy(Policy):
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
"""If under pipeline parallel setting, replacing the original forward method of huggingface
to customized forward method, and add this changing to policy."""
if not self.pipeline_stage_manager:
raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.")
stage_manager = self.pipeline_stage_manager
if self.model.__class__.__name__ == 'GPT2Model':
if self.model.__class__.__name__ == "GPT2Model":
module = self.model
else:
module = self.model.transformer
@@ -176,18 +183,15 @@ class GPT2Policy(Policy):
layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = {
'forward':
partial(new_forward,
stage_manager=stage_manager,
stage_index=stage_index,
shard_config=self.shard_config)
"forward": partial(
new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
)
}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
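set_pipeline_forward swaps the model's forward for a pipeline-aware one with stage_manager, stage_index, and shard_config pre-bound through functools.partial. The same rebinding pattern in isolation, with a hypothetical stage-sliced forward rather than the actual ColossalAI machinery:

import types
from functools import partial

import torch
import torch.nn as nn

def staged_forward(self, x, *, stage_index):
    # hypothetical stand-in: run only this stage's slice of the layers
    start, end = stage_index
    for layer in self.layers[start:end]:
        x = layer(x)
    return x

model = nn.Module()
model.layers = nn.ModuleList(nn.Linear(8, 8) for _ in range(4))

# pre-bind the stage-specific argument, then attach as this instance's forward
model.forward = types.MethodType(partial(staged_forward, stage_index=(0, 2)), model)
out = model(torch.randn(2, 8))  # passes through layers 0 and 1 only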
# GPT2Model
class GPT2ModelPolicy(GPT2Policy):
def __init__(self) -> None:
super().__init__()
@@ -197,9 +201,9 @@ class GPT2ModelPolicy(GPT2Policy):
policy = super().module_policy()
if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(model_cls=GPT2Model,
new_forward=GPT2PipelineForwards.gpt2_model_forward,
policy=policy)
self.set_pipeline_forward(
model_cls=GPT2Model, new_forward=GPT2PipelineForwards.gpt2_model_forward, policy=policy
)
return policy
def get_held_layers(self) -> List[nn.Module]:
@@ -212,7 +216,6 @@ class GPT2ModelPolicy(GPT2Policy):
# GPT2LMHeadModel
class GPT2LMHeadModelPolicy(GPT2Policy):
def __init__(self) -> None:
super().__init__()
@@ -223,18 +226,22 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
if self.shard_config.enable_tensor_parallelism:
addon_module = {
GPT2LMHeadModel:
ModulePolicyDescription(sub_module_replacement=[
GPT2LMHeadModel: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True})
])
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}
)
]
)
}
module_policy.update(addon_module)
if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(model_cls=GPT2LMHeadModel,
new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward,
policy=module_policy)
self.set_pipeline_forward(
model_cls=GPT2LMHeadModel,
new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward,
policy=module_policy,
)
return module_policy
def get_held_layers(self) -> List[nn.Module]:
@@ -244,7 +251,7 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
'''The weights of wte and lm_head are shared.'''
"""The weights of wte and lm_head are shared."""
module = self.model
stage_manager = self.pipeline_stage_manager
if stage_manager is not None:
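get_shared_params reports the tied wte/lm_head weight so that, under pipeline parallelism, the copy held by the first stage (embedding) and the one held by the last stage (LM head) can be kept synchronized. That the two modules really share a single parameter in the Hugging Face model can be checked directly; a minimal check, assuming transformers and torch are installed (tiny random config, no checkpoint download):

from transformers import GPT2Config, GPT2LMHeadModel

config = GPT2Config(n_layer=2, n_head=2, n_embd=64, vocab_size=1000)
model = GPT2LMHeadModel(config)

# the input embedding and the LM head point at the same Parameter object
assert model.transformer.wte.weight is model.lm_head.weight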
@@ -256,7 +263,6 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
# GPT2DoubleHeadsModel
class GPT2DoubleHeadsModelPolicy(GPT2Policy):
def __init__(self) -> None:
super().__init__()
@@ -267,18 +273,22 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):
if self.shard_config.enable_tensor_parallelism:
addon_module = {
GPT2DoubleHeadsModel:
ModulePolicyDescription(sub_module_replacement=[
GPT2DoubleHeadsModel: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True})
])
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}
)
]
)
}
module_policy.update(addon_module)
if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(model_cls=GPT2DoubleHeadsModel,
new_forward=GPT2PipelineForwards.gpt2_double_heads_model_forward,
policy=module_policy)
self.set_pipeline_forward(
model_cls=GPT2DoubleHeadsModel,
new_forward=GPT2PipelineForwards.gpt2_double_heads_model_forward,
policy=module_policy,
)
return module_policy
@@ -295,7 +305,7 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
'''The weights of wte and lm_head are shared.'''
"""The weights of wte and lm_head are shared."""
module = self.model
stage_manager = self.pipeline_stage_manager
if stage_manager is not None:
@@ -307,7 +317,6 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):
# GPT2ForQuestionAnswering
class GPT2ForQuestionAnsweringPolicy(GPT2Policy):
def __init__(self) -> None:
super().__init__()
@@ -317,9 +326,11 @@ class GPT2ForQuestionAnsweringPolicy(GPT2Policy):
module_policy = super().module_policy()
if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(model_cls=GPT2ForQuestionAnswering,
new_forward=GPT2PipelineForwards.gpt2_for_question_answering_forward,
policy=module_policy)
self.set_pipeline_forward(
model_cls=GPT2ForQuestionAnswering,
new_forward=GPT2PipelineForwards.gpt2_for_question_answering_forward,
policy=module_policy,
)
return module_policy
@@ -330,13 +341,12 @@ class GPT2ForQuestionAnsweringPolicy(GPT2Policy):
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
'''No shared_params in gpt2 for QA.'''
"""No shared_params in gpt2 for QA."""
return []
# GPT2ForTokenClassification
class GPT2ForTokenClassificationPolicy(GPT2Policy):
def __init__(self) -> None:
super().__init__()
@@ -347,17 +357,20 @@ class GPT2ForTokenClassificationPolicy(GPT2Policy):
if self.shard_config.enable_tensor_parallelism:
addon_module = {
GPT2ForTokenClassification:
ModulePolicyDescription(sub_module_replacement=[
GPT2ForTokenClassification: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(suffix="dropout", target_module=col_nn.DropoutForParallelInput)
])
]
)
}
module_policy.update(addon_module)
if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(model_cls=GPT2ForTokenClassification,
new_forward=GPT2PipelineForwards.gpt2_for_token_classification_forward,
policy=module_policy)
self.set_pipeline_forward(
model_cls=GPT2ForTokenClassification,
new_forward=GPT2PipelineForwards.gpt2_for_token_classification_forward,
policy=module_policy,
)
return module_policy
def get_held_layers(self) -> List[nn.Module]:
@@ -374,7 +387,6 @@ class GPT2ForTokenClassificationPolicy(GPT2Policy):
# GPT2ForSequenceClassification
class GPT2ForSequenceClassificationPolicy(GPT2Policy):
def __init__(self) -> None:
super().__init__()
@@ -384,9 +396,11 @@ class GPT2ForSequenceClassificationPolicy(GPT2Policy):
module_policy = super().module_policy()
if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(model_cls=GPT2ForSequenceClassification,
new_forward=GPT2PipelineForwards.gpt2_for_sequence_classification_forward,
policy=module_policy)
self.set_pipeline_forward(
model_cls=GPT2ForSequenceClassification,
new_forward=GPT2PipelineForwards.gpt2_for_sequence_classification_forward,
policy=module_policy,
)
return module_policy
def get_held_layers(self) -> List[nn.Module]:
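get_held_layers leans on Policy.distribute_layers to split the transformer blocks across pipeline stages and on Policy.get_stage_index to recover the current stage's [start, end) range, as seen in set_pipeline_forward above. A self-contained sketch of that bookkeeping, with hypothetical helpers rather than the actual Policy implementation:

from typing import List, Tuple

def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
    # spread layers as evenly as possible; earlier stages absorb any remainder
    base, rem = divmod(num_layers, num_stages)
    return [base + (1 if stage < rem else 0) for stage in range(num_stages)]

def get_stage_index(layers_per_stage: List[int], stage: int) -> Tuple[int, int]:
    # turn per-stage counts into this stage's [start, end) slice of block indices
    start = sum(layers_per_stage[:stage])
    return start, start + layers_per_stage[stage]

# usage: 12 GPT-2 blocks over 4 stages -> [3, 3, 3, 3]; stage 2 holds blocks 6..8
counts = distribute_layers(12, 4)
assert get_stage_index(counts, 2) == (6, 9)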