change command

Author: GuangyaoZhang
Date: 2024-06-14 03:04:56 +00:00
parent 0b81163bc0
commit f656d61778
9 changed files with 778 additions and 435 deletions


@@ -7,30 +7,30 @@ from torch import Tensor
 from torch.nn import Module
 
 from colossalai.shardformer.layer import (
-    FusedRMSNorm,
+    FusedCohereLayerNorm,
     Linear1D_Col,
     Linear1D_Row,
     PaddingEmbedding,
     PaddingLMHead,
-    RMSNorm,
+    CohereLayerNorm,
     VocabParallelEmbedding1D,
     VocabParallelLMHead1D,
 )
 
-from ..modeling.llama import (
-    LlamaPipelineForwards,
-    get_llama_flash_attention_forward,
-    get_llama_model_forward_for_flash_attn,
-    get_llama_seq_parallel_attention_forward,
-    get_llama_seq_parallel_model_forward,
+from ..modeling.command import (
+    CommandPipelineForwards,
+    get_command_flash_attention_forward,
+    get_command_model_forward_for_flash_attn,
+    get_command_seq_parallel_attention_forward,
+    get_command_seq_parallel_model_forward,
     get_lm_forward_with_dist_cross_entropy,
 )
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
-__all__ = ["LlamaPolicy", "LlamaForCausalLMPolicy", "LlamaForSequenceClassificationPolicy"]
+__all__ = ["CommandPolicy", "CommandForCausalLMPolicy"]
 
 
-class LlamaPolicy(Policy):
+class CommandPolicy(Policy):
     def config_sanity_check(self):
         pass
@@ -40,18 +40,18 @@ class LlamaPolicy(Policy):
         return self.model
 
     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        from transformers.models.llama.modeling_llama import (
-            LlamaAttention,
-            LlamaDecoderLayer,
-            LlamaFlashAttention2,
-            LlamaModel,
-            LlamaSdpaAttention,
+        from transformers.models.cohere.modeling_cohere import (
+            CohereAttention,
+            CohereDecoderLayer,
+            CohereFlashAttention2,
+            CohereModel,
+            CohereSdpaAttention,
         )
 
         ATTN_IMPLEMENTATION = {
-            "eager": LlamaAttention,
-            "flash_attention_2": LlamaFlashAttention2,
-            "sdpa": LlamaSdpaAttention,
+            "eager": CohereAttention,
+            "flash_attention_2": CohereFlashAttention2,
+            "sdpa": CohereSdpaAttention,
         }
 
         policy = {}
@@ -64,16 +64,16 @@ class LlamaPolicy(Policy):
                 embedding_cls = PaddingEmbedding
 
         if self.shard_config.enable_fused_normalization:
-            norm_cls = FusedRMSNorm
+            norm_cls = FusedCohereLayerNorm
         else:
-            norm_cls = RMSNorm
+            norm_cls = CohereLayerNorm
 
         if self.pipeline_stage_manager is not None:
             self.shard_config.enable_sequence_parallelism = False
             self.shard_config.enable_sequence_overlap = False
             self.shard_config.sequence_parallelism_mode = None
             warnings.warn(
-                f"For llama, sequence parallelism is currently not compatible with pipeline parallelism, set to be False"
+                f"For Command, sequence parallelism is currently not compatible with pipeline parallelism, set to be False"
             )
 
         sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None
         sp_size = self.shard_config.sequence_parallel_size if self.shard_config.enable_sequence_parallelism else None
@@ -94,16 +94,16 @@ class LlamaPolicy(Policy):
         if sp_mode in ["split_gather", "ring"]:
             self.append_or_create_method_replacement(
                 description={
-                    "forward": get_llama_seq_parallel_model_forward(
+                    "forward": get_command_seq_parallel_model_forward(
                         sp_mode=sp_mode, sp_size=sp_size, sp_group=sp_group
                     ),
                 },
                 policy=policy,
-                target_key=LlamaModel,
+                target_key=CohereModel,
             )
             self.append_or_create_method_replacement(
                 description={
-                    "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group),
+                    "forward": get_command_seq_parallel_attention_forward(sp_mode, sp_size, sp_group),
                 },
                 policy=policy,
                 target_key=attn_cls,
             )
@@ -120,21 +120,21 @@ class LlamaPolicy(Policy):
             )
             self.append_or_create_method_replacement(
                 description={
-                    "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group),
+                    "forward": get_command_seq_parallel_attention_forward(sp_mode, sp_size, sp_group),
                 },
                 policy=policy,
                 target_key=attn_cls,
             )
             self.append_or_create_method_replacement(
                 description={
-                    "forward": get_llama_seq_parallel_model_forward(
+                    "forward": get_command_seq_parallel_model_forward(
                         sp_mode=sp_mode,
                         sp_size=sp_size,
                         sp_group=sp_group,
                     ),
                 },
                 policy=policy,
-                target_key=LlamaModel,
+                target_key=CohereModel,
             )
 
         if self.shard_config.enable_tensor_parallelism:
@@ -155,7 +155,7 @@ class LlamaPolicy(Policy):
                     self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size
                 )
 
-            policy[LlamaDecoderLayer] = ModulePolicyDescription(
+            policy[CohereDecoderLayer] = ModulePolicyDescription(
                 attribute_replacement=decoder_attribute_replacement,
                 sub_module_replacement=[
                     SubModuleReplacementDescription(
@@ -204,7 +204,7 @@ class LlamaPolicy(Policy):
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
),
policy=policy,
target_key=LlamaModel,
target_key=CohereModel,
)
# optimization configuration
@@ -215,14 +215,9 @@ class LlamaPolicy(Policy):
                     target_module=norm_cls,
                     kwargs={"sp_partial_derived": sp_partial_derived},
                 ),
-                SubModuleReplacementDescription(
-                    suffix="post_attention_layernorm",
-                    target_module=norm_cls,
-                    kwargs={"sp_partial_derived": sp_partial_derived},
-                ),
             ],
             policy=policy,
-            target_key=LlamaDecoderLayer,
+            target_key=CohereDecoderLayer,
         )
 
         self.append_or_create_submodule_replacement(
@@ -232,26 +227,26 @@ class LlamaPolicy(Policy):
kwargs={"sp_partial_derived": sp_partial_derived},
),
policy=policy,
target_key=LlamaModel,
target_key=CohereModel,
)
# use flash attention
if use_flash_attention:
self.append_or_create_method_replacement(
description={
"forward": get_llama_flash_attention_forward(self.shard_config, sp_mode, sp_group, sp_size),
"forward": get_command_flash_attention_forward(self.shard_config, sp_mode, sp_group, sp_size),
},
policy=policy,
target_key=attn_cls,
)
if self.pipeline_stage_manager is None:
# replace llama model forward method
# replace Command model forward method
self.append_or_create_method_replacement(
description={
"forward": get_llama_model_forward_for_flash_attn(self.shard_config),
"forward": get_command_model_forward_for_flash_attn(self.shard_config),
},
policy=policy,
target_key=LlamaModel,
target_key=CohereModel,
)
return policy
@@ -266,7 +261,7 @@ class LlamaPolicy(Policy):
             return
 
         stage_manager = self.pipeline_stage_manager
-        if self.model.__class__.__name__ == "LlamaModel":
+        if self.model.__class__.__name__ == "CohereModel":
             module = self.model
         else:
             module = self.model.model
@@ -293,7 +288,7 @@ class LlamaPolicy(Policy):
"""Get pipeline layers for current stage."""
assert self.pipeline_stage_manager is not None
if self.model.__class__.__name__ == "LlamaModel":
if self.model.__class__.__name__ == "CohereModel":
module = self.model
else:
module = self.model.model
@@ -323,15 +318,15 @@ class LlamaPolicy(Policy):
         return held_layers
 
 
-class LlamaModelPolicy(LlamaPolicy):
+class CommandModelPolicy(CommandPolicy):
     def module_policy(self):
         policy = super().module_policy()
-        from transformers.models.llama.modeling_llama import LlamaModel
+        from transformers.models.cohere.modeling_cohere import CohereModel
 
         if self.pipeline_stage_manager:
             # set None as default
             self.set_pipeline_forward(
-                model_cls=LlamaModel, new_forward=LlamaPipelineForwards.llama_model_forward, policy=policy
+                model_cls=CohereModel, new_forward=CommandPipelineForwards.command_model_forward, policy=policy
             )
         return policy
@@ -341,20 +336,20 @@ class LlamaModelPolicy(LlamaPolicy):
         return held_layers
 
     def get_shared_params(self) -> List[Dict[int, Tensor]]:
-        """No shared params in llama model"""
+        """No shared params in command model"""
         return []
 
 
-class LlamaForCausalLMPolicy(LlamaPolicy):
+class CommandForCausalLMPolicy(CommandPolicy):
     def module_policy(self):
-        from transformers import LlamaForCausalLM
+        from transformers import CohereForCausalLM
 
         policy = super().module_policy()
 
         if self.shard_config.enable_tensor_parallelism:
             # add a new item for casual lm
             new_item = {
-                LlamaForCausalLM: ModulePolicyDescription(
+                CohereForCausalLM: ModulePolicyDescription(
                     sub_module_replacement=[
                         SubModuleReplacementDescription(
                             suffix="lm_head",
@@ -368,12 +363,12 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
                 )
             }
             if self.shard_config.parallel_output:
-                new_item[LlamaForCausalLM].method_replacement = {
+                new_item[CohereForCausalLM].method_replacement = {
                     "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
                 }
         else:
             new_item = {
-                LlamaForCausalLM: ModulePolicyDescription(
+                CohereForCausalLM: ModulePolicyDescription(
                     sub_module_replacement=[
                         SubModuleReplacementDescription(
                             suffix="lm_head",
@@ -388,7 +383,7 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
         if self.pipeline_stage_manager:
             # set None as default
             self.set_pipeline_forward(
-                model_cls=LlamaForCausalLM, new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward, policy=policy
+                model_cls=CohereForCausalLM, new_forward=CommandPipelineForwards.command_for_causal_lm_forward, policy=policy
             )
 
         return policy
@@ -402,58 +397,17 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
         return held_layers
 
     def get_shared_params(self) -> List[Dict[int, Tensor]]:
-        llama_model = self.model.model
+        command_model = self.model.model
         if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
             if (
-                id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight)
+                id(command_model.embed_tokens.weight) == id(self.model.lm_head.weight)
                 and self.pipeline_stage_manager.num_stages > 1
             ):
                 # tie weights
                 return [
                     {
-                        0: llama_model.embed_tokens.weight,
+                        0: command_model.embed_tokens.weight,
                         self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
                     }
                 ]
         return []
-
-
-class LlamaForSequenceClassificationPolicy(LlamaPolicy):
-    def module_policy(self):
-        from transformers import LlamaForSequenceClassification
-
-        policy = super().module_policy()
-
-        if self.shard_config.enable_tensor_parallelism:
-            # add a new item for sequence classification
-            new_item = {
-                LlamaForSequenceClassification: ModulePolicyDescription(
-                    sub_module_replacement=[
-                        SubModuleReplacementDescription(
-                            suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
-                        )
-                    ]
-                )
-            }
-            policy.update(new_item)
-        # to be confirmed
-        if self.pipeline_stage_manager:
-            # set None as default
-            self.set_pipeline_forward(
-                model_cls=LlamaForSequenceClassification,
-                new_forward=LlamaPipelineForwards.llama_for_sequence_classification_forward,
-                policy=policy,
-            )
-        return policy
-
-    def get_held_layers(self) -> List[Module]:
-        """Get pipeline layers for current stage."""
-        stage_manager = self.pipeline_stage_manager
-        held_layers = super().get_held_layers()
-        if stage_manager.is_last_stage(ignore_chunk=True):
-            held_layers.append(self.model.score)
-        return held_layers
-
-    def get_shared_params(self) -> List[Dict[int, Tensor]]:
-        """No shared params in llama for sequence classification model"""
-        return []
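
For context, a minimal sketch of how the renamed policy might be applied via the generic ShardFormer optimize flow. The import path colossalai.shardformer.policies.command, the example checkpoint, and the process-group wiring below are illustrative assumptions, not something this diff specifies.

# Hypothetical usage sketch (assumed module path and checkpoint; adapt to your launcher setup).
import torch.distributed as dist
from transformers import AutoModelForCausalLM

from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.command import CommandForCausalLMPolicy  # assumed path

# Example Cohere Command checkpoint; any CohereForCausalLM model should fit the policy.
model = AutoModelForCausalLM.from_pretrained("CohereForAI/c4ai-command-r-v01")

shard_config = ShardConfig(
    tensor_parallel_process_group=dist.group.WORLD,  # illustration: one TP group over all ranks
    enable_tensor_parallelism=True,
    enable_fused_normalization=True,  # selects FusedCohereLayerNorm per the policy above
)
shard_former = ShardFormer(shard_config=shard_config)

# optimize() applies the policy's module/method replacements and returns the sharded model
# together with any shared parameters (e.g. tied embedding/lm_head weights across stages).
sharded_model, shared_params = shard_former.optimize(model, policy=CommandForCausalLMPolicy())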