Merge pull request #6306 from wangbluo/upgrade_sam

Upgrade sam
Hanks authored 2025-05-22 14:19:20 +08:00, committed by GitHub
commit 33614b84ce


@@ -1,6 +1,8 @@
 import torch
+from torch import nn
 
+# Same as the SamVisionAttention forward method in the v4.51.3 transformers
 def forward_fn():
     def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
         batch_size, height, width, _ = hidden_states.shape
@@ -16,16 +18,15 @@ def forward_fn():
         attn_weights = (query * self.scale) @ key.transpose(-2, -1)
 
         if self.use_rel_pos:
-            attn_weights = self.add_decomposed_rel_pos(
-                attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
+            decomposed_rel_pos = self.get_decomposed_rel_pos(
+                query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
             )
+            decomposed_rel_pos = decomposed_rel_pos.reshape_as(attn_weights)
+            attn_weights = attn_weights + decomposed_rel_pos
 
         attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
-        # replace dropout process with added DropoutForParallelInput layer
-        # origin code:
-        # attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
-        attn_probs = self.dropout_layer(attn_weights)
+        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
         attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
         attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)
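
For context, the decomposed relative-position term that get_decomposed_rel_pos computes follows the SAM formulation: the query is dotted against learned per-axis offset tables, and the two per-axis biases are broadcast-summed over the key grid. Below is a minimal sketch of that computation, assuming q_size == k_size (so the offset tables need no resizing); the function name and shapes here are illustrative, not the exact transformers v4.51.3 implementation.

import torch

def decomposed_rel_pos(query, rel_pos_h, rel_pos_w, q_size, k_size):
    # query:     (batch * num_heads, q_h * q_w, head_dim)
    # rel_pos_h: (2 * k_h - 1, head_dim), one learned embedding per height offset
    # rel_pos_w: (2 * k_w - 1, head_dim), one learned embedding per width offset
    q_h, q_w = q_size
    k_h, k_w = k_size
    bh, _, dim = query.shape
    r_q = query.reshape(bh, q_h, q_w, dim)

    # Look up the embedding for every (query index - key index) offset per axis.
    idx_h = torch.arange(q_h)[:, None] - torch.arange(k_h)[None, :] + (k_h - 1)
    idx_w = torch.arange(q_w)[:, None] - torch.arange(k_w)[None, :] + (k_w - 1)
    Rh = rel_pos_h[idx_h]  # (q_h, k_h, dim)
    Rw = rel_pos_w[idx_w]  # (q_w, k_w, dim)

    # Per-axis bias: q . Rh over height, q . Rw over width.
    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)  # (bh, q_h, q_w, k_h)
    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)  # (bh, q_h, q_w, k_w)

    # Broadcast-sum into the full bias map and flatten to match
    # attn_weights of shape (bh, q_h * q_w, k_h * k_w).
    bias = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
    return bias.reshape(bh, q_h * q_w, k_h * k_w)

In the patched forward above, a bias of this form is what gets reshape_as(attn_weights) and added before the softmax, replacing the older add_decomposed_rel_pos, which mutated attn_weights inside the helper.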
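forward_fn() returns a plain function whose first parameter is self, so it can be rebound onto the attention class. The diff does not show how ColossalAI's shardformer policy wires it in; a minimal monkey-patch sketch (class path as in transformers):

from transformers.models.sam.modeling_sam import SamVisionAttention

# Rebinding on the class makes every SamVisionAttention instance,
# including those inside an already-built SamModel, use the patched forward.
SamVisionAttention.forward = forward_fn()

Note that the patch also drops the old DropoutForParallelInput indirection (self.dropout_layer) in favor of plain nn.functional.dropout, matching the stock v4.51.3 forward.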