From 5f28e60a9a3e4b448a9e534e8046d4f2c841a806 Mon Sep 17 00:00:00 2001
From: Zach Nussbaum
Date: Thu, 18 May 2023 19:44:04 +0000
Subject: [PATCH] feat: mem attn forward pass on cuda

---
 gpt4all/models/lethe/modeling_lethe.py | 72 +++++++++++---------------
 1 file changed, 30 insertions(+), 42 deletions(-)

diff --git a/gpt4all/models/lethe/modeling_lethe.py b/gpt4all/models/lethe/modeling_lethe.py
index 07f7f86d..7bf2c3b0 100644
--- a/gpt4all/models/lethe/modeling_lethe.py
+++ b/gpt4all/models/lethe/modeling_lethe.py
@@ -16,14 +16,11 @@
 
 from typing import Optional, Tuple, Union
 
-import math
-import torch.nn.functional as F
 import torch
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import CrossEntropyLoss
 
-from transformers import AutoModel
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
@@ -31,7 +28,6 @@ from transformers.modeling_outputs import (
 )
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import logging
-from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
 from gpt4all.models.lethe import LetheConfig
 import hnswlib
 import numpy as np
@@ -44,10 +40,9 @@ GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST = [
     "EleutherAI/gpt-neox-20b",
 ]
 
-# TODO: understand why Phil only does this per batch and doens't persist across many batches
-# TODO: k/v are stored per head and per token!!!
-# reshape query, key, value into (bs * seq_len, num_attention_heads, head_size)
-# for each head, store index of k/v for each token
+# TODO: understand why Phil only does this per batch and doesn't persist across many batches -> he uses multi-query attention
+# TODO: do we need to implement masking for the dense vectors we pull from?
+# TODO: I think Phil is using a memmapped database to pull values from rather than using the index
 
 
 class HNSWIndex:
@@ -110,7 +105,6 @@ class MemoryIndex:
         reshaped_values = values.reshape(values.shape[0] * values.shape[2], values.shape[1], values.shape[3])
 
         for head in range(self.nheads):
-            print(f"adding head {head}")
             self.key_indices[head].add(reshaped_keys[:, head, :])
             self.value_indices[head].add(reshaped_values[:, head, :])
 
@@ -220,13 +214,14 @@ class LetheAttention(nn.Module):
         layer_past: Optional[Tuple[torch.Tensor]] = None,
         use_cache: Optional[bool] = False,
         output_attentions: Optional[bool] = False,
+        use_mem_attn: Optional[bool] = True,
     ):
         has_layer_past = layer_past is not None
 
         # Compute QKV
         # Attention heads [batch, seq_len, hidden_size]
         # --> [batch, seq_len, (np * 3 * head_size)]
-        bs, seq_len, hidden_size = hidden_states.size()
+        _, seq_len, _ = hidden_states.size()
         qkv = self.query_key_value(hidden_states)
 
         # [batch, seq_len, (num_heads * 3 * head_size)]
@@ -265,16 +260,24 @@
         # Compute attention
         attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
 
+        # TODO: do we need to do masking here?
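+        # If memory is enabled: optionally blend the local attention output with
+        # attention over k-NN keys/values pulled from the per-head index (gated by
+        # alpha), then add the current keys/values to the index for later queries.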
         if self.memory:
             # get knns
+            # since we do an eval batch with context beforehand, don't do the expensive kNN step until we need to
             # [batch, knn, num_attention_heads, seq_len, head_size]
-            knn_keys, knn_values = self.index.knn_query(query.detach().numpy(), k=self.num_neighbors)
-            mem_attn = self._mem_attn(query, knn_keys, knn_values, attention_mask, head_mask)
+            if use_mem_attn:
+                knn_keys, knn_values = self.index.knn_query(query.cpu().detach().to(torch.float32).numpy(), k=self.num_neighbors)
+                mem_attn = self._mem_attn(query,
+                    knn_keys.to(query.device),
+                    knn_values.to(query.device),
+                    attention_mask,
+                    head_mask
+                )
 
-            expanded_alpha = self.alpha[None, :, None, None]
-            attn_output = (attn_output * (1 - expanded_alpha)) + (mem_attn * expanded_alpha)
+                expanded_alpha = self.alpha[None, :, None, None]
+                attn_output = (attn_output * (1 - expanded_alpha)) + (mem_attn * expanded_alpha)
 
-            self.index.add(key.detach().numpy(), value.detach().numpy())
+            self.index.add(key.cpu().detach().to(torch.float32).numpy(), value.cpu().detach().to(torch.float32).numpy())
 
         # Reshape outputs
         attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size)
@@ -462,6 +465,7 @@ class LetheLayer(nn.Module):
         use_cache: Optional[bool] = False,
         layer_past: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
+        use_mem_attn: Optional[bool] = True,
     ):
         ln_hidden_states = self.input_layernorm(hidden_states)
         attention_layer_outputs = self.attention(
@@ -472,6 +476,7 @@
             head_mask=head_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
+            use_mem_attn=use_mem_attn,
         )
         attn_output = attention_layer_outputs[0]  # output_attn: attn_output, present, (attn_weights)
         outputs = attention_layer_outputs[1:]
@@ -533,6 +538,7 @@ class LetheModel(LethePreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        use_mem_attn: Optional[bool] = True,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         r"""
         past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
@@ -625,7 +631,7 @@ class LetheModel(LethePreTrainedModel):
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
                         # None for layer_past
-                        return module(*inputs, use_cache, None, output_attentions)
+                        return module(*inputs, use_cache, None, output_attentions, use_mem_attn)
 
                     return custom_forward
 
@@ -645,6 +651,7 @@ class LetheModel(LethePreTrainedModel):
                     layer_past=layer_past,
                     use_cache=use_cache,
                     output_attentions=output_attentions,
+                    use_mem_attn=use_mem_attn,
                 )
             hidden_states = outputs[0]
             if use_cache is True:
@@ -692,7 +699,6 @@ class LetheForCausalLM(LethePreTrainedModel):
     def forward(
         self,
         input_ids: torch.LongTensor,
-        token_type_ids: torch.Tensor,
         attention_mask: Optional[torch.FloatTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
         position_ids: Optional[torch.LongTensor] = None,
@@ -703,6 +709,7 @@ class LetheForCausalLM(LethePreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        use_mem_attn: Optional[bool] = True,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
@@ -744,29 +751,9 @@ class LetheForCausalLM(LethePreTrainedModel):
         >>> prediction_logits = outputs.logits
         ```"""
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        # memories are where token_type_ids == 0
-        memory_mask = token_type_ids == 0
-        # should be shape (num_memories, sequence_length)
-        memories = input_ids[memory_mask]
-        with torch.no_grad():
-            # store memories but we don't back prop
-            self.gpt_neox(memories,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                head_mask=head_mask,
-                inputs_embeds=inputs_embeds,
-                past_key_values=past_key_values,
-                use_cache=use_cache,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
-            )
-
-        questions = input_ids[~memory_mask]
-        answers = labels[~memory_mask]
         outputs = self.gpt_neox(
-            questions,
+            input_ids,
             attention_mask=attention_mask,
             position_ids=position_ids,
             head_mask=head_mask,
@@ -776,20 +763,21 @@
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            use_mem_attn=use_mem_attn,
         )
 
         hidden_states = outputs[0]
         lm_logits = self.embed_out(hidden_states)
 
         lm_loss = None
-        if answers is not None:
+        if labels is not None:
             # move labels to correct device to enable model parallelism
-            answers = answers.to(lm_logits.device)
+            labels = labels.to(lm_logits.device)
             # we are doing next-token prediction; shift prediction scores and input ids by one
             shift_logits = lm_logits[:, :-1, :].contiguous()
-            answers = answers[:, 1:].contiguous()
+            labels = labels[:, 1:].contiguous()
             loss_fct = CrossEntropyLoss()
-            lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), answers.view(-1))
+            lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))
 
         if not return_dict:
             output = (lm_logits,) + outputs[1:]