[upgrade] Upgrade ViT (#6308)

* fix ViT and Whisper flash attention forwards for the upgraded transformers API

* fix ViT policy: drop the obsolete attention dropout submodule replacement

* fix rotary embedding test

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
flybird11111 authored 2025-05-21 16:14:20 +08:00; committed by GitHub
parent d0e13b85fd
commit 04516bb756
4 changed files with 6 additions and 7 deletions


@@ -349,7 +349,7 @@ def get_vit_flash_self_attention_forward():
         value_layer = self.transpose_for_scores(self.value(hidden_states))
         query_layer = self.transpose_for_scores(mixed_query_layer)

-        dropout_p = self.dropout.p if self.training else 0.0
+        dropout_p = self.dropout_prob if self.training else 0.0
         context_layer = ColoAttention.attention(query_layer, key_layer, value_layer, dropout_p=dropout_p)

         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
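The only functional change here is where the attention dropout rate comes from: recent transformers releases store it on ViTSelfAttention as a plain float (dropout_prob) rather than as an nn.Dropout submodule, so self.dropout.p no longer exists. A minimal sketch of the patched forward, assuming the upstream dropout_prob / all_head_size attributes and the ColoAttention import path used elsewhere in shardformer:

# Sketch only: names outside the hunk above are assumptions, not the exact ColossalAI code.
import torch
from colossalai.shardformer.layer import ColoAttention  # assumed import path

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    mixed_query_layer = self.query(hidden_states)
    key_layer = self.transpose_for_scores(self.key(hidden_states))
    value_layer = self.transpose_for_scores(self.value(hidden_states))
    query_layer = self.transpose_for_scores(mixed_query_layer)

    # Newer transformers keeps only the float rate (self.dropout_prob); the old
    # code read the rate off an nn.Dropout module via self.dropout.p.
    dropout_p = self.dropout_prob if self.training else 0.0
    context_layer = ColoAttention.attention(query_layer, key_layer, value_layer, dropout_p=dropout_p)

    context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
    new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
    return context_layer.reshape(new_context_layer_shape)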


@@ -82,6 +82,7 @@ def get_whisper_flash_attention_forward():
         attention_mask: Optional[dict] = None,
         layer_head_mask: Optional[torch.Tensor] = None,
         output_attentions: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         """Input shape: Batch x Time x Channel"""
         assert layer_head_mask is None, "layer_head_mask is not supported for FlashAttention"
@@ -172,6 +173,7 @@ def get_whisper_decoder_forward_for_flash_attention(shard_config: ShardConfig):
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
+        cache_position=None,
     ):
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
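Both Whisper hunks are the same compatibility fix: newer transformers passes a cache_position argument down to the attention and decoder forwards, so the replacement forwards must accept it even though the FlashAttention path never uses it. A sketch of the pattern, with the unrelated parameters from the real signatures elided:

# Sketch only: illustrates why the parameter is needed, not the actual Whisper signatures.
from typing import Optional
import torch

def old_forward(self, hidden_states, attention_mask=None, output_attentions=False):
    # A caller that passes cache_position=... as a keyword raises TypeError here.
    ...

def new_forward(self, hidden_states, attention_mask=None, output_attentions=False,
                cache_position: Optional[torch.LongTensor] = None):
    # Accepted purely for API compatibility with newer transformers; the
    # FlashAttention replacement ignores the value.
    ...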


@@ -93,10 +93,6 @@ class ViTPolicy(Policy):
                             "use_zbv": use_zbv,
                         },
                     ),
-                    SubModuleReplacementDescription(
-                        suffix="attention.attention.dropout",
-                        target_module=col_nn.DropoutForParallelInput,
-                    ),
                     SubModuleReplacementDescription(
                         suffix="attention.output.dense",
                         target_module=col_nn.Linear1D_Row,
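This removal follows from the modeling change above: with no attention.attention.dropout submodule left on the upstream ViTSelfAttention, keeping a DropoutForParallelInput replacement for that suffix would fail when the policy is applied. A sketch of the remaining attention-related replacements (the query entry and import paths are assumptions, not part of this hunk):

# Sketch only: abbreviated, with assumed import paths and surrounding entries.
import colossalai.shardformer.layer as col_nn  # assumed alias, as used in the policy file
from colossalai.shardformer.policies.base_policy import SubModuleReplacementDescription  # assumed path

sub_module_replacement = [
    SubModuleReplacementDescription(
        suffix="attention.attention.query",
        target_module=col_nn.Linear1D_Col,
    ),
    # ...key and value entries omitted...
    # The attention.attention.dropout -> DropoutForParallelInput entry is gone:
    # the upgraded ViTSelfAttention has no dropout submodule left to replace.
    SubModuleReplacementDescription(
        suffix="attention.output.dense",
        target_module=col_nn.Linear1D_Row,
    ),
]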


@@ -1,7 +1,7 @@
 import numpy as np
 import pytest
 import torch
-from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
+from transformers.models.llama.modeling_llama import LlamaConfig, LlamaRotaryEmbedding, apply_rotary_pos_emb

 from colossalai.kernel.kernel_loader import InferenceOpsLoader
@@ -33,7 +33,8 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, K_H, D, dtype):
     position_ids = torch.arange(TOTAL_TOKENS).reshape((BATCH_SIZE, SEQ_LEN))

-    emb = LlamaRotaryEmbedding(D)
+    config = LlamaConfig(max_position_embeddings=SEQ_LEN, num_attention_heads=H, hidden_size=H * D)
+    emb = LlamaRotaryEmbedding(config)
     cos, sin = emb(x0, position_ids)
     embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin)
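The test change tracks the newer LlamaRotaryEmbedding constructor, which takes a config object instead of a bare head dimension; the minimal LlamaConfig is built so that hidden_size // num_attention_heads reproduces the old D. A standalone sketch of the new construction path (sizes are illustrative, not the test's parametrization):

# Sketch only: illustrative sizes, independent of the pytest parametrization above.
import torch
from transformers.models.llama.modeling_llama import LlamaConfig, LlamaRotaryEmbedding

BATCH_SIZE, SEQ_LEN, H, D = 4, 64, 32, 128

# head_dim = hidden_size // num_attention_heads == D, which is what the old
# LlamaRotaryEmbedding(D) constructor received directly.
config = LlamaConfig(max_position_embeddings=SEQ_LEN, num_attention_heads=H, hidden_size=H * D)
emb = LlamaRotaryEmbedding(config)

x = torch.randn(BATCH_SIZE, H, SEQ_LEN, D)  # only dtype/device of x are used by the embedding
position_ids = torch.arange(SEQ_LEN).unsqueeze(0).expand(BATCH_SIZE, SEQ_LEN)
cos, sin = emb(x, position_ids)  # each of shape (BATCH_SIZE, SEQ_LEN, D)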