Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-27 20:46:00 +00:00)

Commit: rebase master llama change
@@ -19,8 +19,8 @@ from colossalai.shardformer.layer import (
 
 from ..modeling.command import (
     CommandPipelineForwards,
-    get_command_seq_parallel_attention_forward,
-    get_command_seq_parallel_model_forward,
+    get_command_flash_attention_forward,
+    get_command_flash_attention_model_forward,
     get_lm_forward_with_dist_cross_entropy,
 )
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
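(Given the `CommandPolicy` class and the relative imports, the file being diffed is apparently colossalai/shardformer/policies/command.py. The import swap above is the core of the rebase: the Command policy drops its dedicated sequence-parallel forwards and reuses the flash-attention-style forwards that the llama policy already provides; the hunks below update each registration site to match.)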
@@ -80,38 +80,7 @@ class CommandPolicy(Policy):
             )
         sp_partial_derived = sp_mode in ["split_gather", "ring"]
 
-        use_flash_attention = self.shard_config.enable_flash_attention
-        # Currently sp cannot to be used with flashattention
-        if sp_mode in ["split_gather", "ring", "all_to_all"]:
-            if use_flash_attention:
-                warnings.warn(
-                    f"Sequence parallelism mode {sp_mode} need to be used with FlashAttention, will disable FlashAttention automatically."
-                )
-                use_flash_attention = False
-
-        if sp_mode in ["split_gather", "ring"]:
-            self.append_or_create_method_replacement(
-                description={
-                    "forward": get_command_seq_parallel_model_forward(
-                        sp_mode=sp_mode,
-                        sp_size=sp_size,
-                        sp_group=sp_group,
-                        use_flash_attention=use_flash_attention,
-                    ),
-                },
-                policy=policy,
-                target_key=CohereModel,
-            )
-            self.append_or_create_method_replacement(
-                description={
-                    "forward": get_command_seq_parallel_attention_forward(
-                        sp_mode, sp_size, sp_group, use_flash_attention=use_flash_attention
-                    ),
-                },
-                policy=policy,
-                target_key=attn_cls,
-            )
-        elif sp_mode == "all_to_all":
+        if sp_mode == "all_to_all":
             decoder_attribute_replacement = {
                 "num_heads": self.model.config.num_attention_heads // sp_size,
             }
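After this hunk, only the "all_to_all" attribute rewrite remains at this point in the method; the FlashAttention enable/disable juggling above is gone. A minimal sketch of what that surviving rewrite computes (the concrete numbers are hypothetical, not from this commit):

# Sketch of the surviving "all_to_all" branch (hypothetical example values).
num_attention_heads = 32  # stand-in for self.model.config.num_attention_heads
sp_size = 4               # stand-in for the sequence-parallel group size
sp_mode = "all_to_all"

if sp_mode == "all_to_all":
    # all_to_all sequence parallelism trades sequence shards for head shards
    # inside attention, so each rank only keeps 1/sp_size of the heads.
    decoder_attribute_replacement = {
        "num_heads": num_attention_heads // sp_size,
    }
    assert decoder_attribute_replacement["num_heads"] == 8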
@@ -121,27 +90,28 @@ class CommandPolicy(Policy):
             policy[attn_cls] = ModulePolicyDescription(
                 attribute_replacement=decoder_attribute_replacement,
             )
+        if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
             self.append_or_create_method_replacement(
                 description={
-                    "forward": get_command_seq_parallel_attention_forward(
-                        sp_mode, sp_size, sp_group, use_flash_attention=use_flash_attention
-                    ),
+                    "forward": get_command_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
                 },
                 policy=policy,
                 target_key=attn_cls,
             )
-            self.append_or_create_method_replacement(
-                description={
-                    "forward": get_command_seq_parallel_model_forward(
-                        sp_mode=sp_mode,
-                        sp_size=sp_size,
-                        sp_group=sp_group,
-                        use_flash_attention=use_flash_attention,
-                    ),
-                },
-                policy=policy,
-                target_key=CohereModel,
-            )
+            if self.pipeline_stage_manager is None:
+                self.append_or_create_method_replacement(
+                    description={
+                        "forward": get_command_flash_attention_model_forward(
+                            self.shard_config,
+                            sp_mode=sp_mode,
+                            sp_size=sp_size,
+                            sp_group=sp_group,
+                        ),
+                    },
+                    policy=policy,
+                    target_key=CohereModel,
+                )
+
 
         if self.shard_config.enable_tensor_parallelism:
             assert (
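Note the new `if self.pipeline_stage_manager is None:` guard: the whole-model forward is now swapped only when pipeline parallelism is off, presumably because `CommandPipelineForwards` (still imported above) supplies the stage-wise model forward when a pipeline stage manager is present.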
@@ -236,29 +206,6 @@ class CommandPolicy(Policy):
             target_key=CohereModel,
         )
 
-        # use flash attention
-        if use_flash_attention:
-            self.append_or_create_method_replacement(
-                description={
-                    "forward": get_command_seq_parallel_attention_forward(
-                        sp_mode, sp_group, sp_size, use_flash_attention=use_flash_attention
-                    ),
-                },
-                policy=policy,
-                target_key=attn_cls,
-            )
-            if self.pipeline_stage_manager is None:
-                # replace Command model forward method
-                self.append_or_create_method_replacement(
-                    description={
-                        "forward": get_command_seq_parallel_model_forward(
-                            sp_mode, sp_size, sp_group, use_flash_attention=use_flash_attention
-                        ),
-                    },
-                    policy=policy,
-                    target_key=CohereModel,
-                )
-
         return policy
 
     def postprocess(self):
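For readers unfamiliar with the policy machinery these hunks call into: `append_or_create_method_replacement` merges a new forward into the policy table keyed by module class. A simplified stand-in, inferred from the call sites in this diff rather than copied from base_policy.py:

from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional

@dataclass
class ModulePolicyDescription:
    # Simplified stand-in for the real class in .base_policy; only the two
    # fields this diff exercises are modelled here.
    attribute_replacement: Optional[Dict[str, Any]] = None
    method_replacement: Optional[Dict[str, Callable]] = None

def append_or_create_method_replacement(description, policy, target_key):
    # Behaviour inferred from the call sites above: merge the new forward
    # into an existing entry for target_key, or create a fresh entry.
    if target_key in policy:
        entry = policy[target_key]
        if entry.method_replacement is None:
            entry.method_replacement = {}
        entry.method_replacement.update(description)
    else:
        policy[target_key] = ModulePolicyDescription(method_replacement=dict(description))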