Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 19:13:01 +00:00)
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -7,20 +7,23 @@ from torch import Tensor


 def forward_fn():

     def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
         batch_size, height, width, _ = hidden_states.shape
         # qkv with shape (3, batch_size, nHead, height * width, channel)
-        qkv = (self.qkv(hidden_states).reshape(batch_size, height * width, 3, self.num_attention_heads,
-                                               -1).permute(2, 0, 3, 1, 4))
+        qkv = (
+            self.qkv(hidden_states)
+            .reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
+            .permute(2, 0, 3, 1, 4)
+        )
         # q, k, v with shape (batch_size * nHead, height * width, channel)
         query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0)

         attn_weights = (query * self.scale) @ key.transpose(-2, -1)

         if self.use_rel_pos:
-            attn_weights = self.add_decomposed_rel_pos(attn_weights, query, self.rel_pos_h, self.rel_pos_w,
-                                                       (height, width), (height, width))
+            attn_weights = self.add_decomposed_rel_pos(
+                attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
+            )

         attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
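The change in this hunk is purely cosmetic: the pre-commit run splits the chained `reshape`/`permute` call and the `add_decomposed_rel_pos` call across lines, but the computation is unchanged. A minimal standalone sketch of that qkv layout transformation, using a plain `nn.Linear` as a stand-in for `self.qkv` and made-up sizes (all names and numbers below are illustrative, not taken from the repository):

```python
import torch
from torch import nn

batch_size, height, width, embed_dim, num_attention_heads = 2, 8, 8, 32, 4
hidden_states = torch.randn(batch_size, height, width, embed_dim)
qkv_proj = nn.Linear(embed_dim, 3 * embed_dim)  # stand-in for self.qkv

# Same expression as the reformatted code above, just without the module wrapper.
qkv = (
    qkv_proj(hidden_states)
    .reshape(batch_size, height * width, 3, num_attention_heads, -1)
    .permute(2, 0, 3, 1, 4)
)
# qkv: (3, batch_size, nHead, height * width, channel), matching the comment in the diff
assert qkv.shape == (3, batch_size, num_attention_heads, height * width, embed_dim // num_attention_heads)

# q, k, v: (batch_size * nHead, height * width, channel)
query, key, value = qkv.reshape(3, batch_size * num_attention_heads, height * width, -1).unbind(0)
assert query.shape == (batch_size * num_attention_heads, height * width, embed_dim // num_attention_heads)
```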
@@ -45,8 +48,8 @@ def forward_fn():


 def get_sam_flash_attention_forward():

     from transformers.models.sam.modeling_sam import SamAttention

     try:
         from xformers.ops import memory_efficient_attention as me_attention
     except:
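The `try`/`except` above guards the import of xformers' memory-efficient attention kernel. The body of the `except` branch is cut off in this hunk, so the following is only a sketch of what such a guarded import can look like; the `HAS_MEM_EFF_ATTN` flag and the naive fallback are my own illustration, not the repository's code:

```python
import torch

try:
    from xformers.ops import memory_efficient_attention as me_attention

    HAS_MEM_EFF_ATTN = True  # hypothetical flag for this sketch
except ImportError:
    HAS_MEM_EFF_ATTN = False

    def me_attention(query, key, value, attn_bias=None, p=0.0, scale=None):
        # Naive fallback using the same (batch, seq_len, n_heads, head_dim) layout
        # that xformers' memory_efficient_attention consumes.
        scale = scale if scale is not None else query.shape[-1] ** -0.5
        q = query.permute(0, 2, 1, 3)  # -> (batch, n_heads, seq_len, head_dim)
        k = key.permute(0, 2, 1, 3)
        v = value.permute(0, 2, 1, 3)
        attn = (q @ k.transpose(-2, -1)) * scale
        if attn_bias is not None:
            attn = attn + attn_bias
        attn = torch.softmax(attn, dim=-1)
        if p > 0.0:
            attn = torch.nn.functional.dropout(attn, p=p)
        return (attn @ v).permute(0, 2, 1, 3)  # back to (batch, seq_len, n_heads, head_dim)
```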
@@ -62,11 +65,9 @@ def get_sam_flash_attention_forward():
         batch, n_tokens, n_heads, c_per_head = hidden_states.shape
         return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head)

-    def forward(self: SamAttention,
-                query: Tensor,
-                key: Tensor,
-                value: Tensor,
-                attention_similarity: Tensor = None) -> Tensor:
+    def forward(
+        self: SamAttention, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Tensor = None
+    ) -> Tensor:
         # Input projections
         query = self.q_proj(query)
         key = self.k_proj(key)
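The two context lines at the top of this hunk are the tail of a helper that undoes the head split: it merges the per-head channels back into one channel dimension and re-separates the point batch from the fused batch dimension. A quick shape check with made-up sizes (the helper name and all numbers are illustrative):

```python
import torch

def recombine_heads(hidden_states: torch.Tensor, point_batch_size: int) -> torch.Tensor:
    # (batch * point_batch_size, n_tokens, n_heads, c_per_head)
    #   -> (batch, point_batch_size, n_tokens, n_heads * c_per_head)
    batch, n_tokens, n_heads, c_per_head = hidden_states.shape
    return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head)

batch, point_batch_size, n_tokens, n_heads, c_per_head = 2, 3, 5, 4, 8
x = torch.randn(batch * point_batch_size, n_tokens, n_heads, c_per_head)
assert recombine_heads(x, point_batch_size).shape == (batch, point_batch_size, n_tokens, n_heads * c_per_head)
```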
@@ -96,8 +97,8 @@ def get_sam_flash_attention_forward():


 def get_sam_vision_flash_attention_forward():

     from transformers.models.sam.modeling_sam import SamVisionAttention

     try:
         from xformers.ops import memory_efficient_attention as me_attention
     except:
@@ -181,8 +182,11 @@ def get_sam_vision_flash_attention_forward():
     def forward(self: SamVisionAttention, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
         batch_size, height, width, _ = hidden_states.shape
         # qkv with shape (3, batch_size, nHead, height * width, channel)
-        qkv = (self.qkv(hidden_states).reshape(batch_size, height * width, 3, self.num_attention_heads,
-                                               -1).permute(2, 0, 1, 3, 4))
+        qkv = (
+            self.qkv(hidden_states)
+            .reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
+            .permute(2, 0, 1, 3, 4)
+        )

         query, key, value = qkv.reshape(3, batch_size, height * width, self.num_attention_heads, -1).unbind(0)
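Aside from formatting, the only difference between this qkv projection and the one in the first hunk is the permute order: `(2, 0, 3, 1, 4)` yields the head-major layout used by the plain matmul attention, while `(2, 0, 1, 3, 4)` keeps tokens ahead of heads, i.e. the (batch, seq_len, n_heads, head_dim) layout that xformers' memory-efficient attention consumes. A short shape comparison with made-up sizes (all numbers are illustrative):

```python
import torch

batch_size, height, width, num_attention_heads, head_dim = 2, 8, 8, 4, 8
seq_len = height * width
raw_qkv = torch.randn(batch_size, seq_len, 3, num_attention_heads, head_dim)

# Matmul path (first hunk): heads before tokens.
qkv_matmul = raw_qkv.permute(2, 0, 3, 1, 4)
assert qkv_matmul.shape == (3, batch_size, num_attention_heads, seq_len, head_dim)

# Memory-efficient path (this hunk): tokens before heads.
qkv_me = raw_qkv.permute(2, 0, 1, 3, 4)
assert qkv_me.shape == (3, batch_size, seq_len, num_attention_heads, head_dim)
query, key, value = qkv_me.reshape(3, batch_size, seq_len, num_attention_heads, -1).unbind(0)
assert query.shape == (batch_size, seq_len, num_attention_heads, head_dim)
```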