Mirror of https://github.com/nomic-ai/gpt4all.git
fix: batched xattn

commit e255e0a805, parent ca66d12d89
@@ -296,7 +296,7 @@ class GPTJRCrossAttention(GPTJRAttention):
         new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
         tensor = tensor.view(new_shape)

-        return tensor.permute(0, 2, 1)
+        return tensor.permute(0, 1, 3, 2)


     def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
@@ -304,6 +304,7 @@ class GPTJRCrossAttention(GPTJRAttention):
         Merges attn_head_size dim and num_attn_heads dim into hidden dim
         """
         # tensor -> (bs, seq_len, num_attention_heads, head_dim)
+        tensor = tensor.permute(0, 2, 1, 3).contiguous()
         new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,)
         return tensor.view(new_shape)

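Not part of the commit, a minimal shape check of the two permutes above: it assumes the retrieved key/value tensors enter `_split_knn_attn_heads` as (bs, knn, hidden_dim) and that `_merge_heads` now receives attention output shaped (bs, num_attention_heads, seq_len, head_dim); all sizes are arbitrary.

    import torch

    bs, knn, seq_len, num_heads, head_dim = 2, 4, 8, 16, 64
    hidden_dim = num_heads * head_dim

    # split: (bs, knn, hidden_dim) -> (bs, knn, head_dim, num_heads), per the new permute
    neighbors = torch.randn(bs, knn, hidden_dim)
    new_shape = neighbors.size()[:-1] + (num_heads, head_dim)
    split = neighbors.view(new_shape).permute(0, 1, 3, 2)
    assert split.shape == (bs, knn, head_dim, num_heads)

    # merge: (bs, num_heads, seq_len, head_dim) -> (bs, seq_len, hidden_dim),
    # via the newly added permute(0, 2, 1, 3) before the view
    attn_output = torch.randn(bs, num_heads, seq_len, head_dim)
    merged = attn_output.permute(0, 2, 1, 3).contiguous()
    merged = merged.view(merged.size()[:-2] + (hidden_dim,))
    assert merged.shape == (bs, seq_len, hidden_dim)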
@@ -316,10 +317,10 @@ class GPTJRCrossAttention(GPTJRAttention):
         head_mask=None,
     ):

-        # query -> (bs, seq_len, num_attention_heads, head_dim)
-        # key -> (bs, num_attention_heads, head_dim)
-        # attn_weights -> (bs, seq_len, num_attention_heads, num_attention_heads)
-        attn_weights = torch.matmul(query, key.transpose(-1, -2))
+        # query -> (bs, num_attention_heads, seq_len, head_dim)
+        # key -> (bs, num_attention_heads, head_dim, neighbors)
+        # attn_weights -> (bs, num_attention_heads, seq_len, neighbors)
+        attn_weights = torch.matmul(query, key)

         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
         attn_weights = attn_weights.to(value.dtype)
@@ -329,10 +330,10 @@ class GPTJRCrossAttention(GPTJRAttention):
         if head_mask is not None:
             attn_weights = attn_weights * head_mask

-        # value -> (bs, num_attention_heads, head_dim)
-        # attn_weights -> (bs, seq_len, num_attention_heads, num_attention_heads)
+        # value -> (bs, num_attention_heads, seq_len, head_dim)
+        # attn_weights -> (bs, num_attention_heads, seq_len, neighbors)
         # attn_output -> (bs, num_attention_heads, seq_len, head_dim)
-        attn_output = torch.matmul(attn_weights, value.transpose(-1, -2))
+        attn_output = torch.matmul(attn_weights, value)

         return attn_output, attn_weights

@@ -362,8 +363,8 @@ class GPTJRCrossAttention(GPTJRAttention):
         value = self._split_knn_attn_heads(value, self.num_attention_heads, self.head_dim)


-        key = key.permute(0, 2, 1)
-        query = query.permute(0, 2, 1, 3)
+        value = value.permute(0, 3, 1, 2)
+        key = key.permute(0, 3, 2, 1)

         if layer_past is not None:
             past_key = layer_past[0]
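Also not from the diff: a sketch of how the permuted q/k/v shapes line up in the batched `_attn`, assuming query is already (bs, num_attention_heads, seq_len, head_dim) and key/value come out of `_split_knn_attn_heads` as (bs, knn, head_dim, num_attention_heads); in this reading, the last axis of attn_weights and the "seq_len" in the value comment above are the knn neighbor axis.

    import torch
    import torch.nn as nn

    bs, seq_len, knn, num_heads, head_dim = 2, 8, 4, 16, 64

    query = torch.randn(bs, num_heads, seq_len, head_dim)
    key = torch.randn(bs, knn, head_dim, num_heads).permute(0, 3, 2, 1)    # (bs, heads, head_dim, knn)
    value = torch.randn(bs, knn, head_dim, num_heads).permute(0, 3, 1, 2)  # (bs, heads, knn, head_dim)

    attn_weights = torch.matmul(query, key)                      # (bs, heads, seq_len, knn)
    attn_weights = nn.functional.softmax(attn_weights, dim=-1)   # normalizes over the knn axis
    attn_output = torch.matmul(attn_weights, value)              # (bs, heads, seq_len, head_dim)

    assert attn_weights.shape == (bs, num_heads, seq_len, knn)
    assert attn_output.shape == (bs, num_heads, seq_len, head_dim)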
@@ -454,30 +455,25 @@ class GPTJRBlock(nn.Module):
         self_attention_residual = attn_output + feed_forward_hidden_states + residual

         # encoder_hidden_states -> (bs, knn, encoder_dim)
+        # may not need, can norm encodings
         encoder_normed = self.ln_2(encoder_hidden_states)

-        num_neighbors = encoder_normed.shape[1]
-        cross_attn_outputs = []
-        for k in range(num_neighbors):
-            # cross_attn_outputs -> (bs, seq_len, num_attention_heads, head_dim)
-            cross_attn_output = self.cross_attn(
-                residual,
-                encoder_hidden_states=encoder_normed[:, k, :],
-                attention_mask=attention_mask,
-                head_mask=head_mask,
-                use_cache=use_cache,
-                output_attentions=output_attentions,
-            )
-            cross_attn_outputs.append(cross_attn_output[0])
-
-        cross_attn_output = torch.stack(cross_attn_outputs, dim=1).mean(dim=1)
+        # cross_attn_outputs -> (bs, seq_len, dim)
+        cross_attn_output = self.cross_attn(
+            hidden_states,
+            encoder_hidden_states=encoder_normed,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+        )
         # gpt-j has parallel ff + attn, can do ff on encoder_normed too I guess?
         cross_attn_ff = self.cross_attn_mlp(
-            cross_attn_output
+            cross_attn_output[0]
         )

         alpha = self.alpha if self.training else 0.5

         hidden_states = (1 - alpha) * cross_attn_ff + alpha * self_attention_residual

         if use_cache:
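Finally, a rough stand-in (hypothetical code, not the real GPTJRBlock or GPTJRCrossAttention) for what the block-level change does: the per-neighbor loop with a uniform mean is replaced by one batched call, so the softmax inside `_attn` now distributes weight across the knn neighbors instead of averaging their outputs.

    import torch

    bs, seq_len, dim, knn = 2, 8, 1024, 4
    hidden_states = torch.randn(bs, seq_len, dim)   # decoder states entering the block
    encoder_normed = torch.randn(bs, knn, dim)      # normed retrieved-neighbor encodings

    def cross_attn(states, encoder_hidden_states):
        # placeholder that only reproduces the output shape of the real module
        return (torch.randn(states.shape[0], states.shape[1], dim),)

    # old path (removed): one call per neighbor, then stack and average
    looped = torch.stack(
        [cross_attn(hidden_states, encoder_normed[:, k, :])[0] for k in range(knn)],
        dim=1,
    ).mean(dim=1)

    # new path: a single call over all neighbors at once
    batched = cross_attn(hidden_states, encoder_normed)[0]

    assert looped.shape == batched.shape == (bs, seq_len, dim)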