Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-06-21 21:22:04 +00:00
[upgrade] Upgrade ViT (#6308)
* fix
* fix
* fix rotate embedding test
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent: d0e13b85fd
commit: 04516bb756
@@ -349,7 +349,7 @@ def get_vit_flash_self_attention_forward():
         value_layer = self.transpose_for_scores(self.value(hidden_states))
         query_layer = self.transpose_for_scores(mixed_query_layer)

-        dropout_p = self.dropout.p if self.training else 0.0
+        dropout_p = self.dropout_prob if self.training else 0.0
         context_layer = ColoAttention.attention(query_layer, key_layer, value_layer, dropout_p=dropout_p)

         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
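The functional change in this hunk tracks a transformers refactor: newer releases of ViTSelfAttention keep the attention dropout probability as a plain float attribute (`dropout_prob`) rather than as an `nn.Dropout` submodule. A minimal compatibility sketch, assuming only the two attribute layouts visible in the diff, could read the probability either way:

```python
import torch.nn as nn


def attn_dropout_p(self_attention: nn.Module, training: bool) -> float:
    """Return the attention dropout probability for either ViTSelfAttention layout.

    Assumes only the two shapes seen in the hunk above: a float `dropout_prob`
    (newer transformers) or an nn.Dropout stored at `dropout` (older releases).
    """
    if not training:
        return 0.0  # dropout is disabled at inference time
    if hasattr(self_attention, "dropout_prob"):
        return float(self_attention.dropout_prob)
    return float(self_attention.dropout.p)
```

The PR itself simply switches to the new attribute, which ties the patched forward to recent transformers releases.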
@@ -82,6 +82,7 @@ def get_whisper_flash_attention_forward():
         attention_mask: Optional[dict] = None,
         layer_head_mask: Optional[torch.Tensor] = None,
         output_attentions: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         """Input shape: Batch x Time x Channel"""
         assert layer_head_mask is None, "layer_head_mask is not supported for FlashAttention"
@@ -172,6 +173,7 @@ def get_whisper_decoder_forward_for_flash_attention(shard_config: ShardConfig):
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
+        cache_position=None,
     ):
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
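Both Whisper hunks add a pass-through `cache_position` argument so the patched forwards stay call-compatible with recent transformers, where decoder layers receive the absolute positions of the tokens currently being written into the KV cache. A small sketch of the usual fallback, assuming the common convention that positions simply continue from the cached length (the helper name is illustrative, not from this PR):

```python
import torch


def default_cache_position(past_length: int, num_new_tokens: int, device=None) -> torch.LongTensor:
    """Positions of the tokens decoded in this step, continuing after the KV cache."""
    return torch.arange(past_length, past_length + num_new_tokens, device=device)


# Example: 16 tokens already cached, decoding 4 more in this step.
print(default_cache_position(16, 4))  # tensor([16, 17, 18, 19])
```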
@@ -93,10 +93,6 @@ class ViTPolicy(Policy):
                         "use_zbv": use_zbv,
                     },
                 ),
-                SubModuleReplacementDescription(
-                    suffix="attention.attention.dropout",
-                    target_module=col_nn.DropoutForParallelInput,
-                ),
                 SubModuleReplacementDescription(
                     suffix="attention.output.dense",
                     target_module=col_nn.Linear1D_Row,
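Dropping the `attention.attention.dropout` replacement is a consequence of the first hunk: once ViTSelfAttention stores only a float `dropout_prob`, there is no `nn.Dropout` submodule left at that suffix for the policy to swap, so the description would no longer resolve on newer transformers. A defensive variant, sketched under the assumption that both layouts might still be encountered (the import paths below are assumptions, not taken from the PR), would register the replacement only when the submodule actually exists:

```python
import torch.nn as nn

# Assumed import locations; `SubModuleReplacementDescription` and
# `col_nn.DropoutForParallelInput` are the names used in the hunk above.
import colossalai.shardformer.layer as col_nn
from colossalai.shardformer.policies.base_policy import SubModuleReplacementDescription


def dropout_replacements(vit_layer: nn.Module) -> list:
    """Return the dropout replacement only if the attention still owns an nn.Dropout."""
    if isinstance(getattr(vit_layer.attention.attention, "dropout", None), nn.Dropout):
        return [
            SubModuleReplacementDescription(
                suffix="attention.attention.dropout",
                target_module=col_nn.DropoutForParallelInput,
            )
        ]
    return []  # newer transformers: dropout is a float, nothing to replace
```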
@@ -1,7 +1,7 @@
 import numpy as np
 import pytest
 import torch
-from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
+from transformers.models.llama.modeling_llama import LlamaConfig, LlamaRotaryEmbedding, apply_rotary_pos_emb

 from colossalai.kernel.kernel_loader import InferenceOpsLoader
@@ -33,7 +33,8 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, K_H, D, dtype):

     position_ids = torch.arange(TOTAL_TOKENS).reshape((BATCH_SIZE, SEQ_LEN))

-    emb = LlamaRotaryEmbedding(D)
+    config = LlamaConfig(max_position_embeddings=SEQ_LEN, num_attention_heads=H, hidden_size=H * D)
+    emb = LlamaRotaryEmbedding(config)

     cos, sin = emb(x0, position_ids)
     embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin)
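The test follows the upstream constructor change: recent transformers builds LlamaRotaryEmbedding from a LlamaConfig (with head dimension hidden_size / num_attention_heads) instead of a bare `dim`. A version-tolerant sketch, assuming only the two signatures visible in this diff, could select the right form by inspecting the constructor:

```python
import inspect

from transformers.models.llama.modeling_llama import LlamaConfig, LlamaRotaryEmbedding


def build_rotary_embedding(num_heads: int, head_dim: int, max_positions: int) -> LlamaRotaryEmbedding:
    """Construct LlamaRotaryEmbedding across transformers versions (sketch, not PR code)."""
    if "config" in inspect.signature(LlamaRotaryEmbedding.__init__).parameters:
        # Newer transformers: rope parameters are read from the model config.
        config = LlamaConfig(
            max_position_embeddings=max_positions,
            num_attention_heads=num_heads,
            hidden_size=num_heads * head_dim,
        )
        return LlamaRotaryEmbedding(config=config)
    # Older transformers: the constructor takes the head dimension directly.
    return LlamaRotaryEmbedding(head_dim)
```

Both sides of the diff then call the module the same way, as in the test context above: cos, sin = emb(x0, position_ids).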