Commit: change command
Repository: https://github.com/hpcaitech/ColossalAI.git
@@ -7,30 +7,30 @@ from torch import Tensor
 from torch.nn import Module

 from colossalai.shardformer.layer import (
-    FusedRMSNorm,
+    FusedCohereLayerNorm,
     Linear1D_Col,
     Linear1D_Row,
     PaddingEmbedding,
     PaddingLMHead,
-    RMSNorm,
+    CohereLayerNorm,
     VocabParallelEmbedding1D,
     VocabParallelLMHead1D,
 )

-from ..modeling.llama import (
-    LlamaPipelineForwards,
-    get_llama_flash_attention_forward,
-    get_llama_model_forward_for_flash_attn,
-    get_llama_seq_parallel_attention_forward,
-    get_llama_seq_parallel_model_forward,
+from ..modeling.command import (
+    CommandPipelineForwards,
+    get_command_flash_attention_forward,
+    get_command_model_forward_for_flash_attn,
+    get_command_seq_parallel_attention_forward,
+    get_command_seq_parallel_model_forward,
     get_lm_forward_with_dist_cross_entropy,
 )
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

-__all__ = ["LlamaPolicy", "LlamaForCausalLMPolicy", "LlamaForSequenceClassificationPolicy"]
+__all__ = ["CommandPolicy", "CommandForCausalLMPolicy"]


-class LlamaPolicy(Policy):
+class CommandPolicy(Policy):
     def config_sanity_check(self):
         pass

@@ -40,18 +40,18 @@ class LlamaPolicy(Policy):
         return self.model

     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        from transformers.models.llama.modeling_llama import (
-            LlamaAttention,
-            LlamaDecoderLayer,
-            LlamaFlashAttention2,
-            LlamaModel,
-            LlamaSdpaAttention,
+        from transformers.models.cohere.modeling_cohere import (
+            CohereAttention,
+            CohereDecoderLayer,
+            CohereFlashAttention2,
+            CohereModel,
+            CohereSdpaAttention,
         )

         ATTN_IMPLEMENTATION = {
-            "eager": LlamaAttention,
-            "flash_attention_2": LlamaFlashAttention2,
-            "sdpa": LlamaSdpaAttention,
+            "eager": CohereAttention,
+            "flash_attention_2": CohereFlashAttention2,
+            "sdpa": CohereSdpaAttention,
         }
         policy = {}

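The ATTN_IMPLEMENTATION table above maps each attention backend name that transformers may instantiate for Cohere to the class the policy has to patch. A minimal sketch of resolving the table to one target class; the "requested" value is illustrative and not taken from this commit:

from transformers.models.cohere.modeling_cohere import (
    CohereAttention,
    CohereFlashAttention2,
    CohereSdpaAttention,
)

ATTN_IMPLEMENTATION = {
    "eager": CohereAttention,
    "flash_attention_2": CohereFlashAttention2,
    "sdpa": CohereSdpaAttention,
}

# Illustrative value; in the policy this would come from the attention
# implementation recorded on the model config, falling back to eager.
requested = "sdpa"
attn_cls = ATTN_IMPLEMENTATION.get(requested, CohereAttention)
print(attn_cls.__name__)  # CohereSdpaAttention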
@@ -64,16 +64,16 @@ class LlamaPolicy(Policy):
                 embedding_cls = PaddingEmbedding

         if self.shard_config.enable_fused_normalization:
-            norm_cls = FusedRMSNorm
+            norm_cls = FusedCohereLayerNorm
         else:
-            norm_cls = RMSNorm
+            norm_cls = CohereLayerNorm

         if self.pipeline_stage_manager is not None:
             self.shard_config.enable_sequence_parallelism = False
             self.shard_config.enable_sequence_overlap = False
             self.shard_config.sequence_parallelism_mode = None
             warnings.warn(
-                f"For llama, sequence parallelism is currently not compatible with pipeline parallelism, set to be False"
+                f"For Command, sequence parallelism is currently not compatible with pipeline parallelism, set to be False"
             )
         sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None
         sp_size = self.shard_config.sequence_parallel_size if self.shard_config.enable_sequence_parallelism else None
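The norm swap above (FusedRMSNorm/RMSNorm to FusedCohereLayerNorm/CohereLayerNorm) reflects that Cohere's Command models normalize with a bias-free LayerNorm rather than Llama's RMSNorm. A rough functional sketch of the difference, simplified and ignoring the dtype casting the real layers do:

import torch

def rms_norm(x, weight, eps=1e-5):
    # RMSNorm: scale by the root mean square only, no mean subtraction.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

def cohere_layer_norm(x, weight, eps=1e-5):
    # Cohere-style LayerNorm: subtract the mean, divide by the variance, no bias term.
    mean = x.mean(-1, keepdim=True)
    var = (x - mean).pow(2).mean(-1, keepdim=True)
    return (x - mean) * torch.rsqrt(var + eps) * weight

x = torch.randn(2, 4, 8)
w = torch.ones(8)
print(rms_norm(x, w).shape, cohere_layer_norm(x, w).shape)  # torch.Size([2, 4, 8]) twice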
@@ -94,16 +94,16 @@ class LlamaPolicy(Policy):
         if sp_mode in ["split_gather", "ring"]:
             self.append_or_create_method_replacement(
                 description={
-                    "forward": get_llama_seq_parallel_model_forward(
+                    "forward": get_command_seq_parallel_model_forward(
                         sp_mode=sp_mode, sp_size=sp_size, sp_group=sp_group
                     ),
                 },
                 policy=policy,
-                target_key=LlamaModel,
+                target_key=CohereModel,
             )
             self.append_or_create_method_replacement(
                 description={
-                    "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group),
+                    "forward": get_command_seq_parallel_attention_forward(sp_mode, sp_size, sp_group),
                 },
                 policy=policy,
                 target_key=attn_cls,
@@ -120,21 +120,21 @@ class LlamaPolicy(Policy):
             )
             self.append_or_create_method_replacement(
                 description={
-                    "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group),
+                    "forward": get_command_seq_parallel_attention_forward(sp_mode, sp_size, sp_group),
                 },
                 policy=policy,
                 target_key=attn_cls,
             )
             self.append_or_create_method_replacement(
                 description={
-                    "forward": get_llama_seq_parallel_model_forward(
+                    "forward": get_command_seq_parallel_model_forward(
                         sp_mode=sp_mode,
                         sp_size=sp_size,
                         sp_group=sp_group,
                     ),
                 },
                 policy=policy,
-                target_key=LlamaModel,
+                target_key=CohereModel,
             )

         if self.shard_config.enable_tensor_parallelism:
@@ -155,7 +155,7 @@ class LlamaPolicy(Policy):
                     self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size
                 )

-            policy[LlamaDecoderLayer] = ModulePolicyDescription(
+            policy[CohereDecoderLayer] = ModulePolicyDescription(
                 attribute_replacement=decoder_attribute_replacement,
                 sub_module_replacement=[
                     SubModuleReplacementDescription(
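The attribute replacement built above rewrites per-layer head counts so that each tensor-parallel rank only sees its own shard of the attention heads. A toy illustration of the arithmetic; the head counts are made up, not taken from a real Command-R config:

# Made-up example values.
num_attention_heads, num_key_value_heads, tensor_parallel_size = 32, 8, 4

# The policy divides both counts by the tensor parallel size, so they must divide evenly.
assert num_attention_heads % tensor_parallel_size == 0
assert num_key_value_heads % tensor_parallel_size == 0

print(num_attention_heads // tensor_parallel_size)  # 8 attention heads per rank
print(num_key_value_heads // tensor_parallel_size)  # 2 KV heads per rank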
@@ -204,7 +204,7 @@ class LlamaPolicy(Policy):
                     kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
                 ),
                 policy=policy,
-                target_key=LlamaModel,
+                target_key=CohereModel,
             )

         # optimization configuration
@@ -215,14 +215,9 @@ class LlamaPolicy(Policy):
                     target_module=norm_cls,
                     kwargs={"sp_partial_derived": sp_partial_derived},
                 ),
-                SubModuleReplacementDescription(
-                    suffix="post_attention_layernorm",
-                    target_module=norm_cls,
-                    kwargs={"sp_partial_derived": sp_partial_derived},
-                ),
             ],
             policy=policy,
-            target_key=LlamaDecoderLayer,
+            target_key=CohereDecoderLayer,
         )

         self.append_or_create_submodule_replacement(
@@ -232,26 +227,26 @@ class LlamaPolicy(Policy):
                 kwargs={"sp_partial_derived": sp_partial_derived},
             ),
             policy=policy,
-            target_key=LlamaModel,
+            target_key=CohereModel,
         )

         # use flash attention
         if use_flash_attention:
             self.append_or_create_method_replacement(
                 description={
-                    "forward": get_llama_flash_attention_forward(self.shard_config, sp_mode, sp_group, sp_size),
+                    "forward": get_command_flash_attention_forward(self.shard_config, sp_mode, sp_group, sp_size),
                 },
                 policy=policy,
                 target_key=attn_cls,
             )
             if self.pipeline_stage_manager is None:
-                # replace llama model forward method
+                # replace Command model forward method
                 self.append_or_create_method_replacement(
                     description={
-                        "forward": get_llama_model_forward_for_flash_attn(self.shard_config),
+                        "forward": get_command_model_forward_for_flash_attn(self.shard_config),
                     },
                     policy=policy,
-                    target_key=LlamaModel,
+                    target_key=CohereModel,
                 )

         return policy
@@ -266,7 +261,7 @@ class LlamaPolicy(Policy):
             return

         stage_manager = self.pipeline_stage_manager
-        if self.model.__class__.__name__ == "LlamaModel":
+        if self.model.__class__.__name__ == "CohereModel":
             module = self.model
         else:
             module = self.model.model
@@ -293,7 +288,7 @@ class LlamaPolicy(Policy):
         """Get pipeline layers for current stage."""
         assert self.pipeline_stage_manager is not None

-        if self.model.__class__.__name__ == "LlamaModel":
+        if self.model.__class__.__name__ == "CohereModel":
             module = self.model
         else:
             module = self.model.model
@@ -323,15 +318,15 @@ class LlamaPolicy(Policy):
         return held_layers


-class LlamaModelPolicy(LlamaPolicy):
+class CommandModelPolicy(CommandPolicy):
     def module_policy(self):
         policy = super().module_policy()
-        from transformers.models.llama.modeling_llama import LlamaModel
+        from transformers.models.cohere.modeling_cohere import CohereModel

         if self.pipeline_stage_manager:
             # set None as default
             self.set_pipeline_forward(
-                model_cls=LlamaModel, new_forward=LlamaPipelineForwards.llama_model_forward, policy=policy
+                model_cls=CohereModel, new_forward=CommandPipelineForwards.command_model_forward, policy=policy
             )
         return policy

@@ -341,20 +336,20 @@ class LlamaModelPolicy(LlamaPolicy):
         return held_layers

     def get_shared_params(self) -> List[Dict[int, Tensor]]:
-        """No shared params in llama model"""
+        """No shared params in command model"""
         return []


-class LlamaForCausalLMPolicy(LlamaPolicy):
+class CommandForCausalLMPolicy(CommandPolicy):
     def module_policy(self):
-        from transformers import LlamaForCausalLM
+        from transformers import CohereForCausalLM

         policy = super().module_policy()

         if self.shard_config.enable_tensor_parallelism:
             # add a new item for casual lm
             new_item = {
-                LlamaForCausalLM: ModulePolicyDescription(
+                CohereForCausalLM: ModulePolicyDescription(
                     sub_module_replacement=[
                         SubModuleReplacementDescription(
                             suffix="lm_head",
@@ -368,12 +363,12 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
                 )
             }
             if self.shard_config.parallel_output:
-                new_item[LlamaForCausalLM].method_replacement = {
+                new_item[CohereForCausalLM].method_replacement = {
                     "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
                 }
         else:
             new_item = {
-                LlamaForCausalLM: ModulePolicyDescription(
+                CohereForCausalLM: ModulePolicyDescription(
                     sub_module_replacement=[
                         SubModuleReplacementDescription(
                             suffix="lm_head",
@@ -388,7 +383,7 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
         if self.pipeline_stage_manager:
             # set None as default
             self.set_pipeline_forward(
-                model_cls=LlamaForCausalLM, new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward, policy=policy
+                model_cls=CohereForCausalLM, new_forward=CommandPipelineForwards.command_for_causal_lm_forward, policy=policy
             )

         return policy
@@ -402,58 +397,17 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
         return held_layers

     def get_shared_params(self) -> List[Dict[int, Tensor]]:
-        llama_model = self.model.model
+        command_model = self.model.model
         if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
             if (
-                id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight)
+                id(command_model.embed_tokens.weight) == id(self.model.lm_head.weight)
                 and self.pipeline_stage_manager.num_stages > 1
             ):
                 # tie weights
                 return [
                     {
-                        0: llama_model.embed_tokens.weight,
+                        0: command_model.embed_tokens.weight,
                         self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
                     }
                 ]
-        return []
-
-
-class LlamaForSequenceClassificationPolicy(LlamaPolicy):
-    def module_policy(self):
-        from transformers import LlamaForSequenceClassification
-
-        policy = super().module_policy()
-
-        if self.shard_config.enable_tensor_parallelism:
-            # add a new item for sequence classification
-            new_item = {
-                LlamaForSequenceClassification: ModulePolicyDescription(
-                    sub_module_replacement=[
-                        SubModuleReplacementDescription(
-                            suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
-                        )
-                    ]
-                )
-            }
-            policy.update(new_item)
-        # to be confirmed
-        if self.pipeline_stage_manager:
-            # set None as default
-            self.set_pipeline_forward(
-                model_cls=LlamaForSequenceClassification,
-                new_forward=LlamaPipelineForwards.llama_for_sequence_classification_forward,
-                policy=policy,
-            )
-        return policy
-
-    def get_held_layers(self) -> List[Module]:
-        """Get pipeline layers for current stage."""
-        stage_manager = self.pipeline_stage_manager
-        held_layers = super().get_held_layers()
-        if stage_manager.is_last_stage(ignore_chunk=True):
-            held_layers.append(self.model.score)
-        return held_layers
-
-    def get_shared_params(self) -> List[Dict[int, Tensor]]:
-        """No shared params in llama for sequence classification model"""
-        return []
+        return []
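The get_shared_params change above keeps the same tie-weight check: when the input embedding and the LM head share a parameter, the first and last pipeline stages must treat it as one shared tensor. A standalone illustration of why the id() comparison works; the sizes are arbitrary:

import torch.nn as nn

embed = nn.Embedding(1000, 64)
lm_head = nn.Linear(64, 1000, bias=False)
lm_head.weight = embed.weight  # tie the weights, as tied-embedding LMs do

# Tied parameters are literally the same tensor object, so identity comparison detects them.
assert id(embed.weight) == id(lm_head.weight)
print(embed.weight.data_ptr() == lm_head.weight.data_ptr())  # True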
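For context, a rough usage sketch of how a policy such as CommandForCausalLMPolicy would be handed to ShardFormer. The launch call, ShardConfig flags, checkpoint id, and module path below follow ColossalAI's general shardformer examples and are assumptions, not part of this commit:

import colossalai
from colossalai.shardformer import ShardConfig, ShardFormer
from transformers import AutoModelForCausalLM

# Assumed setup: initialize distributed state with one of ColossalAI's launch helpers.
colossalai.launch_from_torch()

# Hypothetical checkpoint id for a Command-R model.
model = AutoModelForCausalLM.from_pretrained("CohereForAI/c4ai-command-r-v01")

# Flag names follow the ShardConfig options referenced in the policy above.
shard_config = ShardConfig(enable_tensor_parallelism=True, enable_fused_normalization=True)
shard_former = ShardFormer(shard_config=shard_config)

# Assumed module path for the new policy file; the auto-policy registry may also
# resolve the policy from the model class without passing it explicitly.
from colossalai.shardformer.policies.command import CommandForCausalLMPolicy

sharded_model, shared_params = shard_former.optimize(model, policy=CommandForCausalLMPolicy())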