[upgrade]Upgrade vit (#6308)

* fix

* fix

* fix rotate embedding test

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
commit 04516bb756 (parent d0e13b85fd)
Author: flybird11111
Date:   2025-05-21 16:14:20 +08:00
Committed via GitHub

4 changed files with 6 additions and 7 deletions


@@ -82,6 +82,7 @@ def get_whisper_flash_attention_forward():
         attention_mask: Optional[dict] = None,
         layer_head_mask: Optional[torch.Tensor] = None,
         output_attentions: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         """Input shape: Batch x Time x Channel"""
         assert layer_head_mask is None, "layer_head_mask is not supported for FlashAttention"
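Why the hunks add cache_position: recent transformers releases pass this keyword (the positions of the current tokens within the KV cache) into attention and decoder forwards, so a monkey-patched replacement must declare the parameter or every call fails with "TypeError: forward() got an unexpected keyword argument 'cache_position'". A minimal sketch of the pattern, using illustrative names rather than the actual ColossalAI code:

    from typing import Optional, Tuple

    import torch

    def get_patched_attention_forward():
        # Hypothetical stand-in for get_whisper_flash_attention_forward():
        # the returned function replaces the module's forward attribute.
        def forward(
            self,
            hidden_states: torch.Tensor,
            attention_mask: Optional[torch.Tensor] = None,
            output_attentions: bool = False,
            # Declared only for parity with the newer upstream signature;
            # this sketch accepts the kwarg without consuming it.
            cache_position: Optional[torch.LongTensor] = None,
        ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
            # Stub body: a real shim would run the FlashAttention kernel here.
            return hidden_states, None

        return forward

Defaulting the kwarg to None keeps the shim compatible in both directions: older callers simply never pass it, newer ones always do.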
@@ -172,6 +173,7 @@ def get_whisper_decoder_forward_for_flash_attention(shard_config: ShardConfig):
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
+        cache_position=None,
     ):
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
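The decoder hunk is the same fix one level up the call stack: the patched forward gains cache_position=None so the newer keyword surface is accepted there as well. When a shim must in turn call code whose signature varies across transformers versions, one defensive option (an assumption for illustration, not something this PR does) is to filter kwargs against the callee's declared parameters:

    import inspect

    def call_with_supported_kwargs(fn, *args, **kwargs):
        # Hypothetical helper: drop kwargs the callee does not declare
        # (e.g. cache_position on older transformers releases), unless it
        # already absorbs extras via **kwargs.
        params = inspect.signature(fn).parameters
        if not any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
            kwargs = {k: v for k, v in kwargs.items() if k in params}
        return fn(*args, **kwargs)

    # Usage: safe whether or not decoder_layer.forward declares cache_position.
    # call_with_supported_kwargs(decoder_layer.forward, hidden_states, cache_position=pos)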