feat: mem attn forward pass on cuda

Zach Nussbaum 2023-05-18 19:44:04 +00:00
parent 18b04347f5
commit 5f28e60a9a


@@ -16,14 +16,11 @@
from typing import Optional, Tuple, Union
import math
import torch.nn.functional as F
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import AutoModel
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
@@ -31,7 +28,6 @@ from transformers.modeling_outputs import (
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
from gpt4all.models.lethe import LetheConfig
import hnswlib
import numpy as np
@@ -44,10 +40,9 @@ GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST = [
"EleutherAI/gpt-neox-20b",
]
# TODO: understand why Phil only does this per batch and doesn't persist across many batches
# TODO: k/v are stored per head and per token!!!
# reshape query, key, value into (bs * seq_len, num_attention_heads, head_size)
# for each head, store index of k/v for each token
# TODO: understand why Phil only does this per batch and doesn't persist across many batches -> he uses multi-query attention
# TODO: do we need to implement masking for the dense vectors we pull from?
# TODO: i think phil is using a memmapped database to pull out rather than using the index
class HNSWIndex:
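The TODOs above sketch the design: keys and values are stored per head and per token, and queries pull back nearest neighbors per head. Below is a minimal sketch of such a per-head wrapper around hnswlib, assuming an inner-product space and a fixed capacity; `PerHeadKVIndex`, `max_memories`, and the shapes noted in comments are illustrative, not taken from this file.

# Minimal sketch, not the repo's HNSWIndex/MemoryIndex: one hnswlib index per
# head for keys, with values kept in aligned numpy buffers.
import hnswlib
import numpy as np

class PerHeadKVIndex:
    def __init__(self, nheads: int, head_size: int, max_memories: int = 100_000):
        self.nheads = nheads
        self.key_indices = []
        self.value_store = []
        for _ in range(nheads):
            index = hnswlib.Index(space="ip", dim=head_size)
            index.init_index(max_elements=max_memories, ef_construction=200, M=16)
            self.key_indices.append(index)
            self.value_store.append(np.zeros((max_memories, head_size), dtype=np.float32))
        self.count = 0

    def add(self, keys: np.ndarray, values: np.ndarray) -> None:
        # keys, values: (num_tokens, nheads, head_size) float32
        ids = np.arange(self.count, self.count + keys.shape[0])
        for head in range(self.nheads):
            self.key_indices[head].add_items(keys[:, head, :], ids)
            self.value_store[head][ids] = values[:, head, :]
        self.count += keys.shape[0]

    def knn_query(self, queries: np.ndarray, k: int):
        # queries: (num_tokens, nheads, head_size) ->
        # keys, values: (num_tokens, nheads, k, head_size)
        knn_keys, knn_values = [], []
        for head in range(self.nheads):
            labels, _ = self.key_indices[head].knn_query(queries[:, head, :], k=k)
            neigh = [np.asarray(self.key_indices[head].get_items(row), dtype=np.float32)
                     for row in labels]
            knn_keys.append(np.stack(neigh))
            knn_values.append(self.value_store[head][labels])
        return np.stack(knn_keys, axis=1), np.stack(knn_values, axis=1)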
@@ -110,7 +105,6 @@ class MemoryIndex:
reshaped_values = values.reshape(values.shape[0] * values.shape[2], values.shape[1], values.shape[3])
for head in range(self.nheads):
print(f"adding head {head}")
self.key_indices[head].add(reshaped_keys[:, head, :])
self.value_indices[head].add(reshaped_values[:, head, :])
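The reshape above is the bookkeeping that turns the (batch, num_heads, seq_len, head_size) key/value tensors into one row per token per head before indexing. A small sketch of that mapping follows, assuming that layout; the transpose here keeps each token's heads grouped together and is an assumption about the intended layout, not a line from the commit.

import numpy as np

def flatten_per_token(x: np.ndarray) -> np.ndarray:
    # assumption: x is (batch, num_heads, seq_len, head_size)
    bs, nheads, seq_len, head_size = x.shape
    # move heads behind the token axis, then merge batch and sequence
    return x.transpose(0, 2, 1, 3).reshape(bs * seq_len, nheads, head_size)

keys = np.random.randn(2, 8, 16, 64).astype(np.float32)  # (bs, nheads, seq, hd)
flat = flatten_per_token(keys)                            # (32, 8, 64)
# flat[:, head, :] is what gets handed to that head's index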
@@ -220,13 +214,14 @@ class LetheAttention(nn.Module):
layer_past: Optional[Tuple[torch.Tensor]] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
use_mem_attn: Optional[bool] = True,
):
has_layer_past = layer_past is not None
# Compute QKV
# Attention heads [batch, seq_len, hidden_size]
# --> [batch, seq_len, (np * 3 * head_size)]
bs, seq_len, hidden_size = hidden_states.size()
_, seq_len, _ = hidden_states.size()
qkv = self.query_key_value(hidden_states)
# [batch, seq_len, (num_heads * 3 * head_size)]
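For reference, the fused projection produces a [batch, seq_len, num_heads * 3 * head_size] tensor that then gets split into per-head query/key/value in the usual GPT-NeoX fashion. A minimal sketch of that split (a standalone helper, not the exact code in this file):

import torch

def split_qkv(qkv: torch.Tensor, num_heads: int, head_size: int):
    # qkv: (batch, seq_len, num_heads * 3 * head_size)
    bs, seq_len, _ = qkv.size()
    qkv = qkv.view(bs, seq_len, num_heads, 3 * head_size)
    # each output: (batch, num_heads, seq_len, head_size)
    query = qkv[..., :head_size].permute(0, 2, 1, 3)
    key = qkv[..., head_size:2 * head_size].permute(0, 2, 1, 3)
    value = qkv[..., 2 * head_size:].permute(0, 2, 1, 3)
    return query, key, value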
@@ -265,16 +260,24 @@
# Compute attention
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
# TODO: need to do masking??
if self.memory:
# get knns
# since we do an eval batch w context before, let's not do the expensive step until we need to
# [batch, knn, num_attention_heads, seq_len, head_size]
knn_keys, knn_values = self.index.knn_query(query.detach().numpy(), k=self.num_neighbors)
mem_attn = self._mem_attn(query, knn_keys, knn_values, attention_mask, head_mask)
if use_mem_attn:
knn_keys, knn_values = self.index.knn_query(query.cpu().detach().to(torch.float32).numpy(), k=self.num_neighbors)
mem_attn = self._mem_attn(query,
knn_keys.to(query.device),
knn_values.to(query.device),
attention_mask,
head_mask
)
expanded_alpha = self.alpha[None, :, None, None]
attn_output = (attn_output * (1 - expanded_alpha)) + (mem_attn * expanded_alpha)
expanded_alpha = self.alpha[None, :, None, None]
attn_output = (attn_output * (1 - expanded_alpha)) + (mem_attn * expanded_alpha)
self.index.add(key.detach().numpy(), value.detach().numpy())
self.index.add(key.cpu().detach().to(torch.float32).numpy(), value.cpu().detach().to(torch.float32).numpy())
# Reshape outputs
attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size)
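Two pieces make up the CUDA path this commit adds: the query is detached, cast to float32, and moved to the CPU before hnswlib sees it (with the retrieved neighbors moved back to the query's device), and the regular and memory attention outputs are blended with the learned per-head gate alpha. A hedged sketch of both, assuming alpha is a per-head gate in [0, 1] (e.g. via a sigmoid) and that the index returns numpy arrays:

import torch

def blend_with_memory(attn_output: torch.Tensor,
                      mem_attn: torch.Tensor,
                      alpha: torch.Tensor) -> torch.Tensor:
    # attn_output, mem_attn: (batch, num_heads, seq_len, head_size)
    # alpha: (num_heads,) gate, broadcast over batch, sequence, and head_size
    expanded_alpha = alpha[None, :, None, None]
    return attn_output * (1 - expanded_alpha) + mem_attn * expanded_alpha

def memory_lookup(index, query: torch.Tensor, k: int):
    # hnswlib wants float32 numpy on the CPU; detach so no graph is kept
    q_np = query.detach().to(torch.float32).cpu().numpy()
    knn_keys, knn_values = index.knn_query(q_np, k=k)  # assumed numpy outputs

    def as_t(a):
        # bring neighbors back to the query's device and dtype before attention
        return torch.from_numpy(a).to(device=query.device, dtype=query.dtype)

    return as_t(knn_keys), as_t(knn_values)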
@@ -462,6 +465,7 @@ class LetheLayer(nn.Module):
use_cache: Optional[bool] = False,
layer_past: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_mem_attn: Optional[bool] = True,
):
ln_hidden_states = self.input_layernorm(hidden_states)
attention_layer_outputs = self.attention(
@@ -472,6 +476,7 @@
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
use_mem_attn=use_mem_attn,
)
attn_output = attention_layer_outputs[0] # output_attn: attn_output, present, (attn_weights)
outputs = attention_layer_outputs[1:]
@@ -533,6 +538,7 @@ class LetheModel(LethePreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
use_mem_attn: Optional[bool] = True,
) -> Union[Tuple, BaseModelOutputWithPast]:
r"""
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
@@ -625,7 +631,7 @@
def create_custom_forward(module):
def custom_forward(*inputs):
# None for layer_past
return module(*inputs, use_cache, None, output_attentions)
return module(*inputs, use_cache, None, output_attentions, use_mem_attn)
return custom_forward
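The new use_mem_attn flag also has to ride through gradient checkpointing, where torch.utils.checkpoint replays a closure over the saved inputs; non-tensor constants like the flag are easiest to capture in the closure rather than pass as checkpointed inputs. A standalone sketch of that pattern (argument names and order are illustrative):

import torch
from torch.utils.checkpoint import checkpoint

def make_custom_forward(layer, use_cache, output_attentions, use_mem_attn):
    def custom_forward(*inputs):
        # None stands in for layer_past; constants are captured, not checkpointed
        return layer(*inputs, use_cache, None, output_attentions, use_mem_attn)
    return custom_forward

# hypothetical call, mirroring the loop above:
# outputs = checkpoint(
#     make_custom_forward(layer, use_cache, output_attentions, use_mem_attn),
#     hidden_states, attention_mask, position_ids, head_mask,
# )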
@@ -645,6 +651,7 @@
layer_past=layer_past,
use_cache=use_cache,
output_attentions=output_attentions,
use_mem_attn=use_mem_attn,
)
hidden_states = outputs[0]
if use_cache is True:
@@ -692,7 +699,6 @@ class LetheForCausalLM(LethePreTrainedModel):
def forward(
self,
input_ids: torch.LongTensor,
token_type_ids: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
position_ids: Optional[torch.LongTensor] = None,
@@ -703,6 +709,7 @@ class LetheForCausalLM(LethePreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
use_mem_attn: Optional[bool] = True,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
@@ -744,29 +751,9 @@ class LetheForCausalLM(LethePreTrainedModel):
>>> prediction_logits = outputs.logits
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# memories are where token_type_ids == 0
memory_mask = token_type_ids == 0
# should be shape (num_memories, sequence_length)
memories = input_ids[memory_mask]
with torch.no_grad():
# store memories but we don't back prop
self.gpt_neox(memories,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
questions = input_ids[~memory_mask]
answers = labels[~memory_mask]
outputs = self.gpt_neox(
questions,
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
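With the removed block above gone, the model no longer splits a batch into memories (token_type_ids == 0) and questions internally; the caller decides when the index is being populated versus queried. A hedged sketch of how a driver might reproduce the old two-pass behaviour with the new flag (function and tensor names are hypothetical, and it assumes the attention layers add keys/values to the index on every forward pass):

import torch

def answer_with_memory(model, context_ids: torch.LongTensor, question_ids: torch.LongTensor):
    # pass 1: index the context; use_mem_attn=False skips the knn lookup,
    # and no_grad avoids keeping a graph for the memory pass
    with torch.no_grad():
        model(input_ids=context_ids, use_mem_attn=False)
    # pass 2: queries now pull nearest-neighbor keys/values from the index
    return model(input_ids=question_ids, labels=question_ids, use_mem_attn=True)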
@@ -776,20 +763,21 @@ class LetheForCausalLM(LethePreTrainedModel):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
use_mem_attn=use_mem_attn
)
hidden_states = outputs[0]
lm_logits = self.embed_out(hidden_states)
lm_loss = None
if answers is not None:
if labels is not None:
# move labels to correct device to enable model parallelism
answers = answers.to(lm_logits.device)
labels = labels.to(lm_logits.device)
# we are doing next-token prediction; shift prediction scores and input ids by one
shift_logits = lm_logits[:, :-1, :].contiguous()
answers = answers[:, 1:].contiguous()
labels = labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss()
lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), answers.view(-1))
lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))
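The loss now uses the labels tensor directly, shifted by one position for next-token prediction. A small self-contained illustration of that shift (shapes are arbitrary):

import torch
from torch.nn import CrossEntropyLoss

batch, seq_len, vocab = 2, 8, 50304
lm_logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))

# logits at position t are scored against the token at position t + 1
shift_logits = lm_logits[:, :-1, :].contiguous()
shift_labels = labels[:, 1:].contiguous()
loss = CrossEntropyLoss()(shift_logits.view(-1, vocab), shift_labels.view(-1))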
if not return_dict:
output = (lm_logits,) + outputs[1:]