[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
This commit is contained in:
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -1,8 +1,5 @@
import torch.nn as nn
import colossalai.shardformer.layer as col_nn
from .._utils import getattr_, setattr_
from ..modeling.blip2 import (
forward_fn,
get_blip2_flash_attention_forward,
@@ -12,11 +9,10 @@ from ..modeling.blip2 import (
from ..modeling.jit import get_jit_fused_dropout_add_func
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ['BlipPolicy', 'BlipModelPolicy']
__all__ = ["BlipPolicy", "BlipModelPolicy"]
class BlipPolicy(Policy):
def config_sanity_check(self):
pass
@@ -48,263 +44,293 @@ class BlipPolicy(Policy):
policy = {}
if self.shard_config.enable_tensor_parallelism:
policy[Blip2EncoderLayer] = ModulePolicyDescription(attribute_replacement={
"self_attn.num_heads":
self.model.config.vision_config.num_attention_heads // self.shard_config.tensor_parallel_size,
"self_attn.embed_dim":
self.model.config.vision_config.hidden_size // self.shard_config.tensor_parallel_size,
},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attn.dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="self_attn.qkv",
target_module=col_nn.FusedLinear1D_Col,
kwargs={
"n_fused": 3,
}),
SubModuleReplacementDescription(
suffix="self_attn.projection",
target_module=col_nn.Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="mlp.fc1",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="mlp.fc2",
target_module=col_nn.Linear1D_Row,
),
])
policy[Blip2EncoderLayer] = ModulePolicyDescription(
attribute_replacement={
"self_attn.num_heads": self.model.config.vision_config.num_attention_heads
// self.shard_config.tensor_parallel_size,
"self_attn.embed_dim": self.model.config.vision_config.hidden_size
// self.shard_config.tensor_parallel_size,
},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attn.dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="self_attn.qkv",
target_module=col_nn.FusedLinear1D_Col,
kwargs={
"n_fused": 3,
},
),
SubModuleReplacementDescription(
suffix="self_attn.projection",
target_module=col_nn.Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="mlp.fc1",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="mlp.fc2",
target_module=col_nn.Linear1D_Row,
),
],
)
policy[Blip2QFormerModel] = ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForParallelInput,
),
])
policy[Blip2QFormerModel] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForParallelInput,
),
]
)
policy[Blip2QFormerLayer] = ModulePolicyDescription(attribute_replacement={
"attention.attention.num_attention_heads":
self.model.config.qformer_config.num_attention_heads // self.shard_config.tensor_parallel_size,
"attention.attention.all_head_size":
self.model.config.qformer_config.hidden_size // self.shard_config.tensor_parallel_size,
"crossattention.attention.num_attention_heads":
self.model.config.qformer_config.num_attention_heads // self.shard_config.tensor_parallel_size,
"crossattention.attention.all_head_size":
self.model.config.qformer_config.hidden_size // self.shard_config.tensor_parallel_size,
},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="attention.attention.query",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.attention.key",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.attention.value",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.attention.dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="attention.output.dense",
target_module=col_nn.Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="attention.output.dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="crossattention.attention.query",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="crossattention.attention.key",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="crossattention.attention.value",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="crossattention.attention.dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="crossattention.output.dense",
target_module=col_nn.Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="crossattention.output.dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="intermediate_query.dense",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="output_query.dense",
target_module=col_nn.Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="output_query.dropout",
target_module=col_nn.DropoutForParallelInput,
)
])
policy[Blip2QFormerLayer] = ModulePolicyDescription(
attribute_replacement={
"attention.attention.num_attention_heads": self.model.config.qformer_config.num_attention_heads
// self.shard_config.tensor_parallel_size,
"attention.attention.all_head_size": self.model.config.qformer_config.hidden_size
// self.shard_config.tensor_parallel_size,
"crossattention.attention.num_attention_heads": self.model.config.qformer_config.num_attention_heads
// self.shard_config.tensor_parallel_size,
"crossattention.attention.all_head_size": self.model.config.qformer_config.hidden_size
// self.shard_config.tensor_parallel_size,
},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="attention.attention.query",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.attention.key",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.attention.value",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.attention.dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="attention.output.dense",
target_module=col_nn.Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="attention.output.dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="crossattention.attention.query",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="crossattention.attention.key",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="crossattention.attention.value",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="crossattention.attention.dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="crossattention.output.dense",
target_module=col_nn.Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="crossattention.output.dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="intermediate_query.dense",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="output_query.dense",
target_module=col_nn.Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="output_query.dropout",
target_module=col_nn.DropoutForParallelInput,
),
],
)
policy[OPTDecoderLayer] = ModulePolicyDescription(attribute_replacement={
"self_attn.embed_dim":
self.model.config.text_config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attn.num_heads":
self.model.config.text_config.num_attention_heads // self.shard_config.tensor_parallel_size
},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attn.out_proj",
target_module=col_nn.Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="fc1",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="fc2",
target_module=col_nn.Linear1D_Row,
)
])
policy[OPTDecoderLayer] = ModulePolicyDescription(
attribute_replacement={
"self_attn.embed_dim": self.model.config.text_config.hidden_size
// self.shard_config.tensor_parallel_size,
"self_attn.num_heads": self.model.config.text_config.num_attention_heads
// self.shard_config.tensor_parallel_size,
},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attn.out_proj",
target_module=col_nn.Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="fc1",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="fc2",
target_module=col_nn.Linear1D_Row,
),
],
)
policy[OPTForCausalLM] = ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="model.decoder.embed_tokens",
target_module=col_nn.VocabParallelEmbedding1D,
),
SubModuleReplacementDescription(
suffix="lm_head",
target_module=col_nn.Linear1D_Col,
kwargs={"gather_output": True},
),
])
policy[OPTForCausalLM] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="model.decoder.embed_tokens",
target_module=col_nn.VocabParallelEmbedding1D,
),
SubModuleReplacementDescription(
suffix="lm_head",
target_module=col_nn.Linear1D_Col,
kwargs={"gather_output": True},
),
]
)
policy[Blip2Attention] = ModulePolicyDescription(method_replacement={"forward": forward_fn()})
# optimization configuration
if self.shard_config.enable_fused_normalization:
# Handle Blip2EncoderLayer layer
self.append_or_create_submodule_replacement(description=[
SubModuleReplacementDescription(
suffix="layer_norm1",
target_module=col_nn.FusedLayerNorm,
),
SubModuleReplacementDescription(
suffix="layer_norm2",
target_module=col_nn.FusedLayerNorm,
)
],
policy=policy,
target_key=Blip2EncoderLayer)
# handle Blip2VisionModel layer
self.append_or_create_submodule_replacement(description=[
SubModuleReplacementDescription(
suffix="post_layernorm",
target_module=col_nn.FusedLayerNorm,
)
],
policy=policy,
target_key=Blip2VisionModel)
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="layer_norm1",
target_module=col_nn.FusedLayerNorm,
),
SubModuleReplacementDescription(
suffix="layer_norm2",
target_module=col_nn.FusedLayerNorm,
),
],
policy=policy,
target_key=Blip2EncoderLayer,
)
# handle Blip2VisionModel layer
self.append_or_create_submodule_replacement(
description=[SubModuleReplacementDescription(
suffix="layernorm",
target_module=col_nn.FusedLayerNorm,
)],
description=[
SubModuleReplacementDescription(
suffix="post_layernorm",
target_module=col_nn.FusedLayerNorm,
)
],
policy=policy,
target_key=Blip2QFormerModel)
target_key=Blip2VisionModel,
)
# handle Blip2VisionModel layer
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="layernorm",
target_module=col_nn.FusedLayerNorm,
)
],
policy=policy,
target_key=Blip2QFormerModel,
)
# handle Blip2QFormerLayer layer
self.append_or_create_submodule_replacement(description=[
SubModuleReplacementDescription(
suffix="attention.output.LayerNorm",
target_module=col_nn.FusedLayerNorm,
),
SubModuleReplacementDescription(
suffix="crossattention.output.LayerNorm",
target_module=col_nn.FusedLayerNorm,
),
SubModuleReplacementDescription(
suffix="output_query.LayerNorm",
target_module=col_nn.FusedLayerNorm,
)
],
policy=policy,
target_key=Blip2QFormerLayer)
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="attention.output.LayerNorm",
target_module=col_nn.FusedLayerNorm,
),
SubModuleReplacementDescription(
suffix="crossattention.output.LayerNorm",
target_module=col_nn.FusedLayerNorm,
),
SubModuleReplacementDescription(
suffix="output_query.LayerNorm",
target_module=col_nn.FusedLayerNorm,
),
],
policy=policy,
target_key=Blip2QFormerLayer,
)
# handle OPTForCausalLM layer
self.append_or_create_submodule_replacement(description=[
SubModuleReplacementDescription(
suffix="model.decoder.final_layer_norm",
target_module=col_nn.FusedLayerNorm,
)
],
policy=policy,
target_key=OPTForCausalLM)
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="model.decoder.final_layer_norm",
target_module=col_nn.FusedLayerNorm,
)
],
policy=policy,
target_key=OPTForCausalLM,
)
# handle OPTDecoderLayer layer
self.append_or_create_submodule_replacement(description=[
SubModuleReplacementDescription(
suffix="self_attn_layer_norm",
target_module=col_nn.FusedLayerNorm,
),
SubModuleReplacementDescription(
suffix="final_layer_norm",
target_module=col_nn.FusedLayerNorm,
)
],
policy=policy,
target_key=OPTDecoderLayer)
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="self_attn_layer_norm",
target_module=col_nn.FusedLayerNorm,
),
SubModuleReplacementDescription(
suffix="final_layer_norm",
target_module=col_nn.FusedLayerNorm,
),
],
policy=policy,
target_key=OPTDecoderLayer,
)
# use flash attention
if self.shard_config.enable_flash_attention:
self.append_or_create_method_replacement(description={
'forward': get_blip2_flash_attention_forward(),
},
policy=policy,
target_key=Blip2Attention)
self.append_or_create_method_replacement(
description={
"forward": get_blip2_flash_attention_forward(),
},
policy=policy,
target_key=Blip2Attention,
)
# use jit operator
if self.shard_config.enable_jit_fused:
self.append_or_create_method_replacement(description={
'forward': get_jit_fused_blip2_QFormer_self_output_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
},
policy=policy,
target_key=Blip2QFormerSelfOutput)
self.append_or_create_method_replacement(description={
'forward': get_jit_fused_blip2_QFormer_output_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
},
policy=policy,
target_key=Blip2QFormerOutput)
self.append_or_create_method_replacement(
description={
"forward": get_jit_fused_blip2_QFormer_self_output_forward(),
"dropout_add": get_jit_fused_dropout_add_func(),
},
policy=policy,
target_key=Blip2QFormerSelfOutput,
)
self.append_or_create_method_replacement(
description={
"forward": get_jit_fused_blip2_QFormer_output_forward(),
"dropout_add": get_jit_fused_dropout_add_func(),
},
policy=policy,
target_key=Blip2QFormerOutput,
)
return policy
@@ -314,13 +340,11 @@ class BlipPolicy(Policy):
# Blip2Model
class Blip2ModelPolicy(BlipPolicy):
def __init__(self) -> None:
super().__init__()
# Blip2ForConditionalGeneration
class Blip2ForConditionalGenerationPolicy(BlipPolicy):
def __init__(self) -> None:
super().__init__()