[shardformer] vit/llama/t5 ignore the sequence parallelism flag and some fix. (#4498)

* [shardformer] chatglm support sequence parallel

* fix

* [shardformer] jit fused fix

* activate checks
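
The headline change is that the ViT, LLaMA and T5 policies now ignore an unsupported sequence parallelism flag instead of failing. Below is a minimal sketch of that pattern, assuming the shard config carries an `enable_sequence_parallelism` boolean; the helper name and warning text are illustrative, not the exact code from this commit.

import warnings


def maybe_disable_sequence_parallelism(shard_config, model_name: str) -> None:
    """Hypothetical helper: if a model has no sequence-parallel forward,
    drop the flag and warn instead of raising, so the rest of the shard
    plan can still be applied."""
    if shard_config.enable_sequence_parallelism:
        shard_config.enable_sequence_parallelism = False
        warnings.warn(f"{model_name} doesn't support sequence parallelism now, "
                      "will ignore the sequence parallelism flag.")

In the policies themselves such a check would run once per model family (ViT, LLaMA, T5) before the module policy is built.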
Author: flybird11111
Committed: 2023-08-24 15:50:02 +08:00 (via GitHub)
Parent: e04436a82a
Commit: 3353e55c80
7 changed files with 46 additions and 21 deletions

@@ -104,16 +104,20 @@ class OPTPolicy(Policy):
         # use flash attention
         if self.shard_config.enable_flash_attention:
-            policy[OPTAttention] = ModulePolicyDescription(method_replacement={
+            self.append_or_create_method_replacement(description={
                 'forward': get_opt_flash_attention_forward(),
-            })
+            },
+                                                     policy=policy,
+                                                     target_key=OPTAttention)
         # use jit fused operator
         if self.shard_config.enable_jit_fused:
-            policy[OPTDecoderLayer] = ModulePolicyDescription(method_replacement={
+            self.append_or_create_method_replacement(description={
                 'forward': get_jit_fused_opt_decoder_layer_forward(),
                 'dropout_add': get_jit_fused_dropout_add_func(),
-            })
+            },
+                                                     policy=policy,
+                                                     target_key=OPTDecoderLayer)
         return policy
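
The hunk above swaps direct `policy[target_key] = ModulePolicyDescription(...)` assignments for `self.append_or_create_method_replacement(...)`, so a later feature (flash attention, JIT fusion) merges into whatever description an earlier feature already registered for the same module instead of overwriting it. A sketch of what such a helper plausibly does, with the signature taken from the diff but the body and import path assumed, and written as a free function for brevity (in the policy base class it is a method):

from typing import Callable, Dict

from colossalai.shardformer.policies.base_policy import ModulePolicyDescription


def append_or_create_method_replacement(description: Dict[str, Callable],
                                        policy: Dict,
                                        target_key) -> Dict:
    # Merge into an existing description if the target module is already in
    # the policy (e.g. registered by tensor parallelism); otherwise create one.
    if target_key in policy:
        if policy[target_key].method_replacement is None:
            policy[target_key].method_replacement = {}
        policy[target_key].method_replacement.update(description)
    else:
        policy[target_key] = ModulePolicyDescription(method_replacement=description)
    return policy

The `dropout_add` replacement registered for OPTDecoderLayer points at a TorchScript-fused dropout-plus-residual function. The usual shape of such a function is shown below; the exact body returned by `get_jit_fused_dropout_add_func()` may differ.

import torch
import torch.nn.functional as F


@torch.jit.script
def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
    # Scripted so dropout and the residual add compile into one fused graph.
    out = F.dropout(x, p=prob, training=training)
    return out + residual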