[shardformer] fix whisper test failure due to significant accuracy differences (#4710)

* [shardformer] fix whisper test failure
Author: flybird11111
2023-09-14 21:34:20 +08:00
committed by GitHub
parent e2c0e7f92a
commit 20190b49a5
2 changed files with 6 additions and 1 deletion


@@ -57,6 +57,11 @@ class WhisperPolicy(Policy):
             warnings.warn(
                 "Whisper doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
+        # TODO: using the jit fused add_and_dropout affects the accuracy
+        if self.shard_config.enable_jit_fused:
+            self.shard_config.enable_jit_fused = False
+            warnings.warn("Whisper doesn't support jit fused operator now, will ignore the jit fused operator flag.")
+
         if self.shard_config.enable_tensor_parallelism:
             policy[WhisperEncoderLayer] = ModulePolicyDescription(attribute_replacement={
                 "self_attn.embed_dim":