mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-27 20:46:00 +00:00
merge model and attention forward
This commit is contained in:
@@ -19,8 +19,6 @@ from colossalai.shardformer.layer import (
|
||||
|
||||
from ..modeling.command import (
|
||||
CommandPipelineForwards,
|
||||
get_command_flash_attention_forward,
|
||||
get_command_model_forward_for_flash_attn,
|
||||
get_command_seq_parallel_attention_forward,
|
||||
get_command_seq_parallel_model_forward,
|
||||
get_lm_forward_with_dist_cross_entropy,
|
||||
@@ -95,7 +93,10 @@ class CommandPolicy(Policy):
|
||||
self.append_or_create_method_replacement(
|
||||
description={
|
||||
"forward": get_command_seq_parallel_model_forward(
|
||||
sp_mode=sp_mode, sp_size=sp_size, sp_group=sp_group
|
||||
sp_mode=sp_mode,
|
||||
sp_size=sp_size,
|
||||
sp_group=sp_group,
|
||||
use_flash_attention=use_flash_attention,
|
||||
),
|
||||
},
|
||||
policy=policy,
|
||||
@@ -103,7 +104,9 @@ class CommandPolicy(Policy):
|
||||
)
|
||||
self.append_or_create_method_replacement(
|
||||
description={
|
||||
"forward": get_command_seq_parallel_attention_forward(sp_mode, sp_size, sp_group),
|
||||
"forward": get_command_seq_parallel_attention_forward(
|
||||
sp_mode, sp_size, sp_group, use_flash_attention=use_flash_attention
|
||||
),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=attn_cls,
|
||||
@@ -120,7 +123,9 @@ class CommandPolicy(Policy):
|
||||
)
|
||||
self.append_or_create_method_replacement(
|
||||
description={
|
||||
"forward": get_command_seq_parallel_attention_forward(sp_mode, sp_size, sp_group),
|
||||
"forward": get_command_seq_parallel_attention_forward(
|
||||
sp_mode, sp_size, sp_group, use_flash_attention=use_flash_attention
|
||||
),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=attn_cls,
|
||||
@@ -131,6 +136,7 @@ class CommandPolicy(Policy):
|
||||
sp_mode=sp_mode,
|
||||
sp_size=sp_size,
|
||||
sp_group=sp_group,
|
||||
use_flash_attention=use_flash_attention,
|
||||
),
|
||||
},
|
||||
policy=policy,
|
||||
@@ -234,7 +240,9 @@ class CommandPolicy(Policy):
|
||||
if use_flash_attention:
|
||||
self.append_or_create_method_replacement(
|
||||
description={
|
||||
"forward": get_command_flash_attention_forward(self.shard_config, sp_mode, sp_group, sp_size),
|
||||
"forward": get_command_seq_parallel_attention_forward(
|
||||
sp_mode, sp_group, sp_size, use_flash_attention=use_flash_attention
|
||||
),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=attn_cls,
|
||||
@@ -243,7 +251,9 @@ class CommandPolicy(Policy):
|
||||
# replace Command model forward method
|
||||
self.append_or_create_method_replacement(
|
||||
description={
|
||||
"forward": get_command_model_forward_for_flash_attn(self.shard_config),
|
||||
"forward": get_command_seq_parallel_model_forward(
|
||||
sp_mode, sp_size, sp_group, use_flash_attention=use_flash_attention
|
||||
),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=CohereModel,
|
||||
|
Reference in New Issue
Block a user