feat: mem attn forward pass on cuda

commit 5f28e60a9a
parent 18b04347f5
@@ -16,14 +16,11 @@

 from typing import Optional, Tuple, Union

 import math
 import torch.nn.functional as F
 import torch
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import CrossEntropyLoss

 from transformers import AutoModel
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
@@ -31,7 +28,6 @@ from transformers.modeling_outputs import (
 )
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import logging
 from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
 from gpt4all.models.lethe import LetheConfig
 import hnswlib
 import numpy as np
@@ -44,10 +40,9 @@ GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST = [
     "EleutherAI/gpt-neox-20b",
 ]

-# TODO: understand why Phil only does this per batch and doesn't persist across many batches
-# TODO: k/v are stored per head and per token!!!
-# reshape query, key, value into (bs * seq_len, num_attention_heads, head_size)
-# for each head, store index of k/v for each token
+# TODO: understand why Phil only does this per batch and doesn't persist across many batches -> he uses multi-query attention
+# TODO: do we need to implement masking for the dense vectors we pull from?
+# TODO: i think phil is using a memmapped database to pull out rather than using the index


 class HNSWIndex:
@@ -110,7 +105,6 @@ class MemoryIndex:
         reshaped_values = values.reshape(values.shape[0] * values.shape[2], values.shape[1], values.shape[3])

         for head in range(self.nheads):
             print(f"adding head {head}")
             self.key_indices[head].add(reshaped_keys[:, head, :])
             self.value_indices[head].add(reshaped_values[:, head, :])

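The hunk above only shows the per-head add loop, so here is a minimal sketch of the kind of per-head key/value store MemoryIndex appears to wrap, using hnswlib's real Index/init_index/add_items/knn_query API. The class name PerHeadMemory and the head_size/max_mems parameters are illustrative assumptions, not code from this repository.

# Sketch only: per-head k-NN memory in the spirit of MemoryIndex above.
# hnswlib assigns sequential labels when ids are omitted, so keys/values can
# be kept in plain arrays addressed by those labels.
import hnswlib
import numpy as np

class PerHeadMemory:
    def __init__(self, nheads, head_size, max_mems=100_000):
        self.nheads = nheads
        self.key_indices, self.keys, self.values = [], [], []
        for _ in range(nheads):
            idx = hnswlib.Index(space="ip", dim=head_size)   # inner-product search over keys
            idx.init_index(max_elements=max_mems, ef_construction=200, M=16)
            self.key_indices.append(idx)
            self.keys.append(np.zeros((0, head_size), dtype=np.float32))
            self.values.append(np.zeros((0, head_size), dtype=np.float32))

    def add(self, keys, values):
        # keys/values: (num_tokens, nheads, head_size), float32
        for h in range(self.nheads):
            self.key_indices[h].add_items(keys[:, h, :])
            self.keys[h] = np.concatenate([self.keys[h], keys[:, h, :]], axis=0)
            self.values[h] = np.concatenate([self.values[h], values[:, h, :]], axis=0)

    def knn_query(self, queries, k):
        # queries: (num_tokens, nheads, head_size) -> (num_tokens, nheads, k, head_size)
        knn_keys, knn_values = [], []
        for h in range(self.nheads):
            labels, _ = self.key_indices[h].knn_query(queries[:, h, :], k=k)
            knn_keys.append(self.keys[h][labels])
            knn_values.append(self.values[h][labels])
        return np.stack(knn_keys, axis=1), np.stack(knn_values, axis=1)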
@@ -220,13 +214,14 @@ class LetheAttention(nn.Module):
         layer_past: Optional[Tuple[torch.Tensor]] = None,
         use_cache: Optional[bool] = False,
         output_attentions: Optional[bool] = False,
+        use_mem_attn: Optional[bool] = True,
     ):
         has_layer_past = layer_past is not None

         # Compute QKV
         # Attention heads [batch, seq_len, hidden_size]
         # --> [batch, seq_len, (np * 3 * head_size)]
-        bs, seq_len, hidden_size = hidden_states.size()
+        _, seq_len, _ = hidden_states.size()
         qkv = self.query_key_value(hidden_states)

         # [batch, seq_len, (num_heads * 3 * head_size)]
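The comments above describe the shape of the fused QKV projection. For reference, a minimal sketch of how such a [batch, seq_len, num_heads * 3 * head_size] tensor is typically split into per-head query/key/value in GPT-NeoX-style attention; the helper name and exact slicing are assumptions, not this file's code.

# Sketch only: splitting a fused QKV projection into per-head q/k/v.
import torch

def split_fused_qkv(qkv, num_heads, head_size):
    # qkv: [batch, seq_len, num_heads * 3 * head_size]
    bs, seq_len, _ = qkv.size()
    qkv = qkv.view(bs, seq_len, num_heads, 3 * head_size)
    query, key, value = qkv.split(head_size, dim=-1)       # each [batch, seq_len, num_heads, head_size]
    # -> [batch, num_heads, seq_len, head_size] for the attention computation
    return tuple(t.permute(0, 2, 1, 3) for t in (query, key, value))

q, k, v = split_fused_qkv(torch.randn(2, 5, 4 * 3 * 16), num_heads=4, head_size=16)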
@@ -265,16 +260,24 @@ class LetheAttention(nn.Module):
         # Compute attention
         attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

         # TODO: need to do masking??
         if self.memory:
             # get knns
+            # since we do an eval batch w context before, let's not do the expensive step until we need to
             # [batch, knn, num_attention_heads, seq_len, head_size]
-            knn_keys, knn_values = self.index.knn_query(query.detach().numpy(), k=self.num_neighbors)
-            mem_attn = self._mem_attn(query, knn_keys, knn_values, attention_mask, head_mask)
+            if use_mem_attn:
+                knn_keys, knn_values = self.index.knn_query(query.cpu().detach().to(torch.float32).numpy(), k=self.num_neighbors)
+                mem_attn = self._mem_attn(query,
+                                          knn_keys.to(query.device),
+                                          knn_values.to(query.device),
+                                          attention_mask,
+                                          head_mask
+                                          )

             expanded_alpha = self.alpha[None, :, None, None]
             attn_output = (attn_output * (1 - expanded_alpha)) + (mem_attn * expanded_alpha)

-            self.index.add(key.detach().numpy(), value.detach().numpy())
+            self.index.add(key.cpu().detach().to(torch.float32).numpy(), value.cpu().detach().to(torch.float32).numpy())

         # Reshape outputs
         attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size)
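The new branch fetches approximate nearest neighbours for the current queries and mixes the resulting memory attention with local attention through a learned per-head gate alpha. A minimal sketch of that blend and of the CPU/float32 round trip the CUDA path needs (hnswlib searches on the CPU); it assumes alpha has shape (num_attention_heads,) and that index.knn_query returns torch tensors, as the .to(query.device) calls above imply.

# Sketch only: gated combination of local and memory attention, plus the
# device round trip used when queries live on a CUDA device.
import torch

def blend_memory_attention(attn_output, mem_attn, alpha):
    # attn_output/mem_attn: [batch, num_heads, seq_len, head_size]; alpha: [num_heads]
    expanded_alpha = alpha[None, :, None, None]            # broadcast over batch, seq_len, head_size
    return attn_output * (1 - expanded_alpha) + mem_attn * expanded_alpha

def query_memory(index, query, k):
    # hnswlib works on CPU float32 arrays, so queries leave the GPU...
    knn_keys, knn_values = index.knn_query(query.cpu().detach().to(torch.float32).numpy(), k=k)
    # ...and the retrieved keys/values are moved back to the query's device
    return knn_keys.to(query.device), knn_values.to(query.device)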
@@ -462,6 +465,7 @@ class LetheLayer(nn.Module):
         use_cache: Optional[bool] = False,
         layer_past: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
+        use_mem_attn: Optional[bool] = True,
     ):
         ln_hidden_states = self.input_layernorm(hidden_states)
         attention_layer_outputs = self.attention(
@@ -472,6 +476,7 @@ class LetheLayer(nn.Module):
             head_mask=head_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
+            use_mem_attn=use_mem_attn,
         )
         attn_output = attention_layer_outputs[0]  # output_attn: attn_output, present, (attn_weights)
         outputs = attention_layer_outputs[1:]
@@ -533,6 +538,7 @@ class LetheModel(LethePreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        use_mem_attn: Optional[bool] = True,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         r"""
         past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
@@ -625,7 +631,7 @@ class LetheModel(LethePreTrainedModel):
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for layer_past
-                        return module(*inputs, use_cache, None, output_attentions)
+                        return module(*inputs, use_cache, None, output_attentions, use_mem_attn)

                    return custom_forward

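create_custom_forward now also threads use_mem_attn through the checkpointed call. A self-contained sketch of that pattern with a toy layer, so the closure's argument splicing is visible; ToyLayer and the checkpoint invocation are illustrative assumptions, not this file's code.

# Sketch only: passing non-tensor args (use_cache, layer_past=None,
# output_attentions, use_mem_attn) through torch.utils.checkpoint via a closure.
import torch
import torch.utils.checkpoint
from torch import nn

class ToyLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, hidden_states, use_cache=False, layer_past=None,
                output_attentions=False, use_mem_attn=True):
        return (self.linear(hidden_states),)

use_cache, output_attentions, use_mem_attn = False, False, True

def create_custom_forward(module):
    def custom_forward(*inputs):
        # None stands in for layer_past, as in the hunk above
        return module(*inputs, use_cache, None, output_attentions, use_mem_attn)
    return custom_forward

layer = ToyLayer()
hidden_states = torch.randn(2, 4, 8, requires_grad=True)
outputs = torch.utils.checkpoint.checkpoint(create_custom_forward(layer), hidden_states)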
@@ -645,6 +651,7 @@ class LetheModel(LethePreTrainedModel):
                    layer_past=layer_past,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
+                    use_mem_attn=use_mem_attn,
                )
            hidden_states = outputs[0]
            if use_cache is True:
@@ -692,7 +699,6 @@ class LetheForCausalLM(LethePreTrainedModel):
    def forward(
        self,
        input_ids: torch.LongTensor,
-        token_type_ids: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        position_ids: Optional[torch.LongTensor] = None,
@@ -703,6 +709,7 @@ class LetheForCausalLM(LethePreTrainedModel):
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
+        use_mem_attn: Optional[bool] = True,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
@@ -744,29 +751,9 @@ class LetheForCausalLM(LethePreTrainedModel):
        >>> prediction_logits = outputs.logits
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        # memories are where token_type_ids == 0
-        memory_mask = token_type_ids == 0
-        # should be shape (num_memories, sequence_length)
-        memories = input_ids[memory_mask]
-        with torch.no_grad():
-            # store memories but we don't back prop
-            self.gpt_neox(memories,
-                          attention_mask=attention_mask,
-                          position_ids=position_ids,
-                          head_mask=head_mask,
-                          inputs_embeds=inputs_embeds,
-                          past_key_values=past_key_values,
-                          use_cache=use_cache,
-                          output_attentions=output_attentions,
-                          output_hidden_states=output_hidden_states,
-                          return_dict=return_dict,
-            )
-
-        questions = input_ids[~memory_mask]
-        answers = labels[~memory_mask]

        outputs = self.gpt_neox(
-            questions,
+            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
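This hunk drops the token_type_ids-driven two-pass scheme, in which rows marked as memories (token_type_ids == 0) were first run through self.gpt_neox under torch.no_grad() purely to populate the attention memory, and only the remaining question rows were trained on; the new code forwards input_ids directly and relies on use_mem_attn. For reference, a hedged sketch of the removed splitting step, assuming a row-level token_type_ids of shape (batch,) as the deleted shape comment implies.

# Sketch only: the memory/question split this hunk removes.
import torch

def split_memories_and_questions(input_ids, token_type_ids, labels):
    # token_type_ids == 0 marks memory rows; the rest are question rows
    memory_mask = token_type_ids == 0
    memories = input_ids[memory_mask]        # fed once, without gradients, to fill the memory index
    questions = input_ids[~memory_mask]      # rows the loss is computed on
    answers = labels[~memory_mask]
    return memories, questions, answers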
@@ -776,20 +763,21 @@ class LetheForCausalLM(LethePreTrainedModel):
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
+            use_mem_attn=use_mem_attn
        )

        hidden_states = outputs[0]
        lm_logits = self.embed_out(hidden_states)

        lm_loss = None
-        if answers is not None:
+        if labels is not None:
            # move labels to correct device to enable model parallelism
-            answers = answers.to(lm_logits.device)
+            labels = labels.to(lm_logits.device)
            # we are doing next-token prediction; shift prediction scores and input ids by one
            shift_logits = lm_logits[:, :-1, :].contiguous()
-            answers = answers[:, 1:].contiguous()
+            labels = labels[:, 1:].contiguous()
            loss_fct = CrossEntropyLoss()
-            lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), answers.view(-1))
+            lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
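The loss block keeps the usual causal-LM shift, now reading from labels instead of answers. A small self-contained sketch of that next-token objective on toy shapes:

# Sketch only: shifted next-token loss; position t's logits are scored
# against the token at position t + 1.
import torch
from torch.nn import CrossEntropyLoss

batch, seq_len, vocab = 2, 5, 11
lm_logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))

shift_logits = lm_logits[:, :-1, :].contiguous()   # drop the prediction for the last position
shift_labels = labels[:, 1:].contiguous()          # drop the first token as a target

loss_fct = CrossEntropyLoss()
lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))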