From b032cf9b168ba165808da2bbc115dfc0d6233d7a Mon Sep 17 00:00:00 2001
From: wangbluo <2538539015@qq.com>
Date: Wed, 14 May 2025 12:45:34 +0800
Subject: [PATCH] upgrade_sam

---
 colossalai/shardformer/modeling/sam.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/colossalai/shardformer/modeling/sam.py b/colossalai/shardformer/modeling/sam.py
index 49fce0556..c84395989 100644
--- a/colossalai/shardformer/modeling/sam.py
+++ b/colossalai/shardformer/modeling/sam.py
@@ -1,4 +1,5 @@
 import torch
+from torch import nn
 
 
 def forward_fn():
@@ -16,16 +17,15 @@ def forward_fn():
         attn_weights = (query * self.scale) @ key.transpose(-2, -1)
 
         if self.use_rel_pos:
-            attn_weights = self.add_decomposed_rel_pos(
-                attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
+            decomposed_rel_pos = self.get_decomposed_rel_pos(
+                query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
             )
+            decomposed_rel_pos = decomposed_rel_pos.reshape_as(attn_weights)
+            attn_weights = attn_weights + decomposed_rel_pos
 
         attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
 
-        # replace dropout process with added DropoutForParallelInput layer
-        # origin code:
-        # attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
-        attn_probs = self.dropout_layer(attn_weights)
+        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
 
         attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
         attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)
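
The patched forward appears to track a newer `SamVisionAttention` interface: the relative-position term is obtained from `get_decomposed_rel_pos` and added to `attn_weights` by the caller (rather than `add_decomposed_rel_pos` folding it in), and the shardformer-specific `dropout_layer` (`DropoutForParallelInput`) is replaced by plain `nn.functional.dropout`. The sketch below is a standalone illustration of that decomposed relative-position flow, not the ColossalAI or transformers implementation; the `get_rel_pos` / `get_decomposed_rel_pos` helpers here are simplified stand-ins (the real code also interpolates `rel_pos` when the sizes differ), and the shapes follow the SAM vision-encoder convention of `(batch * heads, height * width, channels)`.

    # Standalone sketch of a decomposed relative-position attention bias
    # (assumed shape: query is (batch * heads, q_h * q_w, channels)).
    import torch


    def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
        # Select the relative-position embedding for every (query, key) pair along one
        # axis. rel_pos is assumed to already have 2 * max(q_size, k_size) - 1 rows
        # (the real implementation interpolates when it does not).
        q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
        k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
        relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
        return rel_pos[relative_coords.long()]  # (q_size, k_size, channels)


    def get_decomposed_rel_pos(query, rel_pos_h, rel_pos_w, q_size, k_size):
        # Return the additive bias with shape (batch * heads, q_h, q_w, k_h, k_w);
        # the caller reshapes it to match attn_weights and adds it, as in the patch.
        q_h, q_w = q_size
        k_h, k_w = k_size
        rel_h = get_rel_pos(q_h, k_h, rel_pos_h)  # (q_h, k_h, C)
        rel_w = get_rel_pos(q_w, k_w, rel_pos_w)  # (q_w, k_w, C)
        batch, _, dim = query.shape
        r_q = query.reshape(batch, q_h, q_w, dim)
        bias_h = torch.einsum("bhwc,hkc->bhwk", r_q, rel_h)  # (B, q_h, q_w, k_h)
        bias_w = torch.einsum("bhwc,wkc->bhwk", r_q, rel_w)  # (B, q_h, q_w, k_w)
        return bias_h[:, :, :, :, None] + bias_w[:, :, :, None, :]


    if __name__ == "__main__":
        heads, height, width, channels = 2, 4, 4, 8
        query = torch.randn(heads, height * width, channels)
        key = torch.randn(heads, height * width, channels)
        rel_pos_h = torch.randn(2 * height - 1, channels)
        rel_pos_w = torch.randn(2 * width - 1, channels)

        attn_weights = (query * channels**-0.5) @ key.transpose(-2, -1)
        decomposed_rel_pos = get_decomposed_rel_pos(
            query, rel_pos_h, rel_pos_w, (height, width), (height, width)
        )
        # Same pattern as the patched forward: reshape the bias, then add it.
        attn_weights = attn_weights + decomposed_rel_pos.reshape_as(attn_weights)
        print(attn_weights.shape)  # torch.Size([2, 16, 16])

Keeping the bias decomposed into per-axis terms means the learned tables need only O(height + width) rows instead of a full O(height * width) table, which is why the height and width embeddings stay separate until the final broadcast-add.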