upgrade_sam

commit b032cf9b16
parent 2237531137
@@ -1,4 +1,5 @@
 import torch
+from torch import nn
 
 
 def forward_fn():
@@ -16,16 +17,15 @@ def forward_fn():
         attn_weights = (query * self.scale) @ key.transpose(-2, -1)
 
         if self.use_rel_pos:
-            attn_weights = self.add_decomposed_rel_pos(
-                attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
+            decomposed_rel_pos = self.get_decomposed_rel_pos(
+                query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
             )
+            decomposed_rel_pos = decomposed_rel_pos.reshape_as(attn_weights)
+            attn_weights = attn_weights + decomposed_rel_pos
 
         attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
 
-        # replace dropout process with added DropoutForParallelInput layer
-        # origin code:
-        # attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
-        attn_probs = self.dropout_layer(attn_weights)
+        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
 
         attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
         attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)
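For context on the first change: the diff tracks the upstream Hugging Face transformers SAM API, where `add_decomposed_rel_pos` (which folded the relative-position bias into `attn_weights` in place) gives way to `get_decomposed_rel_pos`, which only computes the bias tensor and leaves the reshape and addition to the caller. Below is a minimal, self-contained sketch of that call pattern, not the library implementation: the sizes are made up, `rel_pos_h`/`rel_pos_w` are assumed to be pre-gathered (q, k, channel) tables, and the real method also resizes its learned tables to the query/key extents.

import torch


def get_decomposed_rel_pos(query, rel_pos_h, rel_pos_w, q_size, k_size):
    # Sketch of SAM's decomposed relative-position bias: project the query
    # onto per-axis position tables, then broadcast-add the height and width
    # contributions into a (batch, q_h, q_w, k_h, k_w) bias tensor.
    # k_size is kept for signature parity; the real method uses it to
    # resize its relative-position tables.
    q_h, q_w = q_size
    batch, _, dim = query.shape
    r_q = query.reshape(batch, q_h, q_w, dim)
    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, rel_pos_h)  # height-axis term
    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, rel_pos_w)  # width-axis term
    return rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]


# Made-up sizes: 2 flattened batch*head slots, a 4x4 feature map, 8 channels.
batch_heads, height, width, channels = 2, 4, 4, 8
query = torch.randn(batch_heads, height * width, channels)
rel_pos_h = torch.randn(height, height, channels)  # assumed pre-gathered (q_h, k_h, C)
rel_pos_w = torch.randn(width, width, channels)  # assumed pre-gathered (q_w, k_w, C)
attn_weights = torch.randn(batch_heads, height * width, height * width)

# The new call pattern from the diff: compute the bias, reshape it to match
# attn_weights, then add it explicitly (the old helper did all three at once).
decomposed_rel_pos = get_decomposed_rel_pos(query, rel_pos_h, rel_pos_w, (height, width), (height, width))
attn_weights = attn_weights + decomposed_rel_pos.reshape_as(attn_weights)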
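The second change drops ColossalAI's `DropoutForParallelInput` substitution (`self.dropout_layer`) in favor of stock `nn.functional.dropout`, which is also why the first hunk adds `from torch import nn`. A minimal sketch of the stock path, with made-up values standing in for `self.dropout` and `self.training`:

import torch
from torch import nn

attn_weights = torch.randn(2, 16, 16)

# Stock PyTorch dropout used after the upgrade: during training it zeroes
# each element with probability p and rescales survivors by 1 / (1 - p);
# in eval mode it is the identity. p and training are made up here.
attn_probs = nn.functional.dropout(attn_weights, p=0.1, training=True)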