Mirror of https://github.com/nomic-ai/gpt4all.git (synced 2025-07-05 03:27:09 +00:00)
fix: batched xattn
commit e255e0a805, parent ca66d12d89
@@ -296,7 +296,7 @@ class GPTJRCrossAttention(GPTJRAttention):
         new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
         tensor = tensor.view(new_shape)
-        return tensor.permute(0, 2, 1)
+        return tensor.permute(0, 1, 3, 2)

     def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
@@ -304,6 +304,7 @@ class GPTJRCrossAttention(GPTJRAttention):
         Merges attn_head_size dim and num_attn_heads dim into hidden dim
         """
+        # tensor -> (bs, seq_len, num_attention_heads, head_dim)
         tensor = tensor.permute(0, 2, 1, 3).contiguous()
         new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,)
         return tensor.view(new_shape)
@@ -316,10 +317,10 @@ class GPTJRCrossAttention(GPTJRAttention):
         head_mask=None,
     ):

-        # query -> (bs, seq_len, num_attention_heads, head_dim)
-        # key -> (bs, num_attention_heads, head_dim)
-        # attn_weights -> (bs, seq_len, num_attention_heads, num_attention_heads)
-        attn_weights = torch.matmul(query, key.transpose(-1, -2))
+        # query -> (bs, num_attention_heads, seq_len, head_dim)
+        # key -> (bs, num_attention_heads, head_dim, neighbors)
+        # attn_weights -> (bs, num_attention_heads, seq_len, neighbors)
+        attn_weights = torch.matmul(query, key)

         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
         attn_weights = attn_weights.to(value.dtype)
@@ -329,10 +330,10 @@ class GPTJRCrossAttention(GPTJRAttention):
         if head_mask is not None:
             attn_weights = attn_weights * head_mask

-        # value -> (bs, num_attention_heads, head_dim)
-        # attn_weights -> (bs, seq_len, num_attention_heads, num_attention_heads)
+        # value -> (bs, num_attention_heads, seq_len, head_dim)
+        # attn_weights -> (bs, num_attention_heads, seq_len, neighbors)
         # attn_output -> (bs, num_attention_heads, seq_len, head_dim)
-        attn_output = torch.matmul(attn_weights, value.transpose(-1, -2))
+        attn_output = torch.matmul(attn_weights, value)

         return attn_output, attn_weights
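Note (not part of the commit): a minimal shape check of the batched attention path above, with illustrative sizes; tensor names follow the shape comments in the diff.

import torch

# Illustrative sizes, not taken from the repo.
bs, num_heads, seq_len, head_dim, neighbors = 2, 4, 5, 8, 3

query = torch.randn(bs, num_heads, seq_len, head_dim)    # (bs, num_attention_heads, seq_len, head_dim)
key = torch.randn(bs, num_heads, head_dim, neighbors)    # (bs, num_attention_heads, head_dim, neighbors)
value = torch.randn(bs, num_heads, neighbors, head_dim)  # (bs, num_attention_heads, neighbors, head_dim)

attn_weights = torch.matmul(query, key)                          # (bs, num_heads, seq_len, neighbors)
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
attn_output = torch.matmul(attn_weights, value)                  # (bs, num_heads, seq_len, head_dim)

assert attn_weights.shape == (bs, num_heads, seq_len, neighbors)
assert attn_output.shape == (bs, num_heads, seq_len, head_dim)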
@@ -362,8 +363,8 @@ class GPTJRCrossAttention(GPTJRAttention):
         value = self._split_knn_attn_heads(value, self.num_attention_heads, self.head_dim)

-        key = key.permute(0, 2, 1)
+        query = query.permute(0, 2, 1, 3)
+        value = value.permute(0, 3, 1, 2)
+        key = key.permute(0, 3, 2, 1)

         if layer_past is not None:
             past_key = layer_past[0]
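Note (not part of the commit): a sketch of how the permutes above produce the shapes _attn expects, assuming query comes out of _split_heads as (bs, seq_len, num_attention_heads, head_dim) and key/value come out of the updated _split_knn_attn_heads as (bs, neighbors, head_dim, num_attention_heads); sizes are illustrative.

import torch

bs, num_heads, seq_len, head_dim, neighbors = 2, 4, 5, 8, 3

query = torch.randn(bs, seq_len, num_heads, head_dim)   # assumed output of _split_heads
key = torch.randn(bs, neighbors, head_dim, num_heads)   # assumed output of _split_knn_attn_heads (permute(0, 1, 3, 2))
value = torch.randn(bs, neighbors, head_dim, num_heads)

query = query.permute(0, 2, 1, 3)   # -> (bs, num_heads, seq_len, head_dim)
value = value.permute(0, 3, 1, 2)   # -> (bs, num_heads, neighbors, head_dim)
key = key.permute(0, 3, 2, 1)       # -> (bs, num_heads, head_dim, neighbors)

assert query.shape == (bs, num_heads, seq_len, head_dim)
assert value.shape == (bs, num_heads, neighbors, head_dim)
assert key.shape == (bs, num_heads, head_dim, neighbors)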
@@ -454,30 +455,25 @@ class GPTJRBlock(nn.Module):
         self_attention_residual = attn_output + feed_forward_hidden_states + residual

         # encoder_hidden_states -> (bs, knn, encoder_dim)
         # may not need, can norm encodings
         encoder_normed = self.ln_2(encoder_hidden_states)

-        num_neighbors = encoder_normed.shape[1]
-        cross_attn_outputs = []
-        for k in range(num_neighbors):
-            # cross_attn_outputs -> (bs, seq_len, num_attention_heads, head_dim)
-            cross_attn_output = self.cross_attn(
-                residual,
-                encoder_hidden_states=encoder_normed[:, k, :],
-                attention_mask=attention_mask,
-                head_mask=head_mask,
-                use_cache=use_cache,
-                output_attentions=output_attentions,
-            )
-            cross_attn_outputs.append(cross_attn_output[0])
+        # cross_attn_outputs -> (bs, seq_len, dim)
+        cross_attn_output = self.cross_attn(
+            hidden_states,
+            encoder_hidden_states=encoder_normed,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+        )

-        cross_attn_output = torch.stack(cross_attn_outputs, dim=1).mean(dim=1)
         # gpt-j has parallel ff + attn, can do ff on encoder_normed too I guess?
         cross_attn_ff = self.cross_attn_mlp(
-            cross_attn_output
+            cross_attn_output[0]
         )

         alpha = self.alpha if self.training else 0.5

         hidden_states = (1 - alpha) * cross_attn_ff + alpha * self_attention_residual

         if use_cache:
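Note (not part of the commit): a rough, self-contained sketch of the batched block logic after this change: normalize the (bs, knn, encoder_dim) neighbor encodings once, make a single cross-attention call instead of looping per neighbor, feed the first element of the returned tuple to the MLP, and blend with the self-attention residual via alpha. DummyCrossAttn and the sizes here are placeholders, not the repo's classes.

import torch
import torch.nn as nn

class DummyCrossAttn(nn.Module):
    """Placeholder for GPTJRCrossAttention: returns a tuple whose first element is the attention output."""
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states, encoder_hidden_states=None, **kwargs):
        # The real module attends over the neighbors axis; here we just pool it to keep shapes consistent.
        ctx = encoder_hidden_states.mean(dim=1, keepdim=True)   # (bs, 1, dim)
        return (self.proj(hidden_states + ctx),)                # tuple, so [0] is the output

bs, seq_len, knn, dim = 2, 5, 3, 16
hidden_states = torch.randn(bs, seq_len, dim)
self_attention_residual = torch.randn(bs, seq_len, dim)
encoder_hidden_states = torch.randn(bs, knn, dim)               # (bs, knn, encoder_dim)

ln_2 = nn.LayerNorm(dim)
cross_attn = DummyCrossAttn(dim)
cross_attn_mlp = nn.Linear(dim, dim)

encoder_normed = ln_2(encoder_hidden_states)                    # norm the neighbor encodings once
cross_attn_output = cross_attn(hidden_states, encoder_hidden_states=encoder_normed)
cross_attn_ff = cross_attn_mlp(cross_attn_output[0])            # first element of the returned tuple

alpha = 0.5                                                     # eval-time value used in the diff
out = (1 - alpha) * cross_attn_ff + alpha * self_attention_residual
assert out.shape == (bs, seq_len, dim)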
|
Loading…
Reference in New Issue
Block a user