Mirror of https://github.com/nomic-ai/gpt4all.git (synced 2025-07-07 20:41:24 +00:00)

feat: mem attn forward pass on cuda

parent 18b04347f5
commit 5f28e60a9a
@@ -16,14 +16,11 @@
 from typing import Optional, Tuple, Union
 
-import math
-import torch.nn.functional as F
 import torch
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import CrossEntropyLoss
 
-from transformers import AutoModel
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
@@ -31,7 +28,6 @@ from transformers.modeling_outputs import (
 )
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import logging
-from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
 from gpt4all.models.lethe import LetheConfig
 import hnswlib
 import numpy as np
@@ -44,10 +40,9 @@ GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST = [
     "EleutherAI/gpt-neox-20b",
 ]
 
-# TODO: understand why Phil only does this per batch and doesn't persist across many batches
-# TODO: k/v are stored per head and per token!!!
-# reshape query, key, value into (bs * seq_len, num_attention_heads, head_size)
-# for each head, store index of k/v for each token
+# TODO: understand why Phil only does this per batch and doesn't persist across many batches -> he uses multi-query attention
+# TODO: do we need to implement masking for the dense vectors we pull from?
+# TODO: i think phil is using a memmapped database to pull out rather than using the index
 
 
 class HNSWIndex:
@@ -110,7 +105,6 @@ class MemoryIndex:
         reshaped_values = values.reshape(values.shape[0] * values.shape[2], values.shape[1], values.shape[3])
 
         for head in range(self.nheads):
-            print(f"adding head {head}")
             self.key_indices[head].add(reshaped_keys[:, head, :])
             self.value_indices[head].add(reshaped_values[:, head, :])
 
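For context, MemoryIndex keeps one approximate-nearest-neighbour index per attention head and adds the per-head key/value vectors after each forward pass. A minimal sketch of that per-head pattern with raw hnswlib follows; the dimensions, index parameters, and helper names are illustrative assumptions, not the exact MemoryIndex implementation.

import hnswlib
import numpy as np

num_heads, head_size = 8, 64                      # assumed sizes for the sketch
key_indices = []
for _ in range(num_heads):
    idx = hnswlib.Index(space="ip", dim=head_size)                 # inner-product space
    idx.init_index(max_elements=100_000, ef_construction=200, M=16)
    key_indices.append(idx)

def add_keys(keys: np.ndarray) -> None:
    # keys: (num_tokens, num_heads, head_size) float32; one add per head-specific index
    for head in range(num_heads):
        key_indices[head].add_items(keys[:, head, :])

def query_keys(queries: np.ndarray, k: int = 32):
    # returns (labels, distances) per head; stored vectors are then gathered by label
    return [key_indices[head].knn_query(queries[:, head, :], k=k) for head in range(num_heads)]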
@@ -220,13 +214,14 @@ class LetheAttention(nn.Module):
         layer_past: Optional[Tuple[torch.Tensor]] = None,
         use_cache: Optional[bool] = False,
         output_attentions: Optional[bool] = False,
+        use_mem_attn: Optional[bool] = True,
     ):
         has_layer_past = layer_past is not None
 
         # Compute QKV
         # Attention heads [batch, seq_len, hidden_size]
         # --> [batch, seq_len, (np * 3 * head_size)]
-        bs, seq_len, hidden_size = hidden_states.size()
+        _, seq_len, _ = hidden_states.size()
         qkv = self.query_key_value(hidden_states)
 
         # [batch, seq_len, (num_heads * 3 * head_size)]
@@ -265,16 +260,24 @@ class LetheAttention(nn.Module):
         # Compute attention
         attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
 
+        # TODO: need to do masking??
         if self.memory:
             # get knns
+            # since we do an eval batch w context before, let's not do the expensive step until we need to
             # [batch, knn, num_attention_heads, seq_len, head_size]
-            knn_keys, knn_values = self.index.knn_query(query.detach().numpy(), k=self.num_neighbors)
-            mem_attn = self._mem_attn(query, knn_keys, knn_values, attention_mask, head_mask)
+            if use_mem_attn:
+                knn_keys, knn_values = self.index.knn_query(query.cpu().detach().to(torch.float32).numpy(), k=self.num_neighbors)
+                mem_attn = self._mem_attn(query,
+                                          knn_keys.to(query.device),
+                                          knn_values.to(query.device),
+                                          attention_mask,
+                                          head_mask
+                )
 
             expanded_alpha = self.alpha[None, :, None, None]
             attn_output = (attn_output * (1 - expanded_alpha)) + (mem_attn * expanded_alpha)
 
-            self.index.add(key.detach().numpy(), value.detach().numpy())
+            self.index.add(key.cpu().detach().to(torch.float32).numpy(), value.cpu().detach().to(torch.float32).numpy())
 
         # Reshape outputs
         attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size)
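The .cpu().detach().to(torch.float32).numpy() chains above are what make this forward pass work on CUDA: hnswlib only consumes float32 NumPy arrays on the host, while query/key/value may live on the GPU (possibly in half precision). A minimal sketch of that round trip, assuming a hypothetical index whose knn_query returns torch tensors of retrieved keys and values:

import torch

def memory_lookup(index, query: torch.Tensor, k: int):
    # Detach from autograd, move to host, cast to float32, and convert to NumPy for hnswlib.
    q_host = query.detach().to("cpu", torch.float32).numpy()
    knn_keys, knn_values = index.knn_query(q_host, k=k)   # assumed to return torch tensors
    # Retrieved memories come back on the CPU; move them to the query's device (and dtype)
    # before they participate in the memory-attention scores.
    knn_keys = knn_keys.to(device=query.device, dtype=query.dtype)
    knn_values = knn_values.to(device=query.device, dtype=query.dtype)
    return knn_keys, knn_values

The learned gate self.alpha appears to hold one scalar per attention head; indexing it as alpha[None, :, None, None] broadcasts it over batch, sequence, and head-size dimensions, so each head mixes its local attention output and its memory attention output with its own weight.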
@@ -462,6 +465,7 @@ class LetheLayer(nn.Module):
         use_cache: Optional[bool] = False,
         layer_past: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
+        use_mem_attn: Optional[bool] = True,
     ):
         ln_hidden_states = self.input_layernorm(hidden_states)
         attention_layer_outputs = self.attention(
@@ -472,6 +476,7 @@ class LetheLayer(nn.Module):
             head_mask=head_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
+            use_mem_attn=use_mem_attn,
         )
         attn_output = attention_layer_outputs[0]  # output_attn: attn_output, present, (attn_weights)
         outputs = attention_layer_outputs[1:]
@@ -533,6 +538,7 @@ class LetheModel(LethePreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        use_mem_attn: Optional[bool] = True,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         r"""
         past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
@@ -625,7 +631,7 @@ class LetheModel(LethePreTrainedModel):
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
                         # None for layer_past
-                        return module(*inputs, use_cache, None, output_attentions)
+                        return module(*inputs, use_cache, None, output_attentions, use_mem_attn)
 
                     return custom_forward
 
@@ -645,6 +651,7 @@ class LetheModel(LethePreTrainedModel):
                     layer_past=layer_past,
                     use_cache=use_cache,
                     output_attentions=output_attentions,
+                    use_mem_attn=use_mem_attn,
                 )
             hidden_states = outputs[0]
             if use_cache is True:
@@ -692,7 +699,6 @@ class LetheForCausalLM(LethePreTrainedModel):
     def forward(
         self,
         input_ids: torch.LongTensor,
-        token_type_ids: torch.Tensor,
         attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
         position_ids: Optional[torch.LongTensor] = None,
@@ -703,6 +709,7 @@ class LetheForCausalLM(LethePreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        use_mem_attn: Optional[bool] = True,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
@@ -744,29 +751,9 @@ class LetheForCausalLM(LethePreTrainedModel):
         >>> prediction_logits = outputs.logits
         ```"""
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        # memories are where token_type_ids == 0
-        memory_mask = token_type_ids == 0
-        # should be shape (num_memories, sequence_length)
-        memories = input_ids[memory_mask]
-        with torch.no_grad():
-            # store memories but we don't back prop
-            self.gpt_neox(memories,
-                          attention_mask=attention_mask,
-                          position_ids=position_ids,
-                          head_mask=head_mask,
-                          inputs_embeds=inputs_embeds,
-                          past_key_values=past_key_values,
-                          use_cache=use_cache,
-                          output_attentions=output_attentions,
-                          output_hidden_states=output_hidden_states,
-                          return_dict=return_dict,
-            )
-
-        questions = input_ids[~memory_mask]
-        answers = labels[~memory_mask]
 
         outputs = self.gpt_neox(
-            questions,
+            input_ids,
             attention_mask=attention_mask,
             position_ids=position_ids,
             head_mask=head_mask,
@@ -776,20 +763,21 @@ class LetheForCausalLM(LethePreTrainedModel):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            use_mem_attn=use_mem_attn
         )
 
         hidden_states = outputs[0]
         lm_logits = self.embed_out(hidden_states)
 
         lm_loss = None
-        if answers is not None:
+        if labels is not None:
             # move labels to correct device to enable model parallelism
-            answers = answers.to(lm_logits.device)
+            labels = labels.to(lm_logits.device)
             # we are doing next-token prediction; shift prediction scores and input ids by one
             shift_logits = lm_logits[:, :-1, :].contiguous()
-            answers = answers[:, 1:].contiguous()
+            labels = labels[:, 1:].contiguous()
             loss_fct = CrossEntropyLoss()
-            lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), answers.view(-1))
+            lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))
 
         if not return_dict:
             output = (lm_logits,) + outputs[1:]
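With the token_type_ids splitting removed from LetheForCausalLM.forward, the caller decides which passes populate the memory index and which passes attend over it, via the new use_mem_attn flag threaded from the model down to every attention layer. A hypothetical two-pass usage sketch follows; the import path, config defaults, vocabulary size, and tensor names are assumptions for illustration, not the project's documented API.

import torch
from gpt4all.models.lethe import LetheConfig
from gpt4all.models.lethe.modeling_lethe import LetheForCausalLM   # module path is an assumption

model = LetheForCausalLM(LetheConfig()).to("cuda").eval()          # default config, CUDA assumed

# Placeholder batches; in practice these come from a tokenizer (vocab size is a placeholder).
context_ids = torch.randint(0, 50254, (1, 512), device="cuda")     # "memory" documents
question_ids = torch.randint(0, 50254, (1, 128), device="cuda")    # prompt to score
labels = question_ids.clone()

with torch.no_grad():
    # First pass: write the context's keys/values into each layer's memory index,
    # but skip the expensive memory-attention lookup.
    model(context_ids, use_mem_attn=False)

# Second pass: the question attends over both local context and the stored memories.
outputs = model(question_ids, labels=labels, use_mem_attn=True)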