[Hotfix] Fix llama fwd replacement bug (#6031)

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Author: Wenxuan Tan
Date: 2024-08-23 15:44:27 +08:00 (committed by GitHub)
Parent: 39e2597426
Commit: 7cf9df07bc


@@ -95,19 +95,20 @@ class LlamaPolicy(Policy):
                 policy=policy,
                 target_key=attn_cls,
             )
-            if self.pipeline_stage_manager is None:
-                self.append_or_create_method_replacement(
-                    description={
-                        "forward": get_llama_flash_attention_model_forward(
-                            self.shard_config,
-                            sp_mode=sp_mode,
-                            sp_size=sp_size,
-                            sp_group=sp_group,
-                        ),
-                    },
-                    policy=policy,
-                    target_key=LlamaModel,
-                )
+
+        if self.pipeline_stage_manager is None:
+            self.append_or_create_method_replacement(
+                description={
+                    "forward": get_llama_flash_attention_model_forward(
+                        self.shard_config,
+                        sp_mode=sp_mode,
+                        sp_size=sp_size,
+                        sp_group=sp_group,
+                    ),
+                },
+                policy=policy,
+                target_key=LlamaModel,
+            )
 
         if self.shard_config.enable_tensor_parallelism:
             assert (
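
The change dedents the LlamaModel forward replacement out of the enclosing attention branch, so it is registered whenever pipeline parallelism is off (self.pipeline_stage_manager is None) rather than only when that branch is taken. The sketch below illustrates the guard pattern in isolation; ToyPolicy, method_replacements, and set_forward_replacements are hypothetical stand-ins for illustration, not ColossalAI's actual classes.

    # Minimal sketch of the "replace the whole-model forward only when
    # pipeline parallelism is off" pattern. All names here are toy
    # stand-ins, not the real shardformer API.
    class ToyPolicy:
        def __init__(self, pipeline_stage_manager=None):
            self.pipeline_stage_manager = pipeline_stage_manager
            # target class name -> {method name: replacement callable}
            self.method_replacements = {}

        def append_or_create_method_replacement(self, description, policy=None, target_key=None):
            # `policy` is unused in this toy version; kept only to echo
            # the shape of the real call site.
            self.method_replacements.setdefault(target_key, {}).update(description)

        def set_forward_replacements(self):
            def custom_model_forward(*args, **kwargs):
                return "custom whole-model forward"

            # The guard: only swap in the custom model forward when no
            # pipeline stage manager is active, so a pipeline-parallel
            # run keeps its own staged forward.
            if self.pipeline_stage_manager is None:
                self.append_or_create_method_replacement(
                    description={"forward": custom_model_forward},
                    target_key="LlamaModel",
                )

    # Non-pipeline run: the replacement is registered.
    policy = ToyPolicy()
    policy.set_forward_replacements()
    assert "forward" in policy.method_replacements["LlamaModel"]

    # Pipeline run: the whole-model forward is left untouched.
    pp_policy = ToyPolicy(pipeline_stage_manager=object())
    pp_policy.set_forward_replacements()
    assert "LlamaModel" not in pp_policy.method_replacements

The guard itself predates this commit; the hotfix only changes where it sits, presumably because pipeline execution installs its own per-stage forward elsewhere and an unconditional whole-model replacement would clobber it.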