[shardformer] fix whisper test failure due to significant accuracy differences (#4710)
* [shardformer] fix whisper test failure
This commit is contained in:
parent e2c0e7f92a
commit 20190b49a5
@@ -114,7 +114,7 @@ We will follow this roadmap to develop Shardformer:
 | bloom | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] |
 | chatglm2 | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] |
 | vit | [x] | [x] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] |
-| whisper | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [ ] | [ ] |
+| whisper | [x] | [x] | [x] | [x] | [x] | [ ] | [x] | [ ] | [ ] |
 | sam | [x] | [ ] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] |
 | blip2 | [x] | [ ] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] |
 | roberta | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
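The only substantive change in this hunk is the whisper row: one support column flips from [x] to [ ], presumably the JIT-fused operator column, given the policy change below. For callers this means a request for JIT fusion on whisper is now ignored rather than honored. A minimal sketch of the caller-side effect; the field names enable_tensor_parallelism and enable_jit_fused are taken from the policy hunk below, while the import path and constructor usage are assumptions:

    # Hedged sketch: the config still accepts the flag, but WhisperPolicy
    # resets it to False and emits a warning (see the policy hunk below).
    # Import path and other ShardConfig arguments are assumed, not verified.
    from colossalai.shardformer import ShardConfig

    shard_config = ShardConfig(
        enable_tensor_parallelism=True,
        enable_jit_fused=True,  # reset to False by the whisper policy
    )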
@@ -57,6 +57,11 @@ class WhisperPolicy(Policy):
             warnings.warn(
                 "Whisper doesn't support sequence parallelism now; the sequence parallelism flag will be ignored.")
 
+        # TODO: using the JIT fused add_and_dropout affects the accuracy
+        if self.shard_config.enable_jit_fused:
+            self.shard_config.enable_jit_fused = False
+            warnings.warn("Whisper doesn't support the JIT fused operator now; the JIT fused operator flag will be ignored.")
+
         if self.shard_config.enable_tensor_parallelism:
             policy[WhisperEncoderLayer] = ModulePolicyDescription(attribute_replacement={
                 "self_attn.embed_dim":
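For reference, a self-contained demo of the reset-and-warn guard added above. ShardConfig is stood in by a plain dataclass so the sketch runs without ColossalAI installed; apart from the field name enable_jit_fused, which comes from the hunk, everything here is illustrative:

    import warnings
    from dataclasses import dataclass

    @dataclass
    class FakeShardConfig:
        # Stand-in for shardformer's ShardConfig; only the one field used here.
        enable_jit_fused: bool = True

    cfg = FakeShardConfig()
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # Same pattern as the hunk above: turn the flag off in place and warn
        # the caller instead of raising an error.
        if cfg.enable_jit_fused:
            cfg.enable_jit_fused = False
            warnings.warn("Whisper doesn't support the JIT fused operator now; "
                          "the JIT fused operator flag will be ignored.")

    assert cfg.enable_jit_fused is False
    assert len(caught) == 1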