mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2025-08-13 13:46:08 +00:00
wip
This commit is contained in:
parent
55fef489ad
commit
3128db96ca
@ -14,7 +14,7 @@
|
||||
},
|
||||
"gradient_clipping": 1.0,
|
||||
"zero_optimization": {
|
||||
"stage": 2,
|
||||
"stage": 1,
|
||||
"offload_param": {
|
||||
"device": "none"
|
||||
},
|
||||
@ -35,5 +35,15 @@
|
||||
],
|
||||
"eps": 1e-08
|
||||
}
|
||||
},
|
||||
"scheduler": {
|
||||
"type": "WarmupDecayLR",
|
||||
"params": {
|
||||
"warmup_min_lr": 0,
|
||||
"warmup_max_lr": "auto",
|
||||
"warmup_num_steps": "auto",
|
||||
"warmup_type": "linear",
|
||||
"total_num_steps": "auto"
|
||||
}
|
||||
}
|
||||
}
|
@ -1,9 +1,10 @@
|
||||
# model/tokenizer
|
||||
model_name: "/home/paperspace/gpt4all/gpt4all/train/ckpts/mem_attn/step_1000"
|
||||
model_name: "/home/paperspace/gpt4all/gpt4all/train/ckpts/qk_no_norm/step_500"
|
||||
tokenizer_name: "EleutherAI/pythia-1b"
|
||||
version: null
|
||||
gradient_checkpointing: false
|
||||
gradient_checkpointing: true
|
||||
memory_attn_layer: 12
|
||||
seed: 42
|
||||
|
||||
|
||||
# dataset
|
||||
|
@ -5,7 +5,7 @@ version: null
|
||||
gradient_checkpointing: true
|
||||
save_name: "nomic-ai/lethe"
|
||||
push_to_hub: false
|
||||
memory_attn_layer: 12
|
||||
memory_attn_layer: [9, 12, 15]
|
||||
|
||||
# dataset
|
||||
streaming: false
|
||||
@ -17,7 +17,7 @@ pct_test: 0.05
|
||||
q_column: "question"
|
||||
a_column: "answer"
|
||||
context_column: "text"
|
||||
num_memories_per_index: 2000000
|
||||
num_memories_per_index: 2048
|
||||
num_neighbors_to_retrieve: 2
|
||||
num_neighbors_to_store: 1
|
||||
mem_chunk_size: 64
|
||||
@ -26,15 +26,15 @@ mem_chunk_size: 64
|
||||
lr: 1.0e-5
|
||||
min_lr: 0
|
||||
weight_decay: 0.0
|
||||
eval_every: 100
|
||||
save_every: 100
|
||||
eval_every: 250
|
||||
save_every: 250
|
||||
log_grads_every: 100
|
||||
log_lr_every: 10
|
||||
output_dir: "ckpts/mem_attn_no_cosine_sim"
|
||||
output_dir: "ckpts/qk_no_norm"
|
||||
checkpoint: null
|
||||
lora: false
|
||||
warmup_steps: 200
|
||||
num_epochs: 5
|
||||
num_epochs: 2
|
||||
debug: false
|
||||
scheduler: false
|
||||
|
||||
@ -43,4 +43,3 @@ wandb: true
|
||||
wandb_entity: gpt4all
|
||||
wandb_project_name: mem_attn
|
||||
seed: 42
|
||||
|
||||
|
@ -10,20 +10,21 @@ memory_attn_layer: 12
|
||||
# dataset
|
||||
streaming: false
|
||||
num_proc: 64
|
||||
dataset_path: "JeanKaddour/minipile"
|
||||
dataset_path: "pg19"
|
||||
max_length: 2048
|
||||
batch_size: 64
|
||||
seq_len: 512
|
||||
segments: 16
|
||||
batch_size: 16
|
||||
pct_test: 0.05
|
||||
num_memories_per_index: 5000000
|
||||
num_memories_per_index: 100000
|
||||
mem_chunk_size: 512
|
||||
num_chunks: 10
|
||||
num_neighbors_to_retrieve: 32
|
||||
|
||||
# train dynamics
|
||||
lr: 1.0e-4
|
||||
lr: 2.0e-4
|
||||
min_lr: 0
|
||||
weight_decay: 0.0
|
||||
eval_every: 100
|
||||
eval_every: 250
|
||||
save_every: -1
|
||||
log_grads_every: 100
|
||||
log_lr_every: 10
|
||||
@ -38,6 +39,6 @@ scheduler: false
|
||||
# logging
|
||||
wandb: true
|
||||
wandb_entity: gpt4all
|
||||
wandb_project_name: minipile
|
||||
wandb_project_name: enwik8
|
||||
seed: 42
|
||||
|
62
gpt4all/data/enwik8.py
Normal file
62
gpt4all/data/enwik8.py
Normal file
@ -0,0 +1,62 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from transformers import DefaultDataCollator
|
||||
|
||||
|
||||
|
||||
class EnWik8Dataset(Dataset):
|
||||
def __init__(self, data, seq_len):
|
||||
# pyarrow chunked array
|
||||
self.data = torch.from_numpy(data)
|
||||
self.seq_len = seq_len
|
||||
|
||||
def __getitem__(self, index):
|
||||
full_seq = self.data[index].long()
|
||||
return full_seq.cuda()
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
|
||||
def load_enwik8_dataloader(config, tokenizer):
|
||||
ds = load_dataset(config["dataset_path"], split="train")
|
||||
|
||||
ds = ds.train_test_split(test_size=0.05, seed=config['seed'])
|
||||
|
||||
train_ds, val_ds = ds["train"], ds["test"]
|
||||
|
||||
keep_cols = ["input_ids"]
|
||||
|
||||
train_ds = train_ds.map(lambda x: {"len": [len(t) for t in x["text"]]}, batched=True)
|
||||
train_ds = train_ds.sort("len")
|
||||
train_ds = train_ds.map(lambda x: tokenizer(x["text"], padding="longest", truncation=True, return_tensors="pt"),
|
||||
batched=True,
|
||||
batch_size=config["batch_size"])
|
||||
|
||||
remove_cols = [col for col in train_ds.column_names if col not in keep_cols]
|
||||
train_ds = train_ds.remove_columns(remove_cols)
|
||||
|
||||
val_ds = val_ds.map(lambda x: {"len": [len(t) for t in x["text"]]}, batched=True)
|
||||
val_ds = val_ds.sort("len")
|
||||
val_ds = val_ds.map(lambda x: tokenizer(x["text"], padding="longest", truncation=True, return_tensors="pt"),
|
||||
batched=True,
|
||||
batch_size=config["batch_size"])
|
||||
|
||||
remove_cols = [col for col in train_ds.column_names if col not in keep_cols]
|
||||
val_ds = val_ds.remove_columns(remove_cols)
|
||||
|
||||
train_dl = DataLoader(train_ds,
|
||||
batch_size=config["batch_size"],
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
collate_fn=DefaultDataCollator())
|
||||
|
||||
val_dl = DataLoader(val_ds,
|
||||
batch_size=config["batch_size"],
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
collate_fn=DefaultDataCollator())
|
||||
|
||||
return train_dl, val_dl
|
@ -1,6 +1,6 @@
|
||||
import glob
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from datasets import load_dataset, load_from_disk
|
||||
import os
|
||||
import hnswlib
|
||||
from torch.utils.data import DataLoader
|
||||
@ -12,20 +12,23 @@ def load_data(config, tokenizer):
|
||||
dataset_path = config["dataset_path"]
|
||||
|
||||
if os.path.exists(dataset_path):
|
||||
if os.path.isdir(dataset_path):
|
||||
files = glob.glob(os.path.join(dataset_path, "*_clean.jsonl"))
|
||||
else:
|
||||
files = [dataset_path]
|
||||
dataset = load_from_disk(dataset_path)
|
||||
# if os.path.isdir(dataset_path):
|
||||
# files = glob.glob(os.path.join(dataset_path, "*_clean.jsonl"))
|
||||
# else:
|
||||
# files = [dataset_path]
|
||||
|
||||
print(f"Reading files {files}")
|
||||
# print(f"Reading files {files}")
|
||||
|
||||
dataset = load_dataset("json", data_files=files, split="train")
|
||||
# dataset = load_dataset("json", data_files=files, split="train")
|
||||
|
||||
else:
|
||||
dataset = load_dataset(dataset_path, split="train")
|
||||
|
||||
dataset = dataset.train_test_split(test_size=.05, seed=config["seed"])
|
||||
|
||||
dataset = dataset.map(lambda x: {"prompt": [text + " " + question for text, question in zip(x["text"], x["question"])]}, batched=True)
|
||||
|
||||
train_dataset, val_dataset = dataset["train"], dataset["test"]
|
||||
|
||||
if config["streaming"] is False:
|
||||
@ -33,19 +36,27 @@ def load_data(config, tokenizer):
|
||||
else:
|
||||
kwargs = {}
|
||||
|
||||
cols_to_keep = ["input_ids", "labels", "attention_mask"]
|
||||
|
||||
# tokenize inputs and return labels and attention mask
|
||||
train_dataset = train_dataset.map(
|
||||
lambda ele: tokenize_inputs(config, tokenizer, ele, "prompt", "response"),
|
||||
lambda ele: tokenize_inputs(config, tokenizer, ele, "prompt", "answer"),
|
||||
batched=True,
|
||||
remove_columns=["source", "prompt"],
|
||||
# remove_columns=["source", "prompt"],
|
||||
**kwargs
|
||||
)
|
||||
|
||||
cols_to_remove = [col for col in train_dataset.column_names if col not in cols_to_keep]
|
||||
train_dataset = train_dataset.remove_columns(cols_to_remove)
|
||||
|
||||
val_dataset = val_dataset.map(
|
||||
lambda ele: tokenize_inputs(config, tokenizer, ele, "prompt", "response"),
|
||||
lambda ele: tokenize_inputs(config, tokenizer, ele, "prompt", "answer"),
|
||||
batched=True,
|
||||
remove_columns=["source", "prompt"],
|
||||
# remove_columns=["source", "prompt"],
|
||||
**kwargs
|
||||
)
|
||||
cols_to_remove = [col for col in val_dataset.column_names if col not in cols_to_keep]
|
||||
val_dataset = val_dataset.remove_columns(cols_to_remove)
|
||||
|
||||
train_dataset = train_dataset.with_format("torch")
|
||||
val_dataset = val_dataset.with_format("torch")
|
||||
@ -56,12 +67,14 @@ def load_data(config, tokenizer):
|
||||
train_dataset,
|
||||
collate_fn=DefaultDataCollator(),
|
||||
batch_size=config["batch_size"],
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
val_dataloader = DataLoader(
|
||||
val_dataset,
|
||||
collate_fn=DefaultDataCollator(),
|
||||
batch_size=config["batch_size"],
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
return train_dataloader, val_dataloader
|
||||
|
@ -5,7 +5,7 @@ def tokenize_inputs(config, tokenizer, examples, input_col, target_col):
|
||||
|
||||
# hacky backward compatible
|
||||
different_eos = tokenizer.eos_token != "</s>"
|
||||
out = {"labels": [], "input_ids": []}
|
||||
out = {"labels": [], "input_ids": [], "attention_mask": []}
|
||||
for prompt, response in zip(examples[input_col], examples[target_col]):
|
||||
if different_eos:
|
||||
if response.count("</s> \n") > 0:
|
||||
@ -42,9 +42,10 @@ def tokenize_inputs(config, tokenizer, examples, input_col, target_col):
|
||||
print(response)
|
||||
raise
|
||||
|
||||
input_tokens = tokenizer.pad({"input_ids": input_tokens}, padding="max_length", max_length=max_length)["input_ids"]
|
||||
tokenized = tokenizer.pad({"input_ids": input_tokens}, padding="max_length", max_length=max_length)
|
||||
out["labels"].append(labels)
|
||||
out["input_ids"].append(input_tokens)
|
||||
out["input_ids"].append(tokenized["input_ids"])
|
||||
out["attention_mask"].append(tokenized["attention_mask"])
|
||||
|
||||
out = {k: torch.stack(v) if isinstance(v, list) else v for k, v in out.items()}
|
||||
|
||||
|
@ -41,7 +41,7 @@ def load_retrieval_augmented_data(config, tokenizer, split="train", split_datase
|
||||
if encoder_column != "encoder_hidden_states":
|
||||
dataset = dataset.rename_column(encoder_column, "encoder_hidden_states")
|
||||
|
||||
columns_to_keep = ["input_ids", "labels", "encoder_hidden_states"]
|
||||
columns_to_keep = ["input_ids", "attention_mask", "labels", "encoder_hidden_states"]
|
||||
|
||||
col_names_to_rm = [col for col in dataset.column_names if col not in columns_to_keep]
|
||||
dataset = dataset.remove_columns(col_names_to_rm)
|
||||
@ -115,7 +115,7 @@ def load_memory_augmented_data(config, tokenizer, split="train", split_dataset=T
|
||||
**kwargs
|
||||
)
|
||||
|
||||
columns_to_keep = ["id", "input_ids", "labels", "retrieved_context"]
|
||||
columns_to_keep = ["id", "input_ids", "attention_mask", "labels", "retrieved_context"]
|
||||
|
||||
col_names_to_rm = [col for col in dataset.column_names if col not in columns_to_keep]
|
||||
dataset = dataset.remove_columns(col_names_to_rm)
|
||||
@ -128,12 +128,16 @@ def load_memory_augmented_data(config, tokenizer, split="train", split_dataset=T
|
||||
train_dataset.remove_columns("id"),
|
||||
batch_size=config["batch_size"],
|
||||
collate_fn=DefaultDataCollator(),
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
val_dataloader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=config["batch_size"],
|
||||
collate_fn=DefaultDataCollator(),
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
return train_dataloader, val_dataloader
|
||||
|
@ -1,7 +1,7 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from gpt4all.models import LetheForCausalLM
|
||||
from gpt4all.models.lethe.modeling_lethe import MemoryIndex
|
||||
from gpt4all.models.lethe.modeling_lethe import BatchedMemory
|
||||
from gpt4all.data.retrieval_dataloader import load_memory_augmented_data
|
||||
from gpt4all.train.metrics import f1_score, exact_match_score
|
||||
from gpt4all.utils.read import read_config
|
||||
@ -28,17 +28,21 @@ def greedy_search(input_ids, model, tokenizer, max_new_tokens=100):
|
||||
while True:
|
||||
if num_new_tokens >= max_new_tokens:
|
||||
break
|
||||
outputs = model(input_ids, save_kv=False)
|
||||
attention_mask = input_ids.ne(tokenizer.pad_token_id)
|
||||
outputs = model(input_ids, attention_mask=attention_mask, save_kv=False)
|
||||
|
||||
new_tokens = torch.argmax(outputs.logits[:, -1, :], dim=-1)
|
||||
next_token_idx = torch.argmax((input_ids == tokenizer.pad_token_id).type(torch.float32))
|
||||
# -1 because logits at last position predict next token
|
||||
new_token = torch.argmax(outputs.logits[:, next_token_idx - 1, :], dim=-1)
|
||||
|
||||
input_ids[:, next_token_idx] = new_token
|
||||
|
||||
input_ids = torch.cat([input_ids, new_tokens.unsqueeze(1)], dim=-1)
|
||||
num_new_tokens += 1
|
||||
|
||||
if torch.equal(input_ids[0, -1].cpu(), torch.tensor(tokenizer.eos_token_id)):
|
||||
if torch.equal(new_token.cpu(), torch.tensor(tokenizer.eos_token_id)):
|
||||
break
|
||||
|
||||
print(tokenizer.batch_decode(input_ids, skip_special_tokens=True))
|
||||
print(f"GENERATED: {tokenizer.batch_decode(input_ids, skip_special_tokens=True)}")
|
||||
|
||||
return input_ids
|
||||
|
||||
@ -63,10 +67,11 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
model_config = AutoConfig.from_pretrained(config["model_name"])
|
||||
|
||||
head_size = model_config.hidden_size // model_config.num_attention_heads
|
||||
index = MemoryIndex(head_size,
|
||||
index = BatchedMemory(config["batch_size"],
|
||||
head_size,
|
||||
config["num_memories_per_index"],
|
||||
model_config.num_attention_heads
|
||||
)
|
||||
model_config.num_attention_heads,
|
||||
)
|
||||
model = LetheForCausalLM.from_pretrained(config["model_name"],
|
||||
revision=config['version'] if 'version' in config else None,
|
||||
memory_attn_layer=config["memory_attn_layer"],
|
||||
@ -90,16 +95,19 @@ with torch.no_grad():
|
||||
mem_chunk = memories[chunk_start:chunk_end]
|
||||
model(input_ids=mem_chunk.to(device))
|
||||
|
||||
del memories
|
||||
torch.cuda.empty_cache()
|
||||
qa_inputs = batch["input_ids"]
|
||||
qa_labels = batch["labels"]
|
||||
for i in range(qa_inputs.shape[0]):
|
||||
inputs = qa_inputs[i].to(device)
|
||||
print(f"EXPECTED: {tokenizer.decode(inputs, skip_special_tokens=True)}")
|
||||
labels = qa_labels[i].to(device)
|
||||
|
||||
cutoff = torch.argmax((labels != -100).type(torch.float32))
|
||||
greedy_search(inputs[:cutoff.item()].unsqueeze(0).to(device), model, tokenizer)
|
||||
print(tokenizer.decode(inputs, skip_special_tokens=True))
|
||||
inputs[cutoff:] = tokenizer.pad_token_id
|
||||
greedy_search(inputs.unsqueeze(0).to(device), model, tokenizer)
|
||||
print(f"CONTEXT: {tokenizer.decode(memories[i], skip_special_tokens=True)}")
|
||||
import pdb; pdb.set_trace()
|
||||
|
||||
|
||||
# batch_loss = calc_loss_per_item(outputs.logits, qa_labels.to(device))
|
||||
|
96
gpt4all/eval/eval_squad_atlas_map/dataset_info.json
Normal file
96
gpt4all/eval/eval_squad_atlas_map/dataset_info.json
Normal file
@ -0,0 +1,96 @@
|
||||
{
|
||||
"builder_name": "squad",
|
||||
"citation": "@article{2016arXiv160605250R,\n author = {{Rajpurkar}, Pranav and {Zhang}, Jian and {Lopyrev},\n Konstantin and {Liang}, Percy},\n title = \"{SQuAD: 100,000+ Questions for Machine Comprehension of Text}\",\n journal = {arXiv e-prints},\n year = 2016,\n eid = {arXiv:1606.05250},\n pages = {arXiv:1606.05250},\narchivePrefix = {arXiv},\n eprint = {1606.05250},\n}\n",
|
||||
"config_name": "plain_text",
|
||||
"dataset_size": 89819092,
|
||||
"description": "Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset, consisting of questions posed by crowdworkers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span, from the corresponding reading passage, or the question might be unanswerable.\n",
|
||||
"download_checksums": {
|
||||
"https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json": {
|
||||
"num_bytes": 30288272,
|
||||
"checksum": null
|
||||
},
|
||||
"https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json": {
|
||||
"num_bytes": 4854279,
|
||||
"checksum": null
|
||||
}
|
||||
},
|
||||
"download_size": 35142551,
|
||||
"features": {
|
||||
"id": {
|
||||
"dtype": "string",
|
||||
"_type": "Value"
|
||||
},
|
||||
"title": {
|
||||
"dtype": "string",
|
||||
"_type": "Value"
|
||||
},
|
||||
"context": {
|
||||
"dtype": "string",
|
||||
"_type": "Value"
|
||||
},
|
||||
"question": {
|
||||
"dtype": "string",
|
||||
"_type": "Value"
|
||||
},
|
||||
"answers": {
|
||||
"feature": {
|
||||
"text": {
|
||||
"dtype": "string",
|
||||
"_type": "Value"
|
||||
},
|
||||
"answer_start": {
|
||||
"dtype": "int32",
|
||||
"_type": "Value"
|
||||
}
|
||||
},
|
||||
"_type": "Sequence"
|
||||
},
|
||||
"neighbor_ids": {
|
||||
"feature": {
|
||||
"dtype": "uint64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"_type": "Sequence"
|
||||
},
|
||||
"neighbor_text": {
|
||||
"feature": {
|
||||
"dtype": "string",
|
||||
"_type": "Value"
|
||||
},
|
||||
"_type": "Sequence"
|
||||
},
|
||||
"loss": {
|
||||
"dtype": "float64",
|
||||
"_type": "Value"
|
||||
}
|
||||
},
|
||||
"homepage": "https://rajpurkar.github.io/SQuAD-explorer/",
|
||||
"license": "",
|
||||
"size_in_bytes": 124961643,
|
||||
"splits": {
|
||||
"train": {
|
||||
"name": "train",
|
||||
"num_bytes": 79346108,
|
||||
"num_examples": 87599,
|
||||
"dataset_name": "squad"
|
||||
},
|
||||
"validation": {
|
||||
"name": "validation",
|
||||
"num_bytes": 10472984,
|
||||
"num_examples": 10570,
|
||||
"dataset_name": "squad"
|
||||
}
|
||||
},
|
||||
"task_templates": [
|
||||
{
|
||||
"task": "question-answering-extractive"
|
||||
}
|
||||
],
|
||||
"version": {
|
||||
"version_str": "1.0.0",
|
||||
"description": "",
|
||||
"major": 1,
|
||||
"minor": 0,
|
||||
"patch": 0
|
||||
}
|
||||
}
|
16
gpt4all/eval/eval_squad_atlas_map/state.json
Normal file
16
gpt4all/eval/eval_squad_atlas_map/state.json
Normal file
@ -0,0 +1,16 @@
|
||||
{
|
||||
"_data_files": [
|
||||
{
|
||||
"filename": "data-00000-of-00002.arrow"
|
||||
},
|
||||
{
|
||||
"filename": "data-00001-of-00002.arrow"
|
||||
}
|
||||
],
|
||||
"_fingerprint": "c178f5c8269012a2",
|
||||
"_format_columns": null,
|
||||
"_format_kwargs": {},
|
||||
"_format_type": null,
|
||||
"_output_all_columns": false,
|
||||
"_split": "validation"
|
||||
}
|
42
gpt4all/eval/eval_synthetic.py
Normal file
42
gpt4all/eval/eval_synthetic.py
Normal file
@ -0,0 +1,42 @@
|
||||
import torch
|
||||
from gpt4all.data.instruction_tuning_dataloader import load_data
|
||||
from gpt4all.utils.read import read_config
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from argparse import ArgumentParser
|
||||
from tqdm import tqdm
|
||||
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--config", type=str, default="config.yaml")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
config = read_config(args.config)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(config["tokenizer_name"])
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
train_dataloader, val_dataloader = load_data(config, tokenizer)
|
||||
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(config["model_name"],
|
||||
trust_remote_code=True)
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
model = model.half().to(device)
|
||||
model.eval()
|
||||
|
||||
# Evaluate the model on the SQUAD dataset
|
||||
f1s = []
|
||||
exact_matches = []
|
||||
with torch.no_grad():
|
||||
for batch in tqdm(val_dataloader):
|
||||
inputs = batch["input_ids"].to(device)
|
||||
labels = batch["labels"].to(device)
|
||||
cutoff = torch.argmax((labels != -100).type(torch.float32))
|
||||
outputs = model.generate(inputs[:, :cutoff], max_new_tokens=100)
|
||||
print(f"Predicted: {tokenizer.batch_decode(outputs, skip_special_tokens=True)}")
|
||||
print(f"Ground truth: {tokenizer.batch_decode(inputs[:, cutoff:], skip_special_tokens=True)}")
|
||||
print(tokenizer.batch_decode(inputs, skip_special_tokens=True))
|
||||
|
@ -2,7 +2,6 @@ from .gpt_jr.configuration_gpt_jr import GPTJRConfig
|
||||
from .gpt_jr.modeling_gpt_jr import GPTJRForCausalLM
|
||||
|
||||
from .pythiaseek import PythiaSeekForCausalLM, PythiaSeekConfig
|
||||
from .pythia_retro import PythiaRetroForCausalLM, PythiaRetroConfig
|
||||
from .lethe import LetheConfig, LetheForCausalLM
|
||||
|
||||
|
||||
|
@ -18,6 +18,8 @@ import wandb
|
||||
import math
|
||||
import torch.nn.functional as F
|
||||
import matplotlib.pyplot as plt
|
||||
import plotly.express as px
|
||||
import pandas as pd
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
@ -121,58 +123,104 @@ class MemoryIndex:
|
||||
# NOTE: we are storing kv pairs, instead indices for both keys and values
|
||||
self.key_indices = [HNSWIndex(num_mems, hidden_dim) for _ in range(nheads)]
|
||||
|
||||
shape = (num_mems, nheads, 2, hidden_dim)
|
||||
shape = (nheads, num_mems, 2, hidden_dim)
|
||||
self.kv_pairs = np.zeros(shape, dtype=np.float32)
|
||||
self.idx_offset = 0
|
||||
|
||||
def add(self, keys, values):
|
||||
# k/v are (bs, num_attention_heads, seq_len, head_size)
|
||||
reshaped_keys = keys.reshape(keys.shape[0] * keys.shape[2], keys.shape[1], keys.shape[3])
|
||||
reshaped_values = values.reshape(values.shape[0] * values.shape[2], values.shape[1], values.shape[3])
|
||||
# k/v are (num_attention_heads, seq_len, head_size)
|
||||
# keys = keys.reshape(keys.shape[1], keys.shape[0], keys.shape[2])
|
||||
# values = values.reshape(values.shape[1], values.shape[0], values.shape[2])
|
||||
|
||||
for head in range(self.nheads):
|
||||
self.key_indices[head].add(reshaped_keys[:, head, :])
|
||||
self.key_indices[head].add(keys[head, :, :])
|
||||
|
||||
kv_pairs = np.stack((reshaped_keys, reshaped_values), axis=2)
|
||||
kv_pairs = np.stack((keys, values), axis=2)
|
||||
|
||||
if self.idx_offset + kv_pairs.shape[0] > self.kv_pairs.shape[0]:
|
||||
raise ValueError("Not enough memory!")
|
||||
if self.idx_offset + kv_pairs.shape[1] > self.kv_pairs.shape[1]:
|
||||
# reset to 0 to overwrite oldest memories
|
||||
self.idx_offet = 0
|
||||
|
||||
self.kv_pairs[self.idx_offset:self.idx_offset + kv_pairs.shape[0]] = kv_pairs
|
||||
self.idx_offset += kv_pairs.shape[0]
|
||||
self.kv_pairs[:, self.idx_offset:self.idx_offset + kv_pairs.shape[1]] = kv_pairs
|
||||
self.idx_offset += kv_pairs.shape[1]
|
||||
|
||||
def knn_query(self, query, k=1):
|
||||
reshaped_query = query.reshape(query.shape[0] * query.shape[2], query.shape[1], query.shape[3])
|
||||
|
||||
mem_keys = []
|
||||
mem_values = []
|
||||
mem_indices = []
|
||||
|
||||
# we can prob make this better
|
||||
for head in range(self.nheads):
|
||||
knn_indices = self.key_indices[head].query(reshaped_query[:, head, :], k=k)
|
||||
kv_pairs = self.kv_pairs[:, head, :, :][knn_indices]
|
||||
knn_indices = self.key_indices[head].query(query[head, :, :], k=k)
|
||||
kv_pairs = self.kv_pairs[head, :, :, :][knn_indices]
|
||||
|
||||
mem_keys.append(kv_pairs[:, :, 0, :])
|
||||
mem_values.append(kv_pairs[:, :, 1, :])
|
||||
mem_indices.append(knn_indices)
|
||||
|
||||
mem_keys = torch.from_numpy(np.stack(mem_keys, axis=1))
|
||||
# (bs, num_attention_heads, seq_len, k, head_size)
|
||||
mem_keys = mem_keys.view(query.shape[:-1] + (k,) + (query.shape[-1],))
|
||||
mem_keys = torch.from_numpy(np.stack(mem_keys, axis=0))
|
||||
# (num_attention_heads, seq_len, k, head_size)
|
||||
# mem_keys = mem_keys.view(query.shape[:-1] + (k,) + (query.shape[-1],))
|
||||
|
||||
mem_values = torch.from_numpy(np.stack(mem_values, axis=1))
|
||||
# (bs, num_attention_heads, seq_len, k, head_size)
|
||||
mem_values = mem_values.view(query.shape[:-1] + (k,) + (query.shape[-1],))
|
||||
mem_values = torch.from_numpy(np.stack(mem_values, axis=0))
|
||||
# (num_attention_heads, seq_len, k, head_size)
|
||||
# mem_values = mem_values.view(query.shape[:-1] + (k,) + (query.shape[-1],))
|
||||
|
||||
return mem_keys, mem_values, np.stack(mem_indices, axis=1)
|
||||
return mem_keys, mem_values, np.stack(mem_indices, axis=0)
|
||||
|
||||
|
||||
def reset(self):
|
||||
for head in range(self.nheads):
|
||||
self.key_indices[head].reset()
|
||||
|
||||
self.kv_pairs = np.zeros((self.kv_pairs.shape[0], self.nheads, 2, self.kv_pairs.shape[-1]), dtype=np.float32)
|
||||
self.kv_pairs = np.zeros(self.kv_pairs.shape, dtype=np.float32)
|
||||
self.idx_offset = 0
|
||||
|
||||
|
||||
class BatchedMemory:
|
||||
def __init__(self, batch_size, hidden_dim, num_mems, nheads):
|
||||
self.indices = [MemoryIndex(hidden_dim, num_mems, nheads) for _ in range(batch_size)]
|
||||
|
||||
|
||||
def add(self, keys, values):
|
||||
for bs in range(len(self.indices)):
|
||||
self.indices[bs].add(keys[bs], values[bs])
|
||||
|
||||
|
||||
def knn_query(self, query, k=1):
|
||||
batched_mem_keys = []
|
||||
batched_mem_values = []
|
||||
batched_labels = []
|
||||
|
||||
for bs in range(len(self.indices)):
|
||||
knn_keys, knn_values, knn_labels = self.indices[bs].knn_query(query[bs], k=k)
|
||||
batched_mem_keys.append(knn_keys)
|
||||
batched_mem_values.append(knn_values)
|
||||
batched_labels.append(knn_labels)
|
||||
|
||||
|
||||
return torch.stack(batched_mem_keys, dim=0), torch.stack(batched_mem_values, dim=0), np.stack(batched_labels, axis=0)
|
||||
|
||||
def reset(self):
|
||||
for bs in range(len(self.indices)):
|
||||
self.indices[bs].reset()
|
||||
|
||||
|
||||
class BatchedStorage:
|
||||
def __init__(self, batch_size, hidden_dim, nheads, seq_len):
|
||||
self.indices = np.zeros((batch_size, nheads, seq_len, 2, hidden_dim), dtype=np.float32)
|
||||
|
||||
def knn_query(self, query, k=None):
|
||||
return torch.from_numpy(self.indices[:, :, :, 0, :]), torch.from_numpy(self.indices[:, :, :, 1, :])
|
||||
|
||||
|
||||
def add(self, keys, values):
|
||||
self.indices[:, :, :, 0, :] = keys
|
||||
self.indices[:, :, :, 1, :] = values
|
||||
|
||||
def reset(self):
|
||||
self.indices = np.zeros_like(self.indices)
|
||||
|
||||
|
||||
|
||||
class LethePreTrainedModel(PreTrainedModel):
|
||||
@ -206,12 +254,13 @@ class LethePreTrainedModel(PreTrainedModel):
|
||||
|
||||
|
||||
class LetheAttention(nn.Module):
|
||||
def __init__(self, config, memory_attention=False, index=None, tracker=None):
|
||||
def __init__(self, config, memory_attention=False, index=None, layer_idx=None, tracker=None):
|
||||
super().__init__()
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.hidden_size = config.hidden_size
|
||||
self.head_size = self.hidden_size // self.num_attention_heads
|
||||
self.rotary_ndims = int(self.head_size * config.rotary_pct)
|
||||
self.layer_idx = layer_idx
|
||||
max_positions = config.max_position_embeddings
|
||||
self.register_buffer(
|
||||
"bias",
|
||||
@ -274,7 +323,6 @@ class LetheAttention(nn.Module):
|
||||
key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3)
|
||||
value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3)
|
||||
|
||||
# if self.memory:
|
||||
if self.memory:
|
||||
# QKNorm: https://arxiv.org/abs/2010.04245
|
||||
query = F.normalize(query, dim=-1)
|
||||
@ -309,17 +357,28 @@ class LetheAttention(nn.Module):
|
||||
self.index.add(key.cpu().detach().to(torch.float32).numpy(), value.cpu().detach().to(torch.float32).numpy())
|
||||
|
||||
knn_keys, knn_values, knn_labels = self.index.knn_query(query.cpu().detach().to(torch.float32).numpy(), k=self.num_neighbors)
|
||||
# if log_attn_scores:
|
||||
# batch_size = query.shape[0]
|
||||
# seq_len = query.shape[-2]
|
||||
|
||||
# key_labels = knn_labels // seq_len
|
||||
# key_labels = key_labels.reshape(batch_size, seq_len, self.num_attention_heads, -1)
|
||||
# correct_keys = np.equal(key_labels, np.arange(batch_size)[:, np.newaxis, np.newaxis, np.newaxis])
|
||||
# # calculate the accuracy
|
||||
# key_acc = np.sum(correct_keys) / np.prod(correct_keys.shape)
|
||||
|
||||
# self.tracker.log({"retrieved_acc": key_acc}, step=step)
|
||||
|
||||
if log_attn_scores:
|
||||
batch_size = query.shape[0]
|
||||
seq_len = query.shape[-2]
|
||||
total_examples = 0
|
||||
unique_examples = 0
|
||||
for bs in range(query.shape[0]):
|
||||
for head in range(query.shape[1]):
|
||||
labels_per_head = knn_labels[bs, head, :, 0].tolist()
|
||||
total_examples += len(labels_per_head)
|
||||
unique_examples += len(set(labels_per_head))
|
||||
|
||||
key_labels = knn_labels // seq_len
|
||||
key_labels = key_labels.reshape(batch_size, seq_len, self.num_attention_heads, -1)
|
||||
correct_keys = np.equal(key_labels, np.arange(batch_size)[:, np.newaxis, np.newaxis, np.newaxis])
|
||||
# calculate the accuracy
|
||||
key_acc = np.sum(correct_keys) / np.prod(correct_keys.shape)
|
||||
|
||||
self.tracker.log({"retrieved_acc": key_acc}, step=step)
|
||||
self.tracker.log({"unique_retrieved_pct": unique_examples / total_examples}, step=step)
|
||||
|
||||
attn_output = self._mem_attn(query,
|
||||
knn_keys.to(query.device).to(value.dtype),
|
||||
@ -407,6 +466,7 @@ class LetheAttention(nn.Module):
|
||||
local_attn_scores = local_attn_scores + attention_mask
|
||||
|
||||
mem_attn_scores = torch.einsum("bhsd, bhsnd-> bhsn", query, knn_key)
|
||||
# mem_attn_scores = torch.matmul(query, knn_key.transpose(-1, -2))
|
||||
# attn_scores: [bs, seq_len, num_attention_heads, knn]
|
||||
mem_attn_scores = mem_attn_scores * scale
|
||||
|
||||
@ -417,48 +477,56 @@ class LetheAttention(nn.Module):
|
||||
attn_weights = attn_weights.to(local_value.dtype)
|
||||
|
||||
mem_attn_weights, local_attn_weights = attn_weights.split([self.num_neighbors, local_attn_scores.size(-1)], dim=-1)
|
||||
if log_attn_scores:
|
||||
# (bs, seq_len, num_attention_heads, knn) probabilities
|
||||
# curate (x,y) pairs
|
||||
# where x is attention weight, y is accuracy of retrieved token
|
||||
bs, seq_len = mem_attn_weights.size(0), mem_attn_weights.size(2)
|
||||
key_labels = knn_labels // seq_len
|
||||
key_labels = key_labels.reshape(bs, self.num_attention_heads, seq_len, -1)
|
||||
correct_keys = np.equal(key_labels, np.arange(bs)[:, np.newaxis, np.newaxis, np.newaxis])
|
||||
# mem_attn_weights, local_attn_weights = attn_weights.chunk(2, dim=-1)
|
||||
|
||||
bin_width = 0.05
|
||||
|
||||
# Calculate the number of bins
|
||||
num_bins = int(1 / bin_width)
|
||||
# if log_attn_scores:
|
||||
# # (bs, seq_len, num_attention_heads, knn) probabilities
|
||||
# # curate (x,y) pairs
|
||||
# # where x is attention weight, y is accuracy of retrieved token
|
||||
# bs, seq_len = mem_attn_weights.size(0), mem_attn_weights.size(2)
|
||||
# key_labels = knn_labels // seq_len
|
||||
# key_labels = key_labels.reshape(bs, self.num_attention_heads, seq_len, -1)
|
||||
# correct_keys = np.equal(key_labels, np.arange(bs)[:, np.newaxis, np.newaxis, np.newaxis])
|
||||
|
||||
# Create empty lists for storing bin probabilities and accuracies
|
||||
bin_probabilities = []
|
||||
bin_accuracies = []
|
||||
# bin_width = 0.05
|
||||
|
||||
probs = mem_attn_weights.clone().detach().cpu().numpy().reshape(-1).tolist()
|
||||
correct_keys = correct_keys.reshape(-1).tolist()
|
||||
# # Calculate the number of bins
|
||||
# num_bins = int(1 / bin_width)
|
||||
|
||||
# Iterate over each bin
|
||||
for i in range(num_bins):
|
||||
bin_lower = i * bin_width
|
||||
bin_upper = (i + 1) * bin_width
|
||||
# # Create empty lists for storing bin probabilities and accuracies
|
||||
# bin_probabilities = []
|
||||
# bin_accuracies = []
|
||||
# bin_sizes = []
|
||||
|
||||
# Filter data points within the current bin range
|
||||
bin_x_values = [x for x in probs if bin_lower <= x < bin_upper]
|
||||
bin_y_values = [y for j, y in enumerate(correct_keys) if bin_lower <= probs[j] < bin_upper]
|
||||
# probs = mem_attn_weights.clone().detach().cpu().numpy().reshape(-1).tolist()
|
||||
# correct_keys = correct_keys.reshape(-1).tolist()
|
||||
|
||||
# Calculate accuracy for the bin
|
||||
total = len(bin_x_values)
|
||||
correct = sum(bin_y_values)
|
||||
accuracy = correct / total if total > 0 else 0
|
||||
# # Iterate over each bin
|
||||
# for i in range(num_bins):
|
||||
# bin_lower = i * bin_width
|
||||
# bin_upper = (i + 1) * bin_width
|
||||
|
||||
# Store the probability and accuracy for the bin
|
||||
bin_probabilities.append((bin_lower + bin_upper) / 2)
|
||||
bin_accuracies.append(accuracy)
|
||||
# # Filter data points within the current bin range
|
||||
# bin_x_values = [x for x in probs if bin_lower <= x < bin_upper]
|
||||
# bin_y_values = [y for j, y in enumerate(correct_keys) if bin_lower <= probs[j] < bin_upper]
|
||||
|
||||
data = [[x, y] for x, y in zip(bin_probabilities, bin_accuracies)]
|
||||
table = wandb.Table(data=data, columns=["attn_prob", "retrieved_acc"])
|
||||
self.tracker.log({"attn_vs_acc": wandb.plot.scatter(table, "attn_prob", "retrieved_acc")}, step=step)
|
||||
# # Calculate accuracy for the bin
|
||||
# total = len(bin_x_values)
|
||||
# correct = sum(bin_y_values)
|
||||
# accuracy = correct / total if total > 0 else 0
|
||||
|
||||
# # Store the probability and accuracy for the bin
|
||||
# bin_probabilities.append((bin_lower + bin_upper) / 2)
|
||||
# bin_accuracies.append(accuracy)
|
||||
# bin_sizes.append(len(bin_x_values))
|
||||
|
||||
# df = pd.DataFrame({"attn_prob": bin_probabilities, "retrieved_acc": bin_accuracies, "bin_size": bin_sizes})
|
||||
|
||||
# fig = px.scatter(df, x="attn_prob", y="retrieved_acc",
|
||||
# color="bin_size", hover_data=["attn_prob", "retrieved_acc", "bin_size"],
|
||||
# title="Attention Probability vs Retrieved Accuracy")
|
||||
# self.tracker.log({"attn_vs_acc": fig}, step=step)
|
||||
|
||||
|
||||
if log_attn_scores:
|
||||
@ -470,10 +538,10 @@ class LetheAttention(nn.Module):
|
||||
mem_hist = torch.histc(mem_flat, bins=20, min=0, max=1)
|
||||
mem_bins = torch.linspace(0, 1, steps=20 + 1)
|
||||
plt.stairs(mem_hist.tolist(), mem_bins.tolist())
|
||||
plt.title(f"mem_attn_score_{head}")
|
||||
plt.title(f"mem_attn_score_{head}_layer_{self.layer_idx}")
|
||||
# set arbitrarily but we want to see those peaks!!
|
||||
plt.ylim((0, 1000))
|
||||
self.tracker.log({f"mem_attn_score_{head}": wandb.Image(plt)}, step=step)
|
||||
self.tracker.log({f"mem_attn_score_{head}_layer_{self.layer_idx}": wandb.Image(plt)}, step=step)
|
||||
plt.close()
|
||||
|
||||
|
||||
@ -482,15 +550,16 @@ class LetheAttention(nn.Module):
|
||||
local_hist = torch.histc(local_flat, bins=20, min=0, max=1)
|
||||
local_bins = torch.linspace(0, 1, steps=20 + 1)
|
||||
plt.stairs(local_hist.tolist(), local_bins.tolist())
|
||||
plt.title(f"local_attn_score_{head}")
|
||||
plt.title(f"local_attn_score_{head}_layer_{self.layer_idx}")
|
||||
# set arbitrarily but we want to see those peaks!!
|
||||
plt.ylim((0, 1000))
|
||||
self.tracker.log({f"local_attn_score_{head}": wandb.Image(plt)}, step=step)
|
||||
self.tracker.log({f"local_attn_score_{head}_layer_{self.layer_idx}": wandb.Image(plt)}, step=step)
|
||||
plt.close()
|
||||
|
||||
|
||||
# attn_output: [bs, num_attention_heads, seq_len, attn_head_size]
|
||||
mem_attn_output = torch.einsum("bhsn, bhsnd-> bhsd", mem_attn_weights, knn_value)
|
||||
# mem_attn_output = torch.matmul(mem_attn_weights, knn_value)
|
||||
local_attn_output = torch.matmul(local_attn_weights, local_value)
|
||||
|
||||
# TODO: do we need flamingo style gating
|
||||
@ -524,8 +593,6 @@ class LetheAttention(nn.Module):
|
||||
alpha=1.0 if self.memory else (torch.tensor(1.0, dtype=self.norm_factor.dtype, device=self.norm_factor.device) / self.norm_factor)
|
||||
)
|
||||
attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)
|
||||
if self.memory:
|
||||
attn_scores = attn_scores * self.scale.exp()
|
||||
|
||||
mask_value = torch.finfo(attn_scores.dtype).min
|
||||
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
|
||||
@ -610,12 +677,15 @@ class LetheMLP(nn.Module):
|
||||
|
||||
|
||||
class LetheLayer(nn.Module):
|
||||
def __init__(self, config, memory_attention=False, index=None, tracker=None):
|
||||
def __init__(self, config, memory_attention=False, layer_idx=None, index=None, tracker=None):
|
||||
super().__init__()
|
||||
self.use_parallel_residual = config.use_parallel_residual
|
||||
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.attention = LetheAttention(config, memory_attention=memory_attention, index=index, tracker=tracker)
|
||||
self.attention = LetheAttention(config, memory_attention=memory_attention,
|
||||
layer_idx=layer_idx,
|
||||
index=index[layer_idx] if memory_attention else None,
|
||||
tracker=tracker)
|
||||
self.mlp = LetheMLP(config)
|
||||
|
||||
def forward(
|
||||
@ -676,8 +746,9 @@ class LetheModel(LethePreTrainedModel):
|
||||
self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
|
||||
self.layers = nn.ModuleList([LetheLayer(config,
|
||||
memory_attention=i+1 == config.memory_attn_layer,
|
||||
index=index if i+1 == config.memory_attn_layer else None,
|
||||
memory_attention=i+1 in config.memory_attn_layer,
|
||||
layer_idx=i,
|
||||
index=index,
|
||||
tracker=tracker)
|
||||
for i in range(config.num_hidden_layers)])
|
||||
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
|
@ -1,2 +0,0 @@
|
||||
from .configuration_pythia_retro import PythiaRetroConfig
|
||||
from .modeling_pythia_retro import PythiaRetroForCausalLM
|
288
gpt4all/train/pretrain_mem_retrieval.py
Normal file
288
gpt4all/train/pretrain_mem_retrieval.py
Normal file
@ -0,0 +1,288 @@
|
||||
import os
|
||||
import torch.nn.functional as F
|
||||
from transformers import AutoTokenizer, get_scheduler, AutoConfig
|
||||
import torch
|
||||
from torch.optim import AdamW
|
||||
from argparse import ArgumentParser
|
||||
from gpt4all.utils.read import read_config
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import DummyScheduler, DummyOptim, set_seed
|
||||
from gpt4all.data.enwik8 import load_enwik8_dataloader
|
||||
from torchmetrics import MeanMetric
|
||||
from tqdm import tqdm
|
||||
from gpt4all.models import LetheForCausalLM, LetheConfig
|
||||
from gpt4all.models.lethe.modeling_lethe import BatchedMemory
|
||||
import wandb
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
def format_metrics(metrics, split, prefix=""):
|
||||
log = f"[{split}]" + prefix
|
||||
log += " ".join([f"{key}: {value:.4f}" for key, value in metrics.items()])
|
||||
|
||||
return log
|
||||
|
||||
|
||||
def evaluate(model, index, pad_token_id, config, val_dataloader, main_process=False):
|
||||
model.eval()
|
||||
val_loss = MeanMetric(nan_strategy="error").to(model.device)
|
||||
|
||||
chunk_size = config["seq_len"]
|
||||
with torch.no_grad():
|
||||
for batch in tqdm(val_dataloader, disable=not main_process):
|
||||
seq_len = batch.shape[1]
|
||||
for chunk_start in range(0, seq_len, chunk_size):
|
||||
chunk_end = min(seq_len, chunk_start + chunk_size)
|
||||
inputs = batch[:, chunk_start:chunk_end].to(model.device)
|
||||
labels = inputs.clone()
|
||||
outputs = model(input_ids=inputs,
|
||||
attention_mask=inputs.ne(pad_token_id),
|
||||
labels=labels,
|
||||
log_attn_scores=False,
|
||||
step=None,
|
||||
save_kv=True,
|
||||
)
|
||||
loss = outputs.loss / config["segments"]
|
||||
loss_values = accelerator.gather_for_metrics({"loss": loss.item()})
|
||||
val_loss.update(loss_values["loss"])
|
||||
|
||||
index.reset()
|
||||
|
||||
return val_loss
|
||||
|
||||
|
||||
def train(accelerator, config):
|
||||
set_seed(config['seed'])
|
||||
|
||||
accelerator.print(config)
|
||||
accelerator.print(f"Using {accelerator.num_processes} GPUs")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(config['tokenizer_name'], model_max_length=config['max_length'])
|
||||
# if no pad token, set it to eos
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
|
||||
with accelerator.main_process_first():
|
||||
train_dataloader, val_dataloader = load_enwik8_dataloader(config, tokenizer)
|
||||
|
||||
|
||||
if accelerator.state.deepspeed_plugin is not None:
|
||||
gradient_accumulation_steps = accelerator.state.deepspeed_plugin.deepspeed_config[
|
||||
"gradient_accumulation_steps"
|
||||
]
|
||||
|
||||
accelerator.print(f"Len of train_dataloader: {len(train_dataloader)}")
|
||||
total_num_steps = (len(train_dataloader) / gradient_accumulation_steps) * config["num_epochs"]
|
||||
# instead of decaying to zero, decay to ratio of min_lr / lr
|
||||
accelerator.print(f"Total training steps: {total_num_steps}")
|
||||
|
||||
checkpoint = config["gradient_checkpointing"]
|
||||
|
||||
model_config = LetheConfig.from_pretrained(config["model_name"])
|
||||
model_config.memory_attn_layer = config["memory_attn_layer"]
|
||||
model_config.num_neighbors_to_retrieve = config["num_neighbors_to_retrieve"]
|
||||
model_config.use_cache = False if checkpoint else True
|
||||
|
||||
head_size = model_config.hidden_size // model_config.num_attention_heads
|
||||
index = BatchedMemory(config["batch_size"],
|
||||
head_size,
|
||||
config["num_memories_per_index"],
|
||||
model_config.num_attention_heads,
|
||||
)
|
||||
|
||||
model = LetheForCausalLM(model_config,
|
||||
index=index,
|
||||
tracker=accelerator.get_tracker("wandb"))
|
||||
|
||||
|
||||
accelerator.print(f"Training a {model.num_parameters():,} parameter model")
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
optimizer_cls = (
|
||||
AdamW
|
||||
if accelerator.state.deepspeed_plugin is None
|
||||
or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config
|
||||
else DummyOptim
|
||||
)
|
||||
|
||||
# karpathy doesn't decay embeddding, maybe we should exclude
|
||||
# https://github.com/karpathy/minGPT/commit/bbbdac74fa9b2e55574d70056163ffbae42310c1#diff-2075fa9c224b395be5bda85544dd36572b59c76c54562819eadadbf268602834R157s
|
||||
optimizer = optimizer_cls(model.parameters(), lr=config["lr"], weight_decay=config["weight_decay"])
|
||||
|
||||
|
||||
# Creates Dummy Scheduler if `scheduler` was spcified in the config file else creates `args.lr_scheduler_type` Scheduler
|
||||
if config["scheduler"] or "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config:
|
||||
if (
|
||||
accelerator.state.deepspeed_plugin is None
|
||||
or "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
|
||||
):
|
||||
scheduler = get_scheduler(
|
||||
name="cosine",
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=config["warmup_steps"] * accelerator.num_processes,
|
||||
num_training_steps=total_num_steps,
|
||||
)
|
||||
else:
|
||||
scheduler = DummyScheduler(
|
||||
optimizer, total_num_steps=total_num_steps, warmup_num_steps=config["warmup_steps"]
|
||||
)
|
||||
model, optimizer, scheduler, train_dataloader, val_dataloader = accelerator.prepare(
|
||||
model, optimizer, scheduler, train_dataloader, val_dataloader
|
||||
)
|
||||
use_scheduler = True
|
||||
else:
|
||||
model, optimizer, train_dataloader, val_dataloader = accelerator.prepare(
|
||||
model, optimizer, train_dataloader, val_dataloader
|
||||
)
|
||||
use_scheduler = False
|
||||
|
||||
# setup for saving training states in case preemption
|
||||
if use_scheduler:
|
||||
accelerator.register_for_checkpointing(scheduler)
|
||||
|
||||
if config["checkpoint"]:
|
||||
accelerator.load_state(config["checkpoint"])
|
||||
accelerator.print(f"Resumed from checkpoint: {config['checkpoint']}")
|
||||
path = os.path.basename(config["train_args"]["resume_from_checkpoint"])
|
||||
training_difference = os.path.splitext(path)[0]
|
||||
resume_step = int(training_difference.replace("step_", ""))
|
||||
accelerator.skip_first_batches(train_dataloader, resume_step)
|
||||
accelerator.print(f"Resuming from step {resume_step}")
|
||||
|
||||
# log gradients
|
||||
if accelerator.is_main_process and config["wandb"]:
|
||||
wandb.watch(model, log_freq=config["log_grads_every"], log="all")
|
||||
|
||||
main_process = accelerator.is_main_process
|
||||
|
||||
chunk_size = config["seq_len"]
|
||||
for epoch in range(config["num_epochs"]):
|
||||
train_loss = MeanMetric(nan_strategy="error").to(model.device)
|
||||
for step, batch in enumerate(tqdm(train_dataloader, disable=not main_process)):
|
||||
epoch_step = epoch * len(train_dataloader) + step * config["segments"]
|
||||
seq_len = batch["input_ids"].shape[1]
|
||||
model.train()
|
||||
for i, chunk_start in enumerate(range(0, seq_len, chunk_size)):
|
||||
curr_step = epoch_step + i
|
||||
chunk_end = min(seq_len, chunk_start + chunk_size)
|
||||
inputs = batch["input_ids"][:, chunk_start:chunk_end]
|
||||
labels = inputs.clone()
|
||||
labels[labels == tokenizer.pad_token_id] = -100
|
||||
outputs = model(input_ids=inputs,
|
||||
attention_mask=inputs.ne(tokenizer.pad_token_id),
|
||||
labels=labels,
|
||||
log_attn_scores=True,
|
||||
step=curr_step,
|
||||
save_kv=True,
|
||||
)
|
||||
loss = outputs.loss / config["segments"]
|
||||
|
||||
if config["wandb"]:
|
||||
accelerator.log({"loss": loss}, step=curr_step)
|
||||
|
||||
# gather loss before backprop in case of gradient accumulation
|
||||
loss_values = accelerator.gather_for_metrics({"loss": loss.detach().float()})
|
||||
train_loss.update(loss_values["loss"])
|
||||
|
||||
loss = loss / gradient_accumulation_steps
|
||||
accelerator.backward(loss)
|
||||
|
||||
# log LR in case something weird happens
|
||||
if config["wandb"]:
|
||||
if step > 0 and step % (config["log_lr_every"] ) == 0:
|
||||
lr = optimizer.param_groups[0]["lr"]
|
||||
accelerator.log({"lr": lr}, step=curr_step)
|
||||
|
||||
optimizer.step()
|
||||
if use_scheduler:
|
||||
scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# reset index on batch end
|
||||
index.reset()
|
||||
|
||||
if step > 0 and config["save_every"] > 0 and step % config["save_every"] == 0:
|
||||
# accelerator.save_state(f"{config['output_dir']}/step_{curr_step}")
|
||||
unwrapped_model = accelerator.unwrap_model(model)
|
||||
unwrapped_model.save_pretrained(
|
||||
f"{config['output_dir']}/step_{step}",
|
||||
is_main_process=accelerator.is_main_process,
|
||||
save_function=accelerator.save,
|
||||
state_dict=accelerator.get_state_dict(model),
|
||||
)
|
||||
|
||||
if step > 0 and (step % config["eval_every"] == 0 or step == len(train_dataloader) - 1):
|
||||
val_loss = evaluate(model, index, tokenizer.pad_token_id, config, val_dataloader, main_process=main_process)
|
||||
|
||||
log_train = {
|
||||
"train_loss": train_loss.compute()
|
||||
}
|
||||
log_val = {
|
||||
"val_loss": val_loss.compute(),
|
||||
}
|
||||
|
||||
if config["wandb"]:
|
||||
curr_step = step + epoch * len(train_dataloader)
|
||||
accelerator.log({**log_train, **log_val}, step=curr_step)
|
||||
|
||||
accelerator.print(f"Current LR: {optimizer.param_groups[0]['lr']}")
|
||||
accelerator.print(format_metrics(log_train, "train", f" step {step} "))
|
||||
accelerator.print(format_metrics(log_val, "val", f" step {step} "))
|
||||
|
||||
train_loss.reset()
|
||||
|
||||
accelerator.print(f"Epoch {epoch} finished")
|
||||
accelerator.wait_for_everyone()
|
||||
unwrapped_model = accelerator.unwrap_model(model)
|
||||
if config["push_to_hub"]:
|
||||
accelerator.print(f"Pushing to HF hub")
|
||||
try:
|
||||
if accelerator.is_main_process:
|
||||
unwrapped_model.push_to_hub(config["save_name"] + f"-epoch_{epoch}", private=True)
|
||||
|
||||
except Exception as e:
|
||||
accelerator.print(e)
|
||||
accelerator.print(f"Failed to push to hub")
|
||||
|
||||
unwrapped_model.save_pretrained(
|
||||
f"{config['output_dir']}/epoch_{epoch}",
|
||||
is_main_process=accelerator.is_main_process,
|
||||
save_function=accelerator.save,
|
||||
state_dict=accelerator.get_state_dict(model),
|
||||
)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
unwrapped_model = accelerator.unwrap_model(model)
|
||||
unwrapped_model.save_pretrained(
|
||||
f"{config['output_dir']}/final",
|
||||
is_main_process=accelerator.is_main_process,
|
||||
save_function=accelerator.save,
|
||||
state_dict=accelerator.get_state_dict(model),
|
||||
)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# parse arguments by reading in a config
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--config", type=str, default="config.yaml")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
config = read_config(args.config)
|
||||
|
||||
if config["wandb"]:
|
||||
accelerator = Accelerator(log_with="wandb")
|
||||
accelerator.init_trackers(
|
||||
project_name=config["wandb_project_name"],
|
||||
config=config,
|
||||
init_kwargs={"wandb": {"entity": config["wandb_entity"]}},
|
||||
)
|
||||
else:
|
||||
accelerator = Accelerator()
|
||||
|
||||
train(accelerator, config=config)
|
@ -116,6 +116,7 @@ def train(accelerator, config):
|
||||
|
||||
if config["checkpoint"]:
|
||||
accelerator.load_state(config["checkpoint"])
|
||||
import pdb; pdb.set_trace()
|
||||
accelerator.print(f"Resumed from checkpoint: {config['checkpoint']}")
|
||||
path = os.path.basename(config["train_args"]["resume_from_checkpoint"])
|
||||
training_difference = os.path.splitext(path)[0]
|
||||
@ -131,9 +132,12 @@ def train(accelerator, config):
|
||||
for epoch in range(config["num_epochs"]):
|
||||
train_loss = MeanMetric(nan_strategy="error").to(model.device)
|
||||
for step, batch in enumerate(tqdm(train_dataloader)):
|
||||
curr_step = step + epoch * len(train_dataloader)
|
||||
model.train()
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
if config["wandb"]:
|
||||
accelerator.log({"loss": loss}, step=curr_step)
|
||||
|
||||
# gather loss before backprop in case of gradient accumulation
|
||||
loss_values = accelerator.gather_for_metrics({"loss": loss.detach().float()})
|
||||
@ -157,7 +161,13 @@ def train(accelerator, config):
|
||||
|
||||
if step > 0 and step % config["save_every"] == 0:
|
||||
curr_step = step + epoch * len(train_dataloader)
|
||||
accelerator.save_state(f"{config['output_dir']}/step_{curr_step}")
|
||||
unwrapped_model = accelerator.unwrap_model(model)
|
||||
unwrapped_model.save_pretrained(
|
||||
f"{config['output_dir']}/step_{curr_step}",
|
||||
is_main_process=accelerator.is_main_process,
|
||||
save_function=accelerator.save,
|
||||
state_dict=accelerator.get_state_dict(model),
|
||||
)
|
||||
|
||||
if step > 0 and (step % config["eval_every"] == 0 or step == len(train_dataloader) - 1):
|
||||
val_loss = evaluate(model, val_dataloader)
|
||||
|
@ -11,7 +11,7 @@ from gpt4all.data.retrieval_dataloader import load_memory_augmented_data
|
||||
from torchmetrics import MeanMetric
|
||||
from tqdm import tqdm
|
||||
from gpt4all.models import LetheForCausalLM
|
||||
from gpt4all.models.lethe.modeling_lethe import MemoryIndex
|
||||
from gpt4all.models.lethe.modeling_lethe import BatchedMemory
|
||||
import wandb
|
||||
import pyarrow as pa
|
||||
from pyarrow import feather
|
||||
@ -58,13 +58,15 @@ def evaluate(model, index, config, val_dataloader, main_process=False):
|
||||
qa_labels = batch["labels"]
|
||||
outputs = model(input_ids=qa_inputs,
|
||||
labels=qa_labels,
|
||||
save_kv=False
|
||||
)
|
||||
|
||||
del memories
|
||||
torch.cuda.empty_cache()
|
||||
loss_values = accelerator.gather_for_metrics({"loss": outputs["loss"].detach()})
|
||||
|
||||
index.reset()
|
||||
for ind in index.values():
|
||||
ind.reset()
|
||||
val_loss.update(loss_values["loss"])
|
||||
|
||||
per_example_loss = calculate_per_example_loss(outputs["logits"], qa_labels)
|
||||
@ -110,16 +112,18 @@ def train(accelerator, config):
|
||||
model_config = AutoConfig.from_pretrained(config["model_name"])
|
||||
|
||||
head_size = model_config.hidden_size // model_config.num_attention_heads
|
||||
index = MemoryIndex(head_size,
|
||||
indices = {i - 1: BatchedMemory(config["batch_size"],
|
||||
head_size,
|
||||
config["num_memories_per_index"],
|
||||
model_config.num_attention_heads
|
||||
)
|
||||
model_config.num_attention_heads,
|
||||
) for i in config["memory_attn_layer"]}
|
||||
|
||||
model = LetheForCausalLM.from_pretrained(config["model_name"],
|
||||
revision=config['version'] if 'version' in config else None,
|
||||
use_cache=False if checkpoint else True,
|
||||
memory_attn_layer=config["memory_attn_layer"],
|
||||
num_neighbors_to_retrieve=config["num_neighbors_to_retrieve"],
|
||||
index=index,
|
||||
index=indices,
|
||||
tracker=accelerator.get_tracker("wandb"),
|
||||
)
|
||||
|
||||
@ -206,8 +210,10 @@ def train(accelerator, config):
|
||||
|
||||
model.train()
|
||||
qa_inputs = batch["input_ids"]
|
||||
attn_mask = batch["attention_mask"]
|
||||
qa_labels = batch["labels"]
|
||||
outputs = model(input_ids=qa_inputs,
|
||||
attention_mask=attn_mask,
|
||||
labels=qa_labels,
|
||||
log_attn_scores=True,
|
||||
step=curr_step,
|
||||
@ -225,6 +231,7 @@ def train(accelerator, config):
|
||||
loss = loss / gradient_accumulation_steps
|
||||
accelerator.backward(loss)
|
||||
# !! don't reset index until after backwards pass
|
||||
for index in indices.values():
|
||||
index.reset()
|
||||
# get gradient norm of all params
|
||||
|
||||
@ -251,7 +258,7 @@ def train(accelerator, config):
|
||||
)
|
||||
|
||||
if step > 0 and (step % config["eval_every"] == 0 or step == len(train_dataloader) - 1):
|
||||
val_loss, loss_table = evaluate(model, index, config, val_dataloader, main_process=main_process)
|
||||
val_loss, loss_table = evaluate(model, indices, config, val_dataloader, main_process=main_process)
|
||||
|
||||
local_rank = accelerator.process_index
|
||||
feather.write_feather(loss_table, f"{config['output_dir']}/val_losses_step_{curr_step}_rank_{local_rank}.feather")
|
||||
|
Loading…
Reference in New Issue
Block a user