diff --git a/colossalai/shardformer/modeling/vit.py b/colossalai/shardformer/modeling/vit.py
index b1a5c4143..5106d97cf 100644
--- a/colossalai/shardformer/modeling/vit.py
+++ b/colossalai/shardformer/modeling/vit.py
@@ -349,7 +349,7 @@ def get_vit_flash_self_attention_forward():
         value_layer = self.transpose_for_scores(self.value(hidden_states))
         query_layer = self.transpose_for_scores(mixed_query_layer)
 
-        dropout_p = self.dropout.p if self.training else 0.0
+        dropout_p = self.dropout_prob if self.training else 0.0
         context_layer = ColoAttention.attention(query_layer, key_layer, value_layer, dropout_p=dropout_p)
 
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
diff --git a/colossalai/shardformer/modeling/whisper.py b/colossalai/shardformer/modeling/whisper.py
index cf925983b..619bbc98e 100644
--- a/colossalai/shardformer/modeling/whisper.py
+++ b/colossalai/shardformer/modeling/whisper.py
@@ -82,6 +82,7 @@ def get_whisper_flash_attention_forward():
         attention_mask: Optional[dict] = None,
         layer_head_mask: Optional[torch.Tensor] = None,
         output_attentions: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         """Input shape: Batch x Time x Channel"""
         assert layer_head_mask is None, "layer_head_mask is not supported for FlashAttention"
@@ -172,6 +173,7 @@ def get_whisper_decoder_forward_for_flash_attention(shard_config: ShardConfig):
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
+        cache_position=None,
     ):
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py
index 7b7dbf555..420ea286f 100644
--- a/colossalai/shardformer/policies/vit.py
+++ b/colossalai/shardformer/policies/vit.py
@@ -93,10 +93,6 @@ class ViTPolicy(Policy):
                         "use_zbv": use_zbv,
                     },
                 ),
-                SubModuleReplacementDescription(
-                    suffix="attention.attention.dropout",
-                    target_module=col_nn.DropoutForParallelInput,
-                ),
                 SubModuleReplacementDescription(
                     suffix="attention.output.dense",
                     target_module=col_nn.Linear1D_Row,
diff --git a/tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py b/tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py
index 57a82647d..ab3f04c05 100644
--- a/tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py
+++ b/tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py
@@ -1,7 +1,7 @@
 import numpy as np
 import pytest
 import torch
-from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
+from transformers.models.llama.modeling_llama import LlamaConfig, LlamaRotaryEmbedding, apply_rotary_pos_emb
 
 from colossalai.kernel.kernel_loader import InferenceOpsLoader
 
@@ -33,7 +33,8 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, K_H, D, dtype):
 
     position_ids = torch.arange(TOTAL_TOKENS).reshape((BATCH_SIZE, SEQ_LEN))
 
-    emb = LlamaRotaryEmbedding(D)
+    config = LlamaConfig(max_position_embeddings=SEQ_LEN, num_attention_heads=H, hidden_size=H * D)
+    emb = LlamaRotaryEmbedding(config)
     cos, sin = emb(x0, position_ids)
     embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin)
 
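The last hunk tracks the upstream change where `LlamaRotaryEmbedding` is constructed from a `LlamaConfig` rather than a bare head dimension. A minimal standalone sketch of that config-driven usage is below, assuming a recent `transformers` release with the config-based constructor; the tensor sizes and variable names are illustrative only and are not taken from the test's parametrization.

```python
# Illustrative sketch (not part of the diff): config-driven rotary embedding usage.
import torch
from transformers.models.llama.modeling_llama import LlamaConfig, LlamaRotaryEmbedding, apply_rotary_pos_emb

# Example sizes (assumed for illustration): 2 sequences, 16 tokens, 4 heads of dim 32.
BATCH_SIZE, SEQ_LEN, H, D = 2, 16, 4, 32

# The rotary embedding now derives its dimension and rope parameters from the config.
config = LlamaConfig(max_position_embeddings=SEQ_LEN, num_attention_heads=H, hidden_size=H * D)
emb = LlamaRotaryEmbedding(config)

q = torch.randn(BATCH_SIZE, H, SEQ_LEN, D)
k = torch.randn(BATCH_SIZE, H, SEQ_LEN, D)
position_ids = torch.arange(SEQ_LEN).repeat(BATCH_SIZE, 1)

# cos/sin are computed per position and broadcast over the head dimension.
cos, sin = emb(q, position_ids)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
```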