merge model and attention forward

This commit is contained in:
GuangyaoZhang
2024-06-17 08:50:05 +00:00
parent 7a2b08646f
commit 363cde6957
2 changed files with 52 additions and 242 deletions

View File

@@ -19,8 +19,6 @@ from colossalai.shardformer.layer import (
from ..modeling.command import (
CommandPipelineForwards,
get_command_flash_attention_forward,
get_command_model_forward_for_flash_attn,
get_command_seq_parallel_attention_forward,
get_command_seq_parallel_model_forward,
get_lm_forward_with_dist_cross_entropy,
@@ -95,7 +93,10 @@ class CommandPolicy(Policy):
self.append_or_create_method_replacement(
description={
"forward": get_command_seq_parallel_model_forward(
sp_mode=sp_mode, sp_size=sp_size, sp_group=sp_group
sp_mode=sp_mode,
sp_size=sp_size,
sp_group=sp_group,
use_flash_attention=use_flash_attention,
),
},
policy=policy,
@@ -103,7 +104,9 @@ class CommandPolicy(Policy):
)
self.append_or_create_method_replacement(
description={
"forward": get_command_seq_parallel_attention_forward(sp_mode, sp_size, sp_group),
"forward": get_command_seq_parallel_attention_forward(
sp_mode, sp_size, sp_group, use_flash_attention=use_flash_attention
),
},
policy=policy,
target_key=attn_cls,
@@ -120,7 +123,9 @@ class CommandPolicy(Policy):
)
self.append_or_create_method_replacement(
description={
"forward": get_command_seq_parallel_attention_forward(sp_mode, sp_size, sp_group),
"forward": get_command_seq_parallel_attention_forward(
sp_mode, sp_size, sp_group, use_flash_attention=use_flash_attention
),
},
policy=policy,
target_key=attn_cls,
@@ -131,6 +136,7 @@ class CommandPolicy(Policy):
sp_mode=sp_mode,
sp_size=sp_size,
sp_group=sp_group,
use_flash_attention=use_flash_attention,
),
},
policy=policy,
@@ -234,7 +240,9 @@ class CommandPolicy(Policy):
if use_flash_attention:
self.append_or_create_method_replacement(
description={
"forward": get_command_flash_attention_forward(self.shard_config, sp_mode, sp_group, sp_size),
"forward": get_command_seq_parallel_attention_forward(
sp_mode, sp_group, sp_size, use_flash_attention=use_flash_attention
),
},
policy=policy,
target_key=attn_cls,
@@ -243,7 +251,9 @@ class CommandPolicy(Policy):
# replace Command model forward method
self.append_or_create_method_replacement(
description={
"forward": get_command_model_forward_for_flash_attn(self.shard_config),
"forward": get_command_seq_parallel_model_forward(
sp_mode, sp_size, sp_group, use_flash_attention=use_flash_attention
),
},
policy=policy,
target_key=CohereModel,