upgrade_sam

wangbluo 2025-05-14 12:45:34 +08:00
parent 2237531137
commit b032cf9b16


@@ -1,4 +1,5 @@
 import torch
+from torch import nn
 
 
 def forward_fn():
@@ -16,16 +17,15 @@ def forward_fn():
         attn_weights = (query * self.scale) @ key.transpose(-2, -1)
 
         if self.use_rel_pos:
-            attn_weights = self.add_decomposed_rel_pos(
-                attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
+            decomposed_rel_pos = self.get_decomposed_rel_pos(
+                query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
             )
+            decomposed_rel_pos = decomposed_rel_pos.reshape_as(attn_weights)
+            attn_weights = attn_weights + decomposed_rel_pos
 
         attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
 
-        # replace dropout process with added DropoutForParallelInput layer
-        # origin code:
-        # attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
-        attn_probs = self.dropout_layer(attn_weights)
+        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
 
         attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
         attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)
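
For reference, below is a minimal runnable sketch of the attention path this diff switches to, under toy assumptions: toy_get_decomposed_rel_pos, the tensor sizes, and dropout_p are illustrative stand-ins, not the real SAM helper, which builds the relative-position term from rel_pos_h and rel_pos_w. It shows the two changes in the hunk: the relative-position term is now returned by the helper, reshaped onto the attention matrix, and added explicitly, and dropout goes back to plain nn.functional.dropout instead of the DropoutForParallelInput layer that was removed.

# Sketch only. "toy_get_decomposed_rel_pos" is a hypothetical stand-in for the
# get_decomposed_rel_pos call in the diff; it returns a zero bias in a 5-D
# (batch*heads, q_h, q_w, k_h, k_w) layout that reshape_as() flattens onto the
# (batch*heads, q_h*q_w, k_h*k_w) attention matrix.
import torch
from torch import nn


def toy_get_decomposed_rel_pos(query, q_size, k_size):
    # Zero bias with the assumed 5-D layout; the values are not meaningful here.
    (q_h, q_w), (k_h, k_w) = q_size, k_size
    return query.new_zeros(query.shape[0], q_h, q_w, k_h, k_w)


batch_heads, height, width, head_dim, dropout_p = 2, 4, 4, 8, 0.1
query = torch.randn(batch_heads, height * width, head_dim)
key = torch.randn(batch_heads, height * width, head_dim)
value = torch.randn(batch_heads, height * width, head_dim)
scale = head_dim**-0.5

attn_weights = (query * scale) @ key.transpose(-2, -1)

# New rel-pos handling: fetch the term, reshape it onto attn_weights, then add.
decomposed_rel_pos = toy_get_decomposed_rel_pos(query, (height, width), (height, width))
decomposed_rel_pos = decomposed_rel_pos.reshape_as(attn_weights)
attn_weights = attn_weights + decomposed_rel_pos

attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)

# Plain functional dropout, as in the updated hunk.
attn_probs = nn.functional.dropout(attn_weights, p=dropout_p, training=True)

attn_output = attn_probs @ value
print(attn_output.shape)  # torch.Size([2, 16, 8])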