[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
Author: Hongxin Liu
Date: 2023-09-19 14:20:26 +08:00
Committed by: GitHub
Parent: 3c6b831c26
Commit: 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions


@@ -1,6 +1,6 @@
import warnings
from functools import partial
from typing import Callable, Dict, List, Optional, Tuple
from typing import Callable, Dict, List, Tuple
import numpy as np
from torch import Tensor, nn
@@ -15,7 +15,6 @@ from colossalai.shardformer.layer import (
)
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription
from .._utils import getattr_, setattr_
from ..modeling.jit import get_jit_fused_dropout_add_func
from ..modeling.t5 import (
T5PipelineForwards,
@@ -30,7 +29,6 @@ __all__ = ["distribute_t5_layers", "T5ModelPolicy", "T5ForConditionalGenerationP
class T5BasePolicy(Policy):
def config_sanity_check(self):
pass
@@ -65,151 +63,181 @@ class T5BasePolicy(Policy):
warnings.warn("T5 dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
if self.shard_config.enable_tensor_parallelism:
policy[T5Stack] = ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=VocabParallelEmbedding1D,
)
])
policy[T5LayerSelfAttention] = ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
),
])
policy[T5LayerCrossAttention] = ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
)
])
policy[T5Attention] = ModulePolicyDescription(attribute_replacement={
"d_model":
self.model.config.d_model // self.shard_config.tensor_parallel_size,
"n_heads":
self.model.config.num_heads // self.shard_config.tensor_parallel_size,
"inner_dim":
self.model.config.num_heads * self.model.config.d_kv // self.shard_config.tensor_parallel_size
},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="q",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="k",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="v",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="o",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="relative_attention_bias",
target_module=Embedding1D,
kwargs=dict(gather_output=False),
ignore_if_not_exist=True)
])
policy[T5LayerFF] = ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
),
])
policy[T5DenseGatedActDense] = ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="wi_0 ",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="wi_1",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="wo", target_module=Linear1D_Col, kwargs=dict(gather_output=True)),
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
)
])
policy[T5DenseActDense] = ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="wi",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="wo",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
)
])
policy[T5Stack] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=VocabParallelEmbedding1D,
),
]
)
policy[T5LayerSelfAttention] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
),
]
)
policy[T5LayerCrossAttention] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
)
]
)
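# Note (editor comment, not part of the commit): q/k/v below are sharded
# column-wise and the output projection o row-wise, so each rank computes
# n_heads / tensor_parallel_size complete heads; the attribute_replacement
# rescales d_model, n_heads, and inner_dim to their per-rank values.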
policy[T5Attention] = ModulePolicyDescription(
attribute_replacement={
"d_model": self.model.config.d_model // self.shard_config.tensor_parallel_size,
"n_heads": self.model.config.num_heads // self.shard_config.tensor_parallel_size,
"inner_dim": self.model.config.num_heads
* self.model.config.d_kv
// self.shard_config.tensor_parallel_size,
},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="q",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="k",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="v",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="o",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="relative_attention_bias",
target_module=Embedding1D,
kwargs=dict(gather_output=False),
ignore_if_not_exist=True,
),
],
)
policy[T5LayerFF] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
),
]
)
policy[T5DenseGatedActDense] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="wi_0 ",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="wi_1",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="wo", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
),
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
),
]
)
policy[T5DenseActDense] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="wi",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="wo",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
),
]
)
# optimization configuration
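# (T5's LayerNorm is RMSNorm-style -- no bias, no mean subtraction -- which is
# why FusedRMSNorm can substitute for it here.)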
if self.shard_config.enable_fused_normalization:
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="layer_norm",
target_module=FusedRMSNorm,
),
policy=policy,
target_key=T5LayerFF)
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="layer_norm",
target_module=FusedRMSNorm,
),
policy=policy,
target_key=T5LayerFF)
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="layer_norm", target_module=FusedRMSNorm),
policy=policy,
target_key=T5LayerSelfAttention)
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="layer_norm", target_module=FusedRMSNorm),
policy=policy,
target_key=T5LayerCrossAttention)
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="final_layer_norm", target_module=FusedRMSNorm),
policy=policy,
target_key=T5Stack)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="layer_norm",
target_module=FusedRMSNorm,
),
policy=policy,
target_key=T5LayerFF,
)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="layer_norm",
target_module=FusedRMSNorm,
),
policy=policy,
target_key=T5LayerFF,
)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm),
policy=policy,
target_key=T5LayerSelfAttention,
)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm),
policy=policy,
target_key=T5LayerCrossAttention,
)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(suffix="final_layer_norm", target_module=FusedRMSNorm),
policy=policy,
target_key=T5Stack,
)
# use flash attention
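# (get_t5_flash_attention_forward swaps T5Attention.forward for a fused
# flash-attention implementation; the rest of the module is left untouched.)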
if self.shard_config.enable_flash_attention:
self.append_or_create_method_replacement(description={
'forward': get_t5_flash_attention_forward(),
},
policy=policy,
target_key=T5Attention)
self.append_or_create_method_replacement(
description={
"forward": get_t5_flash_attention_forward(),
},
policy=policy,
target_key=T5Attention,
)
# use jit operator
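# (The JIT path fuses the dropout + residual-add that follows T5's feed-forward
# and attention sublayers into a single TorchScript-compiled function.)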
if self.shard_config.enable_jit_fused:
self.append_or_create_method_replacement(description={
'forward': get_jit_fused_T5_layer_ff_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
},
policy=policy,
target_key=T5LayerFF)
self.append_or_create_method_replacement(description={
'forward': get_T5_layer_self_attention_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
},
policy=policy,
target_key=T5LayerSelfAttention)
self.append_or_create_method_replacement(description={
'forward': get_T5_layer_cross_attention_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
},
policy=policy,
target_key=T5LayerCrossAttention)
self.append_or_create_method_replacement(
description={
"forward": get_jit_fused_T5_layer_ff_forward(),
"dropout_add": get_jit_fused_dropout_add_func(),
},
policy=policy,
target_key=T5LayerFF,
)
self.append_or_create_method_replacement(
description={
"forward": get_T5_layer_self_attention_forward(),
"dropout_add": get_jit_fused_dropout_add_func(),
},
policy=policy,
target_key=T5LayerSelfAttention,
)
self.append_or_create_method_replacement(
description={
"forward": get_T5_layer_cross_attention_forward(),
"dropout_add": get_jit_fused_dropout_add_func(),
},
policy=policy,
target_key=T5LayerCrossAttention,
)
return policy
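# Editor's sketch (illustrative only; not part of this commit). The policy
# returned above is plain data; a sharder consumes it roughly as follows.
# `apply_policy` is hypothetical -- ColossalAI's real traversal lives in its
# sharder, and from_native_module's process-group arguments are elided -- but
# the field names (sub_module_replacement, suffix, target_module, kwargs,
# ignore_if_not_exist) are the ones used throughout this file.
#
# def apply_policy(model: nn.Module, policy: Dict) -> None:
#     for module in model.modules():
#         desc = policy.get(type(module))
#         if desc is None:
#             continue
#         for rep in desc.sub_module_replacement or []:
#             child = getattr(module, rep.suffix, None)
#             if child is None and rep.ignore_if_not_exist:
#                 continue
#             # swap the native sub-module for its tensor-parallel counterpart
#             new_child = rep.target_module.from_native_module(child, **(rep.kwargs or {}))
#             setattr(module, rep.suffix, new_child)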
@@ -217,8 +245,9 @@ class T5BasePolicy(Policy):
return self.model
@staticmethod
def distribute_t5_layers(num_encoder_layers: int, num_decoder_layers: int,
num_stages: int) -> Tuple[List[int], int]:
def distribute_t5_layers(
num_encoder_layers: int, num_decoder_layers: int, num_stages: int
) -> Tuple[List[int], int]:
"""
Distribute T5 layers among pipeline stages when pipeline parallelism is used.
Return the layer distribution as a list and the starting stage of the decoder.
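# Illustrative example (editor comment; assumes stages are split between encoder
# and decoder roughly in proportion to their layer counts, as the implementation
# below optimizes for -- the numbers are not normative):
#   distribute_t5_layers(num_encoder_layers=12, num_decoder_layers=12, num_stages=4)
#   -> ([6, 6, 6, 6], 2)   # stages 0-1 hold the encoder, stages 2-3 the decoder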
@@ -251,8 +280,9 @@ class T5BasePolicy(Policy):
return encoder_distribution + decoder_distribution, num_encoder_stages
@staticmethod
def get_t5_stage_index(layers_per_stage: List[int], stage: int,
                       decoder_starting_stage: int) -> Tuple[int, int]:
def get_t5_stage_index(
    layers_per_stage: List[int], stage: int, decoder_starting_stage: int
) -> Tuple[int, int]:
"""
Given the distribution of layers among stages, the current stage, and the first stage of the decoder,
return the starting/ending index of this stage's layers within the encoder/decoder.
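# Illustrative example, continuing the one above (editor comment; hypothetical numbers):
#   get_t5_stage_index([6, 6, 6, 6], stage=3, decoder_starting_stage=2)
#   -> (6, 12)   # stage 3 runs decoder blocks with indices 6..11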
@@ -269,16 +299,18 @@ class T5BasePolicy(Policy):
model = self.model
encoder = self.model.encoder
decoder = getattr(self.model, 'decoder', None)
decoder = getattr(self.model, "decoder", None)
num_encoder_layers = len(encoder.block)
num_decoder_layers = len(decoder.block) if decoder else 0
held_layers = []
layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(
num_encoder_layers, num_decoder_layers, stage_manager.num_stages)
start_idx, end_idx = T5BasePolicy.get_t5_stage_index(layers_per_stage, stage_manager.stage,
decoder_starting_stage)
num_encoder_layers, num_decoder_layers, stage_manager.num_stages
)
start_idx, end_idx = T5BasePolicy.get_t5_stage_index(
layers_per_stage, stage_manager.stage, decoder_starting_stage
)
if stage_manager.stage < decoder_starting_stage:
# current stage is in t5's encoder
@@ -303,47 +335,51 @@ class T5BasePolicy(Policy):
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
"""If under pipeline parallel setting, replacing the original forward method of huggingface
to customized forward method, and add this changing to policy."""
to customized forward method, and add this changing to policy."""
if not self.pipeline_stage_manager:
raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.")
stage_manager = self.pipeline_stage_manager
encoder = self.model.encoder
decoder = getattr(self.model, 'decoder', None)
decoder = getattr(self.model, "decoder", None)
num_encoder_layers = len(encoder.block)
num_decoder_layers = len(decoder.block) if decoder else 0
layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(
num_encoder_layers, num_decoder_layers, stage_manager.num_stages)
num_encoder_layers, num_decoder_layers, stage_manager.num_stages
)
stage_index = T5BasePolicy.get_t5_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage)
method_replacement = {
'forward':
partial(new_forward,
stage_manager=stage_manager,
stage_index=stage_index,
decoder_starting_stage=decoder_starting_stage)
"forward": partial(
new_forward,
stage_manager=stage_manager,
stage_index=stage_index,
decoder_starting_stage=decoder_starting_stage,
)
}
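# functools.partial pins the stage context (stage_manager, stage_index,
# decoder_starting_stage) into the replacement forward, so each pipeline rank
# executes only its own slice of encoder/decoder blocks.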
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
class T5ModelPolicy(T5BasePolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers import T5Model
policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="shared",
target_module=VocabParallelEmbedding1D,
),
policy=policy,
target_key=T5Model)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="shared",
target_module=VocabParallelEmbedding1D,
),
policy=policy,
target_key=T5Model,
)
if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(model_cls=T5Model, new_forward=T5PipelineForwards.t5_model_forward, policy=policy)
@@ -356,9 +392,9 @@ class T5ModelPolicy(T5BasePolicy):
module = self.model
stage_manager = self.pipeline_stage_manager
if stage_manager is not None and stage_manager.num_stages > 1:
_, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(len(module.encoder.block),
len(module.decoder.block),
stage_manager.num_stages)
_, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(
len(module.encoder.block), len(module.decoder.block), stage_manager.num_stages
)
if id(module.decoder.embed_tokens.weight) == id(module.shared.weight):
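# Tied weights: the decoder's input embedding shares storage with model.shared,
# so report the pair along with the pipeline stages that hold each copy
# (stage 0 and decoder_starting_stage) for cross-stage synchronization.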
return [{0: module.shared.weight, decoder_starting_stage: module.decoder.embed_tokens.weight}]
@@ -366,7 +402,6 @@ class T5ModelPolicy(T5BasePolicy):
class T5ForConditionalGenerationPolicy(T5BasePolicy):
def __init__(self) -> None:
super().__init__()
@@ -376,22 +411,26 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(description=[
SubModuleReplacementDescription(
suffix="shared",
target_module=VocabParallelEmbedding1D,
),
SubModuleReplacementDescription(suffix="lm_head",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True))
],
policy=policy,
target_key=T5ForConditionalGeneration)
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="shared",
target_module=VocabParallelEmbedding1D,
),
SubModuleReplacementDescription(
suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
),
],
policy=policy,
target_key=T5ForConditionalGeneration,
)
if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(model_cls=T5ForConditionalGeneration,
new_forward=T5PipelineForwards.t5_for_conditional_generation_forward,
policy=policy)
self.set_pipeline_forward(
model_cls=T5ForConditionalGeneration,
new_forward=T5PipelineForwards.t5_for_conditional_generation_forward,
policy=policy,
)
return policy
def get_held_layers(self) -> List[nn.Module]:
@@ -404,9 +443,9 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
module = self.model
stage_manager = self.pipeline_stage_manager
if stage_manager is not None and stage_manager.num_stages > 1:
_, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(len(module.encoder.block),
len(module.decoder.block),
stage_manager.num_stages)
_, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(
len(module.encoder.block), len(module.decoder.block), stage_manager.num_stages
)
shared_params = []
shared_embedding = {}
@@ -427,7 +466,6 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
class T5EncoderPolicy(T5BasePolicy):
def __init__(self) -> None:
super().__init__()
@@ -437,17 +475,19 @@ class T5EncoderPolicy(T5BasePolicy):
policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="shared",
target_module=VocabParallelEmbedding1D,
),
policy=policy,
target_key=T5EncoderModel)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="shared",
target_module=VocabParallelEmbedding1D,
),
policy=policy,
target_key=T5EncoderModel,
)
if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(model_cls=T5EncoderModel,
new_forward=T5PipelineForwards.t5_encoder_model_forward,
policy=policy)
self.set_pipeline_forward(
model_cls=T5EncoderModel, new_forward=T5PipelineForwards.t5_encoder_model_forward, policy=policy
)
return policy
def get_held_layers(self) -> List[nn.Module]: