From 55fef489ad47985818fa38db797435112a2ff989 Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Fri, 2 Jun 2023 23:02:25 +0000 Subject: [PATCH] fix: wip mem xf --- configs/eval/evaluate_lethe.yaml | 23 ++ configs/inference/synth_data.yaml | 19 ++ configs/train/finetune_memory.yaml | 23 +- configs/train/pretrain_minipile.yaml | 43 +++ gpt4all/data/retrieval_dataloader.py | 64 +++- gpt4all/eval/eval_squad_atlas_map.py | 116 +++++++ gpt4all/inference/combine_synth_data.py | 57 ++++ gpt4all/inference/generate_synth_data.py | 149 ++++++++ gpt4all/models/lethe/configuration_lethe.py | 4 +- gpt4all/models/lethe/modeling_lethe.py | 357 +++++++++++++++----- gpt4all/models/lethe/test_index.py | 21 ++ gpt4all/models/lethe/test_lethe.py | 7 +- gpt4all/models/pythia_retro/__init__.py | 2 + gpt4all/train/train_mem_retrieval.py | 90 +++-- gpt4all/utils/distributed_utils.py | 18 + 15 files changed, 860 insertions(+), 133 deletions(-) create mode 100644 configs/eval/evaluate_lethe.yaml create mode 100644 configs/inference/synth_data.yaml create mode 100644 configs/train/pretrain_minipile.yaml create mode 100644 gpt4all/eval/eval_squad_atlas_map.py create mode 100644 gpt4all/inference/combine_synth_data.py create mode 100644 gpt4all/inference/generate_synth_data.py create mode 100644 gpt4all/models/lethe/test_index.py create mode 100644 gpt4all/models/pythia_retro/__init__.py diff --git a/configs/eval/evaluate_lethe.yaml b/configs/eval/evaluate_lethe.yaml new file mode 100644 index 00000000..79c54ea6 --- /dev/null +++ b/configs/eval/evaluate_lethe.yaml @@ -0,0 +1,23 @@ +# model/tokenizer +model_name: "/home/paperspace/gpt4all/gpt4all/train/ckpts/mem_attn/step_1000" +tokenizer_name: "EleutherAI/pythia-1b" +version: null +gradient_checkpointing: false +memory_attn_layer: 12 + + +# dataset +streaming: false +num_proc: 64 +dataset_path: "/home/paperspace/gpt4all/gpt4all/inference/synth_data_combined_174" +# dataset_path: "/home/paperspace/gpt4all/gpt4all/index/squad_supplemented_validation" +max_length: 1024 +batch_size: 1 +pct_test: 0.05 +q_column: "question" +a_column: "answer" +context_column: "text" +num_memories_per_index: 2000000 +num_neighbors_to_retrieve: 2 +num_neighbors_to_store: 1 +mem_chunk_size: 64 \ No newline at end of file diff --git a/configs/inference/synth_data.yaml b/configs/inference/synth_data.yaml new file mode 100644 index 00000000..7be76459 --- /dev/null +++ b/configs/inference/synth_data.yaml @@ -0,0 +1,19 @@ +model_name: "nomic-ai/gpt4all-j" +revision: "v1.3-groovy" +tokenizer_name: "nomic-ai/gpt4all-j" + +# dataset +streaming: false +num_proc: 64 +dataset_path: "nomic-ai/cohere-wiki-sbert" +batch_size: 32 +output_path: "synth_qa_pairs" + +# generation +max_new_tokens: 75 +max_generations: 200000 + +save_every: 1000 + + +seed: 42 \ No newline at end of file diff --git a/configs/train/finetune_memory.yaml b/configs/train/finetune_memory.yaml index 9cff8058..791c065c 100644 --- a/configs/train/finetune_memory.yaml +++ b/configs/train/finetune_memory.yaml @@ -10,35 +10,36 @@ memory_attn_layer: 12 # dataset streaming: false num_proc: 64 -dataset_path: "/home/paperspace/gpt4all/gpt4all/index/squad_supplemented_train" +dataset_path: "/home/paperspace/gpt4all/gpt4all/inference/synth_data_combined_174" max_length: 1024 -batch_size: 64 +batch_size: 32 pct_test: 0.05 q_column: "question" -a_column: "answers" -context_column: "neighbor_text" -num_memories_per_index: 500000 -num_neighbors_to_retrieve: 32 +a_column: "answer" +context_column: "text" +num_memories_per_index: 2000000 
+num_neighbors_to_retrieve: 2 +num_neighbors_to_store: 1 mem_chunk_size: 64 # train dynamics -lr: 1.0e-4 +lr: 1.0e-5 min_lr: 0 weight_decay: 0.0 eval_every: 100 -save_every: -1 +save_every: 100 log_grads_every: 100 log_lr_every: 10 -output_dir: "ckpts/mem_attn" +output_dir: "ckpts/mem_attn_no_cosine_sim" checkpoint: null lora: false -warmup_steps: 500 +warmup_steps: 200 num_epochs: 5 debug: false scheduler: false # logging -wandb: false +wandb: true wandb_entity: gpt4all wandb_project_name: mem_attn seed: 42 diff --git a/configs/train/pretrain_minipile.yaml b/configs/train/pretrain_minipile.yaml new file mode 100644 index 00000000..94b72bec --- /dev/null +++ b/configs/train/pretrain_minipile.yaml @@ -0,0 +1,43 @@ +# model/tokenizer +model_name: "EleutherAI/pythia-1b" +tokenizer_name: "EleutherAI/pythia-1b" +version: null +gradient_checkpointing: true +save_name: "nomic-ai/minipille" +push_to_hub: false +memory_attn_layer: 12 + +# dataset +streaming: false +num_proc: 64 +dataset_path: "JeanKaddour/minipile" +max_length: 2048 +batch_size: 64 +pct_test: 0.05 +num_memories_per_index: 5000000 +mem_chunk_size: 512 +num_chunks: 10 +num_neighbors_to_retrieve: 32 + +# train dynamics +lr: 1.0e-4 +min_lr: 0 +weight_decay: 0.0 +eval_every: 100 +save_every: -1 +log_grads_every: 100 +log_lr_every: 10 +output_dir: "ckpts/minipile" +checkpoint: null +lora: false +warmup_steps: 500 +num_epochs: 5 +debug: false +scheduler: false + +# logging +wandb: true +wandb_entity: gpt4all +wandb_project_name: minipile +seed: 42 + diff --git a/gpt4all/data/retrieval_dataloader.py b/gpt4all/data/retrieval_dataloader.py index 6a04737c..3a4dbacb 100644 --- a/gpt4all/data/retrieval_dataloader.py +++ b/gpt4all/data/retrieval_dataloader.py @@ -85,6 +85,7 @@ def load_memory_augmented_data(config, tokenizer, split="train", split_dataset=T question_col = config["q_column"] answer_col = config["a_column"] + context_col = config["context_column"] if config["streaming"] is False: kwargs = {"num_proc": config["num_proc"]} @@ -96,7 +97,8 @@ def load_memory_augmented_data(config, tokenizer, split="train", split_dataset=T dataset = dataset.map(lambda ele: {question_col: [q.strip() for q in ele[question_col]]}, batched=True) # in squad, the data is formatted where each ele in answers is a dict where the key text holds # a list of the answer - dataset = dataset.map(lambda ele: {answer_col: [t["text"][0] for t in ele[answer_col]]}, batched=True) + dataset = dataset.map(lambda ele: {answer_col: [t.strip() for t in ele[answer_col]]}, batched=True) + # dataset = dataset.map(lambda ele: {answer_col: [t["text"][0] for t in ele[answer_col]]}, batched=True) dataset = dataset.map( lambda ele: tokenize_inputs(config, tokenizer, ele, question_col, answer_col), @@ -106,19 +108,73 @@ def load_memory_augmented_data(config, tokenizer, split="train", split_dataset=T # tokenize contexts for each example dataset = dataset.map( - lambda ele: {"retrieved_context": tokenizer(ele["context"], + lambda ele: {"retrieved_context": tokenizer([ele[context_col]], return_tensors="pt", padding="max_length", truncation=True)["input_ids"]}, - batched=True, **kwargs ) - columns_to_keep = ["input_ids", "labels", "retrieved_context"] + columns_to_keep = ["id", "input_ids", "labels", "retrieved_context"] col_names_to_rm = [col for col in dataset.column_names if col not in columns_to_keep] dataset = dataset.remove_columns(col_names_to_rm) + if split_dataset: + dataset = dataset.train_test_split(test_size=config["pct_test"], seed=config["seed"]) + train_dataset, val_dataset = 
dataset["train"], dataset["test"] + + train_dataloader = DataLoader( + train_dataset.remove_columns("id"), + batch_size=config["batch_size"], + collate_fn=DefaultDataCollator(), + ) + + val_dataloader = DataLoader( + val_dataset, + batch_size=config["batch_size"], + collate_fn=DefaultDataCollator(), + ) + + return train_dataloader, val_dataloader + + else: + dataloader = DataLoader( + dataset, + batch_size=config["batch_size"], + collate_fn=DefaultDataCollator(), + ) + + return dataloader + + +def load_memory_pretraining_data(config, tokenizer, split="train", split_dataset=True): + dataset_path = config["dataset_path"] + + if os.path.exists(dataset_path): + dataset = Dataset.load_from_disk(dataset_path) + else: + dataset = load_dataset(dataset_path, split=split) + + if config["streaming"] is False: + kwargs = {"num_proc": config["num_proc"]} + else: + kwargs = {} + + # e.g. 512 * 10 = 5120 sequence length split up + max_length = config["mem_chunk_size"] * config["num_chunks"] + dataset = dataset.map(lambda ele: tokenizer(ele["text"], padding="max_length", truncation=True, max_length=max_length), + batched=True, **kwargs) + + dataset = dataset.map(lambda x: {"labels": x["input_ids"]}, batched=True, **kwargs) + + + columns_to_keep = ["input_ids", "labels", "attention_mask"] + + col_names_to_rm = [col for col in dataset.column_names if col not in columns_to_keep] + dataset = dataset.remove_columns(col_names_to_rm) + + # we can shuffle since the docs are in one row not split across rows if split_dataset: dataset = dataset.train_test_split(test_size=config["pct_test"], seed=config["seed"]) train_dataset, val_dataset = dataset["train"], dataset["test"] diff --git a/gpt4all/eval/eval_squad_atlas_map.py b/gpt4all/eval/eval_squad_atlas_map.py new file mode 100644 index 00000000..9480ec28 --- /dev/null +++ b/gpt4all/eval/eval_squad_atlas_map.py @@ -0,0 +1,116 @@ +import torch +import torch.nn.functional as F +from gpt4all.models import LetheForCausalLM +from gpt4all.models.lethe.modeling_lethe import MemoryIndex +from gpt4all.data.retrieval_dataloader import load_memory_augmented_data +from gpt4all.train.metrics import f1_score, exact_match_score +from gpt4all.utils.read import read_config +from transformers import AutoTokenizer, AutoConfig +from argparse import ArgumentParser +from tqdm import tqdm +from nomic import atlas +from datasets import load_from_disk + + +def calc_loss_per_item(logits, labels): + lm_logits = logits[:, :-1, :].contiguous() + lm_labels = labels[:, 1:].contiguous() + loss = F.cross_entropy(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1), reduction="none") + loss = loss.reshape(labels.shape[0], -1).mean(dim=-1) + + # return tensor of shape (B,) where B is the batch size + return loss.cpu().tolist() + + +def greedy_search(input_ids, model, tokenizer, max_new_tokens=100): + num_new_tokens = 0 + with torch.no_grad(): + while True: + if num_new_tokens >= max_new_tokens: + break + outputs = model(input_ids, save_kv=False) + + new_tokens = torch.argmax(outputs.logits[:, -1, :], dim=-1) + + input_ids = torch.cat([input_ids, new_tokens.unsqueeze(1)], dim=-1) + num_new_tokens += 1 + + if torch.equal(input_ids[0, -1].cpu(), torch.tensor(tokenizer.eos_token_id)): + break + + print(tokenizer.batch_decode(input_ids, skip_special_tokens=True)) + + return input_ids + + +parser = ArgumentParser() +parser.add_argument("--config", type=str, default="config.yaml") + +args = parser.parse_args() + +config = read_config(args.config) + +tokenizer = 
AutoTokenizer.from_pretrained(config["tokenizer_name"], model_max_length=config["max_length"]) +if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + +dataloader = load_memory_augmented_data(config, tokenizer, split_dataset=False) + +dataset = load_from_disk(config["dataset_path"]) + + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +model_config = AutoConfig.from_pretrained(config["model_name"]) + +head_size = model_config.hidden_size // model_config.num_attention_heads +index = MemoryIndex(head_size, + config["num_memories_per_index"], + model_config.num_attention_heads +) +model = LetheForCausalLM.from_pretrained(config["model_name"], + revision=config['version'] if 'version' in config else None, + memory_attn_layer=config["memory_attn_layer"], + num_neighbors_to_retrieve=config["num_neighbors_to_retrieve"], + index=index, + ).to(device) +model.eval() + +# Evaluate the model on the SQUAD dataset +losses = [] +with torch.no_grad(): + for i, batch in enumerate(tqdm(dataloader)): + memories = batch["retrieved_context"] + memories = memories[:, :config["num_neighbors_to_store"], :] + memories = memories.reshape(-1, memories.shape[-1]) + + # need to set to eval so we don't do mem attn as it's slow + with torch.no_grad(): + for chunk_start in range(0, memories.shape[0], config["mem_chunk_size"]): + chunk_end = min(memories.shape[0], chunk_start + config["mem_chunk_size"]) + mem_chunk = memories[chunk_start:chunk_end] + model(input_ids=mem_chunk.to(device)) + + del memories + torch.cuda.empty_cache() + qa_inputs = batch["input_ids"] + qa_labels = batch["labels"] + for i in range(qa_inputs.shape[0]): + inputs = qa_inputs[i].to(device) + labels = qa_labels[i].to(device) + cutoff = torch.argmax((labels != -100).type(torch.float32)) + greedy_search(inputs[:cutoff.item()].unsqueeze(0).to(device), model, tokenizer) + print(tokenizer.decode(inputs, skip_special_tokens=True)) + + + # batch_loss = calc_loss_per_item(outputs.logits, qa_labels.to(device)) + # losses.extend(batch_loss) + index.reset() + + + +dataset = dataset.add_column("loss", losses) + +dataset.save_to_disk("eval_squad_atlas_map") + + + diff --git a/gpt4all/inference/combine_synth_data.py b/gpt4all/inference/combine_synth_data.py new file mode 100644 index 00000000..3425943c --- /dev/null +++ b/gpt4all/inference/combine_synth_data.py @@ -0,0 +1,57 @@ +import glob +from argparse import ArgumentParser +from datasets import Dataset, load_from_disk, concatenate_datasets + + +PROMPT = "Write a question answer pair based on the following context. If the context isn't specific enough, ignore and return 'No question answer pair`. 
Context : {}\n" + +def load_synth_data(data_dir): + files = glob.glob(data_dir + "/*") + + ds = concatenate_datasets([load_from_disk(f) for f in files]) + table = ds.data.table + + filtered = table.filter(table["valid"]) + ds = Dataset.from_dict(filtered.to_pydict()) + return ds + + +def remove_prompt(examples): + outputs = {"text": [], "generated": [], "question": [], "answer": []} + for context, generated in zip(examples["text"], examples["generated"]): + prompt_w_ctx = PROMPT.format(context) + gen_wo_ctx = generated[len(prompt_w_ctx):] + + assert prompt_w_ctx not in gen_wo_ctx + + question = gen_wo_ctx.split("Answer:")[0].replace("Question:", "").strip() + answer = gen_wo_ctx.split("Answer:")[1].strip() + + outputs["text"].append(context) + outputs["generated"].append(gen_wo_ctx) + outputs["question"].append(question) + outputs["answer"].append(answer) + + return outputs + + + + +def combine_synth_data(data_dir): + ds = load_synth_data(data_dir) + + ds = ds.map(lambda ele: remove_prompt(ele), batched=True, num_proc=64) + + ds.save_to_disk(f"synth_data_combined_{len(ds)/1000:.0f}") + + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("--dataset_dir", type=str, required=True) + + args = parser.parse_args() + + combine_synth_data(args.dataset_dir) + + \ No newline at end of file diff --git a/gpt4all/inference/generate_synth_data.py b/gpt4all/inference/generate_synth_data.py new file mode 100644 index 00000000..60c00074 --- /dev/null +++ b/gpt4all/inference/generate_synth_data.py @@ -0,0 +1,149 @@ +from argparse import ArgumentParser +from datasets import load_dataset, Dataset +import torch +from torch.utils.data import DataLoader, DistributedSampler +from accelerate.utils import set_seed +from transformers import AutoTokenizer, AutoModelForCausalLM, DefaultDataCollator +from gpt4all.utils.read import read_config +from gpt4all.utils.distributed_utils import rank0_print, main_process_first +from tqdm import tqdm +import pyarrow.compute as pc +import pyarrow as pa +import torch.distributed as dist + + +PROMPT = "Write a question answer pair based on the following context. If the context isn't specific enough, ignore and return 'No question answer pair`. 
Context : {}\n" + + +def prepare_data(config, tokenizer, num_processes, local_rank): + dataset = load_dataset(config["dataset_path"], split="train") + dataset = dataset.remove_columns("embedding") + + shuffled = dataset.shuffle(seed=config["seed"]) + indices = shuffled[:config["max_generations"]]["id"] + + table = dataset.data + mask = pc.is_in(table["id"], value_set=pa.array(indices, pa.int32())) + filtered_table = table.filter(mask) + + # convert from pyarrow to Dataset + orig_dataset = Dataset.from_dict(filtered_table.to_pydict()) + + dataset = orig_dataset.map(lambda ele: {"prompted_text": [PROMPT.format(context) for context in ele["text"]]}, + batched=True, + num_proc=config["num_proc"] if "num_proc" in config else None) + + dataset = dataset.map(lambda ele: {"prompt_len": [len(prompt) for prompt in ele["prompted_text"]]}, batched=True, + num_proc=config["num_proc"] if "num_proc" in config else None) + + dataset = dataset.sort("prompt_len") + dataset = dataset.map(lambda ele: tokenizer(ele["prompted_text"], return_tensors="pt", padding="longest", truncation=True, + max_length=tokenizer.model_max_length - config["max_new_tokens"]), batched=True, + batch_size=num_processes * config["batch_size"], + ) + + columns_to_keep = ["id", "input_ids", "attention_mask"] + col_names_to_rm = [col for col in dataset.column_names if col not in columns_to_keep] + + dataset = dataset.remove_columns(col_names_to_rm) + + sampler = DistributedSampler( + dataset, + shuffle=False, + drop_last=True, + num_replicas=num_processes, + rank=local_rank, + ) + + dataloader = DataLoader( + dataset, + batch_size=config["batch_size"], + collate_fn=DefaultDataCollator(), + sampler=sampler, + drop_last=True + ) + + return dataloader, orig_dataset + +def generate_data(config): + set_seed(config["seed"]) + + rank0_print(f"World size: {dist.get_world_size()}") + + tokenizer = AutoTokenizer.from_pretrained(config["tokenizer_name"]) + # since we're doing generation, pad left for autoregressive generation + tokenizer.padding_side = "left" + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + num_processes = dist.get_world_size() + local_rank = dist.get_rank() + + dataloader, dataset = prepare_data(config, tokenizer, num_processes, local_rank) + + dist.barrier() + print(dataset[:10]["id"]) + + model = AutoModelForCausalLM.from_pretrained(config["model_name"], + revision=config["revision"] if "revision" in config else None, + use_cache=True, + torch_dtype=torch.bfloat16,) + model.to(f"cuda:{local_rank}") + + synth_data = [] + valid = [] + ids = [] + total_valid = 0 + + with torch.no_grad(): + for i, batch in enumerate(tqdm(dataloader, disable=local_rank != 0)): + # keep this simple for now, can add temperature and other sampling techniques later + generated = model.generate(batch["input_ids"].to(model.device), + attention_mask=batch["attention_mask"].to(model.device), + max_new_tokens=config["max_new_tokens"]) + + decoded = tokenizer.batch_decode(generated, skip_special_tokens=True) + num_valid = ["\nQuestion:" in t and "\nAnswer:" in t for t in decoded] + rank0_print(f"Num valid: {sum(num_valid)/ len(num_valid):.2f}") + total_valid += sum(num_valid) + + synth_data.extend(decoded) + valid.extend(num_valid) + ids.extend(batch["id"].tolist()) + + if i > 0 and i % config["save_every"] == 0: + table = dataset.data.table + mask = pc.is_in(table["id"], value_set=pa.array(ids, pa.int32())) + filtered_table = table.filter(mask) + + chunk_table = pa.Table.from_pydict({"id": ids, "generated": synth_data, 
"valid": valid}) + joined = filtered_table.join(chunk_table, "id") + curr_dataset = Dataset.from_dict(joined.to_pydict()) + curr_dataset.save_to_disk(f'{config["output_path"]}/chunk_{i}_rank_{local_rank}') + + table = dataset.data.table + mask = pc.is_in(table["id"], value_set=pa.array(ids, pa.int32())) + filtered_table = table.filter(mask) + + chunk_table = pa.Table.from_pydict({"id": ids, "generated": synth_data, "valid": valid}) + joined = filtered_table.join(chunk_table, "id") + full_dataset = Dataset.from_dict(joined.to_pydict()) + full_dataset.save_to_disk(f'{config["output_path"]}_{config["max_generations"]}_rank_{local_rank}') + + rank0_print(f"Total valid: {total_valid}/{config['max_generations']}") + + +def main(): + dist.init_process_group("nccl") + parser = ArgumentParser() + parser.add_argument("--config", type=str, default="config.yaml") + + args = parser.parse_args() + config = read_config(args.config) + + generate_data(config) + + +if __name__ == "__main__": + # parse arguments by reading in a config + main() \ No newline at end of file diff --git a/gpt4all/models/lethe/configuration_lethe.py b/gpt4all/models/lethe/configuration_lethe.py index 558362f2..9b4449ff 100644 --- a/gpt4all/models/lethe/configuration_lethe.py +++ b/gpt4all/models/lethe/configuration_lethe.py @@ -111,6 +111,7 @@ class LetheConfig(PretrainedConfig): memory_attn_layer=9, num_neighbors_to_retrieve=32, num_neighbors_stored=128, + attn_scale_init=20.0, **kwargs, ): super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) @@ -133,4 +134,5 @@ class LetheConfig(PretrainedConfig): # index of cross attention layer to add self.memory_attn_layer = memory_attn_layer self.num_neighbors_to_retrieve = num_neighbors_to_retrieve - self.num_neighbors_stored = num_neighbors_stored \ No newline at end of file + self.num_neighbors_stored = num_neighbors_stored + self.attn_scale_init = attn_scale_init \ No newline at end of file diff --git a/gpt4all/models/lethe/modeling_lethe.py b/gpt4all/models/lethe/modeling_lethe.py index 7bf2c3b0..80d658af 100644 --- a/gpt4all/models/lethe/modeling_lethe.py +++ b/gpt4all/models/lethe/modeling_lethe.py @@ -12,8 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" PyTorch PythiaSeek model.""" +""" PyTorch Lethe model.""" +import wandb +import math +import torch.nn.functional as F +import matplotlib.pyplot as plt from typing import Optional, Tuple, Union import torch @@ -29,8 +33,9 @@ from transformers.modeling_outputs import ( from transformers.modeling_utils import PreTrainedModel from transformers.utils import logging from gpt4all.models.lethe import LetheConfig -import hnswlib import numpy as np +import faiss +import faiss.contrib.torch_utils logger = logging.get_logger(__name__) @@ -40,17 +45,18 @@ GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST = [ "EleutherAI/gpt-neox-20b", ] -# TODO: understand why Phil only does this per batch and doens't persist across many batches -> he uses multi-query attention -# TODO: do we need to implement masking for the dense vectors we pull from? 
-# TODO: i think phil is using a memmapped database to pull out rather than using the index - class HNSWIndex: def __init__(self, max_memories, dimension): # num_memories will be batch size * num_neighbors # can memmap this too like - self.index = hnswlib.Index(space="l2", dim=dimension) - self.index.init_index(max_elements=max_memories, ef_construction=50, M=16) + self.index = faiss.IndexHNSWFlat(dimension, 16, faiss.METRIC_INNER_PRODUCT) + # taking params from: https://www.pinecone.io/learn/vector-indexes/#hnsw-implementation + # and https://www.pinecone.io/learn/hnsw/#hnsw-performance + # seems like efConstruction dictates how long the index takes to build + # and efSearch and M (second arg to faiss.Index) dictates how long it takes to search + self.index.hnsw.efConstruction = 16 + self.index.hnsw.efSearch = 32 self.max_memories = max_memories self.dimension = dimension @@ -60,45 +66,65 @@ class HNSWIndex: def query(self, query, k=1): # hack what should we do here? - if self.index.get_current_count() == 0: - return np.ones((query.shape[0], k, query.shape[1]), dtype=np.float32) + if self.index.ntotal == 0: + return np.ones((query.shape[0], k), dtype=np.int32) - assert query.ndim == 2 - bs_seq_len, _ = query.shape + _, labels = self.index.search(np.ascontiguousarray(query), k=k) - labels, _ = self.index.knn_query(query, k=k) - neighbors = torch.tensor(self.index.get_items(labels.reshape(-1))) - neighbors = neighbors.reshape((bs_seq_len, k, query.shape[1])) - - assert neighbors.ndim == 3 - assert neighbors.shape[0] == bs_seq_len - - return neighbors + return labels def add(self, memories): - assert memories.ndim == 2 - bs_seq_len, _ = memories.shape - - ids = np.arange(self.idx_offset, self.idx_offset + bs_seq_len) - - self.index.add_items(memories, ids) - - self.idx_offset += bs_seq_len + return self.index.add(np.ascontiguousarray(memories)) def reset(self): - self.index = hnswlib.Index(space="l2", dim=self.dimension) - self.index.init_index(max_elements=self.max_memories, ef_construction=50, M=16) + self.index.reset() + +class NumpyKNNIndex: + def __init__(self, max_memories, dimension): + # num_memories will be batch size * num_neighbors + # can memmap this too like + self.index = np.zeros((max_memories, dimension), dtype=np.float32) + self.max_memories = max_memories + self.dimension = dimension + + # if we want to allow for insertion of len(elements) > max_memories + # we need to figure out a way to get the most recent memories + self.idx_offset = 0 + + def query(self, query, k=1): + # hack what should we do here? 
+ if self.index.sum() == 0: + return np.ones((query.shape[0], k), dtype=np.int32) + + dots = query.dot(self.index[:self.idx_offset].T) + labels = np.argsort(dots, axis=1)[:, -k:] + + return labels + + + def add(self, memories): + self.index[self.idx_offset:self.idx_offset + memories.shape[0]] = memories + self.idx_offset += memories.shape[0] + + + def reset(self): + self.index.reset() + class MemoryIndex: def __init__(self, hidden_dim, num_mems, nheads): - # we store an index for each k/v for each head - self.key_indices = [HNSWIndex(num_mems, hidden_dim) for _ in range(nheads)] - self.value_indices = [HNSWIndex(num_mems, hidden_dim) for _ in range(nheads)] self.nheads = nheads + # NOTE: we are storing kv pairs, instead indices for both keys and values + self.key_indices = [HNSWIndex(num_mems, hidden_dim) for _ in range(nheads)] + + shape = (num_mems, nheads, 2, hidden_dim) + self.kv_pairs = np.zeros(shape, dtype=np.float32) + self.idx_offset = 0 + def add(self, keys, values): # k/v are (bs, num_attention_heads, seq_len, head_size) reshaped_keys = keys.reshape(keys.shape[0] * keys.shape[2], keys.shape[1], keys.shape[3]) @@ -106,22 +132,30 @@ class MemoryIndex: for head in range(self.nheads): self.key_indices[head].add(reshaped_keys[:, head, :]) - self.value_indices[head].add(reshaped_values[:, head, :]) + kv_pairs = np.stack((reshaped_keys, reshaped_values), axis=2) + + if self.idx_offset + kv_pairs.shape[0] > self.kv_pairs.shape[0]: + raise ValueError("Not enough memory!") + + self.kv_pairs[self.idx_offset:self.idx_offset + kv_pairs.shape[0]] = kv_pairs + self.idx_offset += kv_pairs.shape[0] def knn_query(self, query, k=1): reshaped_query = query.reshape(query.shape[0] * query.shape[2], query.shape[1], query.shape[3]) mem_keys = [] mem_values = [] + mem_indices = [] - # this is prob so so slow + # we can prob make this better for head in range(self.nheads): - knn_keys = self.key_indices[head].query(reshaped_query[:, head, :], k=k) - knn_values = self.value_indices[head].query(reshaped_query[:, head, :], k=k) - - mem_keys.append(knn_keys) - mem_values.append(knn_values) + knn_indices = self.key_indices[head].query(reshaped_query[:, head, :], k=k) + kv_pairs = self.kv_pairs[:, head, :, :][knn_indices] + + mem_keys.append(kv_pairs[:, :, 0, :]) + mem_values.append(kv_pairs[:, :, 1, :]) + mem_indices.append(knn_indices) mem_keys = torch.from_numpy(np.stack(mem_keys, axis=1)) # (bs, num_attention_heads, seq_len, k, head_size) @@ -131,13 +165,14 @@ class MemoryIndex: # (bs, num_attention_heads, seq_len, k, head_size) mem_values = mem_values.view(query.shape[:-1] + (k,) + (query.shape[-1],)) - return mem_keys, mem_values + return mem_keys, mem_values, np.stack(mem_indices, axis=1) def reset(self): for head in range(self.nheads): self.key_indices[head].reset() - self.value_indices[head].reset() + + self.kv_pairs = np.zeros((self.kv_pairs.shape[0], self.nheads, 2, self.kv_pairs.shape[-1]), dtype=np.float32) class LethePreTrainedModel(PreTrainedModel): @@ -171,7 +206,7 @@ class LethePreTrainedModel(PreTrainedModel): class LetheAttention(nn.Module): - def __init__(self, config, memory_attention=False, index=None): + def __init__(self, config, memory_attention=False, index=None, tracker=None): super().__init__() self.num_attention_heads = config.num_attention_heads self.hidden_size = config.hidden_size @@ -188,21 +223,24 @@ class LetheAttention(nn.Module): self.rotary_emb = RotaryEmbedding( self.rotary_ndims, config.max_position_embeddings, base=config.rotary_emb_base ) - self.register_buffer( + if 
not memory_attention: + self.register_buffer( "norm_factor", torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(torch.get_default_dtype()), persistent=False, ) + self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size) self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.memory = False if memory_attention: + self.scale = nn.Parameter(torch.ones(self.num_attention_heads, 1, 1) * math.log(config.attn_scale_init)) self.memory = True - self.alpha = nn.Parameter(torch.zeros(self.num_attention_heads)) self.num_neighbors = config.num_neighbors_to_retrieve # for testing, just using np array since it's easy self.index = index + self.tracker = tracker def forward( @@ -214,7 +252,9 @@ class LetheAttention(nn.Module): layer_past: Optional[Tuple[torch.Tensor]] = None, use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, - use_mem_attn: Optional[bool] = True, + log_attn_scores: Optional[bool] = False, + step: Optional[int] = None, + save_kv: Optional[bool] = True, ): has_layer_past = layer_past is not None @@ -234,6 +274,12 @@ class LetheAttention(nn.Module): key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3) value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3) + # if self.memory: + if self.memory: + # QKNorm: https://arxiv.org/abs/2010.04245 + query = F.normalize(query, dim=-1) + key = F.normalize(key, dim=-1) + # Compute rotary embeddings on rotary_ndims query_rot = query[..., : self.rotary_ndims] query_pass = query[..., self.rotary_ndims :] @@ -257,27 +303,38 @@ class LetheAttention(nn.Module): value = torch.cat((past_value, value), dim=-2) present = (key, value) if use_cache else None - # Compute attention - attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) - - # TODO: need to do masking?? 
+ # memory attention if self.memory: - # get knns - # since we do an eval batch w context before, let's not do the expensive step until we need to - # [batch, knn, num_attention_heads, seq_len, head_size] - if use_mem_attn: - knn_keys, knn_values = self.index.knn_query(query.cpu().detach().to(torch.float32).numpy(), k=self.num_neighbors) - mem_attn = self._mem_attn(query, - knn_keys.to(query.device), - knn_values.to(query.device), - attention_mask, - head_mask - ) + if save_kv: + self.index.add(key.cpu().detach().to(torch.float32).numpy(), value.cpu().detach().to(torch.float32).numpy()) - expanded_alpha = self.alpha[None, :, None, None] - attn_output = (attn_output * (1 - expanded_alpha)) + (mem_attn * expanded_alpha) + knn_keys, knn_values, knn_labels = self.index.knn_query(query.cpu().detach().to(torch.float32).numpy(), k=self.num_neighbors) + if log_attn_scores: + batch_size = query.shape[0] + seq_len = query.shape[-2] - self.index.add(key.cpu().detach().to(torch.float32).numpy(), value.cpu().detach().to(torch.float32).numpy()) + key_labels = knn_labels // seq_len + key_labels = key_labels.reshape(batch_size, seq_len, self.num_attention_heads, -1) + correct_keys = np.equal(key_labels, np.arange(batch_size)[:, np.newaxis, np.newaxis, np.newaxis]) + # calculate the accuracy + key_acc = np.sum(correct_keys) / np.prod(correct_keys.shape) + + self.tracker.log({"retrieved_acc": key_acc}, step=step) + + attn_output = self._mem_attn(query, + knn_keys.to(query.device).to(value.dtype), + knn_values.to(query.device).to(value.dtype), + key, + value, + attention_mask, + head_mask, + log_attn_scores=log_attn_scores, + step=step, + knn_labels=knn_labels, + ) + else: + # Normal self-attention + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) # Reshape outputs attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size) @@ -315,28 +372,131 @@ class LetheAttention(nn.Module): return tensor - def _mem_attn(self, query, key, value, attention_mask=None, head_mask=None): + def _mem_attn(self, + query, + knn_key, + knn_value, + local_key, + local_value, + attention_mask=None, + head_mask=None, + log_attn_scores=False, + step=None, + knn_labels=None): + # local self-attention # q: [bs, num_attention_heads, seq_len, attn_head_size] # k,v: [bs, num_attention_heads, seq_len, knn, attn_head_size] + query_length = query.size(-2) + key_length = local_key.size(-2) - attn_scores = torch.einsum("bhsd, bhsnd-> bshn", query, key) - # attn_scores: [bs, seq_len, num_attention_heads, knn] - attn_scores = attn_scores / self.norm_factor + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] - # softmax over knns - attn_weights = nn.functional.softmax(attn_scores, dim=-1) - attn_weights = attn_weights.to(value.dtype) + local_attn_scores = torch.matmul(query, local_key.transpose(-1, -2)) + scale = self.scale.exp() + + local_attn_scores = local_attn_scores * scale + + mask_value = torch.finfo(local_attn_scores.dtype).min + # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. 
+ # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + mask_value = torch.tensor(mask_value, dtype=local_attn_scores.dtype).to(local_attn_scores.device) + local_attn_scores = torch.where(causal_mask, local_attn_scores, mask_value) if attention_mask is not None: # Apply the attention mask - attn_scores = attn_scores + attention_mask + local_attn_scores = local_attn_scores + attention_mask + + mem_attn_scores = torch.einsum("bhsd, bhsnd-> bhsn", query, knn_key) + # attn_scores: [bs, seq_len, num_attention_heads, knn] + mem_attn_scores = mem_attn_scores * scale + + attn_scores = torch.cat((mem_attn_scores, local_attn_scores), dim=-1) + + # softmax over knns + attn_weights = nn.functional.softmax(attn_scores, dim=-1) + attn_weights = attn_weights.to(local_value.dtype) + + mem_attn_weights, local_attn_weights = attn_weights.split([self.num_neighbors, local_attn_scores.size(-1)], dim=-1) + if log_attn_scores: + # (bs, seq_len, num_attention_heads, knn) probabilities + # curate (x,y) pairs + # where x is attention weight, y is accuracy of retrieved token + bs, seq_len = mem_attn_weights.size(0), mem_attn_weights.size(2) + key_labels = knn_labels // seq_len + key_labels = key_labels.reshape(bs, self.num_attention_heads, seq_len, -1) + correct_keys = np.equal(key_labels, np.arange(bs)[:, np.newaxis, np.newaxis, np.newaxis]) + + bin_width = 0.05 + + # Calculate the number of bins + num_bins = int(1 / bin_width) + + # Create empty lists for storing bin probabilities and accuracies + bin_probabilities = [] + bin_accuracies = [] + + probs = mem_attn_weights.clone().detach().cpu().numpy().reshape(-1).tolist() + correct_keys = correct_keys.reshape(-1).tolist() + + # Iterate over each bin + for i in range(num_bins): + bin_lower = i * bin_width + bin_upper = (i + 1) * bin_width + + # Filter data points within the current bin range + bin_x_values = [x for x in probs if bin_lower <= x < bin_upper] + bin_y_values = [y for j, y in enumerate(correct_keys) if bin_lower <= probs[j] < bin_upper] + + # Calculate accuracy for the bin + total = len(bin_x_values) + correct = sum(bin_y_values) + accuracy = correct / total if total > 0 else 0 + + # Store the probability and accuracy for the bin + bin_probabilities.append((bin_lower + bin_upper) / 2) + bin_accuracies.append(accuracy) + + data = [[x, y] for x, y in zip(bin_probabilities, bin_accuracies)] + table = wandb.Table(data=data, columns=["attn_prob", "retrieved_acc"]) + self.tracker.log({"attn_vs_acc": wandb.plot.scatter(table, "attn_prob", "retrieved_acc")}, step=step) + + + if log_attn_scores: + # this def won't work well on multi-gpu machines + num_attention_heads = mem_attn_weights.size(1) + for head in range(num_attention_heads): + mem_attn_score_per_head = mem_attn_weights[:, head].reshape(-1) + mem_flat = mem_attn_score_per_head.clone().detach().cpu() + mem_hist = torch.histc(mem_flat, bins=20, min=0, max=1) + mem_bins = torch.linspace(0, 1, steps=20 + 1) + plt.stairs(mem_hist.tolist(), mem_bins.tolist()) + plt.title(f"mem_attn_score_{head}") + # set arbitrarily but we want to see those peaks!! 
+ plt.ylim((0, 1000)) + self.tracker.log({f"mem_attn_score_{head}": wandb.Image(plt)}, step=step) + plt.close() + + + local_attn_scores_per_head = local_attn_weights[:, head].reshape(-1) + local_flat = local_attn_scores_per_head.clone().detach().cpu() + local_hist = torch.histc(local_flat, bins=20, min=0, max=1) + local_bins = torch.linspace(0, 1, steps=20 + 1) + plt.stairs(local_hist.tolist(), local_bins.tolist()) + plt.title(f"local_attn_score_{head}") + # set arbitrarily but we want to see those peaks!! + plt.ylim((0, 1000)) + self.tracker.log({f"local_attn_score_{head}": wandb.Image(plt)}, step=step) + plt.close() - # Mask heads if we want to - if head_mask is not None: - attn_weights = attn_weights * head_mask # attn_output: [bs, num_attention_heads, seq_len, attn_head_size] - attn_output = torch.einsum("bshn, bhsnd-> bhsd", attn_scores, value) + mem_attn_output = torch.einsum("bhsn, bhsnd-> bhsd", mem_attn_weights, knn_value) + local_attn_output = torch.matmul(local_attn_weights, local_value) + + # TODO: do we need flamingo style gating + # of output_gate.tanh * attn_output + attn_output = mem_attn_output + local_attn_output + return attn_output def _attn(self, query, key, value, attention_mask=None, head_mask=None): @@ -361,9 +521,11 @@ class LetheAttention(nn.Module): query, key.transpose(1, 2), beta=1.0, - alpha=(torch.tensor(1.0, dtype=self.norm_factor.dtype, device=self.norm_factor.device) / self.norm_factor), + alpha=1.0 if self.memory else (torch.tensor(1.0, dtype=self.norm_factor.dtype, device=self.norm_factor.device) / self.norm_factor) ) attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length) + if self.memory: + attn_scores = attn_scores * self.scale.exp() mask_value = torch.finfo(attn_scores.dtype).min # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. 
@@ -413,7 +575,7 @@ class RotaryEmbedding(torch.nn.Module): emb = torch.cat((freqs, freqs), dim=-1).to(x.device) self.cos_cached = emb.cos()[None, None, :, :] self.sin_cached = emb.sin()[None, None, :, :] - return self.cos_cached[:seq_len, ...].to(x.device), self.sin_cached[:seq_len, ...].to(x.device) + return self.cos_cached[:seq_len, ...].to(x.device).to(x.dtype), self.sin_cached[:seq_len, ...].to(x.device).to(x.dtype) def rotate_half(x): @@ -448,12 +610,12 @@ class LetheMLP(nn.Module): class LetheLayer(nn.Module): - def __init__(self, config, memory_attention=False, index=None): + def __init__(self, config, memory_attention=False, index=None, tracker=None): super().__init__() self.use_parallel_residual = config.use_parallel_residual self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.attention = LetheAttention(config, memory_attention=memory_attention, index=index) + self.attention = LetheAttention(config, memory_attention=memory_attention, index=index, tracker=tracker) self.mlp = LetheMLP(config) def forward( @@ -465,7 +627,9 @@ class LetheLayer(nn.Module): use_cache: Optional[bool] = False, layer_past: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, - use_mem_attn: Optional[bool] = True, + log_attn_scores: Optional[bool] = False, + step: Optional[int] = None, + save_kv: Optional[bool] = True ): ln_hidden_states = self.input_layernorm(hidden_states) attention_layer_outputs = self.attention( @@ -476,7 +640,9 @@ class LetheLayer(nn.Module): head_mask=head_mask, use_cache=use_cache, output_attentions=output_attentions, - use_mem_attn=use_mem_attn, + log_attn_scores=log_attn_scores, + step=step, + save_kv=save_kv, ) attn_output = attention_layer_outputs[0] # output_attn: attn_output, present, (attn_weights) outputs = attention_layer_outputs[1:] @@ -503,7 +669,7 @@ class LetheLayer(nn.Module): class LetheModel(LethePreTrainedModel): - def __init__(self, config, index): + def __init__(self, config, index, tracker=None): super().__init__(config) self.config = config @@ -511,7 +677,8 @@ class LetheModel(LethePreTrainedModel): self.layers = nn.ModuleList([LetheLayer(config, memory_attention=i+1 == config.memory_attn_layer, - index=index if i+1 == config.memory_attn_layer else None) + index=index if i+1 == config.memory_attn_layer else None, + tracker=tracker) for i in range(config.num_hidden_layers)]) self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -538,7 +705,9 @@ class LetheModel(LethePreTrainedModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - use_mem_attn: Optional[bool] = True, + log_attn_scores: Optional[bool] = False, + step: Optional[int] = None, + save_kv: Optional[bool] = True, ) -> Union[Tuple, BaseModelOutputWithPast]: r""" past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): @@ -631,7 +800,7 @@ class LetheModel(LethePreTrainedModel): def create_custom_forward(module): def custom_forward(*inputs): # None for layer_past - return module(*inputs, use_cache, None, output_attentions, use_mem_attn) + return module(*inputs, use_cache, None, output_attentions, log_attn_scores, step, save_kv) return custom_forward @@ -651,7 +820,9 @@ class LetheModel(LethePreTrainedModel): 
layer_past=layer_past, use_cache=use_cache, output_attentions=output_attentions, - use_mem_attn=use_mem_attn, + log_attn_scores=log_attn_scores, + step=step, + save_kv=save_kv, ) hidden_states = outputs[0] if use_cache is True: @@ -678,10 +849,10 @@ class LetheModel(LethePreTrainedModel): class LetheForCausalLM(LethePreTrainedModel): _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] - def __init__(self, config, index): + def __init__(self, config, index, tracker=None): super().__init__(config) - self.gpt_neox = LetheModel(config, index) + self.gpt_neox = LetheModel(config, index, tracker=tracker) self.embed_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.hidden_size = config.hidden_size @@ -709,7 +880,9 @@ class LetheForCausalLM(LethePreTrainedModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - use_mem_attn: Optional[bool] = True, + log_attn_scores: Optional[bool] = None, + step: Optional[int] = None, + save_kv: Optional[bool] = True, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): @@ -763,7 +936,9 @@ class LetheForCausalLM(LethePreTrainedModel): output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - use_mem_attn=use_mem_attn + log_attn_scores=log_attn_scores, + step=step, + save_kv=save_kv, ) hidden_states = outputs[0] diff --git a/gpt4all/models/lethe/test_index.py b/gpt4all/models/lethe/test_index.py new file mode 100644 index 00000000..2ed5d965 --- /dev/null +++ b/gpt4all/models/lethe/test_index.py @@ -0,0 +1,21 @@ +import time +import numpy as np +from gpt4all.models.lethe.modeling_lethe import MemoryIndex + + +index = MemoryIndex(256, + 575000, + 8 +) + +keys = np.random.randn(32, 8, 1024, 256) +values = np.random.randn(32, 8, 1024, 256) +start = time.time() +index.add(keys, values) +print(f"index.add time: {time.time() - start}") + +print(index.key_indices[0].index.ntotal) +queries = np.random.randn(32, 8, 1024, 256) +start = time.time() +index.knn_query(queries, k=32) +print(f"index.knn_query time: {time.time() - start}") \ No newline at end of file diff --git a/gpt4all/models/lethe/test_lethe.py b/gpt4all/models/lethe/test_lethe.py index 13360319..c548e173 100644 --- a/gpt4all/models/lethe/test_lethe.py +++ b/gpt4all/models/lethe/test_lethe.py @@ -23,8 +23,9 @@ print("loading model") dimension = config.max_position_embeddings * config.hidden_size head_size = config.hidden_size // config.num_attention_heads index = MemoryIndex(head_size, - 500_000, - config.num_attention_heads + 5_000_000, + # 2 since multi-query attention and storing one each for key and value + config.num_attention_heads, ) model = LetheForCausalLM(config, index) model.to("cuda:0") @@ -69,7 +70,7 @@ with torch.no_grad(): for chunk_start in tqdm(range(0, memories.shape[0], 32)): chunk_end = min(memories.shape[0], chunk_start + 32) mem_chunk = memories[chunk_start:chunk_end].to(model.device) - model(input_ids=mem_chunk, labels=None) + model(input_ids=mem_chunk, labels=None,) model.train() diff --git a/gpt4all/models/pythia_retro/__init__.py b/gpt4all/models/pythia_retro/__init__.py new file mode 100644 index 00000000..e58e7069 --- /dev/null +++ b/gpt4all/models/pythia_retro/__init__.py @@ -0,0 +1,2 @@ +from .configuration_pythia_retro import PythiaRetroConfig +from .modeling_pythia_retro import 
PythiaRetroForCausalLM \ No newline at end of file diff --git a/gpt4all/train/train_mem_retrieval.py b/gpt4all/train/train_mem_retrieval.py index bf56d612..864e1dd7 100644 --- a/gpt4all/train/train_mem_retrieval.py +++ b/gpt4all/train/train_mem_retrieval.py @@ -1,4 +1,5 @@ import os +import torch.nn.functional as F from transformers import AutoTokenizer, get_scheduler, AutoConfig import torch from torch.optim import AdamW @@ -12,6 +13,8 @@ from tqdm import tqdm from gpt4all.models import LetheForCausalLM from gpt4all.models.lethe.modeling_lethe import MemoryIndex import wandb +import pyarrow as pa +from pyarrow import feather torch.backends.cuda.matmul.allow_tf32 = True @@ -22,39 +25,58 @@ def format_metrics(metrics, split, prefix=""): return log -def evaluate(model, config, val_dataloader, main_process=False): +def calculate_per_example_loss(logits, labels): + lm_logits = logits[:, :-1, :].contiguous() + lm_labels = labels[:, 1:].contiguous() + loss = F.cross_entropy(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1), reduction="none") + loss = loss.reshape(labels.shape[0], -1).mean(dim=-1) + + # return tensor of shape (B,) where B is the batch size + return loss.cpu().tolist() + + +def evaluate(model, index, config, val_dataloader, main_process=False): model.eval() val_loss = MeanMetric(nan_strategy="error").to(model.device) - head_size = model.config.hidden_size // model.config.num_attention_heads - index = MemoryIndex(head_size, - config["num_memories_per_index"], - model.config.num_attention_heads - ) - + ids = [] + losses = [] with torch.no_grad(): for batch in tqdm(val_dataloader, disable=not main_process): + batch["id"] = batch["id"].detach().cpu() memories = batch["retrieved_context"] - # need to set to eval so we don't do mem attn as it's slow - model.eval() + memories = memories[:, :config["num_neighbors_to_store"], :] + memories = memories.reshape(-1, memories.shape[-1]) + for chunk_start in range(0, memories.shape[0], config["mem_chunk_size"]): chunk_end = min(memories.shape[0], chunk_start + config["mem_chunk_size"]) mem_chunk = memories[chunk_start:chunk_end] - model(input_ids=mem_chunk, labels=None, use_mem_attn=False) + model(input_ids=mem_chunk) qa_inputs = batch["input_ids"] qa_labels = batch["labels"] outputs = model(input_ids=qa_inputs, labels=qa_labels, ) + + del memories + torch.cuda.empty_cache() loss_values = accelerator.gather_for_metrics({"loss": outputs["loss"].detach()}) - - val_loss.update(loss_values["loss"]) index.reset() + val_loss.update(loss_values["loss"]) - return val_loss + per_example_loss = calculate_per_example_loss(outputs["logits"], qa_labels) + + losses.extend(per_example_loss) + ids.extend(batch["id"].tolist()) + + ids = pa.array(ids) + losses = pa.array(losses) + schema = pa.schema([("loss", pa.float64()), ("id", pa.int32())]) + table = pa.Table.from_arrays([losses, ids], schema=schema) + return val_loss, table def train(accelerator, config): @@ -72,6 +94,7 @@ def train(accelerator, config): with accelerator.main_process_first(): train_dataloader, val_dataloader = load_memory_augmented_data(config, tokenizer) + if accelerator.state.deepspeed_plugin is not None: gradient_accumulation_steps = accelerator.state.deepspeed_plugin.deepspeed_config[ "gradient_accumulation_steps" @@ -97,6 +120,7 @@ def train(accelerator, config): memory_attn_layer=config["memory_attn_layer"], num_neighbors_to_retrieve=config["num_neighbors_to_retrieve"], index=index, + tracker=accelerator.get_tracker("wandb"), ) @@ -135,16 +159,15 @@ def train(accelerator, 
config): model, optimizer, train_dataloader, val_dataloader, scheduler = accelerator.prepare( model, optimizer, train_dataloader, val_dataloader, scheduler ) - scheduler = True + use_scheduler = True else: model, optimizer, train_dataloader, val_dataloader = accelerator.prepare( model, optimizer, train_dataloader, val_dataloader ) - scheduler = False - + use_scheduler = False # setup for saving training states in case preemption - if scheduler: + if use_scheduler: accelerator.register_for_checkpointing(scheduler) if config["checkpoint"]: @@ -167,24 +190,33 @@ def train(accelerator, config): for step, batch in enumerate(tqdm(train_dataloader, disable=not main_process)): curr_step = step + epoch * len(train_dataloader) memories = batch["retrieved_context"] + memories = memories[:, :config["num_neighbors_to_store"], :] + memories = memories.reshape(-1, memories.shape[-1]) # need to set to eval so we don't do mem attn as it's slow model.eval() with torch.no_grad(): for chunk_start in range(0, memories.shape[0], config["mem_chunk_size"]): - chunk_end = min(memories.shape[0], chunk_start + 32) + chunk_end = min(memories.shape[0], chunk_start + config["mem_chunk_size"]) mem_chunk = memories[chunk_start:chunk_end] - model(input_ids=mem_chunk, labels=None, use_mem_attn=False) + model(input_ids=mem_chunk) + + del memories + torch.cuda.empty_cache() model.train() qa_inputs = batch["input_ids"] qa_labels = batch["labels"] outputs = model(input_ids=qa_inputs, labels=qa_labels, + log_attn_scores=True, + step=curr_step, + save_kv=False, ) loss = outputs.loss + if config["wandb"]: + accelerator.log({"loss": loss}, step=curr_step) - index.reset() # gather loss before backprop in case of gradient accumulation loss_values = accelerator.gather_for_metrics({"loss": loss.detach().float()}) @@ -192,6 +224,8 @@ def train(accelerator, config): loss = loss / gradient_accumulation_steps accelerator.backward(loss) + # !! 
don't reset index until after backwards pass + index.reset() # get gradient norm of all params # log LR in case something weird happens @@ -202,15 +236,25 @@ def train(accelerator, config): if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: optimizer.step() - if scheduler: + if use_scheduler: scheduler.step() optimizer.zero_grad() if step > 0 and config["save_every"] > 0 and step % config["save_every"] == 0: - accelerator.save_state(f"{config['output_dir']}/step_{curr_step}") + # accelerator.save_state(f"{config['output_dir']}/step_{curr_step}") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained( + f"{config['output_dir']}/step_{step}", + is_main_process=accelerator.is_main_process, + save_function=accelerator.save, + state_dict=accelerator.get_state_dict(model), + ) if step > 0 and (step % config["eval_every"] == 0 or step == len(train_dataloader) - 1): - val_loss = evaluate(model, config, val_dataloader, main_process=main_process) + val_loss, loss_table = evaluate(model, index, config, val_dataloader, main_process=main_process) + + local_rank = accelerator.process_index + feather.write_feather(loss_table, f"{config['output_dir']}/val_losses_step_{curr_step}_rank_{local_rank}.feather") log_train = { "train_loss": train_loss.compute() diff --git a/gpt4all/utils/distributed_utils.py b/gpt4all/utils/distributed_utils.py index 3fcf276f..da95bc9f 100644 --- a/gpt4all/utils/distributed_utils.py +++ b/gpt4all/utils/distributed_utils.py @@ -1,4 +1,5 @@ import torch.distributed as dist +from contextlib import contextmanager def rank0_print(msg): @@ -7,3 +8,20 @@ def rank0_print(msg): print(msg) else: print(msg) + + +@contextmanager +def main_process_first(is_main): + yield from _goes_first(is_main) + + + +def _goes_first(is_main): + if not is_main: + dist.barrier() + + yield + + if is_main: + dist.barrier() + \ No newline at end of file
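
For reference, a minimal sketch (not the patch's exact classes) of the key/value memory pattern this change moves to: per-head keys go into a faiss HNSW index built with inner-product similarity, while the raw key/value pairs stay in a plain NumPy buffer and are recovered from the labels the search returns. Names, sizes, and the single-head layout below are illustrative assumptions.

import faiss
import numpy as np


class KeyValueMemory:
    """Toy single-head version of the index + kv-buffer combination."""

    def __init__(self, max_memories: int, head_dim: int, m: int = 16):
        # HNSW over inner products; efConstruction/efSearch trade build and
        # query time against recall (same values as the patch uses).
        self.index = faiss.IndexHNSWFlat(head_dim, m, faiss.METRIC_INNER_PRODUCT)
        self.index.hnsw.efConstruction = 16
        self.index.hnsw.efSearch = 32
        self.keys = np.zeros((max_memories, head_dim), dtype=np.float32)
        self.values = np.zeros((max_memories, head_dim), dtype=np.float32)
        self.size = 0

    def add(self, keys: np.ndarray, values: np.ndarray) -> None:
        n = keys.shape[0]
        if self.size + n > self.keys.shape[0]:
            raise ValueError("memory buffer is full")
        self.keys[self.size:self.size + n] = keys
        self.values[self.size:self.size + n] = values
        # faiss assigns ids 0..ntotal-1 in insertion order, so the returned
        # labels index directly into the buffers above
        self.index.add(np.ascontiguousarray(keys, dtype=np.float32))
        self.size += n

    def query(self, queries: np.ndarray, k: int = 2):
        if self.index.ntotal == 0:
            labels = np.zeros((queries.shape[0], k), dtype=np.int64)
        else:
            _, labels = self.index.search(
                np.ascontiguousarray(queries, dtype=np.float32), k
            )
        # fancy indexing returns (num_queries, k, head_dim) neighbors
        return self.keys[labels], self.values[labels]


mem = KeyValueMemory(max_memories=10_000, head_dim=64)
mem.add(np.random.randn(512, 64).astype(np.float32),
        np.random.randn(512, 64).astype(np.float32))
knn_keys, knn_values = mem.query(np.random.randn(8, 64).astype(np.float32), k=2)
print(knn_keys.shape, knn_values.shape)  # (8, 2, 64) (8, 2, 64)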
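
Also for reference, a minimal sketch that roughly mirrors the combined memory + local attention step added in the new _mem_attn path: QK-normalized queries and keys, a learned per-head temperature exp(scale) in place of 1/sqrt(head_dim), a causal mask on the local scores, and a single softmax over the concatenated retrieved and local scores. Shapes and names are illustrative assumptions; rotary embeddings and head bookkeeping are omitted, and retrieved keys are assumed to have been normalized before they were stored.

import math
import torch
import torch.nn.functional as F


def memory_attention(query, key, value, knn_key, knn_value, scale, causal_mask):
    # query/key/value: (bs, heads, seq, head_dim)
    # knn_key/knn_value: (bs, heads, seq, k, head_dim) retrieved neighbors
    query = F.normalize(query, dim=-1)           # QKNorm
    key = F.normalize(key, dim=-1)
    temp = scale.exp()                           # learned per-head temperature

    local_scores = torch.matmul(query, key.transpose(-1, -2)) * temp
    local_scores = local_scores.masked_fill(
        ~causal_mask, torch.finfo(local_scores.dtype).min
    )

    # score every query position against its k retrieved keys
    mem_scores = torch.einsum("bhsd,bhsnd->bhsn", query, knn_key) * temp

    # one softmax over [retrieved | local] lets the model trade them off
    weights = F.softmax(torch.cat((mem_scores, local_scores), dim=-1), dim=-1)
    mem_w, local_w = weights.split([knn_key.size(-2), key.size(-2)], dim=-1)

    out = torch.einsum("bhsn,bhsnd->bhsd", mem_w, knn_value)
    return out + torch.matmul(local_w, value)


# toy shapes: batch 2, 4 heads, seq 8, head_dim 16, k = 2 neighbors
bs, h, s, d, k = 2, 4, 8, 16, 2
q, kk, v = (torch.randn(bs, h, s, d) for _ in range(3))
knn_k, knn_v = (torch.randn(bs, h, s, k, d) for _ in range(2))
scale = torch.ones(h, 1, 1) * math.log(20.0)   # attn_scale_init in the config
mask = torch.tril(torch.ones(s, s)).bool().view(1, 1, s, s)
print(memory_attention(q, kk, v, knn_k, knn_v, scale, mask).shape)
# torch.Size([2, 4, 8, 16])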