[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
Author: Hongxin Liu
Date: 2023-09-19 14:20:26 +08:00 (committed by GitHub)
Parent: 3c6b831c26
Commit: 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions
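The last commit-message item, ignoring CUDA for clang-format, amounts to excluding CUDA sources from the clang-format hook in .pre-commit-config.yaml. A minimal sketch of that idea, assuming the standard pre-commit/mirrors-clang-format hook; the rev pin and exclude pattern below are illustrative and not necessarily what this commit actually uses:

repos:
  - repo: https://github.com/pre-commit/mirrors-clang-format
    rev: v16.0.6              # illustrative pin, not the commit's actual rev
    hooks:
      - id: clang-format
        exclude: '\.(cu|cuh)$'  # skip CUDA kernels; C/C++ sources are still formatted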


@@ -1,5 +1,5 @@
from functools import partial
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Callable, Dict, List
import torch.nn as nn
from torch import Tensor
@@ -7,7 +7,6 @@ from torch.nn import Module
import colossalai.shardformer.layer as col_nn
from .._utils import getattr_, setattr_
from ..modeling.bloom import (
BloomPipelineForwards,
build_bloom_alibi_tensor_fn,
@@ -22,7 +21,6 @@ from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDe
class BloomPolicy(Policy):
def config_sanity_check(self):
pass
@@ -47,39 +45,41 @@ class BloomPolicy(Policy):
use_sequence_parallel = self.shard_config.enable_sequence_parallelism
overlap = self.shard_config.enable_sequence_overlap
if self.shard_config.enable_tensor_parallelism:
policy[BloomBlock] = ModulePolicyDescription(attribute_replacement={
"self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attention.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attention.query_key_value",
target_module=col_nn.Linear1D_Col,
kwargs={
'seq_parallel': use_sequence_parallel,
'overlap': overlap
}),
SubModuleReplacementDescription(
suffix="self_attention.dense",
target_module=col_nn.Linear1D_Row,
kwargs={'seq_parallel': use_sequence_parallel}),
SubModuleReplacementDescription(
suffix="self_attention.attention_dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="mlp.dense_h_to_4h",
target_module=col_nn.Linear1D_Col,
kwargs={
'seq_parallel': use_sequence_parallel,
'overlap': overlap
}),
SubModuleReplacementDescription(
suffix="mlp.dense_4h_to_h",
target_module=col_nn.Linear1D_Row,
kwargs={'seq_parallel': use_sequence_parallel}),
])
policy[BloomBlock] = ModulePolicyDescription(
attribute_replacement={
"self_attention.hidden_size": self.model.config.hidden_size
// self.shard_config.tensor_parallel_size,
"self_attention.split_size": self.model.config.hidden_size
// self.shard_config.tensor_parallel_size,
"self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attention.query_key_value",
target_module=col_nn.Linear1D_Col,
kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap},
),
SubModuleReplacementDescription(
suffix="self_attention.dense",
target_module=col_nn.Linear1D_Row,
kwargs={"seq_parallel": use_sequence_parallel},
),
SubModuleReplacementDescription(
suffix="self_attention.attention_dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="mlp.dense_h_to_4h",
target_module=col_nn.Linear1D_Col,
kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap},
),
SubModuleReplacementDescription(
suffix="mlp.dense_4h_to_h",
target_module=col_nn.Linear1D_Row,
kwargs={"seq_parallel": use_sequence_parallel},
),
],
)
policy[BloomModel] = ModulePolicyDescription(
attribute_replacement={
@@ -93,72 +93,86 @@ class BloomPolicy(Policy):
suffix="word_embeddings",
target_module=col_nn.VocabParallelEmbedding1D,
)
])
],
)
# optimization configuration
if self.shard_config.enable_fused_normalization:
# handle bloom model
self.append_or_create_submodule_replacement(description=[
SubModuleReplacementDescription(
suffix="ln_f",
target_module=col_nn.FusedLayerNorm,
),
SubModuleReplacementDescription(
suffix="word_embeddings_layernorm",
target_module=col_nn.FusedLayerNorm,
)
],
policy=policy,
target_key=BloomModel)
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="ln_f",
target_module=col_nn.FusedLayerNorm,
),
SubModuleReplacementDescription(
suffix="word_embeddings_layernorm",
target_module=col_nn.FusedLayerNorm,
),
],
policy=policy,
target_key=BloomModel,
)
# handle bloom block
self.append_or_create_submodule_replacement(description=[
SubModuleReplacementDescription(
suffix="input_layernorm",
target_module=col_nn.FusedLayerNorm,
),
SubModuleReplacementDescription(
suffix="post_attention_layernorm",
target_module=col_nn.FusedLayerNorm,
)
],
policy=policy,
target_key=BloomBlock)
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="input_layernorm",
target_module=col_nn.FusedLayerNorm,
),
SubModuleReplacementDescription(
suffix="post_attention_layernorm",
target_module=col_nn.FusedLayerNorm,
),
],
policy=policy,
target_key=BloomBlock,
)
if use_sequence_parallel:
self.append_or_create_method_replacement(
description={'forward': get_bloom_sequence_parallel_forward_fn(self.shard_config)},
description={"forward": get_bloom_sequence_parallel_forward_fn(self.shard_config)},
policy=policy,
target_key=BloomModel)
target_key=BloomModel,
)
if self.shard_config.enable_flash_attention:
self.append_or_create_method_replacement(description={
'forward': get_bloom_flash_attention_forward(),
'dropout_add': get_dropout_add_func(),
},
policy=policy,
target_key=BloomAttention)
self.append_or_create_method_replacement(
description={
"forward": get_bloom_flash_attention_forward(),
"dropout_add": get_dropout_add_func(),
},
policy=policy,
target_key=BloomAttention,
)
# enable jit fused operator
if self.shard_config.enable_jit_fused:
self.append_or_create_method_replacement(description={
'forward': get_jit_fused_bloom_attention_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
},
policy=policy,
target_key=BloomAttention)
self.append_or_create_method_replacement(description={
'forward': get_jit_fused_bloom_mlp_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
},
policy=policy,
target_key=BloomMLP)
self.append_or_create_method_replacement(description={
'forward': get_jit_fused_bloom_gelu_forward(),
'bloom_gelu_forward': get_jit_fused_gelu_forward_func(),
},
policy=policy,
target_key=BloomGelu)
self.append_or_create_method_replacement(
description={
"forward": get_jit_fused_bloom_attention_forward(),
"dropout_add": get_jit_fused_dropout_add_func(),
},
policy=policy,
target_key=BloomAttention,
)
self.append_or_create_method_replacement(
description={
"forward": get_jit_fused_bloom_mlp_forward(),
"dropout_add": get_jit_fused_dropout_add_func(),
},
policy=policy,
target_key=BloomMLP,
)
self.append_or_create_method_replacement(
description={
"forward": get_jit_fused_bloom_gelu_forward(),
"bloom_gelu_forward": get_jit_fused_gelu_forward_func(),
},
policy=policy,
target_key=BloomGelu,
)
return policy
@@ -167,7 +181,7 @@ class BloomPolicy(Policy):
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
"""If under pipeline parallel setting, replacing the original forward method of huggingface
to customized forward method, and add this changing to policy."""
if self.pipeline_stage_manager:
stage_manager = self.pipeline_stage_manager
if self.model.__class__.__name__ == "BloomModel":
@@ -178,22 +192,20 @@ class BloomPolicy(Policy):
layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = {
'forward':
partial(new_forward,
stage_manager=stage_manager,
stage_index=stage_index,
shard_config=self.shard_config)
"forward": partial(
new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
)
}
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=model_cls)
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=model_cls
)
return
def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
assert self.pipeline_stage_manager is not None
if self.model.__class__.__name__ == 'BloomModel':
if self.model.__class__.__name__ == "BloomModel":
module = self.model
else:
module = self.model.transformer
@@ -213,17 +225,17 @@ class BloomPolicy(Policy):
class BloomModelPolicy(BloomPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
policy = super().module_policy()
from transformers.models.bloom.modeling_bloom import BloomModel
if self.pipeline_stage_manager:
self.set_pipeline_forward(model_cls=BloomModel,
new_forward=BloomPipelineForwards.bloom_model_forward,
policy=policy)
self.set_pipeline_forward(
model_cls=BloomModel, new_forward=BloomPipelineForwards.bloom_model_forward, policy=policy
)
return policy
def get_held_layers(self) -> List[Module]:
@@ -234,26 +246,29 @@ class BloomModelPolicy(BloomPolicy):
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
'''no shared params in bloom model'''
"""no shared params in bloom model"""
return []
class BloomForCausalLMPolicy(BloomPolicy):
def module_policy(self):
from transformers.models.bloom.modeling_bloom import BloomForCausalLM
policy = super().module_policy()
# handle tensor parallelism
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)),
policy=policy,
target_key=BloomForCausalLM)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)
),
policy=policy,
target_key=BloomForCausalLM,
)
if self.pipeline_stage_manager:
self.set_pipeline_forward(model_cls=BloomForCausalLM,
new_forward=BloomPipelineForwards.bloom_for_causal_lm_forward,
policy=policy)
self.set_pipeline_forward(
model_cls=BloomForCausalLM, new_forward=BloomPipelineForwards.bloom_for_causal_lm_forward, policy=policy
)
return policy
def get_held_layers(self) -> List[Module]:
@@ -269,29 +284,36 @@ class BloomForCausalLMPolicy(BloomPolicy):
if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
if id(bloom_model.transformer.word_embeddings.weight) == id(bloom_model.lm_head.weight):
# tie weights
return [{
0: bloom_model.transformer.word_embeddings.weight,
self.pipeline_stage_manager.num_stages - 1: bloom_model.lm_head.weight
}]
return [
{
0: bloom_model.transformer.word_embeddings.weight,
self.pipeline_stage_manager.num_stages - 1: bloom_model.lm_head.weight,
}
]
return []
class BloomForSequenceClassificationPolicy(BloomPolicy):
def module_policy(self):
from transformers.models.bloom.modeling_bloom import BloomForSequenceClassification
policy = super().module_policy()
# handle tensor parallelism
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="score", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)),
policy=policy,
target_key=BloomForSequenceClassification)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="score", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)
),
policy=policy,
target_key=BloomForSequenceClassification,
)
if self.pipeline_stage_manager:
self.set_pipeline_forward(model_cls=BloomForSequenceClassification,
new_forward=BloomPipelineForwards.bloom_for_sequence_classification_forward,
policy=policy)
self.set_pipeline_forward(
model_cls=BloomForSequenceClassification,
new_forward=BloomPipelineForwards.bloom_for_sequence_classification_forward,
policy=policy,
)
return policy
def get_held_layers(self) -> List[Module]:
@@ -308,28 +330,32 @@ class BloomForSequenceClassificationPolicy(BloomPolicy):
class BloomForTokenClassificationPolicy(BloomPolicy):
def module_policy(self):
from transformers.models.bloom.modeling_bloom import BloomForTokenClassification
policy = super().module_policy()
# handle tensor parallelism
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(description=[
SubModuleReplacementDescription(suffix="classifier",
target_module=col_nn.Linear1D_Col,
kwargs=dict(gather_output=True)),
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForReplicatedInput,
),
],
policy=policy,
target_key=BloomForTokenClassification)
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="classifier", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)
),
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForReplicatedInput,
),
],
policy=policy,
target_key=BloomForTokenClassification,
)
if self.pipeline_stage_manager:
self.set_pipeline_forward(model_cls=BloomForTokenClassification,
new_forward=BloomPipelineForwards.bloom_for_token_classification_forward,
policy=policy)
self.set_pipeline_forward(
model_cls=BloomForTokenClassification,
new_forward=BloomPipelineForwards.bloom_for_token_classification_forward,
policy=policy,
)
return policy
@@ -351,11 +377,14 @@ class BloomForQuestionAnsweringPolicy(BloomPolicy):
# No head sharding as the output features is only 2
def module_policy(self):
from transformers.models.bloom.modeling_bloom import BloomForQuestionAnswering
policy = super().module_policy()
if self.pipeline_stage_manager:
self.set_pipeline_forward(model_cls=BloomForQuestionAnswering,
new_forward=BloomPipelineForwards.bloom_for_question_answering_forward,
policy=policy)
self.set_pipeline_forward(
model_cls=BloomForQuestionAnswering,
new_forward=BloomPipelineForwards.bloom_for_question_answering_forward,
policy=policy,
)
return policy
def get_held_layers(self) -> List[Module]: