Zach Nussbaum 2023-07-27 17:00:02 +00:00
parent 55fef489ad
commit 3128db96ca
18 changed files with 761 additions and 135 deletions

View File

@@ -14,7 +14,7 @@
     },
     "gradient_clipping": 1.0,
     "zero_optimization": {
-        "stage": 2,
+        "stage": 1,
         "offload_param": {
             "device": "none"
         },
@@ -35,5 +35,15 @@
         ],
         "eps": 1e-08
         }
+    },
+    "scheduler": {
+        "type": "WarmupDecayLR",
+        "params": {
+            "warmup_min_lr": 0,
+            "warmup_max_lr": "auto",
+            "warmup_num_steps": "auto",
+            "warmup_type": "linear",
+            "total_num_steps": "auto"
+        }
     }
 }
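The new "scheduler" block hands learning-rate scheduling to DeepSpeed's WarmupDecayLR, with the "auto" entries filled in from the training arguments at launch. As a rough sketch (not part of this commit, and only an approximation of the schedule shape): linear warmup from warmup_min_lr to the peak LR, then linear decay toward zero over the remaining steps.

# Hypothetical illustration of the warmup-then-decay shape; real values replace
# the "auto" placeholders at launch time.
def warmup_decay_lr(step, warmup_min_lr=0.0, warmup_max_lr=1e-4,
                    warmup_num_steps=200, total_num_steps=10000):
    if step < warmup_num_steps:
        # linear warmup
        return warmup_min_lr + (warmup_max_lr - warmup_min_lr) * step / warmup_num_steps
    # linear decay over the remaining steps
    remaining = max(total_num_steps - step, 0)
    return warmup_max_lr * remaining / max(total_num_steps - warmup_num_steps, 1)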

View File

@@ -1,9 +1,10 @@
 # model/tokenizer
-model_name: "/home/paperspace/gpt4all/gpt4all/train/ckpts/mem_attn/step_1000"
+model_name: "/home/paperspace/gpt4all/gpt4all/train/ckpts/qk_no_norm/step_500"
 tokenizer_name: "EleutherAI/pythia-1b"
 version: null
-gradient_checkpointing: false
+gradient_checkpointing: true
 memory_attn_layer: 12
+seed: 42

 # dataset

View File

@@ -5,7 +5,7 @@ version: null
 gradient_checkpointing: true
 save_name: "nomic-ai/lethe"
 push_to_hub: false
-memory_attn_layer: 12
+memory_attn_layer: [9, 12, 15]

 # dataset
 streaming: false
@@ -17,7 +17,7 @@ pct_test: 0.05
 q_column: "question"
 a_column: "answer"
 context_column: "text"
-num_memories_per_index: 2000000
+num_memories_per_index: 2048
 num_neighbors_to_retrieve: 2
 num_neighbors_to_store: 1
 mem_chunk_size: 64
@@ -26,15 +26,15 @@ mem_chunk_size: 64
 lr: 1.0e-5
 min_lr: 0
 weight_decay: 0.0
-eval_every: 100
-save_every: 100
+eval_every: 250
+save_every: 250
 log_grads_every: 100
 log_lr_every: 10
-output_dir: "ckpts/mem_attn_no_cosine_sim"
+output_dir: "ckpts/qk_no_norm"
 checkpoint: null
 lora: false
 warmup_steps: 200
-num_epochs: 5
+num_epochs: 2
 debug: false
 scheduler: false
@@ -43,4 +43,3 @@ wandb: true
 wandb_entity: gpt4all
 wandb_project_name: mem_attn
 seed: 42
-
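With memory_attn_layer now a list, the training script later in this commit builds one memory per listed layer (the {i - 1: BatchedMemory(...)} construction in the memory-retrieval trainer). A minimal sketch of that mapping, with illustrative model dimensions and BatchedMemory assumed to be imported from modeling_lethe.py:

# Sketch only: a 1-indexed layer list from the YAML becomes a dict keyed by
# 0-indexed layer, one BatchedMemory per memory layer.
memory_attn_layer = [9, 12, 15]                 # from the config above
batch_size, head_size, num_heads = 16, 64, 16   # hypothetical model dims
num_memories_per_index = 2048

indices = {
    layer - 1: BatchedMemory(batch_size, head_size, num_memories_per_index, num_heads)
    for layer in memory_attn_layer
}
# Each LetheLayer then looks up index[layer_idx] when it is a memory layer.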

View File

@@ -10,20 +10,21 @@ memory_attn_layer: 12
 # dataset
 streaming: false
 num_proc: 64
-dataset_path: "JeanKaddour/minipile"
+dataset_path: "pg19"
 max_length: 2048
-batch_size: 64
+seq_len: 512
+segments: 16
+batch_size: 16
 pct_test: 0.05
-num_memories_per_index: 5000000
+num_memories_per_index: 100000
 mem_chunk_size: 512
-num_chunks: 10
 num_neighbors_to_retrieve: 32

 # train dynamics
-lr: 1.0e-4
+lr: 2.0e-4
 min_lr: 0
 weight_decay: 0.0
-eval_every: 100
+eval_every: 250
 save_every: -1
 log_grads_every: 100
 log_lr_every: 10
@@ -38,6 +39,6 @@ scheduler: false
 # logging
 wandb: true
 wandb_entity: gpt4all
-wandb_project_name: minipile
+wandb_project_name: enwik8
 seed: 42

gpt4all/data/enwik8.py Normal file
View File

@@ -0,0 +1,62 @@
import numpy as np
import torch
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
from transformers import DefaultDataCollator


class EnWik8Dataset(Dataset):
    def __init__(self, data, seq_len):
        # pyarrow chunked array
        self.data = torch.from_numpy(data)
        self.seq_len = seq_len

    def __getitem__(self, index):
        full_seq = self.data[index].long()
        return full_seq.cuda()

    def __len__(self):
        return len(self.data)


def load_enwik8_dataloader(config, tokenizer):
    ds = load_dataset(config["dataset_path"], split="train")
    ds = ds.train_test_split(test_size=0.05, seed=config['seed'])
    train_ds, val_ds = ds["train"], ds["test"]

    keep_cols = ["input_ids"]

    train_ds = train_ds.map(lambda x: {"len": [len(t) for t in x["text"]]}, batched=True)
    train_ds = train_ds.sort("len")
    train_ds = train_ds.map(lambda x: tokenizer(x["text"], padding="longest", truncation=True, return_tensors="pt"),
                            batched=True,
                            batch_size=config["batch_size"])
    remove_cols = [col for col in train_ds.column_names if col not in keep_cols]
    train_ds = train_ds.remove_columns(remove_cols)

    val_ds = val_ds.map(lambda x: {"len": [len(t) for t in x["text"]]}, batched=True)
    val_ds = val_ds.sort("len")
    val_ds = val_ds.map(lambda x: tokenizer(x["text"], padding="longest", truncation=True, return_tensors="pt"),
                        batched=True,
                        batch_size=config["batch_size"])
    remove_cols = [col for col in train_ds.column_names if col not in keep_cols]
    val_ds = val_ds.remove_columns(remove_cols)

    train_dl = DataLoader(train_ds,
                          batch_size=config["batch_size"],
                          shuffle=True,
                          drop_last=True,
                          collate_fn=DefaultDataCollator())

    val_dl = DataLoader(val_ds,
                        batch_size=config["batch_size"],
                        shuffle=True,
                        drop_last=True,
                        collate_fn=DefaultDataCollator())

    return train_dl, val_dl
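A quick usage sketch for the loader above; the config values mirror the YAML in this commit, but the snippet itself is illustrative and not part of the change:

# Hypothetical driver for load_enwik8_dataloader; dataset_path, seed and
# batch_size come from the config file above.
from transformers import AutoTokenizer

config = {"dataset_path": "pg19", "seed": 42, "batch_size": 16}
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-1b")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

train_dl, val_dl = load_enwik8_dataloader(config, tokenizer)
batch = next(iter(train_dl))
print(batch["input_ids"].shape)  # (batch_size, padded sequence length)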

View File

@@ -1,6 +1,6 @@
 import glob
 import torch
-from datasets import load_dataset
+from datasets import load_dataset, load_from_disk
 import os
 import hnswlib
 from torch.utils.data import DataLoader
@@ -12,20 +12,23 @@ def load_data(config, tokenizer):
     dataset_path = config["dataset_path"]

     if os.path.exists(dataset_path):
-        if os.path.isdir(dataset_path):
-            files = glob.glob(os.path.join(dataset_path, "*_clean.jsonl"))
-        else:
-            files = [dataset_path]
-
-        print(f"Reading files {files}")
-
-        dataset = load_dataset("json", data_files=files, split="train")
+        dataset = load_from_disk(dataset_path)
+        # if os.path.isdir(dataset_path):
+        #     files = glob.glob(os.path.join(dataset_path, "*_clean.jsonl"))
+        # else:
+        #     files = [dataset_path]
+        # print(f"Reading files {files}")
+        # dataset = load_dataset("json", data_files=files, split="train")
     else:
         dataset = load_dataset(dataset_path, split="train")

     dataset = dataset.train_test_split(test_size=.05, seed=config["seed"])
+    dataset = dataset.map(lambda x: {"prompt": [text + " " + question for text, question in zip(x["text"], x["question"])]}, batched=True)

     train_dataset, val_dataset = dataset["train"], dataset["test"]

     if config["streaming"] is False:
@@ -33,19 +36,27 @@ def load_data(config, tokenizer):
     else:
         kwargs = {}

+    cols_to_keep = ["input_ids", "labels", "attention_mask"]
     # tokenize inputs and return labels and attention mask
     train_dataset = train_dataset.map(
-        lambda ele: tokenize_inputs(config, tokenizer, ele, "prompt", "response"),
+        lambda ele: tokenize_inputs(config, tokenizer, ele, "prompt", "answer"),
         batched=True,
-        remove_columns=["source", "prompt"],
+        # remove_columns=["source", "prompt"],
         **kwargs
     )
+    cols_to_remove = [col for col in train_dataset.column_names if col not in cols_to_keep]
+    train_dataset = train_dataset.remove_columns(cols_to_remove)
     val_dataset = val_dataset.map(
-        lambda ele: tokenize_inputs(config, tokenizer, ele, "prompt", "response"),
+        lambda ele: tokenize_inputs(config, tokenizer, ele, "prompt", "answer"),
         batched=True,
-        remove_columns=["source", "prompt"],
+        # remove_columns=["source", "prompt"],
         **kwargs
     )
+    cols_to_remove = [col for col in val_dataset.column_names if col not in cols_to_keep]
+    val_dataset = val_dataset.remove_columns(cols_to_remove)

     train_dataset = train_dataset.with_format("torch")
     val_dataset = val_dataset.with_format("torch")
@@ -56,12 +67,14 @@ def load_data(config, tokenizer):
         train_dataset,
         collate_fn=DefaultDataCollator(),
         batch_size=config["batch_size"],
+        shuffle=True,
     )

     val_dataloader = DataLoader(
         val_dataset,
         collate_fn=DefaultDataCollator(),
         batch_size=config["batch_size"],
+        shuffle=True,
     )

     return train_dataloader, val_dataloader

View File

@@ -5,7 +5,7 @@ def tokenize_inputs(config, tokenizer, examples, input_col, target_col):
     # hacky backward compatible
     different_eos = tokenizer.eos_token != "</s>"
-    out = {"labels": [], "input_ids": []}
+    out = {"labels": [], "input_ids": [], "attention_mask": []}
     for prompt, response in zip(examples[input_col], examples[target_col]):
         if different_eos:
             if response.count("</s> \n") > 0:
@@ -42,9 +42,10 @@ def tokenize_inputs(config, tokenizer, examples, input_col, target_col):
             print(response)
             raise

-        input_tokens = tokenizer.pad({"input_ids": input_tokens}, padding="max_length", max_length=max_length)["input_ids"]
+        tokenized = tokenizer.pad({"input_ids": input_tokens}, padding="max_length", max_length=max_length)
         out["labels"].append(labels)
-        out["input_ids"].append(input_tokens)
+        out["input_ids"].append(tokenized["input_ids"])
+        out["attention_mask"].append(tokenized["attention_mask"])

     out = {k: torch.stack(v) if isinstance(v, list) else v for k, v in out.items()}
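The change above keeps the full dict that tokenizer.pad returns instead of only "input_ids", so the padding mask travels with the batch. A small standalone illustration (model name and lengths are arbitrary, not from this commit):

# Illustrative only: tokenizer.pad on a dict of input_ids returns both the
# padded ids and a matching attention_mask (1 for real tokens, 0 for pads).
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("EleutherAI/pythia-1b")
tok.pad_token = tok.eos_token

ids = tok("retrieval augments attention", return_tensors="pt")["input_ids"][0]
padded = tok.pad({"input_ids": [ids]}, padding="max_length", max_length=8, return_tensors="pt")
print(padded["input_ids"].shape, padded["attention_mask"].sum())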

View File

@@ -41,7 +41,7 @@ def load_retrieval_augmented_data(config, tokenizer, split="train", split_datase
     if encoder_column != "encoder_hidden_states":
         dataset = dataset.rename_column(encoder_column, "encoder_hidden_states")

-    columns_to_keep = ["input_ids", "labels", "encoder_hidden_states"]
+    columns_to_keep = ["input_ids", "attention_mask", "labels", "encoder_hidden_states"]
     col_names_to_rm = [col for col in dataset.column_names if col not in columns_to_keep]
     dataset = dataset.remove_columns(col_names_to_rm)
@@ -115,7 +115,7 @@ def load_memory_augmented_data(config, tokenizer, split="train", split_dataset=T
         **kwargs
     )

-    columns_to_keep = ["id", "input_ids", "labels", "retrieved_context"]
+    columns_to_keep = ["id", "input_ids", "attention_mask", "labels", "retrieved_context"]
     col_names_to_rm = [col for col in dataset.column_names if col not in columns_to_keep]
     dataset = dataset.remove_columns(col_names_to_rm)
@@ -128,12 +128,16 @@ def load_memory_augmented_data(config, tokenizer, split="train", split_dataset=T
         train_dataset.remove_columns("id"),
         batch_size=config["batch_size"],
         collate_fn=DefaultDataCollator(),
+        shuffle=True,
+        drop_last=True,
     )

     val_dataloader = DataLoader(
         val_dataset,
         batch_size=config["batch_size"],
         collate_fn=DefaultDataCollator(),
+        shuffle=True,
+        drop_last=True,
     )

     return train_dataloader, val_dataloader

View File

@@ -1,7 +1,7 @@
 import torch
 import torch.nn.functional as F
 from gpt4all.models import LetheForCausalLM
-from gpt4all.models.lethe.modeling_lethe import MemoryIndex
+from gpt4all.models.lethe.modeling_lethe import BatchedMemory
 from gpt4all.data.retrieval_dataloader import load_memory_augmented_data
 from gpt4all.train.metrics import f1_score, exact_match_score
 from gpt4all.utils.read import read_config
@@ -28,17 +28,21 @@ def greedy_search(input_ids, model, tokenizer, max_new_tokens=100):
     while True:
         if num_new_tokens >= max_new_tokens:
             break
-        outputs = model(input_ids, save_kv=False)
+        attention_mask = input_ids.ne(tokenizer.pad_token_id)
+        outputs = model(input_ids, attention_mask=attention_mask, save_kv=False)

-        new_tokens = torch.argmax(outputs.logits[:, -1, :], dim=-1)
+        next_token_idx = torch.argmax((input_ids == tokenizer.pad_token_id).type(torch.float32))
+        # -1 because logits at last position predict next token
+        new_token = torch.argmax(outputs.logits[:, next_token_idx - 1, :], dim=-1)
+        input_ids[:, next_token_idx] = new_token

-        input_ids = torch.cat([input_ids, new_tokens.unsqueeze(1)], dim=-1)
         num_new_tokens += 1

-        if torch.equal(input_ids[0, -1].cpu(), torch.tensor(tokenizer.eos_token_id)):
+        if torch.equal(new_token.cpu(), torch.tensor(tokenizer.eos_token_id)):
             break

-    print(tokenizer.batch_decode(input_ids, skip_special_tokens=True))
+    print(f"GENERATED: {tokenizer.batch_decode(input_ids, skip_special_tokens=True)}")

     return input_ids
@@ -63,10 +67,11 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model_config = AutoConfig.from_pretrained(config["model_name"])
 head_size = model_config.hidden_size // model_config.num_attention_heads
-index = MemoryIndex(head_size,
+index = BatchedMemory(config["batch_size"],
+                      head_size,
                       config["num_memories_per_index"],
-                      model_config.num_attention_heads
+                      model_config.num_attention_heads,
 )
 model = LetheForCausalLM.from_pretrained(config["model_name"],
                                          revision=config['version'] if 'version' in config else None,
                                          memory_attn_layer=config["memory_attn_layer"],
@@ -90,16 +95,19 @@ with torch.no_grad():
         mem_chunk = memories[chunk_start:chunk_end]
         model(input_ids=mem_chunk.to(device))

-    del memories
     torch.cuda.empty_cache()

     qa_inputs = batch["input_ids"]
     qa_labels = batch["labels"]
     for i in range(qa_inputs.shape[0]):
         inputs = qa_inputs[i].to(device)
+        print(f"EXPECTED: {tokenizer.decode(inputs, skip_special_tokens=True)}")
         labels = qa_labels[i].to(device)
         cutoff = torch.argmax((labels != -100).type(torch.float32))
-        greedy_search(inputs[:cutoff.item()].unsqueeze(0).to(device), model, tokenizer)
-        print(tokenizer.decode(inputs, skip_special_tokens=True))
+        inputs[cutoff:] = tokenizer.pad_token_id
+        greedy_search(inputs.unsqueeze(0).to(device), model, tokenizer)
+        print(f"CONTEXT: {tokenizer.decode(memories[i], skip_special_tokens=True)}")
+        import pdb; pdb.set_trace()

     # batch_loss = calc_loss_per_item(outputs.logits, qa_labels.to(device))
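The reworked greedy_search writes each new token into the first pad slot instead of concatenating, so the tensor shape stays fixed while decoding. A toy illustration of the position bookkeeping it relies on (values are made up):

# Toy example of the indexing trick: argmax over (input_ids == pad_token_id)
# returns the position of the first pad token, and the logits at position
# next_token_idx - 1 predict the token that should fill that slot.
import torch

pad_token_id = 0
input_ids = torch.tensor([[11, 57, 42, 0, 0, 0]])
next_token_idx = torch.argmax((input_ids == pad_token_id).type(torch.float32))
print(next_token_idx.item())  # 3: first pad position, filled using logits at index 2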

View File

@ -0,0 +1,96 @@
{
"builder_name": "squad",
"citation": "@article{2016arXiv160605250R,\n author = {{Rajpurkar}, Pranav and {Zhang}, Jian and {Lopyrev},\n Konstantin and {Liang}, Percy},\n title = \"{SQuAD: 100,000+ Questions for Machine Comprehension of Text}\",\n journal = {arXiv e-prints},\n year = 2016,\n eid = {arXiv:1606.05250},\n pages = {arXiv:1606.05250},\narchivePrefix = {arXiv},\n eprint = {1606.05250},\n}\n",
"config_name": "plain_text",
"dataset_size": 89819092,
"description": "Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset, consisting of questions posed by crowdworkers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span, from the corresponding reading passage, or the question might be unanswerable.\n",
"download_checksums": {
"https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json": {
"num_bytes": 30288272,
"checksum": null
},
"https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json": {
"num_bytes": 4854279,
"checksum": null
}
},
"download_size": 35142551,
"features": {
"id": {
"dtype": "string",
"_type": "Value"
},
"title": {
"dtype": "string",
"_type": "Value"
},
"context": {
"dtype": "string",
"_type": "Value"
},
"question": {
"dtype": "string",
"_type": "Value"
},
"answers": {
"feature": {
"text": {
"dtype": "string",
"_type": "Value"
},
"answer_start": {
"dtype": "int32",
"_type": "Value"
}
},
"_type": "Sequence"
},
"neighbor_ids": {
"feature": {
"dtype": "uint64",
"_type": "Value"
},
"_type": "Sequence"
},
"neighbor_text": {
"feature": {
"dtype": "string",
"_type": "Value"
},
"_type": "Sequence"
},
"loss": {
"dtype": "float64",
"_type": "Value"
}
},
"homepage": "https://rajpurkar.github.io/SQuAD-explorer/",
"license": "",
"size_in_bytes": 124961643,
"splits": {
"train": {
"name": "train",
"num_bytes": 79346108,
"num_examples": 87599,
"dataset_name": "squad"
},
"validation": {
"name": "validation",
"num_bytes": 10472984,
"num_examples": 10570,
"dataset_name": "squad"
}
},
"task_templates": [
{
"task": "question-answering-extractive"
}
],
"version": {
"version_str": "1.0.0",
"description": "",
"major": 1,
"minor": 0,
"patch": 0
}
}
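This dataset_info.json describes a SQuAD split augmented with neighbor_ids/neighbor_text columns and saved to disk; the updated dataloader in this commit reads such directories with load_from_disk. A hedged sketch of inspecting one (the path is a placeholder, not from this commit):

# Sketch, assuming the directory containing this dataset_info.json was written
# with Dataset.save_to_disk.
from datasets import load_from_disk

ds = load_from_disk("path/to/squad_with_neighbors")
example = ds[0]
print(example["question"])
print(example["neighbor_text"][:1])  # retrieved context passages stored per example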

View File

@ -0,0 +1,16 @@
{
"_data_files": [
{
"filename": "data-00000-of-00002.arrow"
},
{
"filename": "data-00001-of-00002.arrow"
}
],
"_fingerprint": "c178f5c8269012a2",
"_format_columns": null,
"_format_kwargs": {},
"_format_type": null,
"_output_all_columns": false,
"_split": "validation"
}

View File

@@ -0,0 +1,42 @@
import torch
from gpt4all.data.instruction_tuning_dataloader import load_data
from gpt4all.utils.read import read_config
from transformers import AutoTokenizer, AutoModelForCausalLM
from argparse import ArgumentParser
from tqdm import tqdm

parser = ArgumentParser()
parser.add_argument("--config", type=str, default="config.yaml")

args = parser.parse_args()
config = read_config(args.config)

tokenizer = AutoTokenizer.from_pretrained(config["tokenizer_name"])
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

train_dataloader, val_dataloader = load_data(config, tokenizer)

model = AutoModelForCausalLM.from_pretrained(config["model_name"],
                                             trust_remote_code=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.half().to(device)
model.eval()

# Evaluate the model on the SQUAD dataset
f1s = []
exact_matches = []
with torch.no_grad():
    for batch in tqdm(val_dataloader):
        inputs = batch["input_ids"].to(device)
        labels = batch["labels"].to(device)
        cutoff = torch.argmax((labels != -100).type(torch.float32))
        outputs = model.generate(inputs[:, :cutoff], max_new_tokens=100)
        print(f"Predicted: {tokenizer.batch_decode(outputs, skip_special_tokens=True)}")
        print(f"Ground truth: {tokenizer.batch_decode(inputs[:, cutoff:], skip_special_tokens=True)}")
        print(tokenizer.batch_decode(inputs, skip_special_tokens=True))

View File

@@ -2,7 +2,6 @@ from .gpt_jr.configuration_gpt_jr import GPTJRConfig
 from .gpt_jr.modeling_gpt_jr import GPTJRForCausalLM

 from .pythiaseek import PythiaSeekForCausalLM, PythiaSeekConfig
-from .pythia_retro import PythiaRetroForCausalLM, PythiaRetroConfig

 from .lethe import LetheConfig, LetheForCausalLM

View File

@@ -18,6 +18,8 @@ import wandb
 import math
 import torch.nn.functional as F
 import matplotlib.pyplot as plt
+import plotly.express as px
+import pandas as pd

 from typing import Optional, Tuple, Union

 import torch
@@ -121,58 +123,104 @@ class MemoryIndex:
         # NOTE: we are storing kv pairs, instead indices for both keys and values
         self.key_indices = [HNSWIndex(num_mems, hidden_dim) for _ in range(nheads)]
-        shape = (num_mems, nheads, 2, hidden_dim)
+        shape = (nheads, num_mems, 2, hidden_dim)
         self.kv_pairs = np.zeros(shape, dtype=np.float32)
         self.idx_offset = 0

     def add(self, keys, values):
-        # k/v are (bs, num_attention_heads, seq_len, head_size)
-        reshaped_keys = keys.reshape(keys.shape[0] * keys.shape[2], keys.shape[1], keys.shape[3])
-        reshaped_values = values.reshape(values.shape[0] * values.shape[2], values.shape[1], values.shape[3])
+        # k/v are (num_attention_heads, seq_len, head_size)
+        # keys = keys.reshape(keys.shape[1], keys.shape[0], keys.shape[2])
+        # values = values.reshape(values.shape[1], values.shape[0], values.shape[2])
         for head in range(self.nheads):
-            self.key_indices[head].add(reshaped_keys[:, head, :])
+            self.key_indices[head].add(keys[head, :, :])

-        kv_pairs = np.stack((reshaped_keys, reshaped_values), axis=2)
-        if self.idx_offset + kv_pairs.shape[0] > self.kv_pairs.shape[0]:
-            raise ValueError("Not enough memory!")
+        kv_pairs = np.stack((keys, values), axis=2)
+        if self.idx_offset + kv_pairs.shape[1] > self.kv_pairs.shape[1]:
+            # reset to 0 to overwrite oldest memories
+            self.idx_offet = 0

-        self.kv_pairs[self.idx_offset:self.idx_offset + kv_pairs.shape[0]] = kv_pairs
-        self.idx_offset += kv_pairs.shape[0]
+        self.kv_pairs[:, self.idx_offset:self.idx_offset + kv_pairs.shape[1]] = kv_pairs
+        self.idx_offset += kv_pairs.shape[1]

     def knn_query(self, query, k=1):
-        reshaped_query = query.reshape(query.shape[0] * query.shape[2], query.shape[1], query.shape[3])
         mem_keys = []
         mem_values = []
         mem_indices = []
         # we can prob make this better
         for head in range(self.nheads):
-            knn_indices = self.key_indices[head].query(reshaped_query[:, head, :], k=k)
-            kv_pairs = self.kv_pairs[:, head, :, :][knn_indices]
+            knn_indices = self.key_indices[head].query(query[head, :, :], k=k)
+            kv_pairs = self.kv_pairs[head, :, :, :][knn_indices]
             mem_keys.append(kv_pairs[:, :, 0, :])
             mem_values.append(kv_pairs[:, :, 1, :])
             mem_indices.append(knn_indices)

-        mem_keys = torch.from_numpy(np.stack(mem_keys, axis=1))
-        # (bs, num_attention_heads, seq_len, k, head_size)
-        mem_keys = mem_keys.view(query.shape[:-1] + (k,) + (query.shape[-1],))
+        mem_keys = torch.from_numpy(np.stack(mem_keys, axis=0))
+        # (num_attention_heads, seq_len, k, head_size)
+        # mem_keys = mem_keys.view(query.shape[:-1] + (k,) + (query.shape[-1],))

-        mem_values = torch.from_numpy(np.stack(mem_values, axis=1))
-        # (bs, num_attention_heads, seq_len, k, head_size)
-        mem_values = mem_values.view(query.shape[:-1] + (k,) + (query.shape[-1],))
+        mem_values = torch.from_numpy(np.stack(mem_values, axis=0))
+        # (num_attention_heads, seq_len, k, head_size)
+        # mem_values = mem_values.view(query.shape[:-1] + (k,) + (query.shape[-1],))

-        return mem_keys, mem_values, np.stack(mem_indices, axis=1)
+        return mem_keys, mem_values, np.stack(mem_indices, axis=0)

     def reset(self):
         for head in range(self.nheads):
             self.key_indices[head].reset()
-        self.kv_pairs = np.zeros((self.kv_pairs.shape[0], self.nheads, 2, self.kv_pairs.shape[-1]), dtype=np.float32)
+        self.kv_pairs = np.zeros(self.kv_pairs.shape, dtype=np.float32)
+        self.idx_offset = 0
+
+class BatchedMemory:
+    def __init__(self, batch_size, hidden_dim, num_mems, nheads):
+        self.indices = [MemoryIndex(hidden_dim, num_mems, nheads) for _ in range(batch_size)]
+
+    def add(self, keys, values):
+        for bs in range(len(self.indices)):
+            self.indices[bs].add(keys[bs], values[bs])
+
+    def knn_query(self, query, k=1):
+        batched_mem_keys = []
+        batched_mem_values = []
+        batched_labels = []
+        for bs in range(len(self.indices)):
+            knn_keys, knn_values, knn_labels = self.indices[bs].knn_query(query[bs], k=k)
+            batched_mem_keys.append(knn_keys)
+            batched_mem_values.append(knn_values)
+            batched_labels.append(knn_labels)
+        return torch.stack(batched_mem_keys, dim=0), torch.stack(batched_mem_values, dim=0), np.stack(batched_labels, axis=0)
+
+    def reset(self):
+        for bs in range(len(self.indices)):
+            self.indices[bs].reset()
+
+
+class BatchedStorage:
+    def __init__(self, batch_size, hidden_dim, nheads, seq_len):
+        self.indices = np.zeros((batch_size, nheads, seq_len, 2, hidden_dim), dtype=np.float32)
+
+    def knn_query(self, query, k=None):
+        return torch.from_numpy(self.indices[:, :, :, 0, :]), torch.from_numpy(self.indices[:, :, :, 1, :])
+
+    def add(self, keys, values):
+        self.indices[:, :, :, 0, :] = keys
+        self.indices[:, :, :, 1, :] = values
+
+    def reset(self):
+        self.indices = np.zeros_like(self.indices)
+

 class LethePreTrainedModel(PreTrainedModel):
@@ -206,12 +254,13 @@ class LethePreTrainedModel(PreTrainedModel):

 class LetheAttention(nn.Module):
-    def __init__(self, config, memory_attention=False, index=None, tracker=None):
+    def __init__(self, config, memory_attention=False, index=None, layer_idx=None, tracker=None):
         super().__init__()
         self.num_attention_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_attention_heads
         self.rotary_ndims = int(self.head_size * config.rotary_pct)
+        self.layer_idx = layer_idx
         max_positions = config.max_position_embeddings
         self.register_buffer(
             "bias",
@@ -274,7 +323,6 @@ class LetheAttention(nn.Module):
         key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3)
         value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3)

-        # if self.memory:
         if self.memory:
             # QKNorm: https://arxiv.org/abs/2010.04245
             query = F.normalize(query, dim=-1)
@@ -309,17 +357,28 @@ class LetheAttention(nn.Module):
             self.index.add(key.cpu().detach().to(torch.float32).numpy(), value.cpu().detach().to(torch.float32).numpy())
             knn_keys, knn_values, knn_labels = self.index.knn_query(query.cpu().detach().to(torch.float32).numpy(), k=self.num_neighbors)

+            # if log_attn_scores:
+            #     batch_size = query.shape[0]
+            #     seq_len = query.shape[-2]
+            #     key_labels = knn_labels // seq_len
+            #     key_labels = key_labels.reshape(batch_size, seq_len, self.num_attention_heads, -1)
+            #     correct_keys = np.equal(key_labels, np.arange(batch_size)[:, np.newaxis, np.newaxis, np.newaxis])
+            #     # calculate the accuracy
+            #     key_acc = np.sum(correct_keys) / np.prod(correct_keys.shape)
+            #     self.tracker.log({"retrieved_acc": key_acc}, step=step)
             if log_attn_scores:
-                batch_size = query.shape[0]
-                seq_len = query.shape[-2]
+                total_examples = 0
+                unique_examples = 0
+                for bs in range(query.shape[0]):
+                    for head in range(query.shape[1]):
+                        labels_per_head = knn_labels[bs, head, :, 0].tolist()
+                        total_examples += len(labels_per_head)
+                        unique_examples += len(set(labels_per_head))

-                key_labels = knn_labels // seq_len
-                key_labels = key_labels.reshape(batch_size, seq_len, self.num_attention_heads, -1)
-                correct_keys = np.equal(key_labels, np.arange(batch_size)[:, np.newaxis, np.newaxis, np.newaxis])
-                # calculate the accuracy
-                key_acc = np.sum(correct_keys) / np.prod(correct_keys.shape)
-                self.tracker.log({"retrieved_acc": key_acc}, step=step)
+                self.tracker.log({"unique_retrieved_pct": unique_examples / total_examples}, step=step)

             attn_output = self._mem_attn(query,
                                          knn_keys.to(query.device).to(value.dtype),
@@ -407,6 +466,7 @@ class LetheAttention(nn.Module):
             local_attn_scores = local_attn_scores + attention_mask

         mem_attn_scores = torch.einsum("bhsd, bhsnd-> bhsn", query, knn_key)
+        # mem_attn_scores = torch.matmul(query, knn_key.transpose(-1, -2))

         # attn_scores: [bs, seq_len, num_attention_heads, knn]
         mem_attn_scores = mem_attn_scores * scale
@ -417,48 +477,56 @@ class LetheAttention(nn.Module):
attn_weights = attn_weights.to(local_value.dtype) attn_weights = attn_weights.to(local_value.dtype)
mem_attn_weights, local_attn_weights = attn_weights.split([self.num_neighbors, local_attn_scores.size(-1)], dim=-1) mem_attn_weights, local_attn_weights = attn_weights.split([self.num_neighbors, local_attn_scores.size(-1)], dim=-1)
if log_attn_scores: # mem_attn_weights, local_attn_weights = attn_weights.chunk(2, dim=-1)
# (bs, seq_len, num_attention_heads, knn) probabilities
# curate (x,y) pairs
# where x is attention weight, y is accuracy of retrieved token
bs, seq_len = mem_attn_weights.size(0), mem_attn_weights.size(2)
key_labels = knn_labels // seq_len
key_labels = key_labels.reshape(bs, self.num_attention_heads, seq_len, -1)
correct_keys = np.equal(key_labels, np.arange(bs)[:, np.newaxis, np.newaxis, np.newaxis])
bin_width = 0.05
# Calculate the number of bins # if log_attn_scores:
num_bins = int(1 / bin_width) # # (bs, seq_len, num_attention_heads, knn) probabilities
# # curate (x,y) pairs
# # where x is attention weight, y is accuracy of retrieved token
# bs, seq_len = mem_attn_weights.size(0), mem_attn_weights.size(2)
# key_labels = knn_labels // seq_len
# key_labels = key_labels.reshape(bs, self.num_attention_heads, seq_len, -1)
# correct_keys = np.equal(key_labels, np.arange(bs)[:, np.newaxis, np.newaxis, np.newaxis])
# Create empty lists for storing bin probabilities and accuracies # bin_width = 0.05
bin_probabilities = []
bin_accuracies = []
probs = mem_attn_weights.clone().detach().cpu().numpy().reshape(-1).tolist() # # Calculate the number of bins
correct_keys = correct_keys.reshape(-1).tolist() # num_bins = int(1 / bin_width)
# Iterate over each bin # # Create empty lists for storing bin probabilities and accuracies
for i in range(num_bins): # bin_probabilities = []
bin_lower = i * bin_width # bin_accuracies = []
bin_upper = (i + 1) * bin_width # bin_sizes = []
# Filter data points within the current bin range # probs = mem_attn_weights.clone().detach().cpu().numpy().reshape(-1).tolist()
bin_x_values = [x for x in probs if bin_lower <= x < bin_upper] # correct_keys = correct_keys.reshape(-1).tolist()
bin_y_values = [y for j, y in enumerate(correct_keys) if bin_lower <= probs[j] < bin_upper]
# Calculate accuracy for the bin # # Iterate over each bin
total = len(bin_x_values) # for i in range(num_bins):
correct = sum(bin_y_values) # bin_lower = i * bin_width
accuracy = correct / total if total > 0 else 0 # bin_upper = (i + 1) * bin_width
# Store the probability and accuracy for the bin # # Filter data points within the current bin range
bin_probabilities.append((bin_lower + bin_upper) / 2) # bin_x_values = [x for x in probs if bin_lower <= x < bin_upper]
bin_accuracies.append(accuracy) # bin_y_values = [y for j, y in enumerate(correct_keys) if bin_lower <= probs[j] < bin_upper]
data = [[x, y] for x, y in zip(bin_probabilities, bin_accuracies)] # # Calculate accuracy for the bin
table = wandb.Table(data=data, columns=["attn_prob", "retrieved_acc"]) # total = len(bin_x_values)
self.tracker.log({"attn_vs_acc": wandb.plot.scatter(table, "attn_prob", "retrieved_acc")}, step=step) # correct = sum(bin_y_values)
# accuracy = correct / total if total > 0 else 0
# # Store the probability and accuracy for the bin
# bin_probabilities.append((bin_lower + bin_upper) / 2)
# bin_accuracies.append(accuracy)
# bin_sizes.append(len(bin_x_values))
# df = pd.DataFrame({"attn_prob": bin_probabilities, "retrieved_acc": bin_accuracies, "bin_size": bin_sizes})
# fig = px.scatter(df, x="attn_prob", y="retrieved_acc",
# color="bin_size", hover_data=["attn_prob", "retrieved_acc", "bin_size"],
# title="Attention Probability vs Retrieved Accuracy")
# self.tracker.log({"attn_vs_acc": fig}, step=step)
if log_attn_scores: if log_attn_scores:
@@ -470,10 +538,10 @@ class LetheAttention(nn.Module):
                     mem_hist = torch.histc(mem_flat, bins=20, min=0, max=1)
                     mem_bins = torch.linspace(0, 1, steps=20 + 1)
                     plt.stairs(mem_hist.tolist(), mem_bins.tolist())
-                    plt.title(f"mem_attn_score_{head}")
+                    plt.title(f"mem_attn_score_{head}_layer_{self.layer_idx}")
                     # set arbitrarily but we want to see those peaks!!
                     plt.ylim((0, 1000))
-                    self.tracker.log({f"mem_attn_score_{head}": wandb.Image(plt)}, step=step)
+                    self.tracker.log({f"mem_attn_score_{head}_layer_{self.layer_idx}": wandb.Image(plt)}, step=step)
                     plt.close()

@@ -482,15 +550,16 @@ class LetheAttention(nn.Module):
                     local_hist = torch.histc(local_flat, bins=20, min=0, max=1)
                     local_bins = torch.linspace(0, 1, steps=20 + 1)
                     plt.stairs(local_hist.tolist(), local_bins.tolist())
-                    plt.title(f"local_attn_score_{head}")
+                    plt.title(f"local_attn_score_{head}_layer_{self.layer_idx}")
                     # set arbitrarily but we want to see those peaks!!
                     plt.ylim((0, 1000))
-                    self.tracker.log({f"local_attn_score_{head}": wandb.Image(plt)}, step=step)
+                    self.tracker.log({f"local_attn_score_{head}_layer_{self.layer_idx}": wandb.Image(plt)}, step=step)
                     plt.close()

         # attn_output: [bs, num_attention_heads, seq_len, attn_head_size]
         mem_attn_output = torch.einsum("bhsn, bhsnd-> bhsd", mem_attn_weights, knn_value)
+        # mem_attn_output = torch.matmul(mem_attn_weights, knn_value)
         local_attn_output = torch.matmul(local_attn_weights, local_value)

         # TODO: do we need flamingo style gating
@@ -524,8 +593,6 @@ class LetheAttention(nn.Module):
             alpha=1.0 if self.memory else (torch.tensor(1.0, dtype=self.norm_factor.dtype, device=self.norm_factor.device) / self.norm_factor)
         )
         attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)

-        if self.memory:
-            attn_scores = attn_scores * self.scale.exp()

         mask_value = torch.finfo(attn_scores.dtype).min
         # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
@@ -610,12 +677,15 @@ class LetheMLP(nn.Module):

 class LetheLayer(nn.Module):
-    def __init__(self, config, memory_attention=False, index=None, tracker=None):
+    def __init__(self, config, memory_attention=False, layer_idx=None, index=None, tracker=None):
         super().__init__()
         self.use_parallel_residual = config.use_parallel_residual
         self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-        self.attention = LetheAttention(config, memory_attention=memory_attention, index=index, tracker=tracker)
+        self.attention = LetheAttention(config, memory_attention=memory_attention,
+                                        layer_idx=layer_idx,
+                                        index=index[layer_idx] if memory_attention else None,
+                                        tracker=tracker)
         self.mlp = LetheMLP(config)

     def forward(
@@ -676,8 +746,9 @@ class LetheModel(LethePreTrainedModel):
         self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size)
         self.layers = nn.ModuleList([LetheLayer(config,
-                                                memory_attention=i+1 == config.memory_attn_layer,
-                                                index=index if i+1 == config.memory_attn_layer else None,
+                                                memory_attention=i+1 in config.memory_attn_layer,
+                                                layer_idx=i,
+                                                index=index,
                                                 tracker=tracker)
                                     for i in range(config.num_hidden_layers)])
         self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
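BatchedMemory simply keeps one MemoryIndex per batch element, so each example retrieves only from its own stored keys and values. A rough usage sketch with made-up shapes (batch 2, 4 heads, head size 8); shapes follow MemoryIndex.add after this commit, where per-example keys and values are (num_heads, seq_len, head_size):

# Sketch only; dimensions are illustrative.
import numpy as np

batch_size, num_heads, seq_len, head_size = 2, 4, 16, 8
memory = BatchedMemory(batch_size, head_size, num_mems=1024, nheads=num_heads)

keys = np.random.randn(batch_size, num_heads, seq_len, head_size).astype(np.float32)
values = np.random.randn(batch_size, num_heads, seq_len, head_size).astype(np.float32)
memory.add(keys, values)

queries = np.random.randn(batch_size, num_heads, seq_len, head_size).astype(np.float32)
mem_k, mem_v, labels = memory.knn_query(queries, k=2)
# mem_k / mem_v: (batch, heads, seq_len, k, head_size); reset once the batch is done.
memory.reset()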

View File

@@ -1,2 +0,0 @@
-from .configuration_pythia_retro import PythiaRetroConfig
-from .modeling_pythia_retro import PythiaRetroForCausalLM

View File

@ -0,0 +1,288 @@
import os
import torch.nn.functional as F
from transformers import AutoTokenizer, get_scheduler, AutoConfig
import torch
from torch.optim import AdamW
from argparse import ArgumentParser
from gpt4all.utils.read import read_config
from accelerate import Accelerator
from accelerate.utils import DummyScheduler, DummyOptim, set_seed
from gpt4all.data.enwik8 import load_enwik8_dataloader
from torchmetrics import MeanMetric
from tqdm import tqdm
from gpt4all.models import LetheForCausalLM, LetheConfig
from gpt4all.models.lethe.modeling_lethe import BatchedMemory
import wandb
torch.backends.cuda.matmul.allow_tf32 = True
def format_metrics(metrics, split, prefix=""):
log = f"[{split}]" + prefix
log += " ".join([f"{key}: {value:.4f}" for key, value in metrics.items()])
return log
def evaluate(model, index, pad_token_id, config, val_dataloader, main_process=False):
model.eval()
val_loss = MeanMetric(nan_strategy="error").to(model.device)
chunk_size = config["seq_len"]
with torch.no_grad():
for batch in tqdm(val_dataloader, disable=not main_process):
seq_len = batch.shape[1]
for chunk_start in range(0, seq_len, chunk_size):
chunk_end = min(seq_len, chunk_start + chunk_size)
inputs = batch[:, chunk_start:chunk_end].to(model.device)
labels = inputs.clone()
outputs = model(input_ids=inputs,
attention_mask=inputs.ne(pad_token_id),
labels=labels,
log_attn_scores=False,
step=None,
save_kv=True,
)
loss = outputs.loss / config["segments"]
loss_values = accelerator.gather_for_metrics({"loss": loss.item()})
val_loss.update(loss_values["loss"])
index.reset()
return val_loss
def train(accelerator, config):
set_seed(config['seed'])
accelerator.print(config)
accelerator.print(f"Using {accelerator.num_processes} GPUs")
tokenizer = AutoTokenizer.from_pretrained(config['tokenizer_name'], model_max_length=config['max_length'])
# if no pad token, set it to eos
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
with accelerator.main_process_first():
train_dataloader, val_dataloader = load_enwik8_dataloader(config, tokenizer)
if accelerator.state.deepspeed_plugin is not None:
gradient_accumulation_steps = accelerator.state.deepspeed_plugin.deepspeed_config[
"gradient_accumulation_steps"
]
accelerator.print(f"Len of train_dataloader: {len(train_dataloader)}")
total_num_steps = (len(train_dataloader) / gradient_accumulation_steps) * config["num_epochs"]
# instead of decaying to zero, decay to ratio of min_lr / lr
accelerator.print(f"Total training steps: {total_num_steps}")
checkpoint = config["gradient_checkpointing"]
model_config = LetheConfig.from_pretrained(config["model_name"])
model_config.memory_attn_layer = config["memory_attn_layer"]
model_config.num_neighbors_to_retrieve = config["num_neighbors_to_retrieve"]
model_config.use_cache = False if checkpoint else True
head_size = model_config.hidden_size // model_config.num_attention_heads
index = BatchedMemory(config["batch_size"],
head_size,
config["num_memories_per_index"],
model_config.num_attention_heads,
)
model = LetheForCausalLM(model_config,
index=index,
tracker=accelerator.get_tracker("wandb"))
accelerator.print(f"Training a {model.num_parameters():,} parameter model")
if checkpoint:
model.gradient_checkpointing_enable()
optimizer_cls = (
AdamW
if accelerator.state.deepspeed_plugin is None
or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config
else DummyOptim
)
# karpathy doesn't decay embeddding, maybe we should exclude
# https://github.com/karpathy/minGPT/commit/bbbdac74fa9b2e55574d70056163ffbae42310c1#diff-2075fa9c224b395be5bda85544dd36572b59c76c54562819eadadbf268602834R157s
optimizer = optimizer_cls(model.parameters(), lr=config["lr"], weight_decay=config["weight_decay"])
# Creates Dummy Scheduler if `scheduler` was spcified in the config file else creates `args.lr_scheduler_type` Scheduler
if config["scheduler"] or "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config:
if (
accelerator.state.deepspeed_plugin is None
or "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
):
scheduler = get_scheduler(
name="cosine",
optimizer=optimizer,
num_warmup_steps=config["warmup_steps"] * accelerator.num_processes,
num_training_steps=total_num_steps,
)
else:
scheduler = DummyScheduler(
optimizer, total_num_steps=total_num_steps, warmup_num_steps=config["warmup_steps"]
)
model, optimizer, scheduler, train_dataloader, val_dataloader = accelerator.prepare(
model, optimizer, scheduler, train_dataloader, val_dataloader
)
use_scheduler = True
else:
model, optimizer, train_dataloader, val_dataloader = accelerator.prepare(
model, optimizer, train_dataloader, val_dataloader
)
use_scheduler = False
# setup for saving training states in case preemption
if use_scheduler:
accelerator.register_for_checkpointing(scheduler)
if config["checkpoint"]:
accelerator.load_state(config["checkpoint"])
accelerator.print(f"Resumed from checkpoint: {config['checkpoint']}")
path = os.path.basename(config["train_args"]["resume_from_checkpoint"])
training_difference = os.path.splitext(path)[0]
resume_step = int(training_difference.replace("step_", ""))
accelerator.skip_first_batches(train_dataloader, resume_step)
accelerator.print(f"Resuming from step {resume_step}")
# log gradients
if accelerator.is_main_process and config["wandb"]:
wandb.watch(model, log_freq=config["log_grads_every"], log="all")
main_process = accelerator.is_main_process
chunk_size = config["seq_len"]
for epoch in range(config["num_epochs"]):
train_loss = MeanMetric(nan_strategy="error").to(model.device)
for step, batch in enumerate(tqdm(train_dataloader, disable=not main_process)):
epoch_step = epoch * len(train_dataloader) + step * config["segments"]
seq_len = batch["input_ids"].shape[1]
model.train()
for i, chunk_start in enumerate(range(0, seq_len, chunk_size)):
curr_step = epoch_step + i
chunk_end = min(seq_len, chunk_start + chunk_size)
inputs = batch["input_ids"][:, chunk_start:chunk_end]
labels = inputs.clone()
labels[labels == tokenizer.pad_token_id] = -100
outputs = model(input_ids=inputs,
attention_mask=inputs.ne(tokenizer.pad_token_id),
labels=labels,
log_attn_scores=True,
step=curr_step,
save_kv=True,
)
loss = outputs.loss / config["segments"]
if config["wandb"]:
accelerator.log({"loss": loss}, step=curr_step)
# gather loss before backprop in case of gradient accumulation
loss_values = accelerator.gather_for_metrics({"loss": loss.detach().float()})
train_loss.update(loss_values["loss"])
loss = loss / gradient_accumulation_steps
accelerator.backward(loss)
# log LR in case something weird happens
if config["wandb"]:
if step > 0 and step % (config["log_lr_every"] ) == 0:
lr = optimizer.param_groups[0]["lr"]
accelerator.log({"lr": lr}, step=curr_step)
optimizer.step()
if use_scheduler:
scheduler.step()
optimizer.zero_grad()
# reset index on batch end
index.reset()
if step > 0 and config["save_every"] > 0 and step % config["save_every"] == 0:
# accelerator.save_state(f"{config['output_dir']}/step_{curr_step}")
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(
f"{config['output_dir']}/step_{step}",
is_main_process=accelerator.is_main_process,
save_function=accelerator.save,
state_dict=accelerator.get_state_dict(model),
)
if step > 0 and (step % config["eval_every"] == 0 or step == len(train_dataloader) - 1):
val_loss = evaluate(model, index, tokenizer.pad_token_id, config, val_dataloader, main_process=main_process)
log_train = {
"train_loss": train_loss.compute()
}
log_val = {
"val_loss": val_loss.compute(),
}
if config["wandb"]:
curr_step = step + epoch * len(train_dataloader)
accelerator.log({**log_train, **log_val}, step=curr_step)
accelerator.print(f"Current LR: {optimizer.param_groups[0]['lr']}")
accelerator.print(format_metrics(log_train, "train", f" step {step} "))
accelerator.print(format_metrics(log_val, "val", f" step {step} "))
train_loss.reset()
accelerator.print(f"Epoch {epoch} finished")
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
if config["push_to_hub"]:
accelerator.print(f"Pushing to HF hub")
try:
if accelerator.is_main_process:
unwrapped_model.push_to_hub(config["save_name"] + f"-epoch_{epoch}", private=True)
except Exception as e:
accelerator.print(e)
accelerator.print(f"Failed to push to hub")
unwrapped_model.save_pretrained(
f"{config['output_dir']}/epoch_{epoch}",
is_main_process=accelerator.is_main_process,
save_function=accelerator.save,
state_dict=accelerator.get_state_dict(model),
)
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(
f"{config['output_dir']}/final",
is_main_process=accelerator.is_main_process,
save_function=accelerator.save,
state_dict=accelerator.get_state_dict(model),
)
accelerator.end_training()
if __name__ == "__main__":
# parse arguments by reading in a config
parser = ArgumentParser()
parser.add_argument("--config", type=str, default="config.yaml")
args = parser.parse_args()
config = read_config(args.config)
if config["wandb"]:
accelerator = Accelerator(log_with="wandb")
accelerator.init_trackers(
project_name=config["wandb_project_name"],
config=config,
init_kwargs={"wandb": {"entity": config["wandb_entity"]}},
)
else:
accelerator = Accelerator()
train(accelerator, config=config)
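The new training loop above splits each long batch into seq_len-sized segments, accumulates the per-segment loss scaled by 1/segments, and only resets the memory index at the end of the batch, so later segments can attend to keys and values stored from earlier ones. The core pattern, reduced to a sketch (names mirror the script; model, index and accelerator are stand-ins):

# Minimal sketch of the segment loop used above; `model` stores KV pairs into
# `index` when save_kv=True, and the index is cleared only at batch end.
def train_one_batch(model, index, batch, pad_token_id, seq_len, segments, accelerator):
    total = batch["input_ids"].shape[1]
    for chunk_start in range(0, total, seq_len):
        inputs = batch["input_ids"][:, chunk_start:chunk_start + seq_len]
        labels = inputs.clone()
        labels[labels == pad_token_id] = -100
        outputs = model(input_ids=inputs,
                        attention_mask=inputs.ne(pad_token_id),
                        labels=labels,
                        save_kv=True)
        accelerator.backward(outputs.loss / segments)
    index.reset()  # memories never leak across batches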

View File

@ -116,6 +116,7 @@ def train(accelerator, config):
if config["checkpoint"]: if config["checkpoint"]:
accelerator.load_state(config["checkpoint"]) accelerator.load_state(config["checkpoint"])
import pdb; pdb.set_trace()
accelerator.print(f"Resumed from checkpoint: {config['checkpoint']}") accelerator.print(f"Resumed from checkpoint: {config['checkpoint']}")
path = os.path.basename(config["train_args"]["resume_from_checkpoint"]) path = os.path.basename(config["train_args"]["resume_from_checkpoint"])
training_difference = os.path.splitext(path)[0] training_difference = os.path.splitext(path)[0]
@ -131,9 +132,12 @@ def train(accelerator, config):
for epoch in range(config["num_epochs"]): for epoch in range(config["num_epochs"]):
train_loss = MeanMetric(nan_strategy="error").to(model.device) train_loss = MeanMetric(nan_strategy="error").to(model.device)
for step, batch in enumerate(tqdm(train_dataloader)): for step, batch in enumerate(tqdm(train_dataloader)):
curr_step = step + epoch * len(train_dataloader)
model.train() model.train()
outputs = model(**batch) outputs = model(**batch)
loss = outputs.loss loss = outputs.loss
if config["wandb"]:
accelerator.log({"loss": loss}, step=curr_step)
# gather loss before backprop in case of gradient accumulation # gather loss before backprop in case of gradient accumulation
loss_values = accelerator.gather_for_metrics({"loss": loss.detach().float()}) loss_values = accelerator.gather_for_metrics({"loss": loss.detach().float()})
@ -157,7 +161,13 @@ def train(accelerator, config):
if step > 0 and step % config["save_every"] == 0: if step > 0 and step % config["save_every"] == 0:
curr_step = step + epoch * len(train_dataloader) curr_step = step + epoch * len(train_dataloader)
accelerator.save_state(f"{config['output_dir']}/step_{curr_step}") unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(
f"{config['output_dir']}/step_{curr_step}",
is_main_process=accelerator.is_main_process,
save_function=accelerator.save,
state_dict=accelerator.get_state_dict(model),
)
if step > 0 and (step % config["eval_every"] == 0 or step == len(train_dataloader) - 1): if step > 0 and (step % config["eval_every"] == 0 or step == len(train_dataloader) - 1):
val_loss = evaluate(model, val_dataloader) val_loss = evaluate(model, val_dataloader)

View File

@ -11,7 +11,7 @@ from gpt4all.data.retrieval_dataloader import load_memory_augmented_data
from torchmetrics import MeanMetric from torchmetrics import MeanMetric
from tqdm import tqdm from tqdm import tqdm
from gpt4all.models import LetheForCausalLM from gpt4all.models import LetheForCausalLM
from gpt4all.models.lethe.modeling_lethe import MemoryIndex from gpt4all.models.lethe.modeling_lethe import BatchedMemory
import wandb import wandb
import pyarrow as pa import pyarrow as pa
from pyarrow import feather from pyarrow import feather
@ -58,13 +58,15 @@ def evaluate(model, index, config, val_dataloader, main_process=False):
qa_labels = batch["labels"] qa_labels = batch["labels"]
outputs = model(input_ids=qa_inputs, outputs = model(input_ids=qa_inputs,
labels=qa_labels, labels=qa_labels,
save_kv=False
) )
del memories del memories
torch.cuda.empty_cache() torch.cuda.empty_cache()
loss_values = accelerator.gather_for_metrics({"loss": outputs["loss"].detach()}) loss_values = accelerator.gather_for_metrics({"loss": outputs["loss"].detach()})
index.reset() for ind in index.values():
ind.reset()
val_loss.update(loss_values["loss"]) val_loss.update(loss_values["loss"])
per_example_loss = calculate_per_example_loss(outputs["logits"], qa_labels) per_example_loss = calculate_per_example_loss(outputs["logits"], qa_labels)
@ -110,16 +112,18 @@ def train(accelerator, config):
model_config = AutoConfig.from_pretrained(config["model_name"]) model_config = AutoConfig.from_pretrained(config["model_name"])
head_size = model_config.hidden_size // model_config.num_attention_heads head_size = model_config.hidden_size // model_config.num_attention_heads
index = MemoryIndex(head_size, indices = {i - 1: BatchedMemory(config["batch_size"],
head_size,
config["num_memories_per_index"], config["num_memories_per_index"],
model_config.num_attention_heads model_config.num_attention_heads,
) ) for i in config["memory_attn_layer"]}
model = LetheForCausalLM.from_pretrained(config["model_name"], model = LetheForCausalLM.from_pretrained(config["model_name"],
revision=config['version'] if 'version' in config else None, revision=config['version'] if 'version' in config else None,
use_cache=False if checkpoint else True, use_cache=False if checkpoint else True,
memory_attn_layer=config["memory_attn_layer"], memory_attn_layer=config["memory_attn_layer"],
num_neighbors_to_retrieve=config["num_neighbors_to_retrieve"], num_neighbors_to_retrieve=config["num_neighbors_to_retrieve"],
index=index, index=indices,
tracker=accelerator.get_tracker("wandb"), tracker=accelerator.get_tracker("wandb"),
) )
@ -206,8 +210,10 @@ def train(accelerator, config):
model.train() model.train()
qa_inputs = batch["input_ids"] qa_inputs = batch["input_ids"]
attn_mask = batch["attention_mask"]
qa_labels = batch["labels"] qa_labels = batch["labels"]
outputs = model(input_ids=qa_inputs, outputs = model(input_ids=qa_inputs,
attention_mask=attn_mask,
labels=qa_labels, labels=qa_labels,
log_attn_scores=True, log_attn_scores=True,
step=curr_step, step=curr_step,
@ -225,6 +231,7 @@ def train(accelerator, config):
loss = loss / gradient_accumulation_steps loss = loss / gradient_accumulation_steps
accelerator.backward(loss) accelerator.backward(loss)
# !! don't reset index until after backwards pass # !! don't reset index until after backwards pass
for index in indices.values():
index.reset() index.reset()
# get gradient norm of all params # get gradient norm of all params
@ -251,7 +258,7 @@ def train(accelerator, config):
) )
if step > 0 and (step % config["eval_every"] == 0 or step == len(train_dataloader) - 1): if step > 0 and (step % config["eval_every"] == 0 or step == len(train_dataloader) - 1):
val_loss, loss_table = evaluate(model, index, config, val_dataloader, main_process=main_process) val_loss, loss_table = evaluate(model, indices, config, val_dataloader, main_process=main_process)
local_rank = accelerator.process_index local_rank = accelerator.process_index
feather.write_feather(loss_table, f"{config['output_dir']}/val_losses_step_{curr_step}_rank_{local_rank}.feather") feather.write_feather(loss_table, f"{config['output_dir']}/val_losses_step_{curr_step}_rank_{local_rank}.feather")