Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-21 13:11:27 +00:00)
[upgrade]Upgrade vit (#6308)
* fix
* fix
* fix rotate embedding test
* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent d0e13b85fd
commit 04516bb756
@@ -349,7 +349,7 @@ def get_vit_flash_self_attention_forward():
         value_layer = self.transpose_for_scores(self.value(hidden_states))
         query_layer = self.transpose_for_scores(mixed_query_layer)

-        dropout_p = self.dropout.p if self.training else 0.0
+        dropout_p = self.dropout_prob if self.training else 0.0
         context_layer = ColoAttention.attention(query_layer, key_layer, value_layer, dropout_p=dropout_p)

         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
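For context, this one-line change tracks where the dropout probability lives on the attention module: the old code read it from an nn.Dropout submodule (self.dropout.p), while the transformers version targeted here is assumed to store it directly as self.dropout_prob. A minimal, hypothetical stand-in (not the Hugging Face source) showing the two access patterns:

import torch.nn as nn

class OldStyleSelfAttention(nn.Module):  # hypothetical stand-in for the old layout
    def __init__(self, attention_probs_dropout_prob: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(attention_probs_dropout_prob)  # probability read via self.dropout.p

class NewStyleSelfAttention(nn.Module):  # hypothetical stand-in for the new layout
    def __init__(self, attention_probs_dropout_prob: float = 0.1):
        super().__init__()
        self.dropout_prob = attention_probs_dropout_prob  # probability stored as a plain float

old, new = OldStyleSelfAttention(), NewStyleSelfAttention()
assert old.dropout.p == new.dropout_prob == 0.1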
@@ -82,6 +82,7 @@ def get_whisper_flash_attention_forward():
         attention_mask: Optional[dict] = None,
         layer_head_mask: Optional[torch.Tensor] = None,
         output_attentions: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         """Input shape: Batch x Time x Channel"""
         assert layer_head_mask is None, "layer_head_mask is not supported for FlashAttention"
@@ -172,6 +173,7 @@ def get_whisper_decoder_forward_for_flash_attention(shard_config: ShardConfig):
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
+        cache_position=None,
     ):
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
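Both Whisper hunks simply thread the new cache_position argument through the patched forwards so their signatures keep matching upstream transformers. Roughly, cache_position holds the absolute slot indices that the incoming tokens occupy in the key/value cache; a small illustrative sketch (the values are made up, not taken from this patch):

import torch

past_len, new_tokens = 12, 4  # 12 positions already cached, 4 tokens arriving now
cache_position = torch.arange(past_len, past_len + new_tokens)
print(cache_position)  # tensor([12, 13, 14, 15])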
@@ -93,10 +93,6 @@ class ViTPolicy(Policy):
                         "use_zbv": use_zbv,
                     },
                 ),
-                SubModuleReplacementDescription(
-                    suffix="attention.attention.dropout",
-                    target_module=col_nn.DropoutForParallelInput,
-                ),
                 SubModuleReplacementDescription(
                     suffix="attention.output.dense",
                     target_module=col_nn.Linear1D_Row,
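The dropped policy entry follows from the modeling change above: once the flash-attention forward hands the dropout probability to the fused kernel as a plain float (dropout_p), there is no attention.attention.dropout submodule left for the policy to wrap in DropoutForParallelInput. A rough sketch of that pattern using PyTorch's scaled_dot_product_attention rather than the ColossalAI kernel:

import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 4, 16, 32)  # (batch, heads, seq_len, head_dim)
out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.1)  # dropout is an argument, not an nn.Dropout module
print(out.shape)  # torch.Size([1, 4, 16, 32])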
@@ -1,7 +1,7 @@
 import numpy as np
 import pytest
 import torch
-from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
+from transformers.models.llama.modeling_llama import LlamaConfig, LlamaRotaryEmbedding, apply_rotary_pos_emb

 from colossalai.kernel.kernel_loader import InferenceOpsLoader

@@ -33,7 +33,8 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, K_H, D, dtype):

     position_ids = torch.arange(TOTAL_TOKENS).reshape((BATCH_SIZE, SEQ_LEN))

-    emb = LlamaRotaryEmbedding(D)
+    config = LlamaConfig(max_position_embeddings=SEQ_LEN, num_attention_heads=H, hidden_size=H * D)
+    emb = LlamaRotaryEmbedding(config)

     cos, sin = emb(x0, position_ids)
     embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin)
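For reference, a self-contained sketch of the config-based rotary-embedding construction the updated test now relies on; the sizes below are illustrative, not the test's parametrization:

import torch
from transformers.models.llama.modeling_llama import LlamaConfig, LlamaRotaryEmbedding, apply_rotary_pos_emb

BATCH_SIZE, SEQ_LEN, H, D = 2, 16, 4, 32
config = LlamaConfig(max_position_embeddings=SEQ_LEN, num_attention_heads=H, hidden_size=H * D)
emb = LlamaRotaryEmbedding(config)  # constructed from a config instead of a bare head dimension

q = torch.randn(BATCH_SIZE, H, SEQ_LEN, D)
k = torch.randn(BATCH_SIZE, H, SEQ_LEN, D)
position_ids = torch.arange(BATCH_SIZE * SEQ_LEN).reshape(BATCH_SIZE, SEQ_LEN)

cos, sin = emb(q, position_ids)              # rotary tables for the given positions
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
print(q_rot.shape)                           # torch.Size([2, 4, 16, 32])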