fix: wip mem xf

Zach Nussbaum 2023-06-02 23:02:25 +00:00
parent 3677935ce8
commit 55fef489ad
15 changed files with 860 additions and 133 deletions

View File

@ -0,0 +1,23 @@
# model/tokenizer
model_name: "/home/paperspace/gpt4all/gpt4all/train/ckpts/mem_attn/step_1000"
tokenizer_name: "EleutherAI/pythia-1b"
version: null
gradient_checkpointing: false
memory_attn_layer: 12
# dataset
streaming: false
num_proc: 64
dataset_path: "/home/paperspace/gpt4all/gpt4all/inference/synth_data_combined_174"
# dataset_path: "/home/paperspace/gpt4all/gpt4all/index/squad_supplemented_validation"
max_length: 1024
batch_size: 1
pct_test: 0.05
q_column: "question"
a_column: "answer"
context_column: "text"
num_memories_per_index: 2000000
num_neighbors_to_retrieve: 2
num_neighbors_to_store: 1
mem_chunk_size: 64
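
As a point of reference, a minimal sketch of loading a config like the one above into a plain dict. The scripts in this commit import a read_config helper from gpt4all.utils.read; the stand-in below (PyYAML-based) and the config path are assumptions, not part of the commit:

import yaml

def read_config(path):
    # parse the YAML training/eval config into the dict used throughout the scripts
    with open(path, "r") as f:
        return yaml.safe_load(f)

config = read_config("configs/eval/eval_mem_attn.yaml")  # hypothetical path
print(config["model_name"], config["num_neighbors_to_retrieve"])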

View File

@ -0,0 +1,19 @@
model_name: "nomic-ai/gpt4all-j"
revision: "v1.3-groovy"
tokenizer_name: "nomic-ai/gpt4all-j"
# dataset
streaming: false
num_proc: 64
dataset_path: "nomic-ai/cohere-wiki-sbert"
batch_size: 32
output_path: "synth_qa_pairs"
# generation
max_new_tokens: 75
max_generations: 200000
save_every: 1000
seed: 42

View File

@ -10,35 +10,36 @@ memory_attn_layer: 12
# dataset
streaming: false
num_proc: 64
-dataset_path: "/home/paperspace/gpt4all/gpt4all/index/squad_supplemented_train"
+dataset_path: "/home/paperspace/gpt4all/gpt4all/inference/synth_data_combined_174"
max_length: 1024
-batch_size: 64
+batch_size: 32
pct_test: 0.05
q_column: "question"
-a_column: "answers"
+a_column: "answer"
-context_column: "neighbor_text"
+context_column: "text"
-num_memories_per_index: 500000
+num_memories_per_index: 2000000
-num_neighbors_to_retrieve: 32
+num_neighbors_to_retrieve: 2
+num_neighbors_to_store: 1
mem_chunk_size: 64
# train dynamics
-lr: 1.0e-4
+lr: 1.0e-5
min_lr: 0
weight_decay: 0.0
eval_every: 100
-save_every: -1
+save_every: 100
log_grads_every: 100
log_lr_every: 10
-output_dir: "ckpts/mem_attn"
+output_dir: "ckpts/mem_attn_no_cosine_sim"
checkpoint: null
lora: false
-warmup_steps: 500
+warmup_steps: 200
num_epochs: 5
debug: false
scheduler: false
# logging
-wandb: false
+wandb: true
wandb_entity: gpt4all
wandb_project_name: mem_attn
seed: 42

View File

@ -0,0 +1,43 @@
# model/tokenizer
model_name: "EleutherAI/pythia-1b"
tokenizer_name: "EleutherAI/pythia-1b"
version: null
gradient_checkpointing: true
save_name: "nomic-ai/minipille"
push_to_hub: false
memory_attn_layer: 12
# dataset
streaming: false
num_proc: 64
dataset_path: "JeanKaddour/minipile"
max_length: 2048
batch_size: 64
pct_test: 0.05
num_memories_per_index: 5000000
mem_chunk_size: 512
num_chunks: 10
num_neighbors_to_retrieve: 32
# train dynamics
lr: 1.0e-4
min_lr: 0
weight_decay: 0.0
eval_every: 100
save_every: -1
log_grads_every: 100
log_lr_every: 10
output_dir: "ckpts/minipile"
checkpoint: null
lora: false
warmup_steps: 500
num_epochs: 5
debug: false
scheduler: false
# logging
wandb: true
wandb_entity: gpt4all
wandb_project_name: minipile
seed: 42

View File

@ -85,6 +85,7 @@ def load_memory_augmented_data(config, tokenizer, split="train", split_dataset=T
    question_col = config["q_column"]
    answer_col = config["a_column"]
+   context_col = config["context_column"]

    if config["streaming"] is False:
        kwargs = {"num_proc": config["num_proc"]}
@ -96,7 +97,8 @@ def load_memory_augmented_data(config, tokenizer, split="train", split_dataset=T
    dataset = dataset.map(lambda ele: {question_col: [q.strip() for q in ele[question_col]]}, batched=True)
    # in squad, the data is formatted where each ele in answers is a dict where the key text holds
    # a list of the answer
-   dataset = dataset.map(lambda ele: {answer_col: [t["text"][0] for t in ele[answer_col]]}, batched=True)
+   dataset = dataset.map(lambda ele: {answer_col: [t.strip() for t in ele[answer_col]]}, batched=True)
+   # dataset = dataset.map(lambda ele: {answer_col: [t["text"][0] for t in ele[answer_col]]}, batched=True)

    dataset = dataset.map(
        lambda ele: tokenize_inputs(config, tokenizer, ele, question_col, answer_col),
@ -106,19 +108,73 @@ def load_memory_augmented_data(config, tokenizer, split="train", split_dataset=T
    # tokenize contexts for each example
    dataset = dataset.map(
-       lambda ele: {"retrieved_context": tokenizer(ele["context"],
+       lambda ele: {"retrieved_context": tokenizer([ele[context_col]],
                                                    return_tensors="pt",
                                                    padding="max_length",
                                                    truncation=True)["input_ids"]},
-       batched=True,
        **kwargs
    )

-   columns_to_keep = ["input_ids", "labels", "retrieved_context"]
+   columns_to_keep = ["id", "input_ids", "labels", "retrieved_context"]
    col_names_to_rm = [col for col in dataset.column_names if col not in columns_to_keep]
    dataset = dataset.remove_columns(col_names_to_rm)

+   if split_dataset:
+       dataset = dataset.train_test_split(test_size=config["pct_test"], seed=config["seed"])
+       train_dataset, val_dataset = dataset["train"], dataset["test"]
+
+       train_dataloader = DataLoader(
+           train_dataset.remove_columns("id"),
+           batch_size=config["batch_size"],
+           collate_fn=DefaultDataCollator(),
+       )
+
+       val_dataloader = DataLoader(
+           val_dataset,
+           batch_size=config["batch_size"],
+           collate_fn=DefaultDataCollator(),
+       )
+
+       return train_dataloader, val_dataloader
+   else:
+       dataloader = DataLoader(
+           dataset,
+           batch_size=config["batch_size"],
+           collate_fn=DefaultDataCollator(),
+       )
+       return dataloader
+
+
+def load_memory_pretraining_data(config, tokenizer, split="train", split_dataset=True):
+   dataset_path = config["dataset_path"]
+
+   if os.path.exists(dataset_path):
+       dataset = Dataset.load_from_disk(dataset_path)
+   else:
+       dataset = load_dataset(dataset_path, split=split)
+
+   if config["streaming"] is False:
+       kwargs = {"num_proc": config["num_proc"]}
+   else:
+       kwargs = {}
+
+   # e.g. 512 * 10 = 5120 sequence length split up
+   max_length = config["mem_chunk_size"] * config["num_chunks"]
+   dataset = dataset.map(lambda ele: tokenizer(ele["text"], padding="max_length", truncation=True, max_length=max_length),
+                         batched=True, **kwargs)
+   dataset = dataset.map(lambda x: {"labels": x["input_ids"]}, batched=True, **kwargs)
+
+   columns_to_keep = ["input_ids", "labels", "attention_mask"]
+   col_names_to_rm = [col for col in dataset.column_names if col not in columns_to_keep]
+   dataset = dataset.remove_columns(col_names_to_rm)
+
+   # we can shuffle since the docs are in one row not split across rows
    if split_dataset:
        dataset = dataset.train_test_split(test_size=config["pct_test"], seed=config["seed"])
        train_dataset, val_dataset = dataset["train"], dataset["test"]
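
For orientation, a minimal sketch of what the returned train dataloader yields under the column selection above; shapes are assumptions based on the tokenization settings, not guaranteed by this commit:

# Each batch carries the tokenized question/answer pair plus its retrieved context;
# the "id" column is kept on the validation split but dropped from the train split.
for batch in train_dataloader:
    input_ids = batch["input_ids"]                  # (batch_size, max_length)
    labels = batch["labels"]                        # (batch_size, max_length), -100 where loss is masked
    retrieved_context = batch["retrieved_context"]  # (batch_size, num_neighbors, max_length)
    break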

View File

@ -0,0 +1,116 @@
import torch
import torch.nn.functional as F
from gpt4all.models import LetheForCausalLM
from gpt4all.models.lethe.modeling_lethe import MemoryIndex
from gpt4all.data.retrieval_dataloader import load_memory_augmented_data
from gpt4all.train.metrics import f1_score, exact_match_score
from gpt4all.utils.read import read_config
from transformers import AutoTokenizer, AutoConfig
from argparse import ArgumentParser
from tqdm import tqdm
from nomic import atlas
from datasets import load_from_disk


def calc_loss_per_item(logits, labels):
    lm_logits = logits[:, :-1, :].contiguous()
    lm_labels = labels[:, 1:].contiguous()

    loss = F.cross_entropy(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1), reduction="none")
    loss = loss.reshape(labels.shape[0], -1).mean(dim=-1)

    # return tensor of shape (B,) where B is the batch size
    return loss.cpu().tolist()


def greedy_search(input_ids, model, tokenizer, max_new_tokens=100):
    num_new_tokens = 0
    with torch.no_grad():
        while True:
            if num_new_tokens >= max_new_tokens:
                break
            outputs = model(input_ids, save_kv=False)
            new_tokens = torch.argmax(outputs.logits[:, -1, :], dim=-1)
            input_ids = torch.cat([input_ids, new_tokens.unsqueeze(1)], dim=-1)
            num_new_tokens += 1

            if torch.equal(input_ids[0, -1].cpu(), torch.tensor(tokenizer.eos_token_id)):
                break

    print(tokenizer.batch_decode(input_ids, skip_special_tokens=True))
    return input_ids


parser = ArgumentParser()
parser.add_argument("--config", type=str, default="config.yaml")

args = parser.parse_args()

config = read_config(args.config)

tokenizer = AutoTokenizer.from_pretrained(config["tokenizer_name"], model_max_length=config["max_length"])
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

dataloader = load_memory_augmented_data(config, tokenizer, split_dataset=False)
dataset = load_from_disk(config["dataset_path"])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_config = AutoConfig.from_pretrained(config["model_name"])
head_size = model_config.hidden_size // model_config.num_attention_heads
index = MemoryIndex(head_size,
                    config["num_memories_per_index"],
                    model_config.num_attention_heads
)
model = LetheForCausalLM.from_pretrained(config["model_name"],
                                         revision=config['version'] if 'version' in config else None,
                                         memory_attn_layer=config["memory_attn_layer"],
                                         num_neighbors_to_retrieve=config["num_neighbors_to_retrieve"],
                                         index=index,
).to(device)
model.eval()

# Evaluate the model on the SQUAD dataset
losses = []
with torch.no_grad():
    for i, batch in enumerate(tqdm(dataloader)):
        memories = batch["retrieved_context"]
        memories = memories[:, :config["num_neighbors_to_store"], :]
        memories = memories.reshape(-1, memories.shape[-1])

        # need to set to eval so we don't do mem attn as it's slow
        with torch.no_grad():
            for chunk_start in range(0, memories.shape[0], config["mem_chunk_size"]):
                chunk_end = min(memories.shape[0], chunk_start + config["mem_chunk_size"])
                mem_chunk = memories[chunk_start:chunk_end]
                model(input_ids=mem_chunk.to(device))

        del memories
        torch.cuda.empty_cache()

        qa_inputs = batch["input_ids"]
        qa_labels = batch["labels"]
        for i in range(qa_inputs.shape[0]):
            inputs = qa_inputs[i].to(device)
            labels = qa_labels[i].to(device)
            cutoff = torch.argmax((labels != -100).type(torch.float32))
            greedy_search(inputs[:cutoff.item()].unsqueeze(0).to(device), model, tokenizer)
            print(tokenizer.decode(inputs, skip_special_tokens=True))

        # batch_loss = calc_loss_per_item(outputs.logits, qa_labels.to(device))
        # losses.extend(batch_loss)
        index.reset()

dataset = dataset.add_column("loss", losses)
dataset.save_to_disk("eval_squad_atlas_map")

View File

@ -0,0 +1,57 @@
import glob
from argparse import ArgumentParser
from datasets import Dataset, load_from_disk, concatenate_datasets
PROMPT = "Write a question answer pair based on the following context. If the context isn't specific enough, ignore and return 'No question answer pair`. Context : {}\n"
def load_synth_data(data_dir):
    files = glob.glob(data_dir + "/*")
    ds = concatenate_datasets([load_from_disk(f) for f in files])

    table = ds.data.table
    filtered = table.filter(table["valid"])
    ds = Dataset.from_dict(filtered.to_pydict())

    return ds


def remove_prompt(examples):
    outputs = {"text": [], "generated": [], "question": [], "answer": []}
    for context, generated in zip(examples["text"], examples["generated"]):
        prompt_w_ctx = PROMPT.format(context)
        gen_wo_ctx = generated[len(prompt_w_ctx):]
        assert prompt_w_ctx not in gen_wo_ctx
        question = gen_wo_ctx.split("Answer:")[0].replace("Question:", "").strip()
        answer = gen_wo_ctx.split("Answer:")[1].strip()

        outputs["text"].append(context)
        outputs["generated"].append(gen_wo_ctx)
        outputs["question"].append(question)
        outputs["answer"].append(answer)

    return outputs


def combine_synth_data(data_dir):
    ds = load_synth_data(data_dir)
    ds = ds.map(lambda ele: remove_prompt(ele), batched=True, num_proc=64)
    ds.save_to_disk(f"synth_data_combined_{len(ds)/1000:.0f}")


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--dataset_dir", type=str, required=True)

    args = parser.parse_args()

    combine_synth_data(args.dataset_dir)

View File

@ -0,0 +1,149 @@
from argparse import ArgumentParser
from datasets import load_dataset, Dataset
import torch
from torch.utils.data import DataLoader, DistributedSampler
from accelerate.utils import set_seed
from transformers import AutoTokenizer, AutoModelForCausalLM, DefaultDataCollator
from gpt4all.utils.read import read_config
from gpt4all.utils.distributed_utils import rank0_print, main_process_first
from tqdm import tqdm
import pyarrow.compute as pc
import pyarrow as pa
import torch.distributed as dist
PROMPT = "Write a question answer pair based on the following context. If the context isn't specific enough, ignore and return 'No question answer pair`. Context : {}\n"
def prepare_data(config, tokenizer, num_processes, local_rank):
    dataset = load_dataset(config["dataset_path"], split="train")
    dataset = dataset.remove_columns("embedding")

    shuffled = dataset.shuffle(seed=config["seed"])
    indices = shuffled[:config["max_generations"]]["id"]

    table = dataset.data
    mask = pc.is_in(table["id"], value_set=pa.array(indices, pa.int32()))
    filtered_table = table.filter(mask)

    # convert from pyarrow to Dataset
    orig_dataset = Dataset.from_dict(filtered_table.to_pydict())

    dataset = orig_dataset.map(lambda ele: {"prompted_text": [PROMPT.format(context) for context in ele["text"]]},
                               batched=True,
                               num_proc=config["num_proc"] if "num_proc" in config else None)
    dataset = dataset.map(lambda ele: {"prompt_len": [len(prompt) for prompt in ele["prompted_text"]]}, batched=True,
                          num_proc=config["num_proc"] if "num_proc" in config else None)
    dataset = dataset.sort("prompt_len")

    dataset = dataset.map(lambda ele: tokenizer(ele["prompted_text"], return_tensors="pt", padding="longest", truncation=True,
                                                max_length=tokenizer.model_max_length - config["max_new_tokens"]), batched=True,
                          batch_size=num_processes * config["batch_size"],
                          )

    columns_to_keep = ["id", "input_ids", "attention_mask"]
    col_names_to_rm = [col for col in dataset.column_names if col not in columns_to_keep]
    dataset = dataset.remove_columns(col_names_to_rm)

    sampler = DistributedSampler(
        dataset,
        shuffle=False,
        drop_last=True,
        num_replicas=num_processes,
        rank=local_rank,
    )

    dataloader = DataLoader(
        dataset,
        batch_size=config["batch_size"],
        collate_fn=DefaultDataCollator(),
        sampler=sampler,
        drop_last=True
    )

    return dataloader, orig_dataset


def generate_data(config):
    set_seed(config["seed"])

    rank0_print(f"World size: {dist.get_world_size()}")

    tokenizer = AutoTokenizer.from_pretrained(config["tokenizer_name"])
    # since we're doing generation, pad left for autoregressive generation
    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    num_processes = dist.get_world_size()
    local_rank = dist.get_rank()

    dataloader, dataset = prepare_data(config, tokenizer, num_processes, local_rank)
    dist.barrier()

    print(dataset[:10]["id"])

    model = AutoModelForCausalLM.from_pretrained(config["model_name"],
                                                 revision=config["revision"] if "revision" in config else None,
                                                 use_cache=True,
                                                 torch_dtype=torch.bfloat16,)
    model.to(f"cuda:{local_rank}")

    synth_data = []
    valid = []
    ids = []
    total_valid = 0
    with torch.no_grad():
        for i, batch in enumerate(tqdm(dataloader, disable=local_rank != 0)):
            # keep this simple for now, can add temperature and other sampling techniques later
            generated = model.generate(batch["input_ids"].to(model.device),
                                       attention_mask=batch["attention_mask"].to(model.device),
                                       max_new_tokens=config["max_new_tokens"])

            decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
            num_valid = ["\nQuestion:" in t and "\nAnswer:" in t for t in decoded]
            rank0_print(f"Num valid: {sum(num_valid)/ len(num_valid):.2f}")
            total_valid += sum(num_valid)

            synth_data.extend(decoded)
            valid.extend(num_valid)
            ids.extend(batch["id"].tolist())

            if i > 0 and i % config["save_every"] == 0:
                table = dataset.data.table
                mask = pc.is_in(table["id"], value_set=pa.array(ids, pa.int32()))
                filtered_table = table.filter(mask)

                chunk_table = pa.Table.from_pydict({"id": ids, "generated": synth_data, "valid": valid})
                joined = filtered_table.join(chunk_table, "id")
                curr_dataset = Dataset.from_dict(joined.to_pydict())

                curr_dataset.save_to_disk(f'{config["output_path"]}/chunk_{i}_rank_{local_rank}')

    table = dataset.data.table
    mask = pc.is_in(table["id"], value_set=pa.array(ids, pa.int32()))
    filtered_table = table.filter(mask)

    chunk_table = pa.Table.from_pydict({"id": ids, "generated": synth_data, "valid": valid})
    joined = filtered_table.join(chunk_table, "id")
    full_dataset = Dataset.from_dict(joined.to_pydict())

    full_dataset.save_to_disk(f'{config["output_path"]}_{config["max_generations"]}_rank_{local_rank}')

    rank0_print(f"Total valid: {total_valid}/{config['max_generations']}")


def main():
    dist.init_process_group("nccl")
    parser = ArgumentParser()
    parser.add_argument("--config", type=str, default="config.yaml")

    args = parser.parse_args()
    config = read_config(args.config)

    generate_data(config)


if __name__ == "__main__":
    # parse arguments by reading in a config
    main()

View File

@ -111,6 +111,7 @@ class LetheConfig(PretrainedConfig):
        memory_attn_layer=9,
        num_neighbors_to_retrieve=32,
        num_neighbors_stored=128,
+        attn_scale_init=20.0,
        **kwargs,
    ):
        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
@ -133,4 +134,5 @@ class LetheConfig(PretrainedConfig):
        # index of cross attention layer to add
        self.memory_attn_layer = memory_attn_layer
        self.num_neighbors_to_retrieve = num_neighbors_to_retrieve
        self.num_neighbors_stored = num_neighbors_stored
+        self.attn_scale_init = attn_scale_init
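
The new attn_scale_init field seeds the learned per-head temperature used by the memory attention layer later in this commit: the scale is stored in log space and exponentiated at use time so it stays positive. A small sketch (the head count here is illustrative):

import math
import torch
import torch.nn as nn

attn_scale_init = 20.0
num_attention_heads = 8  # illustrative
scale = nn.Parameter(torch.ones(num_attention_heads, 1, 1) * math.log(attn_scale_init))
# later, with QK-normalized queries/keys:
# attn_scores = attn_scores * scale.exp()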

View File

@ -12,8 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-""" PyTorch PythiaSeek model."""
+""" PyTorch Lethe model."""
+import wandb
+import math
+import torch.nn.functional as F
+import matplotlib.pyplot as plt
from typing import Optional, Tuple, Union

import torch
@ -29,8 +33,9 @@ from transformers.modeling_outputs import (
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from gpt4all.models.lethe import LetheConfig
-import hnswlib
import numpy as np
+import faiss
+import faiss.contrib.torch_utils

logger = logging.get_logger(__name__)
@ -40,17 +45,18 @@ GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST = [
"EleutherAI/gpt-neox-20b", "EleutherAI/gpt-neox-20b",
] ]
# TODO: understand why Phil only does this per batch and doens't persist across many batches -> he uses multi-query attention
# TODO: do we need to implement masking for the dense vectors we pull from?
# TODO: i think phil is using a memmapped database to pull out rather than using the index
class HNSWIndex: class HNSWIndex:
def __init__(self, max_memories, dimension): def __init__(self, max_memories, dimension):
# num_memories will be batch size * num_neighbors # num_memories will be batch size * num_neighbors
# can memmap this too like # can memmap this too like
self.index = hnswlib.Index(space="l2", dim=dimension) self.index = faiss.IndexHNSWFlat(dimension, 16, faiss.METRIC_INNER_PRODUCT)
self.index.init_index(max_elements=max_memories, ef_construction=50, M=16) # taking params from: https://www.pinecone.io/learn/vector-indexes/#hnsw-implementation
# and https://www.pinecone.io/learn/hnsw/#hnsw-performance
# seems like efConstruction dictates how long the index takes to build
# and efSearch and M (second arg to faiss.Index) dictates how long it takes to search
self.index.hnsw.efConstruction = 16
self.index.hnsw.efSearch = 32
self.max_memories = max_memories self.max_memories = max_memories
self.dimension = dimension self.dimension = dimension
@ -60,45 +66,65 @@ class HNSWIndex:
def query(self, query, k=1): def query(self, query, k=1):
# hack what should we do here? # hack what should we do here?
if self.index.get_current_count() == 0: if self.index.ntotal == 0:
return np.ones((query.shape[0], k, query.shape[1]), dtype=np.float32) return np.ones((query.shape[0], k), dtype=np.int32)
assert query.ndim == 2 _, labels = self.index.search(np.ascontiguousarray(query), k=k)
bs_seq_len, _ = query.shape
labels, _ = self.index.knn_query(query, k=k) return labels
neighbors = torch.tensor(self.index.get_items(labels.reshape(-1)))
neighbors = neighbors.reshape((bs_seq_len, k, query.shape[1]))
assert neighbors.ndim == 3
assert neighbors.shape[0] == bs_seq_len
return neighbors
def add(self, memories): def add(self, memories):
assert memories.ndim == 2 return self.index.add(np.ascontiguousarray(memories))
bs_seq_len, _ = memories.shape
ids = np.arange(self.idx_offset, self.idx_offset + bs_seq_len)
self.index.add_items(memories, ids)
self.idx_offset += bs_seq_len
def reset(self): def reset(self):
self.index = hnswlib.Index(space="l2", dim=self.dimension) self.index.reset()
self.index.init_index(max_elements=self.max_memories, ef_construction=50, M=16)
class NumpyKNNIndex:
def __init__(self, max_memories, dimension):
# num_memories will be batch size * num_neighbors
# can memmap this too like
self.index = np.zeros((max_memories, dimension), dtype=np.float32)
self.max_memories = max_memories
self.dimension = dimension
# if we want to allow for insertion of len(elements) > max_memories
# we need to figure out a way to get the most recent memories
self.idx_offset = 0
def query(self, query, k=1):
# hack what should we do here?
if self.index.sum() == 0:
return np.ones((query.shape[0], k), dtype=np.int32)
dots = query.dot(self.index[:self.idx_offset].T)
labels = np.argsort(dots, axis=1)[:, -k:]
return labels
def add(self, memories):
self.index[self.idx_offset:self.idx_offset + memories.shape[0]] = memories
self.idx_offset += memories.shape[0]
def reset(self):
self.index.reset()
class MemoryIndex: class MemoryIndex:
def __init__(self, hidden_dim, num_mems, nheads): def __init__(self, hidden_dim, num_mems, nheads):
# we store an index for each k/v for each head
self.key_indices = [HNSWIndex(num_mems, hidden_dim) for _ in range(nheads)]
self.value_indices = [HNSWIndex(num_mems, hidden_dim) for _ in range(nheads)]
self.nheads = nheads self.nheads = nheads
# NOTE: we are storing kv pairs, instead indices for both keys and values
self.key_indices = [HNSWIndex(num_mems, hidden_dim) for _ in range(nheads)]
shape = (num_mems, nheads, 2, hidden_dim)
self.kv_pairs = np.zeros(shape, dtype=np.float32)
self.idx_offset = 0
def add(self, keys, values): def add(self, keys, values):
# k/v are (bs, num_attention_heads, seq_len, head_size) # k/v are (bs, num_attention_heads, seq_len, head_size)
reshaped_keys = keys.reshape(keys.shape[0] * keys.shape[2], keys.shape[1], keys.shape[3]) reshaped_keys = keys.reshape(keys.shape[0] * keys.shape[2], keys.shape[1], keys.shape[3])
@ -106,22 +132,30 @@ class MemoryIndex:
        for head in range(self.nheads):
            self.key_indices[head].add(reshaped_keys[:, head, :])
-            self.value_indices[head].add(reshaped_values[:, head, :])
+
+        kv_pairs = np.stack((reshaped_keys, reshaped_values), axis=2)
+        if self.idx_offset + kv_pairs.shape[0] > self.kv_pairs.shape[0]:
+            raise ValueError("Not enough memory!")
+
+        self.kv_pairs[self.idx_offset:self.idx_offset + kv_pairs.shape[0]] = kv_pairs
+        self.idx_offset += kv_pairs.shape[0]

    def knn_query(self, query, k=1):
        reshaped_query = query.reshape(query.shape[0] * query.shape[2], query.shape[1], query.shape[3])
        mem_keys = []
        mem_values = []
-        # this is prob so so slow
+        mem_indices = []
+        # we can prob make this better
        for head in range(self.nheads):
-            knn_keys = self.key_indices[head].query(reshaped_query[:, head, :], k=k)
-            knn_values = self.value_indices[head].query(reshaped_query[:, head, :], k=k)
-            mem_keys.append(knn_keys)
-            mem_values.append(knn_values)
+            knn_indices = self.key_indices[head].query(reshaped_query[:, head, :], k=k)
+            kv_pairs = self.kv_pairs[:, head, :, :][knn_indices]
+            mem_keys.append(kv_pairs[:, :, 0, :])
+            mem_values.append(kv_pairs[:, :, 1, :])
+            mem_indices.append(knn_indices)

        mem_keys = torch.from_numpy(np.stack(mem_keys, axis=1))
        # (bs, num_attention_heads, seq_len, k, head_size)
@ -131,13 +165,14 @@ class MemoryIndex:
        # (bs, num_attention_heads, seq_len, k, head_size)
        mem_values = mem_values.view(query.shape[:-1] + (k,) + (query.shape[-1],))

-        return mem_keys, mem_values
+        return mem_keys, mem_values, np.stack(mem_indices, axis=1)

    def reset(self):
        for head in range(self.nheads):
            self.key_indices[head].reset()
-            self.value_indices[head].reset()
+        self.kv_pairs = np.zeros((self.kv_pairs.shape[0], self.nheads, 2, self.kv_pairs.shape[-1]), dtype=np.float32)


class LethePreTrainedModel(PreTrainedModel):
@ -171,7 +206,7 @@ class LethePreTrainedModel(PreTrainedModel):
class LetheAttention(nn.Module):
-    def __init__(self, config, memory_attention=False, index=None):
+    def __init__(self, config, memory_attention=False, index=None, tracker=None):
        super().__init__()
        self.num_attention_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
@ -188,21 +223,24 @@ class LetheAttention(nn.Module):
        self.rotary_emb = RotaryEmbedding(
            self.rotary_ndims, config.max_position_embeddings, base=config.rotary_emb_base
        )
-        self.register_buffer(
+        if not memory_attention:
+            self.register_buffer(
                "norm_factor",
                torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(torch.get_default_dtype()),
                persistent=False,
            )
        self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)

        self.memory = False
        if memory_attention:
+            self.scale = nn.Parameter(torch.ones(self.num_attention_heads, 1, 1) * math.log(config.attn_scale_init))
            self.memory = True
-            self.alpha = nn.Parameter(torch.zeros(self.num_attention_heads))
            self.num_neighbors = config.num_neighbors_to_retrieve
            # for testing, just using np array since it's easy
            self.index = index
+            self.tracker = tracker

    def forward(
@ -214,7 +252,9 @@ class LetheAttention(nn.Module):
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
-        use_mem_attn: Optional[bool] = True,
+        log_attn_scores: Optional[bool] = False,
+        step: Optional[int] = None,
+        save_kv: Optional[bool] = True,
    ):
        has_layer_past = layer_past is not None
@ -234,6 +274,12 @@ class LetheAttention(nn.Module):
        key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3)
        value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3)

+        # if self.memory:
+        if self.memory:
+            # QKNorm: https://arxiv.org/abs/2010.04245
+            query = F.normalize(query, dim=-1)
+            key = F.normalize(key, dim=-1)
+
        # Compute rotary embeddings on rotary_ndims
        query_rot = query[..., : self.rotary_ndims]
        query_pass = query[..., self.rotary_ndims :]
@ -257,27 +303,38 @@ class LetheAttention(nn.Module):
            value = torch.cat((past_value, value), dim=-2)
        present = (key, value) if use_cache else None

-        # Compute attention
-        attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
-
-        # TODO: need to do masking??
+        # memory attention
        if self.memory:
-            # get knns
-            # since we do an eval batch w context before, let's not do the expensive step until we need to
-            # [batch, knn, num_attention_heads, seq_len, head_size]
-            if use_mem_attn:
-                knn_keys, knn_values = self.index.knn_query(query.cpu().detach().to(torch.float32).numpy(), k=self.num_neighbors)
-                mem_attn = self._mem_attn(query,
-                                          knn_keys.to(query.device),
-                                          knn_values.to(query.device),
-                                          attention_mask,
-                                          head_mask
-                )
-
-                expanded_alpha = self.alpha[None, :, None, None]
-                attn_output = (attn_output * (1 - expanded_alpha)) + (mem_attn * expanded_alpha)
-
-            self.index.add(key.cpu().detach().to(torch.float32).numpy(), value.cpu().detach().to(torch.float32).numpy())
+            if save_kv:
+                self.index.add(key.cpu().detach().to(torch.float32).numpy(), value.cpu().detach().to(torch.float32).numpy())
+
+            knn_keys, knn_values, knn_labels = self.index.knn_query(query.cpu().detach().to(torch.float32).numpy(), k=self.num_neighbors)
+            if log_attn_scores:
+                batch_size = query.shape[0]
+                seq_len = query.shape[-2]
+                key_labels = knn_labels // seq_len
+                key_labels = key_labels.reshape(batch_size, seq_len, self.num_attention_heads, -1)
+                correct_keys = np.equal(key_labels, np.arange(batch_size)[:, np.newaxis, np.newaxis, np.newaxis])
+                # calculate the accuracy
+                key_acc = np.sum(correct_keys) / np.prod(correct_keys.shape)
+                self.tracker.log({"retrieved_acc": key_acc}, step=step)
+
+            attn_output = self._mem_attn(query,
+                                         knn_keys.to(query.device).to(value.dtype),
+                                         knn_values.to(query.device).to(value.dtype),
+                                         key,
+                                         value,
+                                         attention_mask,
+                                         head_mask,
+                                         log_attn_scores=log_attn_scores,
+                                         step=step,
+                                         knn_labels=knn_labels,
+            )
+        else:
+            # Normal self-attention
+            attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

        # Reshape outputs
        attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size)
@ -315,28 +372,131 @@ class LetheAttention(nn.Module):
        return tensor

-    def _mem_attn(self, query, key, value, attention_mask=None, head_mask=None):
+    def _mem_attn(self,
+                  query,
+                  knn_key,
+                  knn_value,
+                  local_key,
+                  local_value,
+                  attention_mask=None,
+                  head_mask=None,
+                  log_attn_scores=False,
+                  step=None,
+                  knn_labels=None):
+        # local self-attention
        # q: [bs, num_attention_heads, seq_len, attn_head_size]
        # k,v: [bs, num_attention_heads, seq_len, knn, attn_head_size]
+        query_length = query.size(-2)
+        key_length = local_key.size(-2)

-        attn_scores = torch.einsum("bhsd, bhsnd-> bshn", query, key)
-        # attn_scores: [bs, seq_len, num_attention_heads, knn]
-        attn_scores = attn_scores / self.norm_factor
-
-        # softmax over knns
-        attn_weights = nn.functional.softmax(attn_scores, dim=-1)
-        attn_weights = attn_weights.to(value.dtype)
+        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
+
+        local_attn_scores = torch.matmul(query, local_key.transpose(-1, -2))
+        scale = self.scale.exp()
+
+        local_attn_scores = local_attn_scores * scale
+
+        mask_value = torch.finfo(local_attn_scores.dtype).min
+        # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
+        # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
+        mask_value = torch.tensor(mask_value, dtype=local_attn_scores.dtype).to(local_attn_scores.device)
+        local_attn_scores = torch.where(causal_mask, local_attn_scores, mask_value)

        if attention_mask is not None:
            # Apply the attention mask
-            attn_scores = attn_scores + attention_mask
+            local_attn_scores = local_attn_scores + attention_mask
+
+        mem_attn_scores = torch.einsum("bhsd, bhsnd-> bhsn", query, knn_key)
+        # attn_scores: [bs, seq_len, num_attention_heads, knn]
+        mem_attn_scores = mem_attn_scores * scale
+
+        attn_scores = torch.cat((mem_attn_scores, local_attn_scores), dim=-1)
+
+        # softmax over knns
+        attn_weights = nn.functional.softmax(attn_scores, dim=-1)
+        attn_weights = attn_weights.to(local_value.dtype)
+
+        mem_attn_weights, local_attn_weights = attn_weights.split([self.num_neighbors, local_attn_scores.size(-1)], dim=-1)
+
+        if log_attn_scores:
+            # (bs, seq_len, num_attention_heads, knn) probabilities
+            # curate (x,y) pairs
+            # where x is attention weight, y is accuracy of retrieved token
+            bs, seq_len = mem_attn_weights.size(0), mem_attn_weights.size(2)
+            key_labels = knn_labels // seq_len
+            key_labels = key_labels.reshape(bs, self.num_attention_heads, seq_len, -1)
+            correct_keys = np.equal(key_labels, np.arange(bs)[:, np.newaxis, np.newaxis, np.newaxis])
+
+            bin_width = 0.05
+            # Calculate the number of bins
+            num_bins = int(1 / bin_width)
+
+            # Create empty lists for storing bin probabilities and accuracies
+            bin_probabilities = []
+            bin_accuracies = []
+
+            probs = mem_attn_weights.clone().detach().cpu().numpy().reshape(-1).tolist()
+            correct_keys = correct_keys.reshape(-1).tolist()
+            # Iterate over each bin
+            for i in range(num_bins):
+                bin_lower = i * bin_width
+                bin_upper = (i + 1) * bin_width
+
+                # Filter data points within the current bin range
+                bin_x_values = [x for x in probs if bin_lower <= x < bin_upper]
+                bin_y_values = [y for j, y in enumerate(correct_keys) if bin_lower <= probs[j] < bin_upper]
+
+                # Calculate accuracy for the bin
+                total = len(bin_x_values)
+                correct = sum(bin_y_values)
+                accuracy = correct / total if total > 0 else 0
+
+                # Store the probability and accuracy for the bin
+                bin_probabilities.append((bin_lower + bin_upper) / 2)
+                bin_accuracies.append(accuracy)
+
+            data = [[x, y] for x, y in zip(bin_probabilities, bin_accuracies)]
+            table = wandb.Table(data=data, columns=["attn_prob", "retrieved_acc"])
+            self.tracker.log({"attn_vs_acc": wandb.plot.scatter(table, "attn_prob", "retrieved_acc")}, step=step)
+
+        if log_attn_scores:
+            # this def won't work well on multi-gpu machines
+            num_attention_heads = mem_attn_weights.size(1)
+            for head in range(num_attention_heads):
+                mem_attn_score_per_head = mem_attn_weights[:, head].reshape(-1)
+                mem_flat = mem_attn_score_per_head.clone().detach().cpu()
+                mem_hist = torch.histc(mem_flat, bins=20, min=0, max=1)
+                mem_bins = torch.linspace(0, 1, steps=20 + 1)
+                plt.stairs(mem_hist.tolist(), mem_bins.tolist())
+                plt.title(f"mem_attn_score_{head}")
+                # set arbitrarily but we want to see those peaks!!
+                plt.ylim((0, 1000))
+                self.tracker.log({f"mem_attn_score_{head}": wandb.Image(plt)}, step=step)
+                plt.close()
+
+                local_attn_scores_per_head = local_attn_weights[:, head].reshape(-1)
+                local_flat = local_attn_scores_per_head.clone().detach().cpu()
+                local_hist = torch.histc(local_flat, bins=20, min=0, max=1)
+                local_bins = torch.linspace(0, 1, steps=20 + 1)
+                plt.stairs(local_hist.tolist(), local_bins.tolist())
+                plt.title(f"local_attn_score_{head}")
+                # set arbitrarily but we want to see those peaks!!
+                plt.ylim((0, 1000))
+                self.tracker.log({f"local_attn_score_{head}": wandb.Image(plt)}, step=step)
+                plt.close()

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        # attn_output: [bs, num_attention_heads, seq_len, attn_head_size]
-        attn_output = torch.einsum("bshn, bhsnd-> bhsd", attn_scores, value)
+        mem_attn_output = torch.einsum("bhsn, bhsnd-> bhsd", mem_attn_weights, knn_value)
+        local_attn_output = torch.matmul(local_attn_weights, local_value)
+
+        # TODO: do we need flamingo style gating
+        # of output_gate.tanh * attn_output
+        attn_output = mem_attn_output + local_attn_output

        return attn_output

    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
@ -361,9 +521,11 @@ class LetheAttention(nn.Module):
            query,
            key.transpose(1, 2),
            beta=1.0,
-            alpha=(torch.tensor(1.0, dtype=self.norm_factor.dtype, device=self.norm_factor.device) / self.norm_factor),
+            alpha=1.0 if self.memory else (torch.tensor(1.0, dtype=self.norm_factor.dtype, device=self.norm_factor.device) / self.norm_factor)
        )
        attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)

+        if self.memory:
+            attn_scores = attn_scores * self.scale.exp()
+
        mask_value = torch.finfo(attn_scores.dtype).min
        # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
@ -413,7 +575,7 @@ class RotaryEmbedding(torch.nn.Module):
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.cos_cached = emb.cos()[None, None, :, :]
            self.sin_cached = emb.sin()[None, None, :, :]
-        return self.cos_cached[:seq_len, ...].to(x.device), self.sin_cached[:seq_len, ...].to(x.device)
+        return self.cos_cached[:seq_len, ...].to(x.device).to(x.dtype), self.sin_cached[:seq_len, ...].to(x.device).to(x.dtype)


def rotate_half(x):
@ -448,12 +610,12 @@ class LetheMLP(nn.Module):
class LetheLayer(nn.Module):
-    def __init__(self, config, memory_attention=False, index=None):
+    def __init__(self, config, memory_attention=False, index=None, tracker=None):
        super().__init__()
        self.use_parallel_residual = config.use_parallel_residual
        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-        self.attention = LetheAttention(config, memory_attention=memory_attention, index=index)
+        self.attention = LetheAttention(config, memory_attention=memory_attention, index=index, tracker=tracker)
        self.mlp = LetheMLP(config)

    def forward(
@ -465,7 +627,9 @@ class LetheLayer(nn.Module):
        use_cache: Optional[bool] = False,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
-        use_mem_attn: Optional[bool] = True,
+        log_attn_scores: Optional[bool] = False,
+        step: Optional[int] = None,
+        save_kv: Optional[bool] = True
    ):
        ln_hidden_states = self.input_layernorm(hidden_states)
        attention_layer_outputs = self.attention(
@ -476,7 +640,9 @@ class LetheLayer(nn.Module):
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
-            use_mem_attn=use_mem_attn,
+            log_attn_scores=log_attn_scores,
+            step=step,
+            save_kv=save_kv,
        )
        attn_output = attention_layer_outputs[0]  # output_attn: attn_output, present, (attn_weights)
        outputs = attention_layer_outputs[1:]
@ -503,7 +669,7 @@ class LetheLayer(nn.Module):
class LetheModel(LethePreTrainedModel):
-    def __init__(self, config, index):
+    def __init__(self, config, index, tracker=None):
        super().__init__(config)
        self.config = config
@ -511,7 +677,8 @@ class LetheModel(LethePreTrainedModel):
        self.layers = nn.ModuleList([LetheLayer(config,
                                                memory_attention=i+1 == config.memory_attn_layer,
-                                                index=index if i+1 == config.memory_attn_layer else None)
+                                                index=index if i+1 == config.memory_attn_layer else None,
+                                                tracker=tracker)
                                     for i in range(config.num_hidden_layers)])
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@ -538,7 +705,9 @@ class LetheModel(LethePreTrainedModel):
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
-        use_mem_attn: Optional[bool] = True,
+        log_attn_scores: Optional[bool] = False,
+        step: Optional[int] = None,
+        save_kv: Optional[bool] = True,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        r"""
        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
@ -631,7 +800,7 @@ class LetheModel(LethePreTrainedModel):
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for layer_past
-                        return module(*inputs, use_cache, None, output_attentions, use_mem_attn)
+                        return module(*inputs, use_cache, None, output_attentions, log_attn_scores, step, save_kv)

                    return custom_forward
@ -651,7 +820,9 @@ class LetheModel(LethePreTrainedModel):
                    layer_past=layer_past,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
-                    use_mem_attn=use_mem_attn,
+                    log_attn_scores=log_attn_scores,
+                    step=step,
+                    save_kv=save_kv,
                )
            hidden_states = outputs[0]
            if use_cache is True:
@ -678,10 +849,10 @@ class LetheModel(LethePreTrainedModel):
class LetheForCausalLM(LethePreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

-    def __init__(self, config, index):
+    def __init__(self, config, index, tracker=None):
        super().__init__(config)

-        self.gpt_neox = LetheModel(config, index)
+        self.gpt_neox = LetheModel(config, index, tracker=tracker)
        self.embed_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.hidden_size = config.hidden_size
@ -709,7 +880,9 @@ class LetheForCausalLM(LethePreTrainedModel):
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
-        use_mem_attn: Optional[bool] = True,
+        log_attn_scores: Optional[bool] = None,
+        step: Optional[int] = None,
+        save_kv: Optional[bool] = True,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
@ -763,7 +936,9 @@ class LetheForCausalLM(LethePreTrainedModel):
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
-            use_mem_attn=use_mem_attn
+            log_attn_scores=log_attn_scores,
+            step=step,
+            save_kv=save_kv,
        )

        hidden_states = outputs[0]
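
Pulling the memory-attention pieces of this file together, a condensed, self-contained sketch of the scoring scheme the diff implements (QK-normalized queries/keys, a learned positive per-head scale, retrieved and local scores softmaxed jointly, contributions summed); shapes and the helper name are illustrative, not the module's API:

import torch
import torch.nn.functional as F

def mem_attn_sketch(q, local_k, local_v, knn_k, knn_v, scale, causal_mask):
    # q, local_k, local_v: (bs, heads, seq, head_size); knn_k, knn_v: (bs, heads, seq, k, head_size)
    # causal_mask: bool, True where attention is allowed
    local_scores = torch.matmul(q, local_k.transpose(-1, -2)) * scale        # (bs, heads, seq, seq)
    local_scores = local_scores.masked_fill(~causal_mask, torch.finfo(local_scores.dtype).min)
    mem_scores = torch.einsum("bhsd,bhsnd->bhsn", q, knn_k) * scale          # (bs, heads, seq, k)
    # one softmax over retrieved + local positions
    weights = F.softmax(torch.cat((mem_scores, local_scores), dim=-1), dim=-1)
    mem_w, local_w = weights.split([knn_k.size(-2), local_k.size(-2)], dim=-1)
    # retrieved and local contributions are simply summed, as in the diff above
    return torch.einsum("bhsn,bhsnd->bhsd", mem_w, knn_v) + torch.matmul(local_w, local_v)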

View File

@ -0,0 +1,21 @@
import time
import numpy as np
from gpt4all.models.lethe.modeling_lethe import MemoryIndex
index = MemoryIndex(256,
                    575000,
                    8
)
keys = np.random.randn(32, 8, 1024, 256)
values = np.random.randn(32, 8, 1024, 256)
start = time.time()
index.add(keys, values)
print(f"index.add time: {time.time() - start}")
print(index.key_indices[0].index.ntotal)
queries = np.random.randn(32, 8, 1024, 256)
start = time.time()
index.knn_query(queries, k=32)
print(f"index.knn_query time: {time.time() - start}")

View File

@ -23,8 +23,9 @@ print("loading model")
dimension = config.max_position_embeddings * config.hidden_size
head_size = config.hidden_size // config.num_attention_heads
index = MemoryIndex(head_size,
-                    500_000,
-                    config.num_attention_heads
+                    5_000_000,
+                    # 2 since multi-query attention and storing one each for key and value
+                    config.num_attention_heads,
)
model = LetheForCausalLM(config, index)
model.to("cuda:0")
@ -69,7 +70,7 @@ with torch.no_grad():
    for chunk_start in tqdm(range(0, memories.shape[0], 32)):
        chunk_end = min(memories.shape[0], chunk_start + 32)
        mem_chunk = memories[chunk_start:chunk_end].to(model.device)
-        model(input_ids=mem_chunk, labels=None)
+        model(input_ids=mem_chunk, labels=None,)

model.train()

View File

@ -0,0 +1,2 @@
from .configuration_pythia_retro import PythiaRetroConfig
from .modeling_pythia_retro import PythiaRetroForCausalLM

View File

@ -1,4 +1,5 @@
import os
+import torch.nn.functional as F
from transformers import AutoTokenizer, get_scheduler, AutoConfig
import torch
from torch.optim import AdamW
@ -12,6 +13,8 @@ from tqdm import tqdm
from gpt4all.models import LetheForCausalLM
from gpt4all.models.lethe.modeling_lethe import MemoryIndex
import wandb
+import pyarrow as pa
+from pyarrow import feather

torch.backends.cuda.matmul.allow_tf32 = True
@ -22,39 +25,58 @@ def format_metrics(metrics, split, prefix=""):
    return log


-def evaluate(model, config, val_dataloader, main_process=False):
+def calculate_per_example_loss(logits, labels):
+    lm_logits = logits[:, :-1, :].contiguous()
+    lm_labels = labels[:, 1:].contiguous()
+
+    loss = F.cross_entropy(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1), reduction="none")
+    loss = loss.reshape(labels.shape[0], -1).mean(dim=-1)
+
+    # return tensor of shape (B,) where B is the batch size
+    return loss.cpu().tolist()
+
+
+def evaluate(model, index, config, val_dataloader, main_process=False):
    model.eval()
    val_loss = MeanMetric(nan_strategy="error").to(model.device)
-    head_size = model.config.hidden_size // model.config.num_attention_heads
-    index = MemoryIndex(head_size,
-                        config["num_memories_per_index"],
-                        model.config.num_attention_heads
-    )
+    ids = []
+    losses = []

    with torch.no_grad():
        for batch in tqdm(val_dataloader, disable=not main_process):
+            batch["id"] = batch["id"].detach().cpu()
            memories = batch["retrieved_context"]
-            # need to set to eval so we don't do mem attn as it's slow
-            model.eval()
+            memories = memories[:, :config["num_neighbors_to_store"], :]
+            memories = memories.reshape(-1, memories.shape[-1])
            for chunk_start in range(0, memories.shape[0], config["mem_chunk_size"]):
                chunk_end = min(memories.shape[0], chunk_start + config["mem_chunk_size"])
                mem_chunk = memories[chunk_start:chunk_end]
-                model(input_ids=mem_chunk, labels=None, use_mem_attn=False)
+                model(input_ids=mem_chunk)

            qa_inputs = batch["input_ids"]
            qa_labels = batch["labels"]
            outputs = model(input_ids=qa_inputs,
                            labels=qa_labels,
            )
+            del memories
+            torch.cuda.empty_cache()

            loss_values = accelerator.gather_for_metrics({"loss": outputs["loss"].detach()})
-            val_loss.update(loss_values["loss"])
            index.reset()
+            val_loss.update(loss_values["loss"])

-    return val_loss
+            per_example_loss = calculate_per_example_loss(outputs["logits"], qa_labels)
+            losses.extend(per_example_loss)
+            ids.extend(batch["id"].tolist())
+
+    ids = pa.array(ids)
+    losses = pa.array(losses)
+    schema = pa.schema([("loss", pa.float64()), ("id", pa.int32())])
+    table = pa.Table.from_arrays([losses, ids], schema=schema)
+
+    return val_loss, table


def train(accelerator, config):
@ -72,6 +94,7 @@ def train(accelerator, config):
    with accelerator.main_process_first():
        train_dataloader, val_dataloader = load_memory_augmented_data(config, tokenizer)

    if accelerator.state.deepspeed_plugin is not None:
        gradient_accumulation_steps = accelerator.state.deepspeed_plugin.deepspeed_config[
            "gradient_accumulation_steps"
@ -97,6 +120,7 @@ def train(accelerator, config):
                                          memory_attn_layer=config["memory_attn_layer"],
                                          num_neighbors_to_retrieve=config["num_neighbors_to_retrieve"],
                                          index=index,
+                                          tracker=accelerator.get_tracker("wandb"),
                                          )
@ -135,16 +159,15 @@ def train(accelerator, config):
        model, optimizer, train_dataloader, val_dataloader, scheduler = accelerator.prepare(
            model, optimizer, train_dataloader, val_dataloader, scheduler
        )
-        scheduler = True
+        use_scheduler = True
    else:
        model, optimizer, train_dataloader, val_dataloader = accelerator.prepare(
            model, optimizer, train_dataloader, val_dataloader
        )
-        scheduler = False
+        use_scheduler = False

    # setup for saving training states in case preemption
-    if scheduler:
+    if use_scheduler:
        accelerator.register_for_checkpointing(scheduler)

    if config["checkpoint"]:
@ -167,24 +190,33 @@ def train(accelerator, config):
        for step, batch in enumerate(tqdm(train_dataloader, disable=not main_process)):
            curr_step = step + epoch * len(train_dataloader)
            memories = batch["retrieved_context"]
+            memories = memories[:, :config["num_neighbors_to_store"], :]
+            memories = memories.reshape(-1, memories.shape[-1])
            # need to set to eval so we don't do mem attn as it's slow
            model.eval()
            with torch.no_grad():
                for chunk_start in range(0, memories.shape[0], config["mem_chunk_size"]):
-                    chunk_end = min(memories.shape[0], chunk_start + 32)
+                    chunk_end = min(memories.shape[0], chunk_start + config["mem_chunk_size"])
                    mem_chunk = memories[chunk_start:chunk_end]
-                    model(input_ids=mem_chunk, labels=None, use_mem_attn=False)
+                    model(input_ids=mem_chunk)
+
+            del memories
+            torch.cuda.empty_cache()

            model.train()
            qa_inputs = batch["input_ids"]
            qa_labels = batch["labels"]
            outputs = model(input_ids=qa_inputs,
                            labels=qa_labels,
+                            log_attn_scores=True,
+                            step=curr_step,
+                            save_kv=False,
            )

            loss = outputs.loss
+            if config["wandb"]:
+                accelerator.log({"loss": loss}, step=curr_step)
-            index.reset()

            # gather loss before backprop in case of gradient accumulation
            loss_values = accelerator.gather_for_metrics({"loss": loss.detach().float()})
@ -192,6 +224,8 @@ def train(accelerator, config):
            loss = loss / gradient_accumulation_steps
            accelerator.backward(loss)
+            # !! don't reset index until after backwards pass
+            index.reset()

            # get gradient norm of all params
            # log LR in case something weird happens
@ -202,15 +236,25 @@ def train(accelerator, config):
            if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                optimizer.step()
-                if scheduler:
+                if use_scheduler:
                    scheduler.step()
                optimizer.zero_grad()

            if step > 0 and config["save_every"] > 0 and step % config["save_every"] == 0:
-                accelerator.save_state(f"{config['output_dir']}/step_{curr_step}")
+                # accelerator.save_state(f"{config['output_dir']}/step_{curr_step}")
+                unwrapped_model = accelerator.unwrap_model(model)
+                unwrapped_model.save_pretrained(
+                    f"{config['output_dir']}/step_{step}",
+                    is_main_process=accelerator.is_main_process,
+                    save_function=accelerator.save,
+                    state_dict=accelerator.get_state_dict(model),
+                )

            if step > 0 and (step % config["eval_every"] == 0 or step == len(train_dataloader) - 1):
-                val_loss = evaluate(model, config, val_dataloader, main_process=main_process)
+                val_loss, loss_table = evaluate(model, index, config, val_dataloader, main_process=main_process)
+                local_rank = accelerator.process_index
+                feather.write_feather(loss_table, f"{config['output_dir']}/val_losses_step_{curr_step}_rank_{local_rank}.feather")

                log_train = {
                    "train_loss": train_loss.compute()

View File

@ -1,4 +1,5 @@
import torch.distributed as dist
+from contextlib import contextmanager


def rank0_print(msg):
@ -7,3 +8,20 @@ def rank0_print(msg):
            print(msg)
    else:
        print(msg)
+
+
+@contextmanager
+def main_process_first(is_main):
+    yield from _goes_first(is_main)
+
+
+def _goes_first(is_main):
+    if not is_main:
+        dist.barrier()
+
+    yield
+
+    if is_main:
+        dist.barrier()
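
A minimal usage sketch for the new helper (the pattern is assumed from its definition above; prepare_dataset is a hypothetical stand-in): rank 0 runs the body first, for example to build and cache a dataset, and the remaining ranks wait at the barrier before doing the same work from the warm cache.

import torch.distributed as dist
from gpt4all.utils.distributed_utils import main_process_first

with main_process_first(is_main=dist.get_rank() == 0):
    dataset = prepare_dataset()  # hypothetical helper; rank 0 populates the cache first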