feat: mem attn forward pass on cuda

Zach Nussbaum 2023-05-18 19:44:04 +00:00
parent 18b04347f5
commit 5f28e60a9a


@@ -16,14 +16,11 @@
 from typing import Optional, Tuple, Union
-import math
-import torch.nn.functional as F
 import torch
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import CrossEntropyLoss
-from transformers import AutoModel
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
@@ -31,7 +28,6 @@ from transformers.modeling_outputs import (
 )
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import logging
-from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
 from gpt4all.models.lethe import LetheConfig
 import hnswlib
 import numpy as np
@@ -44,10 +40,9 @@ GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST = [
     "EleutherAI/gpt-neox-20b",
 ]

-# TODO: understand why Phil only does this per batch and doesn't persist across many batches
-# TODO: k/v are stored per head and per token!!!
-# reshape query, key, value into (bs * seq_len, num_attention_heads, head_size)
-# for each head, store index of k/v for each token
+# TODO: understand why Phil only does this per batch and doesn't persist across many batches -> he uses multi-query attention
+# TODO: do we need to implement masking for the dense vectors we pull from?
+# TODO: i think phil is using a memmapped database to pull out rather than using the index
 class HNSWIndex:
@@ -110,7 +105,6 @@ class MemoryIndex:
         reshaped_values = values.reshape(values.shape[0] * values.shape[2], values.shape[1], values.shape[3])
         for head in range(self.nheads):
-            print(f"adding head {head}")
             self.key_indices[head].add(reshaped_keys[:, head, :])
             self.value_indices[head].add(reshaped_values[:, head, :])
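For orientation, here is a minimal sketch of the per-head storage pattern this hunk touches: token-level key/value vectors are flattened out of the batch and routed into one hnswlib index per attention head. `PerHeadIndex` and its sizes are illustrative stand-ins for this file's `MemoryIndex`/`HNSWIndex`, not the actual classes.

```python
import hnswlib
import numpy as np

class PerHeadIndex:
    """Illustrative stand-in: one ANN index per attention head."""

    def __init__(self, nheads: int, head_size: int, max_elements: int = 100_000):
        self.nheads = nheads
        self.indices = []
        for _ in range(nheads):
            index = hnswlib.Index(space="ip", dim=head_size)
            index.init_index(max_elements=max_elements, ef_construction=200, M=16)
            self.indices.append(index)

    def add(self, vectors: np.ndarray):
        # vectors: [num_tokens, nheads, head_size], float32 (hnswlib requirement)
        for head in range(self.nheads):
            self.indices[head].add_items(vectors[:, head, :])

# usage: 8 sequences of 128 tokens, 16 heads, head_size 64
tokens = np.random.rand(8 * 128, 16, 64).astype(np.float32)
store = PerHeadIndex(nheads=16, head_size=64)
store.add(tokens)
```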
@@ -220,13 +214,14 @@ class LetheAttention(nn.Module):
         layer_past: Optional[Tuple[torch.Tensor]] = None,
         use_cache: Optional[bool] = False,
         output_attentions: Optional[bool] = False,
+        use_mem_attn: Optional[bool] = True,
     ):
         has_layer_past = layer_past is not None

         # Compute QKV
         # Attention heads [batch, seq_len, hidden_size]
         #   --> [batch, seq_len, (np * 3 * head_size)]
-        bs, seq_len, hidden_size = hidden_states.size()
+        _, seq_len, _ = hidden_states.size()
         qkv = self.query_key_value(hidden_states)

         # [batch, seq_len, (num_heads * 3 * head_size)]
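The comments above describe the GPT-NeoX-style fused QKV projection: one linear layer maps [batch, seq_len, hidden_size] to [batch, seq_len, num_heads * 3 * head_size], which is then sliced into per-head query/key/value. A self-contained sketch of that split (shapes in comments; not this file's exact code):

```python
import torch

def split_fused_qkv(qkv: torch.Tensor, num_heads: int, head_size: int):
    # qkv: [batch, seq_len, num_heads * 3 * head_size]
    bs, seq_len, _ = qkv.shape
    qkv = qkv.view(bs, seq_len, num_heads, 3 * head_size)
    # each slice: [batch, seq_len, num_heads, head_size] -> [batch, num_heads, seq_len, head_size]
    query = qkv[..., :head_size].permute(0, 2, 1, 3)
    key = qkv[..., head_size:2 * head_size].permute(0, 2, 1, 3)
    value = qkv[..., 2 * head_size:].permute(0, 2, 1, 3)
    return query, key, value

q, k, v = split_fused_qkv(torch.randn(2, 10, 16 * 3 * 64), num_heads=16, head_size=64)
assert q.shape == (2, 16, 10, 64)
```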
@@ -265,16 +260,24 @@
         # Compute attention
         attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

-        # TODO: need to do masking??
         if self.memory:
             # get knns
+            # since we do an eval batch w context before, let's not do the expensive step until we need to
             # [batch, knn, num_attention_heads, seq_len, head_size]
-            knn_keys, knn_values = self.index.knn_query(query.detach().numpy(), k=self.num_neighbors)
-            mem_attn = self._mem_attn(query, knn_keys, knn_values, attention_mask, head_mask)
-
-            expanded_alpha = self.alpha[None, :, None, None]
-            attn_output = (attn_output * (1 - expanded_alpha)) + (mem_attn * expanded_alpha)
-            self.index.add(key.detach().numpy(), value.detach().numpy())
+            if use_mem_attn:
+                knn_keys, knn_values = self.index.knn_query(query.cpu().detach().to(torch.float32).numpy(), k=self.num_neighbors)
+                mem_attn = self._mem_attn(query,
+                                          knn_keys.to(query.device),
+                                          knn_values.to(query.device),
+                                          attention_mask,
+                                          head_mask
+                )
+                expanded_alpha = self.alpha[None, :, None, None]
+                attn_output = (attn_output * (1 - expanded_alpha)) + (mem_attn * expanded_alpha)
+
+            self.index.add(key.cpu().detach().to(torch.float32).numpy(), value.cpu().detach().to(torch.float32).numpy())

         # Reshape outputs
         attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size)
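The CUDA-relevant change in this hunk is the device round-trip: hnswlib runs on CPU over float32 numpy arrays, so queries (and the keys/values being stored) are detached, moved to CPU, and cast before touching the index, and the retrieved neighbors are moved back onto the query's device. Local and memory attention are then blended per head by the learned gate `alpha`. A hedged sketch of that flow, assuming `index.knn_query` returns CPU tensors as this file's wrapper does:

```python
import torch

def blend_memory_attention(query, attn_output, index, mem_attn_fn, alpha, k):
    # query/attn_output: [batch, num_heads, seq_len, head_size], on any device
    # hnswlib only accepts float32 numpy on CPU, hence the detach/move/cast
    query_np = query.detach().to("cpu", torch.float32).numpy()
    knn_keys, knn_values = index.knn_query(query_np, k=k)

    mem_out = mem_attn_fn(query, knn_keys.to(query.device), knn_values.to(query.device))

    # alpha: [num_heads] learned gate, broadcast to [1, num_heads, 1, 1]
    gate = alpha[None, :, None, None]
    return attn_output * (1 - gate) + mem_out * gate
```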
@@ -462,6 +465,7 @@ class LetheLayer(nn.Module):
         use_cache: Optional[bool] = False,
         layer_past: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
+        use_mem_attn: Optional[bool] = True,
     ):
         ln_hidden_states = self.input_layernorm(hidden_states)
         attention_layer_outputs = self.attention(
@@ -472,6 +476,7 @@ class LetheLayer(nn.Module):
             head_mask=head_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
+            use_mem_attn=use_mem_attn,
         )
         attn_output = attention_layer_outputs[0]  # output_attn: attn_output, present, (attn_weights)
         outputs = attention_layer_outputs[1:]
@@ -533,6 +538,7 @@ class LetheModel(LethePreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        use_mem_attn: Optional[bool] = True,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         r"""
         past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
@@ -625,7 +631,7 @@ class LetheModel(LethePreTrainedModel):
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
                         # None for layer_past
-                        return module(*inputs, use_cache, None, output_attentions)
+                        return module(*inputs, use_cache, None, output_attentions, use_mem_attn)

                     return custom_forward
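`torch.utils.checkpoint.checkpoint` only re-plays positional tensor inputs on the backward pass, so the new `use_mem_attn` flag has to ride along in the closure rather than as a checkpointed argument, which is why this hunk extends `custom_forward`. A minimal self-contained sketch of the pattern:

```python
import torch
import torch.utils.checkpoint

def run_checkpointed(layer, hidden_states, attention_mask,
                     use_cache=False, output_attentions=False, use_mem_attn=True):
    def create_custom_forward(module):
        def custom_forward(*inputs):
            # tensors flow through *inputs; flags (and None for layer_past)
            # are captured by the closure instead
            return module(*inputs, use_cache, None, output_attentions, use_mem_attn)
        return custom_forward

    return torch.utils.checkpoint.checkpoint(
        create_custom_forward(layer), hidden_states, attention_mask
    )
```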
@@ -645,6 +651,7 @@ class LetheModel(LethePreTrainedModel):
                     layer_past=layer_past,
                     use_cache=use_cache,
                     output_attentions=output_attentions,
+                    use_mem_attn=use_mem_attn,
                 )
             hidden_states = outputs[0]
             if use_cache is True:
@@ -692,7 +699,6 @@ class LetheForCausalLM(LethePreTrainedModel):
     def forward(
         self,
         input_ids: torch.LongTensor,
-        token_type_ids: torch.Tensor,
         attention_mask: Optional[torch.FloatTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
         position_ids: Optional[torch.LongTensor] = None,
@@ -703,6 +709,7 @@ class LetheForCausalLM(LethePreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        use_mem_attn: Optional[bool] = True,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
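With `token_type_ids` gone, the caller now decides when memory lookups happen. A hypothetical usage sketch (`model`, `context_ids`, `question_ids`, and `question_labels` are placeholders, not names from this repo): prime the index with a lookup-free pass over the context, since keys/values are still written to the index even when lookups are off, then enable memory attention for the question.

```python
import torch

# populate the per-head k/v index without the expensive knn lookups
with torch.no_grad():
    model(context_ids, use_mem_attn=False)

# answer with memory attention blended in
outputs = model(question_ids, labels=question_labels, use_mem_attn=True)
```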
@@ -744,29 +751,9 @@ class LetheForCausalLM(LethePreTrainedModel):
         >>> prediction_logits = outputs.logits
         ```"""
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-        # memories are where token_type_ids == 0
-        memory_mask = token_type_ids == 0
-        # should be shape (num_memories, sequence_length)
-        memories = input_ids[memory_mask]
-        with torch.no_grad():
-            # store memories but we don't back prop
-            self.gpt_neox(memories,
-                          attention_mask=attention_mask,
-                          position_ids=position_ids,
-                          head_mask=head_mask,
-                          inputs_embeds=inputs_embeds,
-                          past_key_values=past_key_values,
-                          use_cache=use_cache,
-                          output_attentions=output_attentions,
-                          output_hidden_states=output_hidden_states,
-                          return_dict=return_dict,
-            )
-
-        questions = input_ids[~memory_mask]
-        answers = labels[~memory_mask]
-
         outputs = self.gpt_neox(
-            questions,
+            input_ids,
             attention_mask=attention_mask,
             position_ids=position_ids,
             head_mask=head_mask,
@@ -776,20 +763,21 @@ class LetheForCausalLM(LethePreTrainedModel):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            use_mem_attn=use_mem_attn
         )

         hidden_states = outputs[0]
         lm_logits = self.embed_out(hidden_states)

         lm_loss = None
-        if answers is not None:
+        if labels is not None:
             # move labels to correct device to enable model parallelism
-            answers = answers.to(lm_logits.device)
+            labels = labels.to(lm_logits.device)
             # we are doing next-token prediction; shift prediction scores and input ids by one
             shift_logits = lm_logits[:, :-1, :].contiguous()
-            answers = answers[:, 1:].contiguous()
+            labels = labels[:, 1:].contiguous()
             loss_fct = CrossEntropyLoss()
-            lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), answers.view(-1))
+            lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))

         if not return_dict:
             output = (lm_logits,) + outputs[1:]
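The relabeled loss above is the standard causal-LM shift: position t predicts token t+1, so the last logit and the first label are dropped before flattening into `CrossEntropyLoss`. A standalone sketch with toy shapes:

```python
import torch
from torch.nn import CrossEntropyLoss

vocab_size = 8
lm_logits = torch.randn(2, 5, vocab_size)        # [batch, seq_len, vocab]
labels = torch.randint(0, vocab_size, (2, 5))    # [batch, seq_len]

shift_logits = lm_logits[:, :-1, :].contiguous() # predictions for positions 0..seq_len-2
shift_labels = labels[:, 1:].contiguous()        # targets shifted left by one

loss = CrossEntropyLoss()(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
```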