diff --git a/configs/deepspeed/ds_config_pythiaseek.json b/configs/deepspeed/ds_config_pythiaseek.json index e6ec172d..8d84c8fd 100644 --- a/configs/deepspeed/ds_config_pythiaseek.json +++ b/configs/deepspeed/ds_config_pythiaseek.json @@ -14,7 +14,7 @@ }, "gradient_clipping": 1.0, "zero_optimization": { - "stage": 2, + "stage": 1, "offload_param": { "device": "none" }, @@ -35,5 +35,15 @@ ], "eps": 1e-08 } + }, + "scheduler": { + "type": "WarmupDecayLR", + "params": { + "warmup_min_lr": 0, + "warmup_max_lr": "auto", + "warmup_num_steps": "auto", + "warmup_type": "linear", + "total_num_steps": "auto" + } } } \ No newline at end of file diff --git a/configs/eval/evaluate_lethe.yaml b/configs/eval/evaluate_lethe.yaml index 79c54ea6..6bf0a72a 100644 --- a/configs/eval/evaluate_lethe.yaml +++ b/configs/eval/evaluate_lethe.yaml @@ -1,9 +1,10 @@ # model/tokenizer -model_name: "/home/paperspace/gpt4all/gpt4all/train/ckpts/mem_attn/step_1000" +model_name: "/home/paperspace/gpt4all/gpt4all/train/ckpts/qk_no_norm/step_500" tokenizer_name: "EleutherAI/pythia-1b" version: null -gradient_checkpointing: false +gradient_checkpointing: true memory_attn_layer: 12 +seed: 42 # dataset diff --git a/configs/train/finetune_memory.yaml b/configs/train/finetune_memory.yaml index 791c065c..10719ba5 100644 --- a/configs/train/finetune_memory.yaml +++ b/configs/train/finetune_memory.yaml @@ -5,7 +5,7 @@ version: null gradient_checkpointing: true save_name: "nomic-ai/lethe" push_to_hub: false -memory_attn_layer: 12 +memory_attn_layer: [9, 12, 15] # dataset streaming: false @@ -17,7 +17,7 @@ pct_test: 0.05 q_column: "question" a_column: "answer" context_column: "text" -num_memories_per_index: 2000000 +num_memories_per_index: 2048 num_neighbors_to_retrieve: 2 num_neighbors_to_store: 1 mem_chunk_size: 64 @@ -26,15 +26,15 @@ mem_chunk_size: 64 lr: 1.0e-5 min_lr: 0 weight_decay: 0.0 -eval_every: 100 -save_every: 100 +eval_every: 250 +save_every: 250 log_grads_every: 100 log_lr_every: 10 -output_dir: "ckpts/mem_attn_no_cosine_sim" +output_dir: "ckpts/qk_no_norm" checkpoint: null lora: false warmup_steps: 200 -num_epochs: 5 +num_epochs: 2 debug: false scheduler: false @@ -42,5 +42,4 @@ scheduler: false wandb: true wandb_entity: gpt4all wandb_project_name: mem_attn -seed: 42 - +seed: 42 \ No newline at end of file diff --git a/configs/train/pretrain_minipile.yaml b/configs/train/pretrain_enwik8.yaml similarity index 79% rename from configs/train/pretrain_minipile.yaml rename to configs/train/pretrain_enwik8.yaml index 94b72bec..4323cea3 100644 --- a/configs/train/pretrain_minipile.yaml +++ b/configs/train/pretrain_enwik8.yaml @@ -10,20 +10,21 @@ memory_attn_layer: 12 # dataset streaming: false num_proc: 64 -dataset_path: "JeanKaddour/minipile" +dataset_path: "pg19" max_length: 2048 -batch_size: 64 +seq_len: 512 +segments: 16 +batch_size: 16 pct_test: 0.05 -num_memories_per_index: 5000000 +num_memories_per_index: 100000 mem_chunk_size: 512 -num_chunks: 10 num_neighbors_to_retrieve: 32 # train dynamics -lr: 1.0e-4 +lr: 2.0e-4 min_lr: 0 weight_decay: 0.0 -eval_every: 100 +eval_every: 250 save_every: -1 log_grads_every: 100 log_lr_every: 10 @@ -38,6 +39,6 @@ scheduler: false # logging wandb: true wandb_entity: gpt4all -wandb_project_name: minipile +wandb_project_name: enwik8 seed: 42 diff --git a/gpt4all/data/enwik8.py b/gpt4all/data/enwik8.py new file mode 100644 index 00000000..9c1a92f3 --- /dev/null +++ b/gpt4all/data/enwik8.py @@ -0,0 +1,62 @@ +import numpy as np +import torch +from datasets import load_dataset +from 
torch.utils.data import Dataset, DataLoader +from transformers import DefaultDataCollator + + + +class EnWik8Dataset(Dataset): + def __init__(self, data, seq_len): + # pyarrow chunked array + self.data = torch.from_numpy(data) + self.seq_len = seq_len + + def __getitem__(self, index): + full_seq = self.data[index].long() + return full_seq.cuda() + + def __len__(self): + return len(self.data) + + +def load_enwik8_dataloader(config, tokenizer): + ds = load_dataset(config["dataset_path"], split="train") + + ds = ds.train_test_split(test_size=0.05, seed=config['seed']) + + train_ds, val_ds = ds["train"], ds["test"] + + keep_cols = ["input_ids"] + + train_ds = train_ds.map(lambda x: {"len": [len(t) for t in x["text"]]}, batched=True) + train_ds = train_ds.sort("len") + train_ds = train_ds.map(lambda x: tokenizer(x["text"], padding="longest", truncation=True, return_tensors="pt"), + batched=True, + batch_size=config["batch_size"]) + + remove_cols = [col for col in train_ds.column_names if col not in keep_cols] + train_ds = train_ds.remove_columns(remove_cols) + + val_ds = val_ds.map(lambda x: {"len": [len(t) for t in x["text"]]}, batched=True) + val_ds = val_ds.sort("len") + val_ds = val_ds.map(lambda x: tokenizer(x["text"], padding="longest", truncation=True, return_tensors="pt"), + batched=True, + batch_size=config["batch_size"]) + + remove_cols = [col for col in train_ds.column_names if col not in keep_cols] + val_ds = val_ds.remove_columns(remove_cols) + + train_dl = DataLoader(train_ds, + batch_size=config["batch_size"], + shuffle=True, + drop_last=True, + collate_fn=DefaultDataCollator()) + + val_dl = DataLoader(val_ds, + batch_size=config["batch_size"], + shuffle=True, + drop_last=True, + collate_fn=DefaultDataCollator()) + + return train_dl, val_dl \ No newline at end of file diff --git a/gpt4all/data/instruction_tuning_dataloader.py b/gpt4all/data/instruction_tuning_dataloader.py index 5803ea76..eb554118 100644 --- a/gpt4all/data/instruction_tuning_dataloader.py +++ b/gpt4all/data/instruction_tuning_dataloader.py @@ -1,6 +1,6 @@ import glob import torch -from datasets import load_dataset +from datasets import load_dataset, load_from_disk import os import hnswlib from torch.utils.data import DataLoader @@ -12,20 +12,23 @@ def load_data(config, tokenizer): dataset_path = config["dataset_path"] if os.path.exists(dataset_path): - if os.path.isdir(dataset_path): - files = glob.glob(os.path.join(dataset_path, "*_clean.jsonl")) - else: - files = [dataset_path] + dataset = load_from_disk(dataset_path) + # if os.path.isdir(dataset_path): + # files = glob.glob(os.path.join(dataset_path, "*_clean.jsonl")) + # else: + # files = [dataset_path] - print(f"Reading files {files}") + # print(f"Reading files {files}") - dataset = load_dataset("json", data_files=files, split="train") + # dataset = load_dataset("json", data_files=files, split="train") else: dataset = load_dataset(dataset_path, split="train") dataset = dataset.train_test_split(test_size=.05, seed=config["seed"]) + dataset = dataset.map(lambda x: {"prompt": [text + " " + question for text, question in zip(x["text"], x["question"])]}, batched=True) + train_dataset, val_dataset = dataset["train"], dataset["test"] if config["streaming"] is False: @@ -33,19 +36,27 @@ def load_data(config, tokenizer): else: kwargs = {} + cols_to_keep = ["input_ids", "labels", "attention_mask"] + # tokenize inputs and return labels and attention mask train_dataset = train_dataset.map( - lambda ele: tokenize_inputs(config, tokenizer, ele, "prompt", "response"), + 
lambda ele: tokenize_inputs(config, tokenizer, ele, "prompt", "answer"), batched=True, - remove_columns=["source", "prompt"], + # remove_columns=["source", "prompt"], **kwargs ) + + cols_to_remove = [col for col in train_dataset.column_names if col not in cols_to_keep] + train_dataset = train_dataset.remove_columns(cols_to_remove) + val_dataset = val_dataset.map( - lambda ele: tokenize_inputs(config, tokenizer, ele, "prompt", "response"), + lambda ele: tokenize_inputs(config, tokenizer, ele, "prompt", "answer"), batched=True, - remove_columns=["source", "prompt"], + # remove_columns=["source", "prompt"], **kwargs ) + cols_to_remove = [col for col in val_dataset.column_names if col not in cols_to_keep] + val_dataset = val_dataset.remove_columns(cols_to_remove) train_dataset = train_dataset.with_format("torch") val_dataset = val_dataset.with_format("torch") @@ -56,12 +67,14 @@ def load_data(config, tokenizer): train_dataset, collate_fn=DefaultDataCollator(), batch_size=config["batch_size"], + shuffle=True, ) val_dataloader = DataLoader( val_dataset, collate_fn=DefaultDataCollator(), batch_size=config["batch_size"], + shuffle=True, ) return train_dataloader, val_dataloader diff --git a/gpt4all/data/preprocess.py b/gpt4all/data/preprocess.py index 6f5c2e26..703baf1f 100644 --- a/gpt4all/data/preprocess.py +++ b/gpt4all/data/preprocess.py @@ -5,7 +5,7 @@ def tokenize_inputs(config, tokenizer, examples, input_col, target_col): # hacky backward compatible different_eos = tokenizer.eos_token != "" - out = {"labels": [], "input_ids": []} + out = {"labels": [], "input_ids": [], "attention_mask": []} for prompt, response in zip(examples[input_col], examples[target_col]): if different_eos: if response.count(" \n") > 0: @@ -42,9 +42,10 @@ def tokenize_inputs(config, tokenizer, examples, input_col, target_col): print(response) raise - input_tokens = tokenizer.pad({"input_ids": input_tokens}, padding="max_length", max_length=max_length)["input_ids"] + tokenized = tokenizer.pad({"input_ids": input_tokens}, padding="max_length", max_length=max_length) out["labels"].append(labels) - out["input_ids"].append(input_tokens) + out["input_ids"].append(tokenized["input_ids"]) + out["attention_mask"].append(tokenized["attention_mask"]) out = {k: torch.stack(v) if isinstance(v, list) else v for k, v in out.items()} diff --git a/gpt4all/data/retrieval_dataloader.py b/gpt4all/data/retrieval_dataloader.py index 3a4dbacb..01e1379c 100644 --- a/gpt4all/data/retrieval_dataloader.py +++ b/gpt4all/data/retrieval_dataloader.py @@ -41,7 +41,7 @@ def load_retrieval_augmented_data(config, tokenizer, split="train", split_datase if encoder_column != "encoder_hidden_states": dataset = dataset.rename_column(encoder_column, "encoder_hidden_states") - columns_to_keep = ["input_ids", "labels", "encoder_hidden_states"] + columns_to_keep = ["input_ids", "attention_mask", "labels", "encoder_hidden_states"] col_names_to_rm = [col for col in dataset.column_names if col not in columns_to_keep] dataset = dataset.remove_columns(col_names_to_rm) @@ -115,7 +115,7 @@ def load_memory_augmented_data(config, tokenizer, split="train", split_dataset=T **kwargs ) - columns_to_keep = ["id", "input_ids", "labels", "retrieved_context"] + columns_to_keep = ["id", "input_ids", "attention_mask", "labels", "retrieved_context"] col_names_to_rm = [col for col in dataset.column_names if col not in columns_to_keep] dataset = dataset.remove_columns(col_names_to_rm) @@ -128,12 +128,16 @@ def load_memory_augmented_data(config, tokenizer, split="train", 
split_dataset=T train_dataset.remove_columns("id"), batch_size=config["batch_size"], collate_fn=DefaultDataCollator(), + shuffle=True, + drop_last=True, ) val_dataloader = DataLoader( val_dataset, batch_size=config["batch_size"], collate_fn=DefaultDataCollator(), + shuffle=True, + drop_last=True, ) return train_dataloader, val_dataloader diff --git a/gpt4all/eval/eval_squad_atlas_map.py b/gpt4all/eval/eval_squad_atlas_map.py index 9480ec28..e2d61d0b 100644 --- a/gpt4all/eval/eval_squad_atlas_map.py +++ b/gpt4all/eval/eval_squad_atlas_map.py @@ -1,7 +1,7 @@ import torch import torch.nn.functional as F from gpt4all.models import LetheForCausalLM -from gpt4all.models.lethe.modeling_lethe import MemoryIndex +from gpt4all.models.lethe.modeling_lethe import BatchedMemory from gpt4all.data.retrieval_dataloader import load_memory_augmented_data from gpt4all.train.metrics import f1_score, exact_match_score from gpt4all.utils.read import read_config @@ -28,17 +28,21 @@ def greedy_search(input_ids, model, tokenizer, max_new_tokens=100): while True: if num_new_tokens >= max_new_tokens: break - outputs = model(input_ids, save_kv=False) + attention_mask = input_ids.ne(tokenizer.pad_token_id) + outputs = model(input_ids, attention_mask=attention_mask, save_kv=False) - new_tokens = torch.argmax(outputs.logits[:, -1, :], dim=-1) + next_token_idx = torch.argmax((input_ids == tokenizer.pad_token_id).type(torch.float32)) + # -1 because logits at last position predict next token + new_token = torch.argmax(outputs.logits[:, next_token_idx - 1, :], dim=-1) + + input_ids[:, next_token_idx] = new_token - input_ids = torch.cat([input_ids, new_tokens.unsqueeze(1)], dim=-1) num_new_tokens += 1 - if torch.equal(input_ids[0, -1].cpu(), torch.tensor(tokenizer.eos_token_id)): + if torch.equal(new_token.cpu(), torch.tensor(tokenizer.eos_token_id)): break - print(tokenizer.batch_decode(input_ids, skip_special_tokens=True)) + print(f"GENERATED: {tokenizer.batch_decode(input_ids, skip_special_tokens=True)}") return input_ids @@ -63,10 +67,11 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model_config = AutoConfig.from_pretrained(config["model_name"]) head_size = model_config.hidden_size // model_config.num_attention_heads -index = MemoryIndex(head_size, - config["num_memories_per_index"], - model_config.num_attention_heads -) +index = BatchedMemory(config["batch_size"], + head_size, + config["num_memories_per_index"], + model_config.num_attention_heads, + ) model = LetheForCausalLM.from_pretrained(config["model_name"], revision=config['version'] if 'version' in config else None, memory_attn_layer=config["memory_attn_layer"], @@ -90,16 +95,19 @@ with torch.no_grad(): mem_chunk = memories[chunk_start:chunk_end] model(input_ids=mem_chunk.to(device)) - del memories torch.cuda.empty_cache() qa_inputs = batch["input_ids"] qa_labels = batch["labels"] for i in range(qa_inputs.shape[0]): inputs = qa_inputs[i].to(device) + print(f"EXPECTED: {tokenizer.decode(inputs, skip_special_tokens=True)}") labels = qa_labels[i].to(device) + cutoff = torch.argmax((labels != -100).type(torch.float32)) - greedy_search(inputs[:cutoff.item()].unsqueeze(0).to(device), model, tokenizer) - print(tokenizer.decode(inputs, skip_special_tokens=True)) + inputs[cutoff:] = tokenizer.pad_token_id + greedy_search(inputs.unsqueeze(0).to(device), model, tokenizer) + print(f"CONTEXT: {tokenizer.decode(memories[i], skip_special_tokens=True)}") + import pdb; pdb.set_trace() # batch_loss = calc_loss_per_item(outputs.logits, 
qa_labels.to(device)) diff --git a/gpt4all/eval/eval_squad_atlas_map/dataset_info.json b/gpt4all/eval/eval_squad_atlas_map/dataset_info.json new file mode 100644 index 00000000..ad25412c --- /dev/null +++ b/gpt4all/eval/eval_squad_atlas_map/dataset_info.json @@ -0,0 +1,96 @@ +{ + "builder_name": "squad", + "citation": "@article{2016arXiv160605250R,\n author = {{Rajpurkar}, Pranav and {Zhang}, Jian and {Lopyrev},\n Konstantin and {Liang}, Percy},\n title = \"{SQuAD: 100,000+ Questions for Machine Comprehension of Text}\",\n journal = {arXiv e-prints},\n year = 2016,\n eid = {arXiv:1606.05250},\n pages = {arXiv:1606.05250},\narchivePrefix = {arXiv},\n eprint = {1606.05250},\n}\n", + "config_name": "plain_text", + "dataset_size": 89819092, + "description": "Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset, consisting of questions posed by crowdworkers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span, from the corresponding reading passage, or the question might be unanswerable.\n", + "download_checksums": { + "https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json": { + "num_bytes": 30288272, + "checksum": null + }, + "https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json": { + "num_bytes": 4854279, + "checksum": null + } + }, + "download_size": 35142551, + "features": { + "id": { + "dtype": "string", + "_type": "Value" + }, + "title": { + "dtype": "string", + "_type": "Value" + }, + "context": { + "dtype": "string", + "_type": "Value" + }, + "question": { + "dtype": "string", + "_type": "Value" + }, + "answers": { + "feature": { + "text": { + "dtype": "string", + "_type": "Value" + }, + "answer_start": { + "dtype": "int32", + "_type": "Value" + } + }, + "_type": "Sequence" + }, + "neighbor_ids": { + "feature": { + "dtype": "uint64", + "_type": "Value" + }, + "_type": "Sequence" + }, + "neighbor_text": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "loss": { + "dtype": "float64", + "_type": "Value" + } + }, + "homepage": "https://rajpurkar.github.io/SQuAD-explorer/", + "license": "", + "size_in_bytes": 124961643, + "splits": { + "train": { + "name": "train", + "num_bytes": 79346108, + "num_examples": 87599, + "dataset_name": "squad" + }, + "validation": { + "name": "validation", + "num_bytes": 10472984, + "num_examples": 10570, + "dataset_name": "squad" + } + }, + "task_templates": [ + { + "task": "question-answering-extractive" + } + ], + "version": { + "version_str": "1.0.0", + "description": "", + "major": 1, + "minor": 0, + "patch": 0 + } +} \ No newline at end of file diff --git a/gpt4all/eval/eval_squad_atlas_map/state.json b/gpt4all/eval/eval_squad_atlas_map/state.json new file mode 100644 index 00000000..a83a9654 --- /dev/null +++ b/gpt4all/eval/eval_squad_atlas_map/state.json @@ -0,0 +1,16 @@ +{ + "_data_files": [ + { + "filename": "data-00000-of-00002.arrow" + }, + { + "filename": "data-00001-of-00002.arrow" + } + ], + "_fingerprint": "c178f5c8269012a2", + "_format_columns": null, + "_format_kwargs": {}, + "_format_type": null, + "_output_all_columns": false, + "_split": "validation" +} \ No newline at end of file diff --git a/gpt4all/eval/eval_synthetic.py b/gpt4all/eval/eval_synthetic.py new file mode 100644 index 00000000..af1fa623 --- /dev/null +++ b/gpt4all/eval/eval_synthetic.py @@ -0,0 +1,42 @@ +import torch +from gpt4all.data.instruction_tuning_dataloader import load_data +from gpt4all.utils.read import read_config +from 
transformers import AutoTokenizer, AutoModelForCausalLM +from argparse import ArgumentParser +from tqdm import tqdm + +parser = ArgumentParser() +parser.add_argument("--config", type=str, default="config.yaml") + +args = parser.parse_args() + +config = read_config(args.config) + +tokenizer = AutoTokenizer.from_pretrained(config["tokenizer_name"]) +if tokenizer.pad_token is None: +    tokenizer.pad_token = tokenizer.eos_token + +train_dataloader, val_dataloader = load_data(config, tokenizer) + + +model = AutoModelForCausalLM.from_pretrained(config["model_name"], +                                             trust_remote_code=True) + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +model = model.half().to(device) +model.eval() + +# Evaluate the model on the SQUAD dataset +f1s = [] +exact_matches = [] +with torch.no_grad(): +    for batch in tqdm(val_dataloader): +        inputs = batch["input_ids"].to(device) +        labels = batch["labels"].to(device) +        cutoff = torch.argmax((labels != -100).type(torch.float32)) +        outputs = model.generate(inputs[:, :cutoff], max_new_tokens=100) +        print(f"Predicted: {tokenizer.batch_decode(outputs, skip_special_tokens=True)}") +        print(f"Ground truth: {tokenizer.batch_decode(inputs[:, cutoff:], skip_special_tokens=True)}") +        print(tokenizer.batch_decode(inputs, skip_special_tokens=True)) + diff --git a/gpt4all/models/__init__.py b/gpt4all/models/__init__.py index 575d8e2f..8ee09794 100644 --- a/gpt4all/models/__init__.py +++ b/gpt4all/models/__init__.py @@ -2,7 +2,6 @@ from .gpt_jr.configuration_gpt_jr import GPTJRConfig from .gpt_jr.modeling_gpt_jr import GPTJRForCausalLM  from .pythiaseek import PythiaSeekForCausalLM, PythiaSeekConfig -from .pythia_retro import PythiaRetroForCausalLM, PythiaRetroConfig  from .lethe import LetheConfig, LetheForCausalLM diff --git a/gpt4all/models/lethe/modeling_lethe.py b/gpt4all/models/lethe/modeling_lethe.py index 80d658af..cdf5363e 100644 --- a/gpt4all/models/lethe/modeling_lethe.py +++ b/gpt4all/models/lethe/modeling_lethe.py @@ -18,6 +18,8 @@ import wandb import math import torch.nn.functional as F import matplotlib.pyplot as plt +import plotly.express as px +import pandas as pd from typing import Optional, Tuple, Union import torch @@ -121,60 +123,106 @@ class MemoryIndex: # NOTE: we are storing kv pairs, instead indices for both keys and values self.key_indices = [HNSWIndex(num_mems, hidden_dim) for _ in range(nheads)] -        shape = (num_mems, nheads, 2, hidden_dim) +        shape = (nheads, num_mems, 2, hidden_dim) self.kv_pairs = np.zeros(shape, dtype=np.float32) self.idx_offset = 0 def add(self, keys, values): -        # k/v are (bs, num_attention_heads, seq_len, head_size) -        reshaped_keys = keys.reshape(keys.shape[0] * keys.shape[2], keys.shape[1], keys.shape[3]) -        reshaped_values = values.reshape(values.shape[0] * values.shape[2], values.shape[1], values.shape[3]) +        # k/v are (num_attention_heads, seq_len, head_size) +        # keys = keys.reshape(keys.shape[1], keys.shape[0], keys.shape[2]) +        # values = values.reshape(values.shape[1], values.shape[0], values.shape[2]) for head in range(self.nheads): -            self.key_indices[head].add(reshaped_keys[:, head, :]) +            self.key_indices[head].add(keys[head, :, :]) -        kv_pairs = np.stack((reshaped_keys, reshaped_values), axis=2) +        kv_pairs = np.stack((keys, values), axis=2) -        if self.idx_offset + kv_pairs.shape[0] > self.kv_pairs.shape[0]: -            raise ValueError("Not enough memory!") +        if self.idx_offset + kv_pairs.shape[1] > self.kv_pairs.shape[1]: +            # reset to 0 to overwrite oldest memories +            self.idx_offset = 0   -        
self.kv_pairs[self.idx_offset:self.idx_offset + kv_pairs.shape[0]] = kv_pairs - self.idx_offset += kv_pairs.shape[0] + self.kv_pairs[:, self.idx_offset:self.idx_offset + kv_pairs.shape[1]] = kv_pairs + self.idx_offset += kv_pairs.shape[1] def knn_query(self, query, k=1): - reshaped_query = query.reshape(query.shape[0] * query.shape[2], query.shape[1], query.shape[3]) - mem_keys = [] mem_values = [] mem_indices = [] # we can prob make this better for head in range(self.nheads): - knn_indices = self.key_indices[head].query(reshaped_query[:, head, :], k=k) - kv_pairs = self.kv_pairs[:, head, :, :][knn_indices] + knn_indices = self.key_indices[head].query(query[head, :, :], k=k) + kv_pairs = self.kv_pairs[head, :, :, :][knn_indices] mem_keys.append(kv_pairs[:, :, 0, :]) mem_values.append(kv_pairs[:, :, 1, :]) mem_indices.append(knn_indices) - mem_keys = torch.from_numpy(np.stack(mem_keys, axis=1)) - # (bs, num_attention_heads, seq_len, k, head_size) - mem_keys = mem_keys.view(query.shape[:-1] + (k,) + (query.shape[-1],)) + mem_keys = torch.from_numpy(np.stack(mem_keys, axis=0)) + # (num_attention_heads, seq_len, k, head_size) + # mem_keys = mem_keys.view(query.shape[:-1] + (k,) + (query.shape[-1],)) - mem_values = torch.from_numpy(np.stack(mem_values, axis=1)) - # (bs, num_attention_heads, seq_len, k, head_size) - mem_values = mem_values.view(query.shape[:-1] + (k,) + (query.shape[-1],)) + mem_values = torch.from_numpy(np.stack(mem_values, axis=0)) + # (num_attention_heads, seq_len, k, head_size) + # mem_values = mem_values.view(query.shape[:-1] + (k,) + (query.shape[-1],)) - return mem_keys, mem_values, np.stack(mem_indices, axis=1) + return mem_keys, mem_values, np.stack(mem_indices, axis=0) def reset(self): for head in range(self.nheads): self.key_indices[head].reset() - self.kv_pairs = np.zeros((self.kv_pairs.shape[0], self.nheads, 2, self.kv_pairs.shape[-1]), dtype=np.float32) + self.kv_pairs = np.zeros(self.kv_pairs.shape, dtype=np.float32) + self.idx_offset = 0 +class BatchedMemory: + def __init__(self, batch_size, hidden_dim, num_mems, nheads): + self.indices = [MemoryIndex(hidden_dim, num_mems, nheads) for _ in range(batch_size)] + + + def add(self, keys, values): + for bs in range(len(self.indices)): + self.indices[bs].add(keys[bs], values[bs]) + + + def knn_query(self, query, k=1): + batched_mem_keys = [] + batched_mem_values = [] + batched_labels = [] + + for bs in range(len(self.indices)): + knn_keys, knn_values, knn_labels = self.indices[bs].knn_query(query[bs], k=k) + batched_mem_keys.append(knn_keys) + batched_mem_values.append(knn_values) + batched_labels.append(knn_labels) + + + return torch.stack(batched_mem_keys, dim=0), torch.stack(batched_mem_values, dim=0), np.stack(batched_labels, axis=0) + + def reset(self): + for bs in range(len(self.indices)): + self.indices[bs].reset() + + +class BatchedStorage: + def __init__(self, batch_size, hidden_dim, nheads, seq_len): + self.indices = np.zeros((batch_size, nheads, seq_len, 2, hidden_dim), dtype=np.float32) + + def knn_query(self, query, k=None): + return torch.from_numpy(self.indices[:, :, :, 0, :]), torch.from_numpy(self.indices[:, :, :, 1, :]) + + + def add(self, keys, values): + self.indices[:, :, :, 0, :] = keys + self.indices[:, :, :, 1, :] = values + + def reset(self): + self.indices = np.zeros_like(self.indices) + + + class LethePreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained @@ -206,12 +254,13 @@ class 
LethePreTrainedModel(PreTrainedModel): class LetheAttention(nn.Module): - def __init__(self, config, memory_attention=False, index=None, tracker=None): + def __init__(self, config, memory_attention=False, index=None, layer_idx=None, tracker=None): super().__init__() self.num_attention_heads = config.num_attention_heads self.hidden_size = config.hidden_size self.head_size = self.hidden_size // self.num_attention_heads self.rotary_ndims = int(self.head_size * config.rotary_pct) + self.layer_idx = layer_idx max_positions = config.max_position_embeddings self.register_buffer( "bias", @@ -274,7 +323,6 @@ class LetheAttention(nn.Module): key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3) value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3) - # if self.memory: if self.memory: # QKNorm: https://arxiv.org/abs/2010.04245 query = F.normalize(query, dim=-1) @@ -309,17 +357,28 @@ class LetheAttention(nn.Module): self.index.add(key.cpu().detach().to(torch.float32).numpy(), value.cpu().detach().to(torch.float32).numpy()) knn_keys, knn_values, knn_labels = self.index.knn_query(query.cpu().detach().to(torch.float32).numpy(), k=self.num_neighbors) + # if log_attn_scores: + # batch_size = query.shape[0] + # seq_len = query.shape[-2] + + # key_labels = knn_labels // seq_len + # key_labels = key_labels.reshape(batch_size, seq_len, self.num_attention_heads, -1) + # correct_keys = np.equal(key_labels, np.arange(batch_size)[:, np.newaxis, np.newaxis, np.newaxis]) + # # calculate the accuracy + # key_acc = np.sum(correct_keys) / np.prod(correct_keys.shape) + + # self.tracker.log({"retrieved_acc": key_acc}, step=step) + if log_attn_scores: - batch_size = query.shape[0] - seq_len = query.shape[-2] + total_examples = 0 + unique_examples = 0 + for bs in range(query.shape[0]): + for head in range(query.shape[1]): + labels_per_head = knn_labels[bs, head, :, 0].tolist() + total_examples += len(labels_per_head) + unique_examples += len(set(labels_per_head)) - key_labels = knn_labels // seq_len - key_labels = key_labels.reshape(batch_size, seq_len, self.num_attention_heads, -1) - correct_keys = np.equal(key_labels, np.arange(batch_size)[:, np.newaxis, np.newaxis, np.newaxis]) - # calculate the accuracy - key_acc = np.sum(correct_keys) / np.prod(correct_keys.shape) - - self.tracker.log({"retrieved_acc": key_acc}, step=step) + self.tracker.log({"unique_retrieved_pct": unique_examples / total_examples}, step=step) attn_output = self._mem_attn(query, knn_keys.to(query.device).to(value.dtype), @@ -407,6 +466,7 @@ class LetheAttention(nn.Module): local_attn_scores = local_attn_scores + attention_mask mem_attn_scores = torch.einsum("bhsd, bhsnd-> bhsn", query, knn_key) + # mem_attn_scores = torch.matmul(query, knn_key.transpose(-1, -2)) # attn_scores: [bs, seq_len, num_attention_heads, knn] mem_attn_scores = mem_attn_scores * scale @@ -417,48 +477,56 @@ class LetheAttention(nn.Module): attn_weights = attn_weights.to(local_value.dtype) mem_attn_weights, local_attn_weights = attn_weights.split([self.num_neighbors, local_attn_scores.size(-1)], dim=-1) - if log_attn_scores: - # (bs, seq_len, num_attention_heads, knn) probabilities - # curate (x,y) pairs - # where x is attention weight, y is accuracy of retrieved token - bs, seq_len = mem_attn_weights.size(0), mem_attn_weights.size(2) - key_labels = knn_labels // seq_len - key_labels = key_labels.reshape(bs, self.num_attention_heads, seq_len, -1) - correct_keys = np.equal(key_labels, np.arange(bs)[:, np.newaxis, np.newaxis, np.newaxis]) + # mem_attn_weights, 
local_attn_weights = attn_weights.chunk(2, dim=-1) - bin_width = 0.05 - # Calculate the number of bins - num_bins = int(1 / bin_width) + # if log_attn_scores: + # # (bs, seq_len, num_attention_heads, knn) probabilities + # # curate (x,y) pairs + # # where x is attention weight, y is accuracy of retrieved token + # bs, seq_len = mem_attn_weights.size(0), mem_attn_weights.size(2) + # key_labels = knn_labels // seq_len + # key_labels = key_labels.reshape(bs, self.num_attention_heads, seq_len, -1) + # correct_keys = np.equal(key_labels, np.arange(bs)[:, np.newaxis, np.newaxis, np.newaxis]) - # Create empty lists for storing bin probabilities and accuracies - bin_probabilities = [] - bin_accuracies = [] + # bin_width = 0.05 - probs = mem_attn_weights.clone().detach().cpu().numpy().reshape(-1).tolist() - correct_keys = correct_keys.reshape(-1).tolist() + # # Calculate the number of bins + # num_bins = int(1 / bin_width) - # Iterate over each bin - for i in range(num_bins): - bin_lower = i * bin_width - bin_upper = (i + 1) * bin_width + # # Create empty lists for storing bin probabilities and accuracies + # bin_probabilities = [] + # bin_accuracies = [] + # bin_sizes = [] - # Filter data points within the current bin range - bin_x_values = [x for x in probs if bin_lower <= x < bin_upper] - bin_y_values = [y for j, y in enumerate(correct_keys) if bin_lower <= probs[j] < bin_upper] + # probs = mem_attn_weights.clone().detach().cpu().numpy().reshape(-1).tolist() + # correct_keys = correct_keys.reshape(-1).tolist() - # Calculate accuracy for the bin - total = len(bin_x_values) - correct = sum(bin_y_values) - accuracy = correct / total if total > 0 else 0 + # # Iterate over each bin + # for i in range(num_bins): + # bin_lower = i * bin_width + # bin_upper = (i + 1) * bin_width - # Store the probability and accuracy for the bin - bin_probabilities.append((bin_lower + bin_upper) / 2) - bin_accuracies.append(accuracy) + # # Filter data points within the current bin range + # bin_x_values = [x for x in probs if bin_lower <= x < bin_upper] + # bin_y_values = [y for j, y in enumerate(correct_keys) if bin_lower <= probs[j] < bin_upper] - data = [[x, y] for x, y in zip(bin_probabilities, bin_accuracies)] - table = wandb.Table(data=data, columns=["attn_prob", "retrieved_acc"]) - self.tracker.log({"attn_vs_acc": wandb.plot.scatter(table, "attn_prob", "retrieved_acc")}, step=step) + # # Calculate accuracy for the bin + # total = len(bin_x_values) + # correct = sum(bin_y_values) + # accuracy = correct / total if total > 0 else 0 + + # # Store the probability and accuracy for the bin + # bin_probabilities.append((bin_lower + bin_upper) / 2) + # bin_accuracies.append(accuracy) + # bin_sizes.append(len(bin_x_values)) + + # df = pd.DataFrame({"attn_prob": bin_probabilities, "retrieved_acc": bin_accuracies, "bin_size": bin_sizes}) + + # fig = px.scatter(df, x="attn_prob", y="retrieved_acc", + # color="bin_size", hover_data=["attn_prob", "retrieved_acc", "bin_size"], + # title="Attention Probability vs Retrieved Accuracy") + # self.tracker.log({"attn_vs_acc": fig}, step=step) if log_attn_scores: @@ -470,10 +538,10 @@ class LetheAttention(nn.Module): mem_hist = torch.histc(mem_flat, bins=20, min=0, max=1) mem_bins = torch.linspace(0, 1, steps=20 + 1) plt.stairs(mem_hist.tolist(), mem_bins.tolist()) - plt.title(f"mem_attn_score_{head}") + plt.title(f"mem_attn_score_{head}_layer_{self.layer_idx}") # set arbitrarily but we want to see those peaks!! 
plt.ylim((0, 1000)) - self.tracker.log({f"mem_attn_score_{head}": wandb.Image(plt)}, step=step) + self.tracker.log({f"mem_attn_score_{head}_layer_{self.layer_idx}": wandb.Image(plt)}, step=step) plt.close() @@ -482,15 +550,16 @@ class LetheAttention(nn.Module): local_hist = torch.histc(local_flat, bins=20, min=0, max=1) local_bins = torch.linspace(0, 1, steps=20 + 1) plt.stairs(local_hist.tolist(), local_bins.tolist()) - plt.title(f"local_attn_score_{head}") + plt.title(f"local_attn_score_{head}_layer_{self.layer_idx}") # set arbitrarily but we want to see those peaks!! plt.ylim((0, 1000)) - self.tracker.log({f"local_attn_score_{head}": wandb.Image(plt)}, step=step) + self.tracker.log({f"local_attn_score_{head}_layer_{self.layer_idx}": wandb.Image(plt)}, step=step) plt.close() # attn_output: [bs, num_attention_heads, seq_len, attn_head_size] mem_attn_output = torch.einsum("bhsn, bhsnd-> bhsd", mem_attn_weights, knn_value) + # mem_attn_output = torch.matmul(mem_attn_weights, knn_value) local_attn_output = torch.matmul(local_attn_weights, local_value) # TODO: do we need flamingo style gating @@ -524,8 +593,6 @@ class LetheAttention(nn.Module): alpha=1.0 if self.memory else (torch.tensor(1.0, dtype=self.norm_factor.dtype, device=self.norm_factor.device) / self.norm_factor) ) attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length) - if self.memory: - attn_scores = attn_scores * self.scale.exp() mask_value = torch.finfo(attn_scores.dtype).min # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. @@ -610,12 +677,15 @@ class LetheMLP(nn.Module): class LetheLayer(nn.Module): - def __init__(self, config, memory_attention=False, index=None, tracker=None): + def __init__(self, config, memory_attention=False, layer_idx=None, index=None, tracker=None): super().__init__() self.use_parallel_residual = config.use_parallel_residual self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.attention = LetheAttention(config, memory_attention=memory_attention, index=index, tracker=tracker) + self.attention = LetheAttention(config, memory_attention=memory_attention, + layer_idx=layer_idx, + index=index[layer_idx] if memory_attention else None, + tracker=tracker) self.mlp = LetheMLP(config) def forward( @@ -676,8 +746,9 @@ class LetheModel(LethePreTrainedModel): self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList([LetheLayer(config, - memory_attention=i+1 == config.memory_attn_layer, - index=index if i+1 == config.memory_attn_layer else None, + memory_attention=i+1 in config.memory_attn_layer, + layer_idx=i, + index=index, tracker=tracker) for i in range(config.num_hidden_layers)]) self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) diff --git a/gpt4all/models/pythia_retro/__init__.py b/gpt4all/models/pythia_retro/__init__.py deleted file mode 100644 index e58e7069..00000000 --- a/gpt4all/models/pythia_retro/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .configuration_pythia_retro import PythiaRetroConfig -from .modeling_pythia_retro import PythiaRetroForCausalLM \ No newline at end of file diff --git a/gpt4all/train/pretrain_mem_retrieval.py b/gpt4all/train/pretrain_mem_retrieval.py new file mode 100644 index 00000000..85df6a24 --- /dev/null +++ b/gpt4all/train/pretrain_mem_retrieval.py @@ -0,0 +1,288 @@ +import os +import 
torch.nn.functional as F +from transformers import AutoTokenizer, get_scheduler, AutoConfig +import torch +from torch.optim import AdamW +from argparse import ArgumentParser +from gpt4all.utils.read import read_config +from accelerate import Accelerator +from accelerate.utils import DummyScheduler, DummyOptim, set_seed +from gpt4all.data.enwik8 import load_enwik8_dataloader +from torchmetrics import MeanMetric +from tqdm import tqdm +from gpt4all.models import LetheForCausalLM, LetheConfig +from gpt4all.models.lethe.modeling_lethe import BatchedMemory +import wandb + +torch.backends.cuda.matmul.allow_tf32 = True + +def format_metrics(metrics, split, prefix=""): + log = f"[{split}]" + prefix + log += " ".join([f"{key}: {value:.4f}" for key, value in metrics.items()]) + + return log + + +def evaluate(model, index, pad_token_id, config, val_dataloader, main_process=False): + model.eval() + val_loss = MeanMetric(nan_strategy="error").to(model.device) + + chunk_size = config["seq_len"] + with torch.no_grad(): + for batch in tqdm(val_dataloader, disable=not main_process): + seq_len = batch.shape[1] + for chunk_start in range(0, seq_len, chunk_size): + chunk_end = min(seq_len, chunk_start + chunk_size) + inputs = batch[:, chunk_start:chunk_end].to(model.device) + labels = inputs.clone() + outputs = model(input_ids=inputs, + attention_mask=inputs.ne(pad_token_id), + labels=labels, + log_attn_scores=False, + step=None, + save_kv=True, + ) + loss = outputs.loss / config["segments"] + loss_values = accelerator.gather_for_metrics({"loss": loss.item()}) + val_loss.update(loss_values["loss"]) + + index.reset() + + return val_loss + + +def train(accelerator, config): + set_seed(config['seed']) + + accelerator.print(config) + accelerator.print(f"Using {accelerator.num_processes} GPUs") + + tokenizer = AutoTokenizer.from_pretrained(config['tokenizer_name'], model_max_length=config['max_length']) + # if no pad token, set it to eos + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + + with accelerator.main_process_first(): + train_dataloader, val_dataloader = load_enwik8_dataloader(config, tokenizer) + + + if accelerator.state.deepspeed_plugin is not None: + gradient_accumulation_steps = accelerator.state.deepspeed_plugin.deepspeed_config[ + "gradient_accumulation_steps" + ] + + accelerator.print(f"Len of train_dataloader: {len(train_dataloader)}") + total_num_steps = (len(train_dataloader) / gradient_accumulation_steps) * config["num_epochs"] + # instead of decaying to zero, decay to ratio of min_lr / lr + accelerator.print(f"Total training steps: {total_num_steps}") + + checkpoint = config["gradient_checkpointing"] + + model_config = LetheConfig.from_pretrained(config["model_name"]) + model_config.memory_attn_layer = config["memory_attn_layer"] + model_config.num_neighbors_to_retrieve = config["num_neighbors_to_retrieve"] + model_config.use_cache = False if checkpoint else True + + head_size = model_config.hidden_size // model_config.num_attention_heads + index = BatchedMemory(config["batch_size"], + head_size, + config["num_memories_per_index"], + model_config.num_attention_heads, + ) + + model = LetheForCausalLM(model_config, + index=index, + tracker=accelerator.get_tracker("wandb")) + + + accelerator.print(f"Training a {model.num_parameters():,} parameter model") + if checkpoint: + model.gradient_checkpointing_enable() + + optimizer_cls = ( + AdamW + if accelerator.state.deepspeed_plugin is None + or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config + 
else DummyOptim + ) + + # karpathy doesn't decay embeddding, maybe we should exclude + # https://github.com/karpathy/minGPT/commit/bbbdac74fa9b2e55574d70056163ffbae42310c1#diff-2075fa9c224b395be5bda85544dd36572b59c76c54562819eadadbf268602834R157s + optimizer = optimizer_cls(model.parameters(), lr=config["lr"], weight_decay=config["weight_decay"]) + + + # Creates Dummy Scheduler if `scheduler` was spcified in the config file else creates `args.lr_scheduler_type` Scheduler + if config["scheduler"] or "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config: + if ( + accelerator.state.deepspeed_plugin is None + or "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config + ): + scheduler = get_scheduler( + name="cosine", + optimizer=optimizer, + num_warmup_steps=config["warmup_steps"] * accelerator.num_processes, + num_training_steps=total_num_steps, + ) + else: + scheduler = DummyScheduler( + optimizer, total_num_steps=total_num_steps, warmup_num_steps=config["warmup_steps"] + ) + model, optimizer, scheduler, train_dataloader, val_dataloader = accelerator.prepare( + model, optimizer, scheduler, train_dataloader, val_dataloader + ) + use_scheduler = True + else: + model, optimizer, train_dataloader, val_dataloader = accelerator.prepare( + model, optimizer, train_dataloader, val_dataloader + ) + use_scheduler = False + + # setup for saving training states in case preemption + if use_scheduler: + accelerator.register_for_checkpointing(scheduler) + + if config["checkpoint"]: + accelerator.load_state(config["checkpoint"]) + accelerator.print(f"Resumed from checkpoint: {config['checkpoint']}") + path = os.path.basename(config["train_args"]["resume_from_checkpoint"]) + training_difference = os.path.splitext(path)[0] + resume_step = int(training_difference.replace("step_", "")) + accelerator.skip_first_batches(train_dataloader, resume_step) + accelerator.print(f"Resuming from step {resume_step}") + + # log gradients + if accelerator.is_main_process and config["wandb"]: + wandb.watch(model, log_freq=config["log_grads_every"], log="all") + + main_process = accelerator.is_main_process + + chunk_size = config["seq_len"] + for epoch in range(config["num_epochs"]): + train_loss = MeanMetric(nan_strategy="error").to(model.device) + for step, batch in enumerate(tqdm(train_dataloader, disable=not main_process)): + epoch_step = epoch * len(train_dataloader) + step * config["segments"] + seq_len = batch["input_ids"].shape[1] + model.train() + for i, chunk_start in enumerate(range(0, seq_len, chunk_size)): + curr_step = epoch_step + i + chunk_end = min(seq_len, chunk_start + chunk_size) + inputs = batch["input_ids"][:, chunk_start:chunk_end] + labels = inputs.clone() + labels[labels == tokenizer.pad_token_id] = -100 + outputs = model(input_ids=inputs, + attention_mask=inputs.ne(tokenizer.pad_token_id), + labels=labels, + log_attn_scores=True, + step=curr_step, + save_kv=True, + ) + loss = outputs.loss / config["segments"] + + if config["wandb"]: + accelerator.log({"loss": loss}, step=curr_step) + + # gather loss before backprop in case of gradient accumulation + loss_values = accelerator.gather_for_metrics({"loss": loss.detach().float()}) + train_loss.update(loss_values["loss"]) + + loss = loss / gradient_accumulation_steps + accelerator.backward(loss) + + # log LR in case something weird happens + if config["wandb"]: + if step > 0 and step % (config["log_lr_every"] ) == 0: + lr = optimizer.param_groups[0]["lr"] + accelerator.log({"lr": lr}, step=curr_step) + + optimizer.step() + if 
use_scheduler: + scheduler.step() + optimizer.zero_grad() + + # reset index on batch end + index.reset() + + if step > 0 and config["save_every"] > 0 and step % config["save_every"] == 0: + # accelerator.save_state(f"{config['output_dir']}/step_{curr_step}") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained( + f"{config['output_dir']}/step_{step}", + is_main_process=accelerator.is_main_process, + save_function=accelerator.save, + state_dict=accelerator.get_state_dict(model), + ) + + if step > 0 and (step % config["eval_every"] == 0 or step == len(train_dataloader) - 1): + val_loss = evaluate(model, index, tokenizer.pad_token_id, config, val_dataloader, main_process=main_process) + + log_train = { + "train_loss": train_loss.compute() + } + log_val = { + "val_loss": val_loss.compute(), + } + + if config["wandb"]: + curr_step = step + epoch * len(train_dataloader) + accelerator.log({**log_train, **log_val}, step=curr_step) + + accelerator.print(f"Current LR: {optimizer.param_groups[0]['lr']}") + accelerator.print(format_metrics(log_train, "train", f" step {step} ")) + accelerator.print(format_metrics(log_val, "val", f" step {step} ")) + + train_loss.reset() + + accelerator.print(f"Epoch {epoch} finished") + accelerator.wait_for_everyone() + unwrapped_model = accelerator.unwrap_model(model) + if config["push_to_hub"]: + accelerator.print(f"Pushing to HF hub") + try: + if accelerator.is_main_process: + unwrapped_model.push_to_hub(config["save_name"] + f"-epoch_{epoch}", private=True) + + except Exception as e: + accelerator.print(e) + accelerator.print(f"Failed to push to hub") + + unwrapped_model.save_pretrained( + f"{config['output_dir']}/epoch_{epoch}", + is_main_process=accelerator.is_main_process, + save_function=accelerator.save, + state_dict=accelerator.get_state_dict(model), + ) + + accelerator.wait_for_everyone() + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained( + f"{config['output_dir']}/final", + is_main_process=accelerator.is_main_process, + save_function=accelerator.save, + state_dict=accelerator.get_state_dict(model), + ) + + accelerator.end_training() + + + +if __name__ == "__main__": + # parse arguments by reading in a config + parser = ArgumentParser() + parser.add_argument("--config", type=str, default="config.yaml") + + args = parser.parse_args() + + config = read_config(args.config) + + if config["wandb"]: + accelerator = Accelerator(log_with="wandb") + accelerator.init_trackers( + project_name=config["wandb_project_name"], + config=config, + init_kwargs={"wandb": {"entity": config["wandb_entity"]}}, + ) + else: + accelerator = Accelerator() + + train(accelerator, config=config) diff --git a/gpt4all/train/train.py b/gpt4all/train/train.py index 75d8eda3..0a0f73b8 100644 --- a/gpt4all/train/train.py +++ b/gpt4all/train/train.py @@ -116,6 +116,7 @@ def train(accelerator, config): if config["checkpoint"]: accelerator.load_state(config["checkpoint"]) + import pdb; pdb.set_trace() accelerator.print(f"Resumed from checkpoint: {config['checkpoint']}") path = os.path.basename(config["train_args"]["resume_from_checkpoint"]) training_difference = os.path.splitext(path)[0] @@ -131,9 +132,12 @@ def train(accelerator, config): for epoch in range(config["num_epochs"]): train_loss = MeanMetric(nan_strategy="error").to(model.device) for step, batch in enumerate(tqdm(train_dataloader)): + curr_step = step + epoch * len(train_dataloader) model.train() outputs = model(**batch) loss = outputs.loss + if config["wandb"]: + 
accelerator.log({"loss": loss}, step=curr_step) # gather loss before backprop in case of gradient accumulation loss_values = accelerator.gather_for_metrics({"loss": loss.detach().float()}) @@ -157,7 +161,13 @@ def train(accelerator, config): if step > 0 and step % config["save_every"] == 0: curr_step = step + epoch * len(train_dataloader) - accelerator.save_state(f"{config['output_dir']}/step_{curr_step}") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained( + f"{config['output_dir']}/step_{curr_step}", + is_main_process=accelerator.is_main_process, + save_function=accelerator.save, + state_dict=accelerator.get_state_dict(model), + ) if step > 0 and (step % config["eval_every"] == 0 or step == len(train_dataloader) - 1): val_loss = evaluate(model, val_dataloader) diff --git a/gpt4all/train/train_mem_retrieval.py b/gpt4all/train/train_mem_retrieval.py index 864e1dd7..b58b0ff8 100644 --- a/gpt4all/train/train_mem_retrieval.py +++ b/gpt4all/train/train_mem_retrieval.py @@ -11,7 +11,7 @@ from gpt4all.data.retrieval_dataloader import load_memory_augmented_data from torchmetrics import MeanMetric from tqdm import tqdm from gpt4all.models import LetheForCausalLM -from gpt4all.models.lethe.modeling_lethe import MemoryIndex +from gpt4all.models.lethe.modeling_lethe import BatchedMemory import wandb import pyarrow as pa from pyarrow import feather @@ -58,13 +58,15 @@ def evaluate(model, index, config, val_dataloader, main_process=False): qa_labels = batch["labels"] outputs = model(input_ids=qa_inputs, labels=qa_labels, + save_kv=False ) del memories torch.cuda.empty_cache() loss_values = accelerator.gather_for_metrics({"loss": outputs["loss"].detach()}) - index.reset() + for ind in index.values(): + ind.reset() val_loss.update(loss_values["loss"]) per_example_loss = calculate_per_example_loss(outputs["logits"], qa_labels) @@ -110,16 +112,18 @@ def train(accelerator, config): model_config = AutoConfig.from_pretrained(config["model_name"]) head_size = model_config.hidden_size // model_config.num_attention_heads - index = MemoryIndex(head_size, + indices = {i - 1: BatchedMemory(config["batch_size"], + head_size, config["num_memories_per_index"], - model_config.num_attention_heads - ) + model_config.num_attention_heads, + ) for i in config["memory_attn_layer"]} + model = LetheForCausalLM.from_pretrained(config["model_name"], revision=config['version'] if 'version' in config else None, use_cache=False if checkpoint else True, memory_attn_layer=config["memory_attn_layer"], num_neighbors_to_retrieve=config["num_neighbors_to_retrieve"], - index=index, + index=indices, tracker=accelerator.get_tracker("wandb"), ) @@ -206,8 +210,10 @@ def train(accelerator, config): model.train() qa_inputs = batch["input_ids"] + attn_mask = batch["attention_mask"] qa_labels = batch["labels"] outputs = model(input_ids=qa_inputs, + attention_mask=attn_mask, labels=qa_labels, log_attn_scores=True, step=curr_step, @@ -225,7 +231,8 @@ def train(accelerator, config): loss = loss / gradient_accumulation_steps accelerator.backward(loss) # !! 
don't reset index until after backwards pass - index.reset() + for index in indices.values(): + index.reset() # get gradient norm of all params # log LR in case something weird happens @@ -251,7 +258,7 @@ def train(accelerator, config): ) if step > 0 and (step % config["eval_every"] == 0 or step == len(train_dataloader) - 1): - val_loss, loss_table = evaluate(model, index, config, val_dataloader, main_process=main_process) + val_loss, loss_table = evaluate(model, indices, config, val_dataloader, main_process=main_process) local_rank = accelerator.process_index feather.write_feather(loss_table, f"{config['output_dir']}/val_losses_step_{curr_step}_rank_{local_rank}.feather")
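
A note on the MemoryIndex/BatchedMemory changes in modeling_lethe.py above: each MemoryIndex keeps one approximate-nearest-neighbour index per attention head plus a (nheads, num_mems, 2, head_size) ring buffer of key/value pairs, and BatchedMemory holds one MemoryIndex per batch position so retrieved memories never mix across examples. Below is a minimal, self-contained sketch of that layout under stated assumptions: NearestKeyIndex is a brute-force stand-in for the repo's hnswlib-backed HNSWIndex wrapper, and MemoryIndexSketch is illustrative only, not the shipped class.

# Illustrative sketch only; shapes follow the (nheads, num_mems, 2, head_size) buffer in the patch.
import numpy as np
import torch

class NearestKeyIndex:
    """Brute-force stand-in for the hnswlib-backed HNSWIndex wrapper (add/query/reset)."""
    def __init__(self, num_mems, dim):
        self.dim = dim
        self.keys = np.zeros((0, dim), dtype=np.float32)

    def add(self, keys):
        # keys: (n, dim) -> appended in insertion order
        self.keys = np.concatenate([self.keys, keys.astype(np.float32)], axis=0)

    def query(self, queries, k=1):
        # queries: (q, dim) -> (q, k) indices of the nearest stored keys
        dists = ((queries[:, None, :] - self.keys[None, :, :]) ** 2).sum(-1)
        return np.argsort(dists, axis=1)[:, :k]

    def reset(self):
        self.keys = np.zeros((0, self.dim), dtype=np.float32)

class MemoryIndexSketch:
    """One index per head plus a (nheads, num_mems, 2, head_size) kv ring buffer."""
    def __init__(self, head_size, num_mems, nheads):
        self.nheads = nheads
        self.key_indices = [NearestKeyIndex(num_mems, head_size) for _ in range(nheads)]
        self.kv_pairs = np.zeros((nheads, num_mems, 2, head_size), dtype=np.float32)
        self.idx_offset = 0

    def add(self, keys, values):
        # keys/values: (nheads, seq_len, head_size)
        for head in range(self.nheads):
            self.key_indices[head].add(keys[head])
        kv = np.stack((keys, values), axis=2)  # (nheads, seq_len, 2, head_size)
        if self.idx_offset + kv.shape[1] > self.kv_pairs.shape[1]:
            self.idx_offset = 0  # wrap around and overwrite the oldest memories
        self.kv_pairs[:, self.idx_offset:self.idx_offset + kv.shape[1]] = kv
        self.idx_offset += kv.shape[1]

    def knn_query(self, query, k=1):
        # query: (nheads, seq_len, head_size) -> keys/values of shape (nheads, seq_len, k, head_size);
        # assumes index labels line up with buffer rows, as in the single-pass usage below
        mem_keys, mem_values = [], []
        for head in range(self.nheads):
            idx = self.key_indices[head].query(query[head], k=k)  # (seq_len, k)
            kv = self.kv_pairs[head][idx]                         # (seq_len, k, 2, head_size)
            mem_keys.append(kv[:, :, 0])
            mem_values.append(kv[:, :, 1])
        return torch.from_numpy(np.stack(mem_keys)), torch.from_numpy(np.stack(mem_values))

if __name__ == "__main__":
    # BatchedMemory-style usage: one MemoryIndexSketch per batch position,
    # so retrieved memories never cross examples.
    nheads, head_size, batch_size, seq_len = 2, 4, 3, 8
    per_example = [MemoryIndexSketch(head_size, num_mems=64, nheads=nheads) for _ in range(batch_size)]
    keys = np.random.randn(batch_size, nheads, seq_len, head_size).astype(np.float32)
    values = np.random.randn(batch_size, nheads, seq_len, head_size).astype(np.float32)
    for b in range(batch_size):
        per_example[b].add(keys[b], values[b])
    knn_k, knn_v = per_example[0].knn_query(keys[0], k=2)
    print(knn_k.shape, knn_v.shape)  # torch.Size([2, 8, 2, 4]) each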
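
The pretrain_mem_retrieval.py loop added above walks each long sequence as consecutive seq_len-token chunks (the enwik8/pg19 config uses seq_len: 512 and segments: 16), scales every chunk loss by 1/segments, saves keys/values so earlier chunks populate the memory that later chunks retrieve from, and resets the memory at the batch boundary. A minimal sketch of that control flow follows; model, index, and optimizer are hypothetical stand-ins (any causal LM accepting labels and save_kv, any memory object with reset()), not the real LetheForCausalLM.

# Sketch of the segment-level chunking pattern in pretrain_mem_retrieval.py (stand-in objects).
import torch

def train_one_batch(model, index, optimizer, input_ids, pad_token_id, seq_len, segments):
    # a long document is processed as `segments` consecutive chunks of `seq_len` tokens
    for chunk_start in range(0, input_ids.shape[1], seq_len):
        chunk = input_ids[:, chunk_start:chunk_start + seq_len]
        labels = chunk.clone()
        labels[labels == pad_token_id] = -100  # padding is ignored by the LM loss
        outputs = model(input_ids=chunk,
                        attention_mask=chunk.ne(pad_token_id),
                        labels=labels,
                        save_kv=True)          # earlier chunks populate the memory index
        loss = outputs.loss / segments         # scale so the summed chunk losses average over segments
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    index.reset()  # clear memories at the batch boundary so documents never leak into each other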
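
The rewritten greedy_search in eval_squad_atlas_map.py decodes in place: the prompt is right-padded to a fixed length, each generated token is written into the first pad slot using the logits at the preceding position, and the input tensor never grows. A simplified sketch of that idea is below; greedy_fill is an illustrative name, not the function in the patch, and model is a stand-in causal LM (batch size 1) returning .logits.

# Sketch of in-place greedy decoding over a right-padded prompt (stand-in model, batch size 1).
import torch

@torch.no_grad()
def greedy_fill(model, input_ids, pad_token_id, eos_token_id, max_new_tokens=100):
    for _ in range(max_new_tokens):
        next_idx = torch.argmax((input_ids == pad_token_id).float())  # first pad position
        if next_idx == 0:
            break  # no pad slots left to fill
        logits = model(input_ids, attention_mask=input_ids.ne(pad_token_id)).logits
        new_token = logits[:, next_idx - 1, :].argmax(dim=-1)  # logits at i-1 predict token i
        input_ids[:, next_idx] = new_token
        if new_token.item() == eos_token_id:
            break
    return input_ids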