rebase master llama change

GuangyaoZhang
2024-06-18 02:56:47 +00:00
parent 20c0b06ff5
commit a83a2336e8
2 changed files with 272 additions and 301 deletions


@@ -19,8 +19,8 @@ from colossalai.shardformer.layer import (
 from ..modeling.command import (
     CommandPipelineForwards,
-    get_command_seq_parallel_attention_forward,
-    get_command_seq_parallel_model_forward,
+    get_command_flash_attention_forward,
+    get_command_flash_attention_model_forward,
     get_lm_forward_with_dist_cross_entropy,
 )
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@@ -80,38 +80,7 @@ class CommandPolicy(Policy):
         )
         sp_partial_derived = sp_mode in ["split_gather", "ring"]
-        use_flash_attention = self.shard_config.enable_flash_attention
-        # Currently sp cannot to be used with flashattention
-        if sp_mode in ["split_gather", "ring", "all_to_all"]:
-            if use_flash_attention:
-                warnings.warn(
-                    f"Sequence parallelism mode {sp_mode} need to be used with FlashAttention, will disable FlashAttention automatically."
-                )
-                use_flash_attention = False
-
-        if sp_mode in ["split_gather", "ring"]:
-            self.append_or_create_method_replacement(
-                description={
-                    "forward": get_command_seq_parallel_model_forward(
-                        sp_mode=sp_mode,
-                        sp_size=sp_size,
-                        sp_group=sp_group,
-                        use_flash_attention=use_flash_attention,
-                    ),
-                },
-                policy=policy,
-                target_key=CohereModel,
-            )
-            self.append_or_create_method_replacement(
-                description={
-                    "forward": get_command_seq_parallel_attention_forward(
-                        sp_mode, sp_size, sp_group, use_flash_attention=use_flash_attention
-                    ),
-                },
-                policy=policy,
-                target_key=attn_cls,
-            )
-        elif sp_mode == "all_to_all":
+        if sp_mode == "all_to_all":
             decoder_attribute_replacement = {
                 "num_heads": self.model.config.num_attention_heads // sp_size,
             }
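
In the all_to_all branch retained above, the attention head count is re-partitioned across the sequence-parallel group. A minimal sketch of that arithmetic with assumed example values (32 attention heads, a sequence-parallel group of 4), not taken from the commit:

# Hypothetical numbers, only to illustrate the attribute_replacement above.
num_attention_heads = 32  # assumed model config value
sp_size = 4               # assumed sequence-parallel group size

decoder_attribute_replacement = {
    "num_heads": num_attention_heads // sp_size,  # each rank is configured with 8 heads
}
print(decoder_attribute_replacement)  # {'num_heads': 8}
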
@@ -121,27 +90,28 @@ class CommandPolicy(Policy):
             policy[attn_cls] = ModulePolicyDescription(
                 attribute_replacement=decoder_attribute_replacement,
             )
+        if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
             self.append_or_create_method_replacement(
                 description={
-                    "forward": get_command_seq_parallel_attention_forward(
-                        sp_mode, sp_size, sp_group, use_flash_attention=use_flash_attention
-                    ),
+                    "forward": get_command_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
                 },
                 policy=policy,
                 target_key=attn_cls,
             )
-            self.append_or_create_method_replacement(
-                description={
-                    "forward": get_command_seq_parallel_model_forward(
-                        sp_mode=sp_mode,
-                        sp_size=sp_size,
-                        sp_group=sp_group,
-                        use_flash_attention=use_flash_attention,
-                    ),
-                },
-                policy=policy,
-                target_key=CohereModel,
-            )
+            if self.pipeline_stage_manager is None:
+                self.append_or_create_method_replacement(
+                    description={
+                        "forward": get_command_flash_attention_model_forward(
+                            self.shard_config,
+                            sp_mode=sp_mode,
+                            sp_size=sp_size,
+                            sp_group=sp_group,
+                        ),
+                    },
+                    policy=policy,
+                    target_key=CohereModel,
+                )
 
         if self.shard_config.enable_tensor_parallelism:
             assert (
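
The block added above gates both forward swaps on a single condition (flash attention or sequence parallelism enabled) and only replaces the CohereModel forward when no pipeline stage manager is present, presumably because the pipeline path installs its own staged forwards elsewhere in the policy. A rough sketch of the replace-the-forward pattern itself, using stand-in classes and a hypothetical getter rather than the real ColossalAI helpers:

from types import MethodType

# Stand-ins for illustration only; not the real ColossalAI classes or getters.
class FakeAttention:
    def forward(self, hidden_states):
        return hidden_states

def get_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
    # The real getters build a forward specialized for the shard configuration;
    # this closure just echoes its input.
    def forward(self, hidden_states):
        return hidden_states
    return forward

attn = FakeAttention()
# Binding the generated closure in place of the module's forward is roughly
# what a method replacement in a policy amounts to.
attn.forward = MethodType(get_flash_attention_forward(shard_config=None), attn)
print(attn.forward("hidden states"))
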
@@ -236,29 +206,6 @@ class CommandPolicy(Policy):
             target_key=CohereModel,
         )
-        # use flash attention
-        if use_flash_attention:
-            self.append_or_create_method_replacement(
-                description={
-                    "forward": get_command_seq_parallel_attention_forward(
-                        sp_mode, sp_group, sp_size, use_flash_attention=use_flash_attention
-                    ),
-                },
-                policy=policy,
-                target_key=attn_cls,
-            )
-            if self.pipeline_stage_manager is None:
-                # replace Command model forward method
-                self.append_or_create_method_replacement(
-                    description={
-                        "forward": get_command_seq_parallel_model_forward(
-                            sp_mode, sp_size, sp_group, use_flash_attention=use_flash_attention
-                        ),
-                    },
-                    policy=policy,
-                    target_key=CohereModel,
-                )
 
         return policy
 
     def postprocess(self):
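
Read end to end, the policy's forward-replacement logic after this commit reduces to a few branches. The sketch below restates them in plain Python for illustration; the function and the dictionary-style shard_config used here are hypothetical stand-ins, not ColossalAI code:

# Plain-Python restatement of the branches visible in the diff; names are stand-ins.
def choose_forward_replacements(shard_config, sp_mode, has_pipeline_stage_manager):
    """List which forwards the policy would swap, mirroring the diff."""
    replacements = []
    if sp_mode == "all_to_all":
        replacements.append("attn_cls: num_heads split across sp_size ranks")
    if shard_config["enable_flash_attention"] or shard_config["enable_sequence_parallelism"]:
        replacements.append("attn_cls.forward <- get_command_flash_attention_forward(...)")
        if not has_pipeline_stage_manager:
            replacements.append("CohereModel.forward <- get_command_flash_attention_model_forward(...)")
    return replacements

print(choose_forward_replacements(
    {"enable_flash_attention": True, "enable_sequence_parallelism": False},
    sp_mode="all_to_all",
    has_pipeline_stage_manager=False,
))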