mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 11:02:05 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
This commit is contained in:
@@ -1,12 +1,10 @@
|
||||
import math
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def forward_fn():
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -62,7 +60,6 @@ def forward_fn():
|
||||
|
||||
|
||||
def get_blip2_flash_attention_forward():
|
||||
|
||||
from transformers.models.blip_2.modeling_blip_2 import Blip2Attention
|
||||
|
||||
from colossalai.kernel.cuda_native import ColoAttention
|
||||
@@ -80,10 +77,9 @@ def get_blip2_flash_attention_forward():
|
||||
mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, -1).permute(2, 0, 1, 3, 4)
|
||||
query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2]
|
||||
|
||||
attention = ColoAttention(embed_dim=self.embed_dim,
|
||||
num_heads=self.num_heads,
|
||||
dropout=self.dropout.p,
|
||||
scale=self.scale)
|
||||
attention = ColoAttention(
|
||||
embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout.p, scale=self.scale
|
||||
)
|
||||
context_layer = attention(query_states, key_states, value_states)
|
||||
|
||||
output = self.projection(context_layer)
|
||||
@@ -95,7 +91,6 @@ def get_blip2_flash_attention_forward():
|
||||
|
||||
|
||||
def get_jit_fused_blip2_QFormer_self_output_forward():
|
||||
|
||||
from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerSelfOutput
|
||||
|
||||
def forward(self: Blip2QFormerSelfOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||
@@ -108,7 +103,6 @@ def get_jit_fused_blip2_QFormer_self_output_forward():
|
||||
|
||||
|
||||
def get_jit_fused_blip2_QFormer_output_forward():
|
||||
|
||||
from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerOutput
|
||||
|
||||
def forward(self: Blip2QFormerOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||
|
Reference in New Issue
Block a user