Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-26 20:23:26 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -11,11 +11,10 @@ from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Ro
 from ..modeling.llama import LlamaPipelineForwards, get_llama_flash_attention_forward
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
-__all__ = ['LlamaPolicy', 'LlamaForCausalLMPolicy', 'LlamaForSequenceClassificationPolicy']
+__all__ = ["LlamaPolicy", "LlamaForCausalLMPolicy", "LlamaForSequenceClassificationPolicy"]
 
 
 class LlamaPolicy(Policy):
-
     def config_sanity_check(self):
         pass
@@ -40,15 +39,15 @@ class LlamaPolicy(Policy):
             self.shard_config.enable_sequence_parallelism = False
             warnings.warn("Llama dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
 
         if self.shard_config.enable_tensor_parallelism:
             decoder_attribute_replacement = {
                 "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                 "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
             }
             if getattr(self.model.config, "num_key_value_heads", False):
-                decoder_attribute_replacement["self_attn.num_key_value_heads"] = \
+                decoder_attribute_replacement["self_attn.num_key_value_heads"] = (
                     self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size
+                )
 
             policy[LlamaDecoderLayer] = ModulePolicyDescription(
                 attribute_replacement=decoder_attribute_replacement,
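The hunk above divides the per-layer attention attributes by the tensor-parallel degree, since each rank keeps only its shard of the heads. A minimal sketch of that arithmetic, with hypothetical LLaMA-7B-like values (not taken from the diff):

```python
# Hypothetical config values chosen for illustration; the real numbers
# come from self.model.config and self.shard_config in the policy above.
hidden_size = 4096
num_attention_heads = 32
num_key_value_heads = 32  # missing on older configs, hence the getattr() guard
tensor_parallel_size = 4

# Heads must divide evenly across ranks for a 1D tensor-parallel split.
assert num_attention_heads % tensor_parallel_size == 0

decoder_attribute_replacement = {
    "self_attn.hidden_size": hidden_size // tensor_parallel_size,        # 1024 per rank
    "self_attn.num_heads": num_attention_heads // tensor_parallel_size,  # 8 per rank
    "self_attn.num_key_value_heads": num_key_value_heads // tensor_parallel_size,
}
print(decoder_attribute_replacement)
```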
@@ -80,45 +79,53 @@ class LlamaPolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="mlp.down_proj",
                     target_module=Linear1D_Row,
-                )
+                ),
             ],
         )
 
-        self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
-            suffix="embed_tokens",
-            target_module=VocabParallelEmbedding1D,
-        ),
-                                                    policy=policy,
-                                                    target_key=LlamaModel)
+        self.append_or_create_submodule_replacement(
+            description=SubModuleReplacementDescription(
+                suffix="embed_tokens",
+                target_module=VocabParallelEmbedding1D,
+            ),
+            policy=policy,
+            target_key=LlamaModel,
+        )
 
         # optimization configuration
         if self.shard_config.enable_fused_normalization:
-            self.append_or_create_submodule_replacement(description=[
-                SubModuleReplacementDescription(
-                    suffix="input_layernorm",
-                    target_module=FusedRMSNorm,
-                ),
-                SubModuleReplacementDescription(
-                    suffix="post_attention_layernorm",
-                    target_module=FusedRMSNorm,
-                )
-            ],
-                                                        policy=policy,
-                                                        target_key=LlamaDecoderLayer)
+            self.append_or_create_submodule_replacement(
+                description=[
+                    SubModuleReplacementDescription(
+                        suffix="input_layernorm",
+                        target_module=FusedRMSNorm,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="post_attention_layernorm",
+                        target_module=FusedRMSNorm,
+                    ),
+                ],
+                policy=policy,
+                target_key=LlamaDecoderLayer,
+            )
 
-            self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
-                suffix="norm",
-                target_module=FusedRMSNorm,
-            ),
-                                                        policy=policy,
-                                                        target_key=LlamaModel)
+            self.append_or_create_submodule_replacement(
+                description=SubModuleReplacementDescription(
+                    suffix="norm",
+                    target_module=FusedRMSNorm,
+                ),
+                policy=policy,
+                target_key=LlamaModel,
+            )
 
         if self.shard_config.enable_flash_attention:
-            self.append_or_create_method_replacement(description={
-                'forward': get_llama_flash_attention_forward(),
-            },
-                                                     policy=policy,
-                                                     target_key=LlamaAttention)
+            self.append_or_create_method_replacement(
+                description={
+                    "forward": get_llama_flash_attention_forward(),
+                },
+                policy=policy,
+                target_key=LlamaAttention,
+            )
 
         return policy
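The `Linear1D_Col`/`Linear1D_Row` pairing in this hunk follows the standard Megatron-style 1D split: the up projection is sharded along its output dimension and the down projection along its input dimension, so only one reduction is needed per MLP block. A single-process toy sketch of the underlying math (illustrative tensors, not ColossalAI's actual kernels):

```python
import torch

torch.manual_seed(0)
tp = 2                       # tensor-parallel degree
x = torch.randn(3, 8)        # (batch, hidden)
w_up = torch.randn(32, 8)    # column-parallel weight (e.g. an up projection)
w_down = torch.randn(8, 32)  # row-parallel weight (e.g. mlp.down_proj)

# Column-parallel: shard the output dim; each "rank" holds a slice of
# the intermediate activation and no communication is needed yet.
inter = [x @ w.t() for w in w_up.chunk(tp, dim=0)]

# Row-parallel: shard the input dim so each rank consumes only its own
# slice, then sum the partial outputs (this sum is the all-reduce).
partial = [h @ w.t() for h, w in zip(inter, w_down.chunk(tp, dim=1))]
y_parallel = sum(partial)

assert torch.allclose(y_parallel, (x @ w_up.t()) @ w_down.t(), atol=1e-5)
```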
@@ -127,7 +134,7 @@ class LlamaPolicy(Policy):
 
     def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
         """If under pipeline parallel setting, replacing the original forward method of huggingface
-           to customized forward method, and add this changing to policy."""
+        to customized forward method, and add this changing to policy."""
         if self.pipeline_stage_manager:
             stage_manager = self.pipeline_stage_manager
             if self.model.__class__.__name__ == "LlamaModel":
@@ -137,10 +144,10 @@ class LlamaPolicy(Policy):
 
         layers_per_stage = Policy.distribute_layers(len(module.layers), stage_manager.num_stages)
         stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
-        method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
-        self.append_or_create_method_replacement(description=method_replacement,
-                                                 policy=policy,
-                                                 target_key=model_cls)
+        method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
+        self.append_or_create_method_replacement(
+            description=method_replacement, policy=policy, target_key=model_cls
+        )
 
         return
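`set_pipeline_forward` installs a stage-aware forward by baking `stage_manager` and `stage_index` into the replacement with `functools.partial`. A minimal stand-alone sketch of that binding (the classes below are stand-ins, not ColossalAI's real ones):

```python
from functools import partial

class StageManager:  # stand-in for the pipeline stage manager
    def __init__(self, stage: int, num_stages: int):
        self.stage, self.num_stages = stage, num_stages

def pipeline_forward(self, hidden_states, stage_manager=None, stage_index=None):
    start, end = stage_index
    return f"stage {stage_manager.stage}: run layers [{start}, {end}) on {hidden_states}"

class TinyLlamaModel:  # stand-in for the HF module class
    pass

stage_manager = StageManager(stage=1, num_stages=4)
method_replacement = {
    "forward": partial(pipeline_forward, stage_manager=stage_manager, stage_index=[8, 16])
}
for name, fn in method_replacement.items():
    setattr(TinyLlamaModel, name, fn)

model = TinyLlamaModel()
# A raw partial is not a descriptor, so self is passed explicitly here;
# the real append_or_create_method_replacement machinery handles binding.
print(TinyLlamaModel.forward(model, "batch0"))
```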
@@ -148,7 +155,7 @@ class LlamaPolicy(Policy):
         """Get pipeline layers for current stage."""
         assert self.pipeline_stage_manager is not None
 
-        if self.model.__class__.__name__ == 'LlamaModel':
+        if self.model.__class__.__name__ == "LlamaModel":
            module = self.model
         else:
             module = self.model.model
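`get_held_layers` relies on the same layer distribution used in `set_pipeline_forward` above. A hedged re-implementation of what `Policy.distribute_layers` and `Policy.get_stage_index` compute (an even split with the remainder spread across stages; ColossalAI's exact tie-breaking may differ):

```python
from typing import List, Tuple

def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
    # Give every stage the base share, then one extra layer to the first
    # `remainder` stages so the counts sum to num_layers.
    base, remainder = divmod(num_layers, num_stages)
    return [base + (1 if stage < remainder else 0) for stage in range(num_stages)]

def get_stage_index(layers_per_stage: List[int], stage: int) -> Tuple[int, int]:
    # Half-open [start, end) range of decoder layers owned by this stage.
    start = sum(layers_per_stage[:stage])
    return start, start + layers_per_stage[stage]

layers_per_stage = distribute_layers(num_layers=32, num_stages=4)
print(layers_per_stage)                      # [8, 8, 8, 8]
print(get_stage_index(layers_per_stage, 1))  # (8, 16)
```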
@@ -167,18 +174,18 @@ class LlamaPolicy(Policy):
 
 
 class LlamaModelPolicy(LlamaPolicy):
-
     def __init__(self) -> None:
         super().__init__()
 
     def module_policy(self):
         policy = super().module_policy()
         from transformers.models.llama.modeling_llama import LlamaModel
 
         if self.pipeline_stage_manager:
             # set None as default
-            self.set_pipeline_forward(model_cls=LlamaModel,
-                                      new_forward=LlamaPipelineForwards.llama_model_forward,
-                                      policy=policy)
+            self.set_pipeline_forward(
+                model_cls=LlamaModel, new_forward=LlamaPipelineForwards.llama_model_forward, policy=policy
+            )
         return policy
 
     def get_held_layers(self) -> List[Module]:
@@ -192,7 +199,6 @@ class LlamaModelPolicy(LlamaPolicy):
 
 
 class LlamaForCausalLMPolicy(LlamaPolicy):
-
     def module_policy(self):
         from transformers import LlamaForCausalLM
 
@@ -201,19 +207,21 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
         if self.shard_config.enable_tensor_parallelism:
             # add a new item for casual lm
             new_item = {
-                LlamaForCausalLM:
-                    ModulePolicyDescription(sub_module_replacement=[
+                LlamaForCausalLM: ModulePolicyDescription(
+                    sub_module_replacement=[
                         SubModuleReplacementDescription(
-                            suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True))
-                    ])
+                            suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
+                        )
+                    ]
+                )
             }
             policy.update(new_item)
 
         if self.pipeline_stage_manager:
             # set None as default
-            self.set_pipeline_forward(model_cls=LlamaForCausalLM,
-                                      new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward,
-                                      policy=policy)
+            self.set_pipeline_forward(
+                model_cls=LlamaForCausalLM, new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward, policy=policy
+            )
 
         return policy
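In this hunk `lm_head` is replaced with a column-parallel linear and `gather_output=True`, so every rank ends up with logits over the full vocabulary rather than just its shard. A toy single-process illustration of that gather (not the real distributed kernel):

```python
import torch

tp, hidden, vocab = 2, 8, 10
x = torch.randn(3, hidden)      # final hidden states
w = torch.randn(vocab, hidden)  # lm_head weight, sharded along the vocab dim

# Each "rank" computes logits for its vocab shard (Linear1D_Col)...
shard_logits = [x @ w_shard.t() for w_shard in w.chunk(tp, dim=0)]
# ...and gather_output=True concatenates the shards (the all-gather),
# so the loss sees logits over the whole vocabulary.
logits = torch.cat(shard_logits, dim=-1)

assert torch.allclose(logits, x @ w.t(), atol=1e-6)
```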
@@ -228,18 +236,21 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
     def get_shared_params(self) -> List[Dict[int, Tensor]]:
         llama_model = self.model.model
         if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
-            if id(llama_model.embed_tokens.weight) == id(
-                    self.model.lm_head.weight) and self.pipeline_stage_manager.num_stages > 1:
+            if (
+                id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight)
+                and self.pipeline_stage_manager.num_stages > 1
+            ):
                 # tie weights
-                return [{
-                    0: llama_model.embed_tokens.weight,
-                    self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight
-                }]
+                return [
+                    {
+                        0: llama_model.embed_tokens.weight,
+                        self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
+                    }
+                ]
         return []
 
 
 class LlamaForSequenceClassificationPolicy(LlamaPolicy):
-
     def module_policy(self):
         from transformers import LlamaForSequenceClassification
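`get_shared_params` detects weight tying with an `id()` comparison and maps each tied parameter to the pipeline stage that holds it, so stage 0 and stage `num_stages - 1` can keep their copies synchronized. A minimal sketch with toy modules (not the HF classes):

```python
import torch.nn as nn

embed_tokens = nn.Embedding(100, 16)
lm_head = nn.Linear(16, 100, bias=False)
lm_head.weight = embed_tokens.weight  # weight tying, as HF's tie_word_embeddings does

num_stages = 4
shared_params = []
if id(embed_tokens.weight) == id(lm_head.weight) and num_stages > 1:
    # first stage holds the embedding, last stage holds the LM head
    shared_params = [{0: embed_tokens.weight, num_stages - 1: lm_head.weight}]
print(shared_params)
```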
@@ -248,19 +259,23 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
         if self.shard_config.enable_tensor_parallelism:
             # add a new item for sequence classification
             new_item = {
-                LlamaForSequenceClassification:
-                    ModulePolicyDescription(sub_module_replacement=[
+                LlamaForSequenceClassification: ModulePolicyDescription(
+                    sub_module_replacement=[
                         SubModuleReplacementDescription(
-                            suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True))
-                    ])
+                            suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
+                        )
+                    ]
+                )
             }
             policy.update(new_item)
         # to be confirmed
         if self.pipeline_stage_manager:
             # set None as default
-            self.set_pipeline_forward(model_cls=LlamaForSequenceClassification,
-                                      new_forward=LlamaPipelineForwards.llama_for_sequence_classification_forward,
-                                      policy=policy)
+            self.set_pipeline_forward(
+                model_cls=LlamaForSequenceClassification,
+                new_forward=LlamaPipelineForwards.llama_for_sequence_classification_forward,
+                policy=policy,
+            )
         return policy
 
     def get_held_layers(self) -> List[Module]: