fix: batched xattn

Zach Nussbaum 2023-04-21 21:54:47 +00:00
parent ca66d12d89
commit e255e0a805


@@ -296,7 +296,7 @@ class GPTJRCrossAttention(GPTJRAttention):
         new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
         tensor = tensor.view(new_shape)
-        return tensor.permute(0, 2, 1)
+        return tensor.permute(0, 1, 3, 2)
 
     def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
@@ -304,6 +304,7 @@ class GPTJRCrossAttention(GPTJRAttention):
         Merges attn_head_size dim and num_attn_heads dim into hidden dim
         """
+        # tensor -> (bs, seq_len, num_attention_heads, head_dim)
         tensor = tensor.permute(0, 2, 1, 3).contiguous()
         new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,)
         return tensor.view(new_shape)
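
For reference, a minimal standalone shape check of the split/merge reshapes in the two hunks above. The sizes are made-up toy values, and it assumes encoder_dim == num_attention_heads * head_dim (as the view implies); this is a sketch, not code from this repo.

import torch

bs, seq_len, knn, num_heads, head_dim = 2, 8, 4, 16, 64

# _split_knn_attn_heads after this commit: view, then permute(0, 1, 3, 2)
enc = torch.randn(bs, knn, num_heads * head_dim)           # (bs, knn, encoder_dim)
split = enc.view(bs, knn, num_heads, head_dim).permute(0, 1, 3, 2)
assert split.shape == (bs, knn, head_dim, num_heads)

# _merge_heads: permute(0, 2, 1, 3), then fold heads back into the hidden dim
attn_out = torch.randn(bs, num_heads, seq_len, head_dim)   # (bs, heads, seq_len, head_dim)
merged = attn_out.permute(0, 2, 1, 3).contiguous()
merged = merged.view(merged.size()[:-2] + (num_heads * head_dim,))
assert merged.shape == (bs, seq_len, num_heads * head_dim)
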
@@ -316,10 +317,10 @@ class GPTJRCrossAttention(GPTJRAttention):
         head_mask=None,
     ):
-        # query -> (bs, seq_len, num_attention_heads, head_dim)
-        # key -> (bs, num_attention_heads, head_dim)
-        # attn_weights -> (bs, seq_len, num_attention_heads, num_attention_heads)
-        attn_weights = torch.matmul(query, key.transpose(-1, -2))
+        # query -> (bs, num_attention_heads, seq_len, head_dim)
+        # key -> (bs, num_attention_heads, head_dim, neighbors)
+        # attn_weights -> (bs, num_attention_heads, seq_len, neighbors)
+        attn_weights = torch.matmul(query, key)
 
         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
         attn_weights = attn_weights.to(value.dtype)
@@ -329,10 +330,10 @@ class GPTJRCrossAttention(GPTJRAttention):
         if head_mask is not None:
             attn_weights = attn_weights * head_mask
 
-        # value -> (bs, num_attention_heads, head_dim)
-        # attn_weights -> (bs, seq_len, num_attention_heads, num_attention_heads)
+        # value -> (bs, num_attention_heads, neighbors, head_dim)
+        # attn_weights -> (bs, num_attention_heads, seq_len, neighbors)
         # attn_output -> (bs, num_attention_heads, seq_len, head_dim)
-        attn_output = torch.matmul(attn_weights, value.transpose(-1, -2))
+        attn_output = torch.matmul(attn_weights, value)
 
         return attn_output, attn_weights
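
A short sketch of the batched matmul chain in _attn after this change, with toy sizes taken from the shape comments above (again an illustration, not repo code):

import torch
import torch.nn as nn

bs, num_heads, seq_len, head_dim, neighbors = 2, 16, 8, 64, 4

query = torch.randn(bs, num_heads, seq_len, head_dim)
key = torch.randn(bs, num_heads, head_dim, neighbors)
value = torch.randn(bs, num_heads, neighbors, head_dim)

attn_weights = torch.matmul(query, key)                     # (bs, heads, seq_len, neighbors)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)  # softmax over the retrieved neighbors
attn_output = torch.matmul(attn_weights, value)             # (bs, heads, seq_len, head_dim)

assert attn_weights.shape == (bs, num_heads, seq_len, neighbors)
assert attn_output.shape == (bs, num_heads, seq_len, head_dim)
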
@@ -362,8 +363,8 @@ class GPTJRCrossAttention(GPTJRAttention):
         value = self._split_knn_attn_heads(value, self.num_attention_heads, self.head_dim)
 
-        key = key.permute(0, 2, 1)
-        query = query.permute(0, 2, 1, 3)
+        value = value.permute(0, 3, 1, 2)
+        key = key.permute(0, 3, 2, 1)
 
         if layer_past is not None:
             past_key = layer_past[0]
@@ -454,30 +455,25 @@ class GPTJRBlock(nn.Module):
         self_attention_residual = attn_output + feed_forward_hidden_states + residual
 
         # encoder_hidden_states -> (bs, knn, encoder_dim)
         # may not need, can norm encodings
         encoder_normed = self.ln_2(encoder_hidden_states)
 
-        num_neighbors = encoder_normed.shape[1]
-        cross_attn_outputs = []
-        for k in range(num_neighbors):
-            # cross_attn_outputs -> (bs, seq_len, num_attention_heads, head_dim)
-            cross_attn_output = self.cross_attn(
-                residual,
-                encoder_hidden_states=encoder_normed[:, k, :],
-                attention_mask=attention_mask,
-                head_mask=head_mask,
-                use_cache=use_cache,
-                output_attentions=output_attentions,
-            )
-            cross_attn_outputs.append(cross_attn_output[0])
+        # cross_attn_output -> (bs, seq_len, dim)
+        cross_attn_output = self.cross_attn(
+            hidden_states,
+            encoder_hidden_states=encoder_normed,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+        )
 
-        cross_attn_output = torch.stack(cross_attn_outputs, dim=1).mean(dim=1)
         # gpt-j has parallel ff + attn, can do ff on encoder_normed too I guess?
         cross_attn_ff = self.cross_attn_mlp(
-            cross_attn_output
+            cross_attn_output[0]
        )
 
         alpha = self.alpha if self.training else 0.5
         hidden_states = (1 - alpha) * cross_attn_ff + alpha * self_attention_residual
 
         if use_cache:
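
Putting the pieces together, a self-contained toy of the batched cross-attention this commit moves to: one call over all retrieved neighbors instead of a Python loop over knn. Names, sizes, and the single-step permutes here are assumptions for illustration (the diff reaches the same layouts via permute(0, 1, 3, 2) followed by a second permute), so this is a sketch rather than the repo's GPTJRCrossAttention.

import torch
import torch.nn as nn

bs, seq_len, knn, hidden, num_heads = 2, 8, 4, 1024, 16
head_dim = hidden // num_heads

# hypothetical projections standing in for the module's q/k/v projections
q_proj = nn.Linear(hidden, hidden, bias=False)
k_proj = nn.Linear(hidden, hidden, bias=False)
v_proj = nn.Linear(hidden, hidden, bias=False)

hidden_states = torch.randn(bs, seq_len, hidden)       # decoder states
encoder_hidden_states = torch.randn(bs, knn, hidden)   # retrieved neighbors

# query -> (bs, heads, seq_len, head_dim)
query = q_proj(hidden_states).view(bs, seq_len, num_heads, head_dim).permute(0, 2, 1, 3)
# key -> (bs, heads, head_dim, knn); value -> (bs, heads, knn, head_dim)
key = k_proj(encoder_hidden_states).view(bs, knn, num_heads, head_dim).permute(0, 2, 3, 1)
value = v_proj(encoder_hidden_states).view(bs, knn, num_heads, head_dim).permute(0, 2, 1, 3)

attn_weights = torch.matmul(query, key).softmax(dim=-1)    # (bs, heads, seq_len, knn)
attn_output = torch.matmul(attn_weights, value)            # (bs, heads, seq_len, head_dim)
attn_output = attn_output.permute(0, 2, 1, 3).reshape(bs, seq_len, hidden)

assert attn_output.shape == (bs, seq_len, hidden)

With each neighbor occupying its own key/value slot, the softmax over the knn axis weights the neighbors directly, which is why the old per-neighbor loop and the stack-and-mean over its outputs are no longer needed.
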