[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
Author: Hongxin Liu
Date: 2023-09-19 14:20:26 +08:00 (committed by GitHub)
Parent: 3c6b831c26
Commit: 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions


@@ -1,19 +1,12 @@
from functools import partial
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Callable, Dict, List, Union
import torch.nn as nn
from torch import Tensor
from transformers.modeling_outputs import BaseModelOutputWithPast
import colossalai.shardformer.layer as col_nn
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.modeling.chatglm2 import ChatGLMPipelineForwards
from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
ChatGLMForConditionalGeneration,
ChatGLMModel,
GLMBlock,
)
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
from ..modeling.chatglm2 import (
get_chatglm_sequence_parallel_forward_fn,
@@ -23,11 +16,10 @@ from ..modeling.chatglm2 import (
from ..modeling.jit import get_jit_fused_dropout_add_func
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ['ChatGLMPolicy', 'ChatGLMModelPolicy', 'ChatGLMForConditionalGenerationPolicy']
__all__ = ["ChatGLMPolicy", "ChatGLMModelPolicy", "ChatGLMForConditionalGenerationPolicy"]
class ChatGLMPolicy(Policy):
def config_sanity_check(self):
pass
@@ -44,12 +36,11 @@ class ChatGLMPolicy(Policy):
if self.pipeline_stage_manager is not None:
# the batch_size_dim is bounded to Model
bsz_dim = 1
setattr(self.model, 'batch_size_dim', bsz_dim)
setattr(self.model, "batch_size_dim", bsz_dim)
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMModel, CoreAttention, GLMBlock
policy = {}
@@ -57,111 +48,129 @@ class ChatGLMPolicy(Policy):
use_sequence_parallel = self.shard_config.enable_sequence_parallelism
overlap = self.shard_config.enable_sequence_overlap
if self.shard_config.enable_tensor_parallelism:
policy[ChatGLMModel] = ModulePolicyDescription(attribute_replacement={},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="embedding.word_embeddings",
target_module=col_nn.VocabParallelEmbedding1D,
)
])
policy[ChatGLMModel] = ModulePolicyDescription(
attribute_replacement={},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="embedding.word_embeddings",
target_module=col_nn.VocabParallelEmbedding1D,
)
],
)
policy[GLMBlock] = ModulePolicyDescription(
attribute_replacement={
"self_attention.num_attention_heads_per_partition":
self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
"self_attention.projection_size":
(self.model.config.kv_channels * self.model.config.num_attention_heads) //
self.shard_config.tensor_parallel_size,
"self_attention.qkv_hidden_size":
(self.model.config.kv_channels * self.model.config.num_attention_heads * 3) //
self.shard_config.tensor_parallel_size,
"self_attention.core_attention.num_attention_heads_per_partition":
self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
"self_attention.core_attention.hidden_size_per_partition":
self.model.config.kv_channels * self.model.config.num_attention_heads //
self.shard_config.tensor_parallel_size,
"self_attention.num_attention_heads_per_partition": self.model.config.num_attention_heads
// self.shard_config.tensor_parallel_size,
"self_attention.projection_size": (
self.model.config.kv_channels * self.model.config.num_attention_heads
)
// self.shard_config.tensor_parallel_size,
"self_attention.qkv_hidden_size": (
self.model.config.kv_channels * self.model.config.num_attention_heads * 3
)
// self.shard_config.tensor_parallel_size,
"self_attention.core_attention.num_attention_heads_per_partition": self.model.config.num_attention_heads
// self.shard_config.tensor_parallel_size,
"self_attention.core_attention.hidden_size_per_partition": self.model.config.kv_channels
* self.model.config.num_attention_heads
// self.shard_config.tensor_parallel_size,
},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(suffix="self_attention.query_key_value",
target_module=col_nn.Linear1D_Col,
kwargs={
'seq_parallel': use_sequence_parallel,
'seq_parallel_dim': 0,
'overlap': overlap
}),
SubModuleReplacementDescription(suffix="self_attention.dense",
target_module=col_nn.Linear1D_Row,
kwargs={
'seq_parallel': use_sequence_parallel,
'seq_parallel_dim': 0
}),
SubModuleReplacementDescription(
suffix="self_attention.query_key_value",
target_module=col_nn.Linear1D_Col,
kwargs={"seq_parallel": use_sequence_parallel, "seq_parallel_dim": 0, "overlap": overlap},
),
SubModuleReplacementDescription(
suffix="self_attention.dense",
target_module=col_nn.Linear1D_Row,
kwargs={"seq_parallel": use_sequence_parallel, "seq_parallel_dim": 0},
),
SubModuleReplacementDescription(
suffix="self_attention.core_attention.attention_dropout",
target_module=col_nn.DropoutForParallelInput,
),
])
],
)
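The attribute_replacement above rescales the attention bookkeeping on GLMBlock so each tensor-parallel rank only tracks its own shard. A minimal sketch of that arithmetic, with the config values assumed (ChatGLM2-6B-like: 32 attention heads, kv_channels=128) rather than read from a real checkpoint:

# Illustrative only: num_attention_heads and kv_channels are assumed values.
num_attention_heads = 32      # assumed, ChatGLM2-6B-like
kv_channels = 128             # assumed
tensor_parallel_size = 4

heads_per_partition = num_attention_heads // tensor_parallel_size                      # 8
projection_size = (kv_channels * num_attention_heads) // tensor_parallel_size          # 1024
qkv_hidden_size = (kv_channels * num_attention_heads * 3) // tensor_parallel_size      # 3072
hidden_size_per_partition = kv_channels * num_attention_heads // tensor_parallel_size  # 1024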
# optimization configuration
if self.shard_config.enable_fused_normalization:
if not self.model.config.rmsnorm:
self.append_or_create_submodule_replacement(description=[
SubModuleReplacementDescription(suffix="input_layernorm", target_module=col_nn.FusedLayerNorm),
SubModuleReplacementDescription(suffix="post_attention_layernorm",
target_module=col_nn.FusedLayerNorm)
],
policy=policy,
target_key=GLMBlock)
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(suffix="input_layernorm", target_module=col_nn.FusedLayerNorm),
SubModuleReplacementDescription(
suffix="post_attention_layernorm", target_module=col_nn.FusedLayerNorm
),
],
policy=policy,
target_key=GLMBlock,
)
if self.model.config.post_layer_norm:
self.append_or_create_submodule_replacement(description=[
SubModuleReplacementDescription(suffix="encoder.final_layernorm",
target_module=col_nn.FusedLayerNorm)
],
policy=policy,
target_key=ChatGLMModel)
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="encoder.final_layernorm", target_module=col_nn.FusedLayerNorm
)
],
policy=policy,
target_key=ChatGLMModel,
)
else:
self.append_or_create_submodule_replacement(description=[
SubModuleReplacementDescription(suffix="input_layernorm", target_module=col_nn.FusedRMSNorm),
SubModuleReplacementDescription(suffix="post_attention_layernorm",
target_module=col_nn.FusedRMSNorm)
],
policy=policy,
target_key=GLMBlock)
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(suffix="input_layernorm", target_module=col_nn.FusedRMSNorm),
SubModuleReplacementDescription(
suffix="post_attention_layernorm", target_module=col_nn.FusedRMSNorm
),
],
policy=policy,
target_key=GLMBlock,
)
if self.model.config.post_layer_norm:
self.append_or_create_submodule_replacement(description=[
SubModuleReplacementDescription(suffix="encoder.final_layernorm",
target_module=col_nn.FusedRMSNorm)
],
policy=policy,
target_key=ChatGLMModel)
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="encoder.final_layernorm", target_module=col_nn.FusedRMSNorm
)
],
policy=policy,
target_key=ChatGLMModel,
)
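The branch above installs col_nn.FusedRMSNorm when config.rmsnorm is set and col_nn.FusedLayerNorm otherwise. As a reminder of what the fused modules compute, a plain-PyTorch reference of the two normalizations (a sketch, not the fused kernels):

import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # RMSNorm skips mean subtraction; it only rescales by the root-mean-square.
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x / rms * weight

def layer_norm_reference(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # LayerNorm centers the activations, then rescales by the standard deviation.
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps) * weight + bias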
# use flash attention
if self.shard_config.enable_flash_attention:
self.append_or_create_method_replacement(description={
'forward': get_flash_core_attention_forward(),
},
policy=policy,
target_key=CoreAttention)
self.append_or_create_method_replacement(
description={
"forward": get_flash_core_attention_forward(),
},
policy=policy,
target_key=CoreAttention,
)
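get_flash_core_attention_forward() swaps CoreAttention.forward for a fused attention path. The stand-in below is not that function; it only sketches the shape of such a replacement forward, assuming (batch, heads, seq, head_dim) inputs, which may differ from ChatGLM's actual layout:

import torch
import torch.nn.functional as F

def sdpa_core_attention_forward(self, query, key, value, attention_mask=None):
    # Hypothetical replacement forward using PyTorch 2.x fused attention;
    # tensor layout and mask handling are assumptions, not the real implementation.
    out = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
    batch, heads, seq, head_dim = out.shape
    return out.transpose(1, 2).reshape(batch, seq, heads * head_dim)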
# use sequence parallel
if use_sequence_parallel:
self.append_or_create_method_replacement(
description={'forward': get_chatglm_sequence_parallel_forward_fn(self.shard_config)},
description={"forward": get_chatglm_sequence_parallel_forward_fn(self.shard_config)},
policy=policy,
target_key=ChatGLMModel)
target_key=ChatGLMModel,
)
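get_chatglm_sequence_parallel_forward_fn(shard_config) installs a forward that shards activations along the sequence dimension, matching the seq_parallel_dim=0 kwargs passed to the Linear1D layers above (ChatGLM activations are sequence-first). A toy version of that split, with rank and world size assumed:

import torch

def split_sequence(hidden_states: torch.Tensor, rank: int, world_size: int, seq_dim: int = 0) -> torch.Tensor:
    # Assumes the sequence length divides evenly; real code would pad or assert.
    return torch.chunk(hidden_states, world_size, dim=seq_dim)[rank]

x = torch.randn(16, 2, 4096)                               # (seq, batch, hidden)
print(split_sequence(x, rank=1, world_size=4).shape)       # torch.Size([4, 2, 4096])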
# use jit fused operator
if self.shard_config.enable_jit_fused:
self.append_or_create_method_replacement(description={
'forward': get_jit_fused_glm_block_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
},
policy=policy,
target_key=GLMBlock)
self.append_or_create_method_replacement(
description={
"forward": get_jit_fused_glm_block_forward(),
"dropout_add": get_jit_fused_dropout_add_func(),
},
policy=policy,
target_key=GLMBlock,
)
return policy
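Once module_policy returns, the policy is consumed by ShardFormer. A hedged usage sketch follows; the ShardConfig/ShardFormer signatures and the build_chatglm_model helper are assumptions, not taken from this diff:

# Hedged usage sketch; exact APIs may differ across ColossalAI versions.
import colossalai
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.chatglm2 import ChatGLMModelPolicy

colossalai.launch_from_torch(config={})     # set up the distributed environment
model = build_chatglm_model()               # hypothetical helper returning a ChatGLMModel
shard_config = ShardConfig(
    enable_tensor_parallelism=True,
    enable_fused_normalization=True,
)
sharded_model, shared_params = ShardFormer(shard_config=shard_config).optimize(
    model, policy=ChatGLMModelPolicy()
)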
@@ -172,7 +181,7 @@ class ChatGLMPolicy(Policy):
"""Get pipeline layers for current stage."""
assert self.pipeline_stage_manager is not None
if self.model.__class__.__name__ == 'ChatGLMModel':
if self.model.__class__.__name__ == "ChatGLMModel":
module = self.model
else:
module = self.model.transformer
@@ -195,11 +204,11 @@ class ChatGLMPolicy(Policy):
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
"""If under pipeline parallel setting, replacing the original forward method of huggingface
to customized forward method, and add this changing to policy."""
to customized forward method, and add this changing to policy."""
if not self.pipeline_stage_manager:
raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.")
stage_manager = self.pipeline_stage_manager
if self.model.__class__.__name__ == 'ChatGLMModel':
if self.model.__class__.__name__ == "ChatGLMModel":
module = self.model
else:
module = self.model.transformer
@@ -207,29 +216,26 @@ class ChatGLMPolicy(Policy):
layers_per_stage = Policy.distribute_layers(module.num_layers, stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = {
'forward':
partial(new_forward,
stage_manager=stage_manager,
stage_index=stage_index,
shard_config=self.shard_config)
"forward": partial(
new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
)
}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
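distribute_layers and get_stage_index decide which GLMBlock layers each pipeline stage owns, and therefore which slice of the model the customized forward runs. An even-split sketch under assumed values (the library's actual balancing may differ):

from typing import List, Tuple

def distribute_layers_evenly(num_layers: int, num_stages: int) -> List[int]:
    # Assumed even split, spreading any remainder over the leading stages.
    base, rem = divmod(num_layers, num_stages)
    return [base + (1 if s < rem else 0) for s in range(num_stages)]

def stage_index(layers_per_stage: List[int], stage: int) -> Tuple[int, int]:
    start = sum(layers_per_stage[:stage])
    return start, start + layers_per_stage[stage]

per_stage = distribute_layers_evenly(28, 4)    # ChatGLM2-6B has 28 GLMBlocks -> [7, 7, 7, 7]
print(per_stage, stage_index(per_stage, 2))    # ([7, 7, 7, 7], (14, 21))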
class ChatGLMModelPolicy(ChatGLMPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
pass
policy = super().module_policy()
if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(model_cls=ChatGLMModel,
new_forward=ChatGLMPipelineForwards.chatglm_model_forward,
policy=policy)
self.set_pipeline_forward(
model_cls=ChatGLMModel, new_forward=ChatGLMPipelineForwards.chatglm_model_forward, policy=policy
)
return policy
def get_held_layers(self) -> List[nn.Module]:
@@ -241,14 +247,15 @@ class ChatGLMModelPolicy(ChatGLMPolicy):
class ChatGLMForConditionalGenerationPolicy(ChatGLMModelPolicy):
def module_policy(self):
policy = super().module_policy()
if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(model_cls=ChatGLMForConditionalGeneration,
new_forward=ChatGLMPipelineForwards.chatglm_for_conditional_generation_forward,
policy=policy)
self.set_pipeline_forward(
model_cls=ChatGLMForConditionalGeneration,
new_forward=ChatGLMPipelineForwards.chatglm_for_conditional_generation_forward,
policy=policy,
)
return policy
def get_held_layers(self) -> List[nn.Module]: