[shardformer] chatglm support sequence parallel (#4482)

* [shardformer] chatglm support sequence parallel

* fix
flybird11111 authored 2023-08-22 23:59:31 +08:00, committed by GitHub
parent 351351a36e
commit 59e252ecdb
11 changed files with 259 additions and 94 deletions
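
For context, a sketch of how the new ChatGLM sequence-parallel path might be switched on from user code. The exact `ShardConfig` fields and the `ShardFormer.optimize()` entry point are assumptions about the ColossalAI API of this era and may have changed since; treat this as illustrative, not as the commit's own test code.

```python
# Hedged sketch: ShardConfig fields and ShardFormer.optimize() are
# assumptions about the ColossalAI shardformer API around this commit.
import torch.distributed as dist
from transformers import AutoModel

from colossalai.shardformer import ShardConfig, ShardFormer

# Assumes a process group has already been initialized.
model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)

shard_config = ShardConfig(
    tensor_parallel_process_group=dist.group.WORLD,  # assumed 1D TP group
    enable_sequence_parallelism=True,                # the feature this commit adds for ChatGLM
)
sharded_model, shared_params = ShardFormer(shard_config=shard_config).optimize(model)
```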


@@ -105,9 +105,11 @@ class LlamaPolicy(Policy):
                                               target_key=LlamaModel)
         if self.shard_config.enable_flash_attention:
-            policy[LlamaAttention] = ModulePolicyDescription(method_replacement={
+            self.append_or_create_method_replacement(description={
                 'forward': get_llama_flash_attention_forward(),
-            })
+            },
+                                                     policy=policy,
+                                                     target_key=LlamaAttention)
         return policy
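
The hunk above swaps a direct `policy[LlamaAttention] = ModulePolicyDescription(...)` assignment for the `append_or_create_method_replacement` helper, so the flash-attention forward is merged into any existing policy entry for `LlamaAttention` instead of overwriting it. A minimal sketch of that merge pattern follows; `ModulePolicyDescription` here is a simplified stand-in, and the real ColossalAI class and helper signature may differ.

```python
# Minimal sketch of the append-or-create pattern from the diff.
# ModulePolicyDescription is a simplified stand-in for illustration only.
from dataclasses import dataclass, field
from typing import Callable, Dict


@dataclass
class ModulePolicyDescription:
    method_replacement: Dict[str, Callable] = field(default_factory=dict)


def append_or_create_method_replacement(description: Dict[str, Callable],
                                        policy: Dict[type, ModulePolicyDescription],
                                        target_key: type) -> None:
    if target_key in policy:
        # Merge into the existing entry rather than clobbering it.
        policy[target_key].method_replacement.update(description)
    else:
        # Create a fresh entry, as the removed code always did.
        policy[target_key] = ModulePolicyDescription(method_replacement=dict(description))
```

The design point is that direct assignment discards any replacement registered for the same module earlier in policy construction, whereas the helper makes repeated registrations additive.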