[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
Author: Hongxin Liu
Date: 2023-09-19 14:20:26 +08:00
Committed by: GitHub
Parent: 3c6b831c26
Commit: 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions


@@ -8,7 +8,6 @@ from torch import Tensor
import colossalai.shardformer.layer as col_nn
from .._utils import getattr_, setattr_
from ..modeling.jit import get_jit_fused_dropout_add_func
from ..modeling.whisper import (
WhisperPipelineForwards,
@@ -19,13 +18,14 @@ from ..modeling.whisper import (
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = [
-'WhisperPolicy', 'WhisperModelPolicy', 'WhisperForConditionalGenerationPolicy',
-'WhisperForAudioClassificationPolicy'
+"WhisperPolicy",
+"WhisperModelPolicy",
+"WhisperForConditionalGenerationPolicy",
+"WhisperForAudioClassificationPolicy",
]
class WhisperPolicy(Policy):
def config_sanity_check(self):
pass
@@ -55,179 +55,197 @@ class WhisperPolicy(Policy):
if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False
warnings.warn(
"Whisper dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
"Whisper dosen't support sequence parallelism now, will ignore the sequence parallelism flag."
)
-#TODO using the jit fused add_and_dropout affect the accuracy
+# TODO using the jit fused add_and_dropout affect the accuracy
if self.shard_config.enable_jit_fused:
self.shard_config.enable_jit_fused = False
warnings.warn("Whisper dosen't support jit fused operator now, will ignore the jit fused operator flag.")
if self.shard_config.enable_tensor_parallelism:
-policy[WhisperEncoderLayer] = ModulePolicyDescription(attribute_replacement={
-"self_attn.embed_dim":
-self.model.config.d_model // self.shard_config.tensor_parallel_size,
-"self_attn.num_heads":
-self.model.config.encoder_attention_heads // self.shard_config.tensor_parallel_size,
-},
-sub_module_replacement=[
-SubModuleReplacementDescription(
-suffix="self_attn.q_proj",
-target_module=col_nn.Linear1D_Col,
-),
-SubModuleReplacementDescription(
-suffix="self_attn.k_proj",
-target_module=col_nn.Linear1D_Col,
-),
-SubModuleReplacementDescription(
-suffix="self_attn.v_proj",
-target_module=col_nn.Linear1D_Col,
-),
-SubModuleReplacementDescription(
-suffix="self_attn.out_proj",
-target_module=col_nn.Linear1D_Row,
-),
-SubModuleReplacementDescription(
-suffix="fc1",
-target_module=col_nn.Linear1D_Col,
-),
-SubModuleReplacementDescription(
-suffix="fc2",
-target_module=col_nn.Linear1D_Row,
-),
-])
+policy[WhisperEncoderLayer] = ModulePolicyDescription(
+attribute_replacement={
+"self_attn.embed_dim": self.model.config.d_model // self.shard_config.tensor_parallel_size,
+"self_attn.num_heads": self.model.config.encoder_attention_heads
+// self.shard_config.tensor_parallel_size,
+},
+sub_module_replacement=[
+SubModuleReplacementDescription(
+suffix="self_attn.q_proj",
+target_module=col_nn.Linear1D_Col,
+),
+SubModuleReplacementDescription(
+suffix="self_attn.k_proj",
+target_module=col_nn.Linear1D_Col,
+),
+SubModuleReplacementDescription(
+suffix="self_attn.v_proj",
+target_module=col_nn.Linear1D_Col,
+),
+SubModuleReplacementDescription(
+suffix="self_attn.out_proj",
+target_module=col_nn.Linear1D_Row,
+),
+SubModuleReplacementDescription(
+suffix="fc1",
+target_module=col_nn.Linear1D_Col,
+),
+SubModuleReplacementDescription(
+suffix="fc2",
+target_module=col_nn.Linear1D_Row,
+),
+],
+)
-policy[WhisperDecoderLayer] = ModulePolicyDescription(attribute_replacement={
-"self_attn.embed_dim":
-self.model.config.d_model // self.shard_config.tensor_parallel_size,
-"self_attn.num_heads":
-self.model.config.decoder_attention_heads // self.shard_config.tensor_parallel_size,
-"encoder_attn.embed_dim":
-self.model.config.d_model // self.shard_config.tensor_parallel_size,
-"encoder_attn.num_heads":
-self.model.config.encoder_attention_heads // self.shard_config.tensor_parallel_size,
-},
-sub_module_replacement=[
-SubModuleReplacementDescription(
-suffix="self_attn.q_proj",
-target_module=col_nn.Linear1D_Col,
-),
-SubModuleReplacementDescription(
-suffix="self_attn.k_proj",
-target_module=col_nn.Linear1D_Col,
-),
-SubModuleReplacementDescription(
-suffix="self_attn.v_proj",
-target_module=col_nn.Linear1D_Col,
-),
-SubModuleReplacementDescription(
-suffix="self_attn.out_proj",
-target_module=col_nn.Linear1D_Row,
-),
-SubModuleReplacementDescription(
-suffix="encoder_attn.q_proj",
-target_module=col_nn.Linear1D_Col,
-),
-SubModuleReplacementDescription(
-suffix="encoder_attn.k_proj",
-target_module=col_nn.Linear1D_Col,
-),
-SubModuleReplacementDescription(
-suffix="encoder_attn.v_proj",
-target_module=col_nn.Linear1D_Col,
-),
-SubModuleReplacementDescription(
-suffix="encoder_attn.out_proj",
-target_module=col_nn.Linear1D_Row,
-),
-SubModuleReplacementDescription(
-suffix="fc1",
-target_module=col_nn.Linear1D_Col,
-),
-SubModuleReplacementDescription(
-suffix="fc2",
-target_module=col_nn.Linear1D_Row,
-),
-])
+policy[WhisperDecoderLayer] = ModulePolicyDescription(
+attribute_replacement={
+"self_attn.embed_dim": self.model.config.d_model // self.shard_config.tensor_parallel_size,
+"self_attn.num_heads": self.model.config.decoder_attention_heads
+// self.shard_config.tensor_parallel_size,
+"encoder_attn.embed_dim": self.model.config.d_model // self.shard_config.tensor_parallel_size,
+"encoder_attn.num_heads": self.model.config.encoder_attention_heads
+// self.shard_config.tensor_parallel_size,
+},
+sub_module_replacement=[
+SubModuleReplacementDescription(
+suffix="self_attn.q_proj",
+target_module=col_nn.Linear1D_Col,
+),
+SubModuleReplacementDescription(
+suffix="self_attn.k_proj",
+target_module=col_nn.Linear1D_Col,
+),
+SubModuleReplacementDescription(
+suffix="self_attn.v_proj",
+target_module=col_nn.Linear1D_Col,
+),
+SubModuleReplacementDescription(
+suffix="self_attn.out_proj",
+target_module=col_nn.Linear1D_Row,
+),
+SubModuleReplacementDescription(
+suffix="encoder_attn.q_proj",
+target_module=col_nn.Linear1D_Col,
+),
+SubModuleReplacementDescription(
+suffix="encoder_attn.k_proj",
+target_module=col_nn.Linear1D_Col,
+),
+SubModuleReplacementDescription(
+suffix="encoder_attn.v_proj",
+target_module=col_nn.Linear1D_Col,
+),
+SubModuleReplacementDescription(
+suffix="encoder_attn.out_proj",
+target_module=col_nn.Linear1D_Row,
+),
+SubModuleReplacementDescription(
+suffix="fc1",
+target_module=col_nn.Linear1D_Col,
+),
+SubModuleReplacementDescription(
+suffix="fc2",
+target_module=col_nn.Linear1D_Row,
+),
+],
+)
-policy[WhisperDecoder] = ModulePolicyDescription(sub_module_replacement=[
-SubModuleReplacementDescription(
-suffix="embed_tokens",
-target_module=col_nn.VocabParallelEmbedding1D,
-),
-])
+policy[WhisperDecoder] = ModulePolicyDescription(
+sub_module_replacement=[
+SubModuleReplacementDescription(
+suffix="embed_tokens",
+target_module=col_nn.VocabParallelEmbedding1D,
+),
+]
+)
# optimization configuration
if self.shard_config.enable_fused_normalization:
# Handle encoder layer
-self.append_or_create_submodule_replacement(description=[
-SubModuleReplacementDescription(
-suffix="self_attn_layer_norm",
-target_module=col_nn.FusedLayerNorm,
-),
-SubModuleReplacementDescription(
-suffix="final_layer_norm",
-target_module=col_nn.FusedLayerNorm,
-)
-],
-policy=policy,
-target_key=WhisperEncoderLayer)
+self.append_or_create_submodule_replacement(
+description=[
+SubModuleReplacementDescription(
+suffix="self_attn_layer_norm",
+target_module=col_nn.FusedLayerNorm,
+),
+SubModuleReplacementDescription(
+suffix="final_layer_norm",
+target_module=col_nn.FusedLayerNorm,
+),
+],
+policy=policy,
+target_key=WhisperEncoderLayer,
+)
# Handle decoder layer
-self.append_or_create_submodule_replacement(description=[
-SubModuleReplacementDescription(
-suffix="self_attn_layer_norm",
-target_module=col_nn.FusedLayerNorm,
-),
-SubModuleReplacementDescription(
-suffix="final_layer_norm",
-target_module=col_nn.FusedLayerNorm,
-)
-],
-policy=policy,
-target_key=WhisperDecoderLayer)
+self.append_or_create_submodule_replacement(
+description=[
+SubModuleReplacementDescription(
+suffix="self_attn_layer_norm",
+target_module=col_nn.FusedLayerNorm,
+),
+SubModuleReplacementDescription(
+suffix="final_layer_norm",
+target_module=col_nn.FusedLayerNorm,
+),
+],
+policy=policy,
+target_key=WhisperDecoderLayer,
+)
# handle encoder layer
-self.append_or_create_submodule_replacement(description=[
-SubModuleReplacementDescription(
-suffix="layer_norm",
-target_module=col_nn.FusedLayerNorm,
-)
-],
-policy=policy,
-target_key=WhisperEncoder)
+self.append_or_create_submodule_replacement(
+description=[
+SubModuleReplacementDescription(
+suffix="layer_norm",
+target_module=col_nn.FusedLayerNorm,
+)
+],
+policy=policy,
+target_key=WhisperEncoder,
+)
# handle decoder layer
-self.append_or_create_submodule_replacement(description=[
-SubModuleReplacementDescription(
-suffix="layer_norm",
-target_module=col_nn.FusedLayerNorm,
-)
-],
-policy=policy,
-target_key=WhisperDecoder)
+self.append_or_create_submodule_replacement(
+description=[
+SubModuleReplacementDescription(
+suffix="layer_norm",
+target_module=col_nn.FusedLayerNorm,
+)
+],
+policy=policy,
+target_key=WhisperDecoder,
+)
# enable flash attention
if self.shard_config.enable_flash_attention:
-self.append_or_create_method_replacement(description={
-'forward': get_whisper_flash_attention_forward(),
-},
-policy=policy,
-target_key=WhisperAttention)
+self.append_or_create_method_replacement(
+description={
+"forward": get_whisper_flash_attention_forward(),
+},
+policy=policy,
+target_key=WhisperAttention,
+)
# use jit fused operator
if self.shard_config.enable_jit_fused:
-self.append_or_create_method_replacement(description={
-'forward': get_jit_fused_whisper_decoder_layer_forward(),
-'dropout_add': get_jit_fused_dropout_add_func(),
-},
-policy=policy,
-target_key=WhisperDecoderLayer)
-self.append_or_create_method_replacement(description={
-'forward': get_jit_fused_whisper_encoder_layer_forward(),
-'dropout_add': get_jit_fused_dropout_add_func(),
-},
-policy=policy,
-target_key=WhisperEncoderLayer)
+self.append_or_create_method_replacement(
+description={
+"forward": get_jit_fused_whisper_decoder_layer_forward(),
+"dropout_add": get_jit_fused_dropout_add_func(),
+},
+policy=policy,
+target_key=WhisperDecoderLayer,
+)
+self.append_or_create_method_replacement(
+description={
+"forward": get_jit_fused_whisper_encoder_layer_forward(),
+"dropout_add": get_jit_fused_dropout_add_func(),
+},
+policy=policy,
+target_key=WhisperEncoderLayer,
+)
return policy
@@ -236,10 +254,13 @@ class WhisperPolicy(Policy):
# optimize for tensor parallelism
if self.shard_config.enable_tensor_parallelism:
-self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
-suffix="proj_out", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}),
-policy=base_policy,
-target_key=WhisperForConditionalGeneration)
+self.append_or_create_submodule_replacement(
+description=SubModuleReplacementDescription(
+suffix="proj_out", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}
+),
+policy=base_policy,
+target_key=WhisperForConditionalGeneration,
+)
return base_policy
@@ -247,8 +268,9 @@ class WhisperPolicy(Policy):
return self.model
@staticmethod
-def distribute_whisper_layers(num_encoder_layers: int, num_decoder_layers: int,
-num_stages: int) -> Tuple[List[int], int]:
+def distribute_whisper_layers(
+num_encoder_layers: int, num_decoder_layers: int, num_stages: int
+) -> Tuple[List[int], int]:
"""
Distribute whisper layers into stages when pipeline parallel is used.
Return the layer distribution as a list and the starting stage of decoder.
@@ -281,8 +303,9 @@ class WhisperPolicy(Policy):
return encoder_distribution + decoder_distribution, num_encoder_stages
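For intuition, the contract documented above can be sketched in a few lines: split the stages between encoder and decoder roughly in proportion to their depths, spread layers evenly within each part, and report the first decoder stage. This is a hypothetical illustration only; the helper names sketch_distribute and even_split are made up, and the real distribute_whisper_layers may balance stages differently.

def sketch_distribute(num_encoder_layers, num_decoder_layers, num_stages):
    # Hypothetical sketch of the documented contract: returns a per-stage
    # layer count plus the first stage that holds decoder layers.
    assert num_stages >= 2 or num_decoder_layers == 0, "sketch needs separate encoder/decoder stages"

    def even_split(n_layers, n_stages):
        base, rem = divmod(n_layers, n_stages)
        return [base + (1 if i < rem else 0) for i in range(n_stages)]

    if num_decoder_layers == 0:
        return even_split(num_encoder_layers, num_stages), num_stages
    total = num_encoder_layers + num_decoder_layers
    num_encoder_stages = max(1, round(num_stages * num_encoder_layers / total))
    num_encoder_stages = min(num_encoder_stages, num_stages - 1)
    encoder_part = even_split(num_encoder_layers, num_encoder_stages)
    decoder_part = even_split(num_decoder_layers, num_stages - num_encoder_stages)
    return encoder_part + decoder_part, num_encoder_stages

# e.g. 12 encoder + 12 decoder layers over 4 stages -> ([6, 6, 6, 6], 2):
# stages 0-1 hold the encoder, stages 2-3 the decoder.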
@staticmethod
-def get_whisper_stage_index(layers_per_stage: List[int], stage: int,
-decoder_starting_stage: int) -> Tuple[bool, int, int]:
+def get_whisper_stage_index(
+layers_per_stage: List[int], stage: int, decoder_starting_stage: int
+) -> Tuple[bool, int, int]:
"""
Input the distribution of layers among stages, the current stage and the first stage of decoder.
Return the starting/ending idx of layers in encoder/decoder
@@ -293,13 +316,12 @@ class WhisperPolicy(Policy):
return Policy.get_stage_index(layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage)
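Building on the illustrative sketch above, the lookup reads naturally: a stage before decoder_starting_stage indexes into the encoder's slice of layers_per_stage, any later stage into the decoder's slice, and the start/end layer indices come from summing the earlier stages in that slice. Again a hypothetical sketch (sketch_stage_index is a made-up name; the real method delegates to Policy.get_stage_index as shown above):

def sketch_stage_index(layers_per_stage, stage, decoder_starting_stage):
    # Pick the encoder or decoder slice, then sum the preceding stages'
    # layer counts to locate this stage's layer range.
    if stage < decoder_starting_stage:
        chunk, offset = layers_per_stage[:decoder_starting_stage], stage
    else:
        chunk = layers_per_stage[decoder_starting_stage:]
        offset = stage - decoder_starting_stage
    start = sum(chunk[:offset])
    return start, start + chunk[offset]

# e.g. sketch_stage_index([6, 6, 6, 6], 2, 2) -> (0, 6): the first decoder
# stage holds decoder layers 0..5.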
def get_held_layers(self) -> List[nn.Module]:
assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None"
stage_manager = self.pipeline_stage_manager
-if self.model.__class__.__name__ == 'WhisperModel':
+if self.model.__class__.__name__ == "WhisperModel":
model = self.model
-elif self.model.__class__.__name__ == 'WhisperForConditionalGeneration':
+elif self.model.__class__.__name__ == "WhisperForConditionalGeneration":
model = self.model.model
else:
model = None
@@ -320,9 +342,11 @@ class WhisperPolicy(Policy):
held_layers = []
layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(
-num_encoder_layers, num_decoder_layers, stage_manager.num_stages)
-start_idx, end_idx = WhisperPolicy.get_whisper_stage_index(layers_per_stage, stage_manager.stage,
-decoder_starting_stage)
+num_encoder_layers, num_decoder_layers, stage_manager.num_stages
+)
+start_idx, end_idx = WhisperPolicy.get_whisper_stage_index(
+layers_per_stage, stage_manager.stage, decoder_starting_stage
+)
if stage_manager.stage < decoder_starting_stage:
# current stage is in whisper's encoder
@@ -347,14 +371,14 @@ class WhisperPolicy(Policy):
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
"""If under pipeline parallel setting, replacing the original forward method of huggingface
to customized forward method, and add this changing to policy."""
to customized forward method, and add this changing to policy."""
if not self.pipeline_stage_manager:
raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.")
stage_manager = self.pipeline_stage_manager
-if self.model.__class__.__name__ == 'WhisperModel':
+if self.model.__class__.__name__ == "WhisperModel":
model = self.model
-elif self.model.__class__.__name__ == 'WhisperForConditionalGeneration':
+elif self.model.__class__.__name__ == "WhisperForConditionalGeneration":
model = self.model.model
else:
model = None
@@ -373,34 +397,37 @@ class WhisperPolicy(Policy):
num_decoder_layers = 0
layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(
-num_encoder_layers, num_decoder_layers, stage_manager.num_stages)
-stage_index = WhisperPolicy.get_whisper_stage_index(layers_per_stage, stage_manager.stage,
-decoder_starting_stage)
+num_encoder_layers, num_decoder_layers, stage_manager.num_stages
+)
+stage_index = WhisperPolicy.get_whisper_stage_index(
+layers_per_stage, stage_manager.stage, decoder_starting_stage
+)
method_replacement = {
-'forward':
-partial(new_forward,
-stage_manager=stage_manager,
-stage_index=stage_index,
-decoder_starting_stage=decoder_starting_stage)
+"forward": partial(
+new_forward,
+stage_manager=stage_manager,
+stage_index=stage_index,
+decoder_starting_stage=decoder_starting_stage,
+)
}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
# WhisperModel
class WhisperModelPolicy(WhisperPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers import WhisperModel
policy = super().module_policy()
if self.pipeline_stage_manager is not None:
-self.set_pipeline_forward(model_cls=WhisperModel,
-new_forward=WhisperPipelineForwards.whisper_model_forward,
-policy=policy)
+self.set_pipeline_forward(
+model_cls=WhisperModel, new_forward=WhisperPipelineForwards.whisper_model_forward, policy=policy
+)
return policy
@@ -414,19 +441,21 @@ class WhisperModelPolicy(WhisperPolicy):
# WhisperForConditionalGeneration
class WhisperForConditionalGenerationPolicy(WhisperPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers import WhisperForConditionalGeneration
policy = super().module_policy()
policy = self.add_lm_head_policy(policy)
if self.pipeline_stage_manager is not None:
-self.set_pipeline_forward(model_cls=WhisperForConditionalGeneration,
-new_forward=WhisperPipelineForwards.whisper_for_conditional_generation_forward,
-policy=policy)
+self.set_pipeline_forward(
+model_cls=WhisperForConditionalGeneration,
+new_forward=WhisperPipelineForwards.whisper_for_conditional_generation_forward,
+policy=policy,
+)
return policy
def postprocess(self):
@@ -457,8 +486,9 @@ class WhisperForConditionalGenerationPolicy(WhisperPolicy):
stage_manager = self.pipeline_stage_manager
if stage_manager is not None and stage_manager.num_stages > 1:
-_, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(num_encoder_layers, num_decoder_layers,
-stage_manager.num_stages)
+_, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(
+num_encoder_layers, num_decoder_layers, stage_manager.num_stages
+)
shared_params = []
shared_embedding = {}
if id(module.proj_out) == id(model.decoder.embed_tokens):
@@ -472,7 +502,6 @@ class WhisperForConditionalGenerationPolicy(WhisperPolicy):
# WhisperForAudioClassification
class WhisperForAudioClassificationPolicy(WhisperPolicy):
def __init__(self) -> None:
super().__init__()
@@ -481,12 +510,15 @@ class WhisperForAudioClassificationPolicy(WhisperPolicy):
def module_policy(self):
from transformers import WhisperForAudioClassification
policy = super().module_policy()
if self.pipeline_stage_manager is not None:
-self.set_pipeline_forward(model_cls=WhisperForAudioClassification,
-new_forward=WhisperPipelineForwards.whisper_for_audio_classification_forward,
-policy=policy)
+self.set_pipeline_forward(
+model_cls=WhisperForAudioClassification,
+new_forward=WhisperPipelineForwards.whisper_for_audio_classification_forward,
+policy=policy,
+)
return policy
def get_held_layers(self) -> List[nn.Module]: