Zach Nussbaum 2023-07-27 17:00:02 +00:00
parent 55fef489ad
commit 3128db96ca
18 changed files with 761 additions and 135 deletions

View File

@ -14,7 +14,7 @@
},
"gradient_clipping": 1.0,
"zero_optimization": {
"stage": 2,
"stage": 1,
"offload_param": {
"device": "none"
},
@ -35,5 +35,15 @@
],
"eps": 1e-08
}
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"warmup_type": "linear",
"total_num_steps": "auto"
}
}
}
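Note (illustrative, not part of this commit): the "auto" values in the scheduler block above are resolved by Accelerate's DeepSpeed integration when the objects are prepared. On the Python side, the training scripts in this commit pair such a config with placeholder DummyOptim/DummyScheduler objects; a minimal sketch of that pairing, assuming a model and the usual lr / weight_decay / warmup_steps / total_num_steps hyperparameters:

from accelerate.utils import DummyOptim, DummyScheduler

def build_deepspeed_optim_and_sched(model, lr, weight_decay, warmup_steps, total_num_steps):
    # Placeholders only: the real optimizer and scheduler are built by DeepSpeed
    # from the JSON config, with "auto" entries filled from these values at prepare() time.
    optimizer = DummyOptim(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = DummyScheduler(optimizer,
                               total_num_steps=total_num_steps,
                               warmup_num_steps=warmup_steps)
    return optimizer, scheduler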

View File

@ -1,9 +1,10 @@
# model/tokenizer
model_name: "/home/paperspace/gpt4all/gpt4all/train/ckpts/mem_attn/step_1000"
model_name: "/home/paperspace/gpt4all/gpt4all/train/ckpts/qk_no_norm/step_500"
tokenizer_name: "EleutherAI/pythia-1b"
version: null
gradient_checkpointing: false
gradient_checkpointing: true
memory_attn_layer: 12
seed: 42
# dataset

View File

@ -5,7 +5,7 @@ version: null
gradient_checkpointing: true
save_name: "nomic-ai/lethe"
push_to_hub: false
memory_attn_layer: 12
memory_attn_layer: [9, 12, 15]
# dataset
streaming: false
@ -17,7 +17,7 @@ pct_test: 0.05
q_column: "question"
a_column: "answer"
context_column: "text"
num_memories_per_index: 2000000
num_memories_per_index: 2048
num_neighbors_to_retrieve: 2
num_neighbors_to_store: 1
mem_chunk_size: 64
@ -26,15 +26,15 @@ mem_chunk_size: 64
lr: 1.0e-5
min_lr: 0
weight_decay: 0.0
eval_every: 100
save_every: 100
eval_every: 250
save_every: 250
log_grads_every: 100
log_lr_every: 10
output_dir: "ckpts/mem_attn_no_cosine_sim"
output_dir: "ckpts/qk_no_norm"
checkpoint: null
lora: false
warmup_steps: 200
num_epochs: 5
num_epochs: 2
debug: false
scheduler: false
@ -43,4 +43,3 @@ wandb: true
wandb_entity: gpt4all
wandb_project_name: mem_attn
seed: 42

View File

@ -10,20 +10,21 @@ memory_attn_layer: 12
# dataset
streaming: false
num_proc: 64
dataset_path: "JeanKaddour/minipile"
dataset_path: "pg19"
max_length: 2048
batch_size: 64
seq_len: 512
segments: 16
batch_size: 16
pct_test: 0.05
num_memories_per_index: 5000000
num_memories_per_index: 100000
mem_chunk_size: 512
num_chunks: 10
num_neighbors_to_retrieve: 32
# train dynamics
lr: 1.0e-4
lr: 2.0e-4
min_lr: 0
weight_decay: 0.0
eval_every: 100
eval_every: 250
save_every: -1
log_grads_every: 100
log_lr_every: 10
@ -38,6 +39,6 @@ scheduler: false
# logging
wandb: true
wandb_entity: gpt4all
wandb_project_name: minipile
wandb_project_name: enwik8
seed: 42

gpt4all/data/enwik8.py Normal file (62 lines)
View File

@ -0,0 +1,62 @@
import numpy as np
import torch
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
from transformers import DefaultDataCollator
class EnWik8Dataset(Dataset):
def __init__(self, data, seq_len):
# pyarrow chunked array
self.data = torch.from_numpy(data)
self.seq_len = seq_len
def __getitem__(self, index):
full_seq = self.data[index].long()
return full_seq.cuda()
def __len__(self):
return len(self.data)
def load_enwik8_dataloader(config, tokenizer):
ds = load_dataset(config["dataset_path"], split="train")
ds = ds.train_test_split(test_size=0.05, seed=config['seed'])
train_ds, val_ds = ds["train"], ds["test"]
keep_cols = ["input_ids"]
train_ds = train_ds.map(lambda x: {"len": [len(t) for t in x["text"]]}, batched=True)
train_ds = train_ds.sort("len")
train_ds = train_ds.map(lambda x: tokenizer(x["text"], padding="longest", truncation=True, return_tensors="pt"),
batched=True,
batch_size=config["batch_size"])
remove_cols = [col for col in train_ds.column_names if col not in keep_cols]
train_ds = train_ds.remove_columns(remove_cols)
val_ds = val_ds.map(lambda x: {"len": [len(t) for t in x["text"]]}, batched=True)
val_ds = val_ds.sort("len")
val_ds = val_ds.map(lambda x: tokenizer(x["text"], padding="longest", truncation=True, return_tensors="pt"),
batched=True,
batch_size=config["batch_size"])
remove_cols = [col for col in val_ds.column_names if col not in keep_cols]
val_ds = val_ds.remove_columns(remove_cols)
train_dl = DataLoader(train_ds,
batch_size=config["batch_size"],
shuffle=True,
drop_last=True,
collate_fn=DefaultDataCollator())
val_dl = DataLoader(val_ds,
batch_size=config["batch_size"],
shuffle=True,
drop_last=True,
collate_fn=DefaultDataCollator())
return train_dl, val_dl

View File

@ -1,6 +1,6 @@
import glob
import torch
from datasets import load_dataset
from datasets import load_dataset, load_from_disk
import os
import hnswlib
from torch.utils.data import DataLoader
@ -12,20 +12,23 @@ def load_data(config, tokenizer):
dataset_path = config["dataset_path"]
if os.path.exists(dataset_path):
if os.path.isdir(dataset_path):
files = glob.glob(os.path.join(dataset_path, "*_clean.jsonl"))
else:
files = [dataset_path]
dataset = load_from_disk(dataset_path)
# if os.path.isdir(dataset_path):
# files = glob.glob(os.path.join(dataset_path, "*_clean.jsonl"))
# else:
# files = [dataset_path]
print(f"Reading files {files}")
# print(f"Reading files {files}")
dataset = load_dataset("json", data_files=files, split="train")
# dataset = load_dataset("json", data_files=files, split="train")
else:
dataset = load_dataset(dataset_path, split="train")
dataset = dataset.train_test_split(test_size=.05, seed=config["seed"])
dataset = dataset.map(lambda x: {"prompt": [text + " " + question for text, question in zip(x["text"], x["question"])]}, batched=True)
train_dataset, val_dataset = dataset["train"], dataset["test"]
if config["streaming"] is False:
@ -33,19 +36,27 @@ def load_data(config, tokenizer):
else:
kwargs = {}
cols_to_keep = ["input_ids", "labels", "attention_mask"]
# tokenize inputs and return labels and attention mask
train_dataset = train_dataset.map(
lambda ele: tokenize_inputs(config, tokenizer, ele, "prompt", "response"),
lambda ele: tokenize_inputs(config, tokenizer, ele, "prompt", "answer"),
batched=True,
remove_columns=["source", "prompt"],
# remove_columns=["source", "prompt"],
**kwargs
)
cols_to_remove = [col for col in train_dataset.column_names if col not in cols_to_keep]
train_dataset = train_dataset.remove_columns(cols_to_remove)
val_dataset = val_dataset.map(
lambda ele: tokenize_inputs(config, tokenizer, ele, "prompt", "response"),
lambda ele: tokenize_inputs(config, tokenizer, ele, "prompt", "answer"),
batched=True,
remove_columns=["source", "prompt"],
# remove_columns=["source", "prompt"],
**kwargs
)
cols_to_remove = [col for col in val_dataset.column_names if col not in cols_to_keep]
val_dataset = val_dataset.remove_columns(cols_to_remove)
train_dataset = train_dataset.with_format("torch")
val_dataset = val_dataset.with_format("torch")
@ -56,12 +67,14 @@ def load_data(config, tokenizer):
train_dataset,
collate_fn=DefaultDataCollator(),
batch_size=config["batch_size"],
shuffle=True,
)
val_dataloader = DataLoader(
val_dataset,
collate_fn=DefaultDataCollator(),
batch_size=config["batch_size"],
shuffle=True,
)
return train_dataloader, val_dataloader

View File

@ -5,7 +5,7 @@ def tokenize_inputs(config, tokenizer, examples, input_col, target_col):
# hacky backward compatibility
different_eos = tokenizer.eos_token != "</s>"
out = {"labels": [], "input_ids": []}
out = {"labels": [], "input_ids": [], "attention_mask": []}
for prompt, response in zip(examples[input_col], examples[target_col]):
if different_eos:
if response.count("</s> \n") > 0:
@ -42,9 +42,10 @@ def tokenize_inputs(config, tokenizer, examples, input_col, target_col):
print(response)
raise
input_tokens = tokenizer.pad({"input_ids": input_tokens}, padding="max_length", max_length=max_length)["input_ids"]
tokenized = tokenizer.pad({"input_ids": input_tokens}, padding="max_length", max_length=max_length)
out["labels"].append(labels)
out["input_ids"].append(input_tokens)
out["input_ids"].append(tokenized["input_ids"])
out["attention_mask"].append(tokenized["attention_mask"])
out = {k: torch.stack(v) if isinstance(v, list) else v for k, v in out.items()}

View File

@ -41,7 +41,7 @@ def load_retrieval_augmented_data(config, tokenizer, split="train", split_datase
if encoder_column != "encoder_hidden_states":
dataset = dataset.rename_column(encoder_column, "encoder_hidden_states")
columns_to_keep = ["input_ids", "labels", "encoder_hidden_states"]
columns_to_keep = ["input_ids", "attention_mask", "labels", "encoder_hidden_states"]
col_names_to_rm = [col for col in dataset.column_names if col not in columns_to_keep]
dataset = dataset.remove_columns(col_names_to_rm)
@ -115,7 +115,7 @@ def load_memory_augmented_data(config, tokenizer, split="train", split_dataset=T
**kwargs
)
columns_to_keep = ["id", "input_ids", "labels", "retrieved_context"]
columns_to_keep = ["id", "input_ids", "attention_mask", "labels", "retrieved_context"]
col_names_to_rm = [col for col in dataset.column_names if col not in columns_to_keep]
dataset = dataset.remove_columns(col_names_to_rm)
@ -128,12 +128,16 @@ def load_memory_augmented_data(config, tokenizer, split="train", split_dataset=T
train_dataset.remove_columns("id"),
batch_size=config["batch_size"],
collate_fn=DefaultDataCollator(),
shuffle=True,
drop_last=True,
)
val_dataloader = DataLoader(
val_dataset,
batch_size=config["batch_size"],
collate_fn=DefaultDataCollator(),
shuffle=True,
drop_last=True,
)
return train_dataloader, val_dataloader

View File

@ -1,7 +1,7 @@
import torch
import torch.nn.functional as F
from gpt4all.models import LetheForCausalLM
from gpt4all.models.lethe.modeling_lethe import MemoryIndex
from gpt4all.models.lethe.modeling_lethe import BatchedMemory
from gpt4all.data.retrieval_dataloader import load_memory_augmented_data
from gpt4all.train.metrics import f1_score, exact_match_score
from gpt4all.utils.read import read_config
@ -28,17 +28,21 @@ def greedy_search(input_ids, model, tokenizer, max_new_tokens=100):
while True:
if num_new_tokens >= max_new_tokens:
break
outputs = model(input_ids, save_kv=False)
attention_mask = input_ids.ne(tokenizer.pad_token_id)
outputs = model(input_ids, attention_mask=attention_mask, save_kv=False)
new_tokens = torch.argmax(outputs.logits[:, -1, :], dim=-1)
next_token_idx = torch.argmax((input_ids == tokenizer.pad_token_id).type(torch.float32))
# -1 because logits at last position predict next token
new_token = torch.argmax(outputs.logits[:, next_token_idx - 1, :], dim=-1)
input_ids[:, next_token_idx] = new_token
input_ids = torch.cat([input_ids, new_tokens.unsqueeze(1)], dim=-1)
num_new_tokens += 1
if torch.equal(input_ids[0, -1].cpu(), torch.tensor(tokenizer.eos_token_id)):
if torch.equal(new_token.cpu(), torch.tensor(tokenizer.eos_token_id)):
break
print(tokenizer.batch_decode(input_ids, skip_special_tokens=True))
print(f"GENERATED: {tokenizer.batch_decode(input_ids, skip_special_tokens=True)}")
return input_ids
@ -63,10 +67,11 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_config = AutoConfig.from_pretrained(config["model_name"])
head_size = model_config.hidden_size // model_config.num_attention_heads
index = MemoryIndex(head_size,
index = BatchedMemory(config["batch_size"],
head_size,
config["num_memories_per_index"],
model_config.num_attention_heads
)
model_config.num_attention_heads,
)
model = LetheForCausalLM.from_pretrained(config["model_name"],
revision=config['version'] if 'version' in config else None,
memory_attn_layer=config["memory_attn_layer"],
@ -90,16 +95,19 @@ with torch.no_grad():
mem_chunk = memories[chunk_start:chunk_end]
model(input_ids=mem_chunk.to(device))
del memories
torch.cuda.empty_cache()
qa_inputs = batch["input_ids"]
qa_labels = batch["labels"]
for i in range(qa_inputs.shape[0]):
inputs = qa_inputs[i].to(device)
print(f"EXPECTED: {tokenizer.decode(inputs, skip_special_tokens=True)}")
labels = qa_labels[i].to(device)
cutoff = torch.argmax((labels != -100).type(torch.float32))
greedy_search(inputs[:cutoff.item()].unsqueeze(0).to(device), model, tokenizer)
print(tokenizer.decode(inputs, skip_special_tokens=True))
inputs[cutoff:] = tokenizer.pad_token_id
greedy_search(inputs.unsqueeze(0).to(device), model, tokenizer)
print(f"CONTEXT: {tokenizer.decode(memories[i], skip_special_tokens=True)}")
import pdb; pdb.set_trace()
# batch_loss = calc_loss_per_item(outputs.logits, qa_labels.to(device))

View File

@ -0,0 +1,96 @@
{
"builder_name": "squad",
"citation": "@article{2016arXiv160605250R,\n author = {{Rajpurkar}, Pranav and {Zhang}, Jian and {Lopyrev},\n Konstantin and {Liang}, Percy},\n title = \"{SQuAD: 100,000+ Questions for Machine Comprehension of Text}\",\n journal = {arXiv e-prints},\n year = 2016,\n eid = {arXiv:1606.05250},\n pages = {arXiv:1606.05250},\narchivePrefix = {arXiv},\n eprint = {1606.05250},\n}\n",
"config_name": "plain_text",
"dataset_size": 89819092,
"description": "Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset, consisting of questions posed by crowdworkers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span, from the corresponding reading passage, or the question might be unanswerable.\n",
"download_checksums": {
"https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json": {
"num_bytes": 30288272,
"checksum": null
},
"https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json": {
"num_bytes": 4854279,
"checksum": null
}
},
"download_size": 35142551,
"features": {
"id": {
"dtype": "string",
"_type": "Value"
},
"title": {
"dtype": "string",
"_type": "Value"
},
"context": {
"dtype": "string",
"_type": "Value"
},
"question": {
"dtype": "string",
"_type": "Value"
},
"answers": {
"feature": {
"text": {
"dtype": "string",
"_type": "Value"
},
"answer_start": {
"dtype": "int32",
"_type": "Value"
}
},
"_type": "Sequence"
},
"neighbor_ids": {
"feature": {
"dtype": "uint64",
"_type": "Value"
},
"_type": "Sequence"
},
"neighbor_text": {
"feature": {
"dtype": "string",
"_type": "Value"
},
"_type": "Sequence"
},
"loss": {
"dtype": "float64",
"_type": "Value"
}
},
"homepage": "https://rajpurkar.github.io/SQuAD-explorer/",
"license": "",
"size_in_bytes": 124961643,
"splits": {
"train": {
"name": "train",
"num_bytes": 79346108,
"num_examples": 87599,
"dataset_name": "squad"
},
"validation": {
"name": "validation",
"num_bytes": 10472984,
"num_examples": 10570,
"dataset_name": "squad"
}
},
"task_templates": [
{
"task": "question-answering-extractive"
}
],
"version": {
"version_str": "1.0.0",
"description": "",
"major": 1,
"minor": 0,
"patch": 0
}
}

View File

@ -0,0 +1,16 @@
{
"_data_files": [
{
"filename": "data-00000-of-00002.arrow"
},
{
"filename": "data-00001-of-00002.arrow"
}
],
"_fingerprint": "c178f5c8269012a2",
"_format_columns": null,
"_format_kwargs": {},
"_format_type": null,
"_output_all_columns": false,
"_split": "validation"
}

View File

@ -0,0 +1,42 @@
import torch
from gpt4all.data.instruction_tuning_dataloader import load_data
from gpt4all.utils.read import read_config
from transformers import AutoTokenizer, AutoModelForCausalLM
from argparse import ArgumentParser
from tqdm import tqdm
parser = ArgumentParser()
parser.add_argument("--config", type=str, default="config.yaml")
args = parser.parse_args()
config = read_config(args.config)
tokenizer = AutoTokenizer.from_pretrained(config["tokenizer_name"])
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
train_dataloader, val_dataloader = load_data(config, tokenizer)
model = AutoModelForCausalLM.from_pretrained(config["model_name"],
trust_remote_code=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.half().to(device)
model.eval()
# Evaluate the model on the SQUAD dataset
f1s = []
exact_matches = []
with torch.no_grad():
for batch in tqdm(val_dataloader):
inputs = batch["input_ids"].to(device)
labels = batch["labels"].to(device)
cutoff = torch.argmax((labels != -100).type(torch.float32))
outputs = model.generate(inputs[:, :cutoff], max_new_tokens=100)
print(f"Predicted: {tokenizer.batch_decode(outputs, skip_special_tokens=True)}")
print(f"Ground truth: {tokenizer.batch_decode(inputs[:, cutoff:], skip_special_tokens=True)}")
print(tokenizer.batch_decode(inputs, skip_special_tokens=True))

View File

@ -2,7 +2,6 @@ from .gpt_jr.configuration_gpt_jr import GPTJRConfig
from .gpt_jr.modeling_gpt_jr import GPTJRForCausalLM
from .pythiaseek import PythiaSeekForCausalLM, PythiaSeekConfig
from .pythia_retro import PythiaRetroForCausalLM, PythiaRetroConfig
from .lethe import LetheConfig, LetheForCausalLM

View File

@ -18,6 +18,8 @@ import wandb
import math
import torch.nn.functional as F
import matplotlib.pyplot as plt
import plotly.express as px
import pandas as pd
from typing import Optional, Tuple, Union
import torch
@ -121,58 +123,104 @@ class MemoryIndex:
# NOTE: we are storing the kv pairs themselves, instead of indices for both keys and values
self.key_indices = [HNSWIndex(num_mems, hidden_dim) for _ in range(nheads)]
shape = (num_mems, nheads, 2, hidden_dim)
shape = (nheads, num_mems, 2, hidden_dim)
self.kv_pairs = np.zeros(shape, dtype=np.float32)
self.idx_offset = 0
def add(self, keys, values):
# k/v are (bs, num_attention_heads, seq_len, head_size)
reshaped_keys = keys.reshape(keys.shape[0] * keys.shape[2], keys.shape[1], keys.shape[3])
reshaped_values = values.reshape(values.shape[0] * values.shape[2], values.shape[1], values.shape[3])
# k/v are (num_attention_heads, seq_len, head_size)
# keys = keys.reshape(keys.shape[1], keys.shape[0], keys.shape[2])
# values = values.reshape(values.shape[1], values.shape[0], values.shape[2])
for head in range(self.nheads):
self.key_indices[head].add(reshaped_keys[:, head, :])
self.key_indices[head].add(keys[head, :, :])
kv_pairs = np.stack((reshaped_keys, reshaped_values), axis=2)
kv_pairs = np.stack((keys, values), axis=2)
if self.idx_offset + kv_pairs.shape[0] > self.kv_pairs.shape[0]:
raise ValueError("Not enough memory!")
if self.idx_offset + kv_pairs.shape[1] > self.kv_pairs.shape[1]:
# reset to 0 to overwrite oldest memories
self.idx_offset = 0
self.kv_pairs[self.idx_offset:self.idx_offset + kv_pairs.shape[0]] = kv_pairs
self.idx_offset += kv_pairs.shape[0]
self.kv_pairs[:, self.idx_offset:self.idx_offset + kv_pairs.shape[1]] = kv_pairs
self.idx_offset += kv_pairs.shape[1]
def knn_query(self, query, k=1):
reshaped_query = query.reshape(query.shape[0] * query.shape[2], query.shape[1], query.shape[3])
mem_keys = []
mem_values = []
mem_indices = []
# we can prob make this better
for head in range(self.nheads):
knn_indices = self.key_indices[head].query(reshaped_query[:, head, :], k=k)
kv_pairs = self.kv_pairs[:, head, :, :][knn_indices]
knn_indices = self.key_indices[head].query(query[head, :, :], k=k)
kv_pairs = self.kv_pairs[head, :, :, :][knn_indices]
mem_keys.append(kv_pairs[:, :, 0, :])
mem_values.append(kv_pairs[:, :, 1, :])
mem_indices.append(knn_indices)
mem_keys = torch.from_numpy(np.stack(mem_keys, axis=1))
# (bs, num_attention_heads, seq_len, k, head_size)
mem_keys = mem_keys.view(query.shape[:-1] + (k,) + (query.shape[-1],))
mem_keys = torch.from_numpy(np.stack(mem_keys, axis=0))
# (num_attention_heads, seq_len, k, head_size)
# mem_keys = mem_keys.view(query.shape[:-1] + (k,) + (query.shape[-1],))
mem_values = torch.from_numpy(np.stack(mem_values, axis=1))
# (bs, num_attention_heads, seq_len, k, head_size)
mem_values = mem_values.view(query.shape[:-1] + (k,) + (query.shape[-1],))
mem_values = torch.from_numpy(np.stack(mem_values, axis=0))
# (num_attention_heads, seq_len, k, head_size)
# mem_values = mem_values.view(query.shape[:-1] + (k,) + (query.shape[-1],))
return mem_keys, mem_values, np.stack(mem_indices, axis=1)
return mem_keys, mem_values, np.stack(mem_indices, axis=0)
def reset(self):
for head in range(self.nheads):
self.key_indices[head].reset()
self.kv_pairs = np.zeros((self.kv_pairs.shape[0], self.nheads, 2, self.kv_pairs.shape[-1]), dtype=np.float32)
self.kv_pairs = np.zeros(self.kv_pairs.shape, dtype=np.float32)
self.idx_offset = 0
class BatchedMemory:
def __init__(self, batch_size, hidden_dim, num_mems, nheads):
self.indices = [MemoryIndex(hidden_dim, num_mems, nheads) for _ in range(batch_size)]
def add(self, keys, values):
for bs in range(len(self.indices)):
self.indices[bs].add(keys[bs], values[bs])
def knn_query(self, query, k=1):
batched_mem_keys = []
batched_mem_values = []
batched_labels = []
for bs in range(len(self.indices)):
knn_keys, knn_values, knn_labels = self.indices[bs].knn_query(query[bs], k=k)
batched_mem_keys.append(knn_keys)
batched_mem_values.append(knn_values)
batched_labels.append(knn_labels)
return torch.stack(batched_mem_keys, dim=0), torch.stack(batched_mem_values, dim=0), np.stack(batched_labels, axis=0)
def reset(self):
for bs in range(len(self.indices)):
self.indices[bs].reset()
class BatchedStorage:
def __init__(self, batch_size, hidden_dim, nheads, seq_len):
self.indices = np.zeros((batch_size, nheads, seq_len, 2, hidden_dim), dtype=np.float32)
def knn_query(self, query, k=None):
return torch.from_numpy(self.indices[:, :, :, 0, :]), torch.from_numpy(self.indices[:, :, :, 1, :])
def add(self, keys, values):
self.indices[:, :, :, 0, :] = keys
self.indices[:, :, :, 1, :] = values
def reset(self):
self.indices = np.zeros_like(self.indices)
class LethePreTrainedModel(PreTrainedModel):
@ -206,12 +254,13 @@ class LethePreTrainedModel(PreTrainedModel):
class LetheAttention(nn.Module):
def __init__(self, config, memory_attention=False, index=None, tracker=None):
def __init__(self, config, memory_attention=False, index=None, layer_idx=None, tracker=None):
super().__init__()
self.num_attention_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_attention_heads
self.rotary_ndims = int(self.head_size * config.rotary_pct)
self.layer_idx = layer_idx
max_positions = config.max_position_embeddings
self.register_buffer(
"bias",
@ -274,7 +323,6 @@ class LetheAttention(nn.Module):
key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3)
value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3)
# if self.memory:
if self.memory:
# QKNorm: https://arxiv.org/abs/2010.04245
query = F.normalize(query, dim=-1)
@ -309,17 +357,28 @@ class LetheAttention(nn.Module):
self.index.add(key.cpu().detach().to(torch.float32).numpy(), value.cpu().detach().to(torch.float32).numpy())
knn_keys, knn_values, knn_labels = self.index.knn_query(query.cpu().detach().to(torch.float32).numpy(), k=self.num_neighbors)
# if log_attn_scores:
# batch_size = query.shape[0]
# seq_len = query.shape[-2]
# key_labels = knn_labels // seq_len
# key_labels = key_labels.reshape(batch_size, seq_len, self.num_attention_heads, -1)
# correct_keys = np.equal(key_labels, np.arange(batch_size)[:, np.newaxis, np.newaxis, np.newaxis])
# # calculate the accuracy
# key_acc = np.sum(correct_keys) / np.prod(correct_keys.shape)
# self.tracker.log({"retrieved_acc": key_acc}, step=step)
if log_attn_scores:
batch_size = query.shape[0]
seq_len = query.shape[-2]
total_examples = 0
unique_examples = 0
for bs in range(query.shape[0]):
for head in range(query.shape[1]):
labels_per_head = knn_labels[bs, head, :, 0].tolist()
total_examples += len(labels_per_head)
unique_examples += len(set(labels_per_head))
key_labels = knn_labels // seq_len
key_labels = key_labels.reshape(batch_size, seq_len, self.num_attention_heads, -1)
correct_keys = np.equal(key_labels, np.arange(batch_size)[:, np.newaxis, np.newaxis, np.newaxis])
# calculate the accuracy
key_acc = np.sum(correct_keys) / np.prod(correct_keys.shape)
self.tracker.log({"retrieved_acc": key_acc}, step=step)
self.tracker.log({"unique_retrieved_pct": unique_examples / total_examples}, step=step)
attn_output = self._mem_attn(query,
knn_keys.to(query.device).to(value.dtype),
@ -407,6 +466,7 @@ class LetheAttention(nn.Module):
local_attn_scores = local_attn_scores + attention_mask
mem_attn_scores = torch.einsum("bhsd, bhsnd-> bhsn", query, knn_key)
# mem_attn_scores = torch.matmul(query, knn_key.transpose(-1, -2))
# attn_scores: [bs, seq_len, num_attention_heads, knn]
mem_attn_scores = mem_attn_scores * scale
@ -417,48 +477,56 @@ class LetheAttention(nn.Module):
attn_weights = attn_weights.to(local_value.dtype)
mem_attn_weights, local_attn_weights = attn_weights.split([self.num_neighbors, local_attn_scores.size(-1)], dim=-1)
if log_attn_scores:
# (bs, seq_len, num_attention_heads, knn) probabilities
# curate (x,y) pairs
# where x is attention weight, y is accuracy of retrieved token
bs, seq_len = mem_attn_weights.size(0), mem_attn_weights.size(2)
key_labels = knn_labels // seq_len
key_labels = key_labels.reshape(bs, self.num_attention_heads, seq_len, -1)
correct_keys = np.equal(key_labels, np.arange(bs)[:, np.newaxis, np.newaxis, np.newaxis])
# mem_attn_weights, local_attn_weights = attn_weights.chunk(2, dim=-1)
bin_width = 0.05
# Calculate the number of bins
num_bins = int(1 / bin_width)
# if log_attn_scores:
# # (bs, seq_len, num_attention_heads, knn) probabilities
# # curate (x,y) pairs
# # where x is attention weight, y is accuracy of retrieved token
# bs, seq_len = mem_attn_weights.size(0), mem_attn_weights.size(2)
# key_labels = knn_labels // seq_len
# key_labels = key_labels.reshape(bs, self.num_attention_heads, seq_len, -1)
# correct_keys = np.equal(key_labels, np.arange(bs)[:, np.newaxis, np.newaxis, np.newaxis])
# Create empty lists for storing bin probabilities and accuracies
bin_probabilities = []
bin_accuracies = []
# bin_width = 0.05
probs = mem_attn_weights.clone().detach().cpu().numpy().reshape(-1).tolist()
correct_keys = correct_keys.reshape(-1).tolist()
# # Calculate the number of bins
# num_bins = int(1 / bin_width)
# Iterate over each bin
for i in range(num_bins):
bin_lower = i * bin_width
bin_upper = (i + 1) * bin_width
# # Create empty lists for storing bin probabilities and accuracies
# bin_probabilities = []
# bin_accuracies = []
# bin_sizes = []
# Filter data points within the current bin range
bin_x_values = [x for x in probs if bin_lower <= x < bin_upper]
bin_y_values = [y for j, y in enumerate(correct_keys) if bin_lower <= probs[j] < bin_upper]
# probs = mem_attn_weights.clone().detach().cpu().numpy().reshape(-1).tolist()
# correct_keys = correct_keys.reshape(-1).tolist()
# Calculate accuracy for the bin
total = len(bin_x_values)
correct = sum(bin_y_values)
accuracy = correct / total if total > 0 else 0
# # Iterate over each bin
# for i in range(num_bins):
# bin_lower = i * bin_width
# bin_upper = (i + 1) * bin_width
# Store the probability and accuracy for the bin
bin_probabilities.append((bin_lower + bin_upper) / 2)
bin_accuracies.append(accuracy)
# # Filter data points within the current bin range
# bin_x_values = [x for x in probs if bin_lower <= x < bin_upper]
# bin_y_values = [y for j, y in enumerate(correct_keys) if bin_lower <= probs[j] < bin_upper]
data = [[x, y] for x, y in zip(bin_probabilities, bin_accuracies)]
table = wandb.Table(data=data, columns=["attn_prob", "retrieved_acc"])
self.tracker.log({"attn_vs_acc": wandb.plot.scatter(table, "attn_prob", "retrieved_acc")}, step=step)
# # Calculate accuracy for the bin
# total = len(bin_x_values)
# correct = sum(bin_y_values)
# accuracy = correct / total if total > 0 else 0
# # Store the probability and accuracy for the bin
# bin_probabilities.append((bin_lower + bin_upper) / 2)
# bin_accuracies.append(accuracy)
# bin_sizes.append(len(bin_x_values))
# df = pd.DataFrame({"attn_prob": bin_probabilities, "retrieved_acc": bin_accuracies, "bin_size": bin_sizes})
# fig = px.scatter(df, x="attn_prob", y="retrieved_acc",
# color="bin_size", hover_data=["attn_prob", "retrieved_acc", "bin_size"],
# title="Attention Probability vs Retrieved Accuracy")
# self.tracker.log({"attn_vs_acc": fig}, step=step)
if log_attn_scores:
@ -470,10 +538,10 @@ class LetheAttention(nn.Module):
mem_hist = torch.histc(mem_flat, bins=20, min=0, max=1)
mem_bins = torch.linspace(0, 1, steps=20 + 1)
plt.stairs(mem_hist.tolist(), mem_bins.tolist())
plt.title(f"mem_attn_score_{head}")
plt.title(f"mem_attn_score_{head}_layer_{self.layer_idx}")
# set arbitrarily but we want to see those peaks!!
plt.ylim((0, 1000))
self.tracker.log({f"mem_attn_score_{head}": wandb.Image(plt)}, step=step)
self.tracker.log({f"mem_attn_score_{head}_layer_{self.layer_idx}": wandb.Image(plt)}, step=step)
plt.close()
@ -482,15 +550,16 @@ class LetheAttention(nn.Module):
local_hist = torch.histc(local_flat, bins=20, min=0, max=1)
local_bins = torch.linspace(0, 1, steps=20 + 1)
plt.stairs(local_hist.tolist(), local_bins.tolist())
plt.title(f"local_attn_score_{head}")
plt.title(f"local_attn_score_{head}_layer_{self.layer_idx}")
# set arbitrarily but we want to see those peaks!!
plt.ylim((0, 1000))
self.tracker.log({f"local_attn_score_{head}": wandb.Image(plt)}, step=step)
self.tracker.log({f"local_attn_score_{head}_layer_{self.layer_idx}": wandb.Image(plt)}, step=step)
plt.close()
# attn_output: [bs, num_attention_heads, seq_len, attn_head_size]
mem_attn_output = torch.einsum("bhsn, bhsnd-> bhsd", mem_attn_weights, knn_value)
# mem_attn_output = torch.matmul(mem_attn_weights, knn_value)
local_attn_output = torch.matmul(local_attn_weights, local_value)
# TODO: do we need flamingo style gating
@ -524,8 +593,6 @@ class LetheAttention(nn.Module):
alpha=1.0 if self.memory else (torch.tensor(1.0, dtype=self.norm_factor.dtype, device=self.norm_factor.device) / self.norm_factor)
)
attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)
if self.memory:
attn_scores = attn_scores * self.scale.exp()
mask_value = torch.finfo(attn_scores.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
@ -610,12 +677,15 @@ class LetheMLP(nn.Module):
class LetheLayer(nn.Module):
def __init__(self, config, memory_attention=False, index=None, tracker=None):
def __init__(self, config, memory_attention=False, layer_idx=None, index=None, tracker=None):
super().__init__()
self.use_parallel_residual = config.use_parallel_residual
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.attention = LetheAttention(config, memory_attention=memory_attention, index=index, tracker=tracker)
self.attention = LetheAttention(config, memory_attention=memory_attention,
layer_idx=layer_idx,
index=index[layer_idx] if memory_attention else None,
tracker=tracker)
self.mlp = LetheMLP(config)
def forward(
@ -676,8 +746,9 @@ class LetheModel(LethePreTrainedModel):
self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = nn.ModuleList([LetheLayer(config,
memory_attention=i+1 == config.memory_attn_layer,
index=index if i+1 == config.memory_attn_layer else None,
memory_attention=i+1 in config.memory_attn_layer,
layer_idx=i,
index=index,
tracker=tracker)
for i in range(config.num_hidden_layers)])
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

View File

@ -1,2 +0,0 @@
from .configuration_pythia_retro import PythiaRetroConfig
from .modeling_pythia_retro import PythiaRetroForCausalLM

View File

@ -0,0 +1,288 @@
import os
import torch.nn.functional as F
from transformers import AutoTokenizer, get_scheduler, AutoConfig
import torch
from torch.optim import AdamW
from argparse import ArgumentParser
from gpt4all.utils.read import read_config
from accelerate import Accelerator
from accelerate.utils import DummyScheduler, DummyOptim, set_seed
from gpt4all.data.enwik8 import load_enwik8_dataloader
from torchmetrics import MeanMetric
from tqdm import tqdm
from gpt4all.models import LetheForCausalLM, LetheConfig
from gpt4all.models.lethe.modeling_lethe import BatchedMemory
import wandb
torch.backends.cuda.matmul.allow_tf32 = True
def format_metrics(metrics, split, prefix=""):
log = f"[{split}]" + prefix
log += " ".join([f"{key}: {value:.4f}" for key, value in metrics.items()])
return log
def evaluate(model, index, pad_token_id, config, val_dataloader, main_process=False):
model.eval()
val_loss = MeanMetric(nan_strategy="error").to(model.device)
chunk_size = config["seq_len"]
with torch.no_grad():
for batch in tqdm(val_dataloader, disable=not main_process):
seq_len = batch["input_ids"].shape[1]
for chunk_start in range(0, seq_len, chunk_size):
chunk_end = min(seq_len, chunk_start + chunk_size)
inputs = batch["input_ids"][:, chunk_start:chunk_end].to(model.device)
labels = inputs.clone()
outputs = model(input_ids=inputs,
attention_mask=inputs.ne(pad_token_id),
labels=labels,
log_attn_scores=False,
step=None,
save_kv=True,
)
loss = outputs.loss / config["segments"]
loss_values = accelerator.gather_for_metrics({"loss": loss.item()})
val_loss.update(loss_values["loss"])
index.reset()
return val_loss
def train(accelerator, config):
set_seed(config['seed'])
accelerator.print(config)
accelerator.print(f"Using {accelerator.num_processes} GPUs")
tokenizer = AutoTokenizer.from_pretrained(config['tokenizer_name'], model_max_length=config['max_length'])
# if no pad token, set it to eos
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
with accelerator.main_process_first():
train_dataloader, val_dataloader = load_enwik8_dataloader(config, tokenizer)
if accelerator.state.deepspeed_plugin is not None:
gradient_accumulation_steps = accelerator.state.deepspeed_plugin.deepspeed_config[
"gradient_accumulation_steps"
]
accelerator.print(f"Len of train_dataloader: {len(train_dataloader)}")
total_num_steps = (len(train_dataloader) / gradient_accumulation_steps) * config["num_epochs"]
# instead of decaying to zero, decay to ratio of min_lr / lr
accelerator.print(f"Total training steps: {total_num_steps}")
checkpoint = config["gradient_checkpointing"]
model_config = LetheConfig.from_pretrained(config["model_name"])
model_config.memory_attn_layer = config["memory_attn_layer"]
model_config.num_neighbors_to_retrieve = config["num_neighbors_to_retrieve"]
model_config.use_cache = False if checkpoint else True
head_size = model_config.hidden_size // model_config.num_attention_heads
index = BatchedMemory(config["batch_size"],
head_size,
config["num_memories_per_index"],
model_config.num_attention_heads,
)
model = LetheForCausalLM(model_config,
index=index,
tracker=accelerator.get_tracker("wandb"))
accelerator.print(f"Training a {model.num_parameters():,} parameter model")
if checkpoint:
model.gradient_checkpointing_enable()
optimizer_cls = (
AdamW
if accelerator.state.deepspeed_plugin is None
or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config
else DummyOptim
)
# karpathy doesn't decay the embedding, maybe we should exclude it too
# https://github.com/karpathy/minGPT/commit/bbbdac74fa9b2e55574d70056163ffbae42310c1#diff-2075fa9c224b395be5bda85544dd36572b59c76c54562819eadadbf268602834R157s
optimizer = optimizer_cls(model.parameters(), lr=config["lr"], weight_decay=config["weight_decay"])
# Create a DummyScheduler if the scheduler is specified in the DeepSpeed config, otherwise build a cosine scheduler here
if config["scheduler"] or "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config:
if (
accelerator.state.deepspeed_plugin is None
or "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
):
scheduler = get_scheduler(
name="cosine",
optimizer=optimizer,
num_warmup_steps=config["warmup_steps"] * accelerator.num_processes,
num_training_steps=total_num_steps,
)
else:
scheduler = DummyScheduler(
optimizer, total_num_steps=total_num_steps, warmup_num_steps=config["warmup_steps"]
)
model, optimizer, scheduler, train_dataloader, val_dataloader = accelerator.prepare(
model, optimizer, scheduler, train_dataloader, val_dataloader
)
use_scheduler = True
else:
model, optimizer, train_dataloader, val_dataloader = accelerator.prepare(
model, optimizer, train_dataloader, val_dataloader
)
use_scheduler = False
# setup for saving training states in case preemption
if use_scheduler:
accelerator.register_for_checkpointing(scheduler)
if config["checkpoint"]:
accelerator.load_state(config["checkpoint"])
accelerator.print(f"Resumed from checkpoint: {config['checkpoint']}")
path = os.path.basename(config["train_args"]["resume_from_checkpoint"])
training_difference = os.path.splitext(path)[0]
resume_step = int(training_difference.replace("step_", ""))
accelerator.skip_first_batches(train_dataloader, resume_step)
accelerator.print(f"Resuming from step {resume_step}")
# log gradients
if accelerator.is_main_process and config["wandb"]:
wandb.watch(model, log_freq=config["log_grads_every"], log="all")
main_process = accelerator.is_main_process
chunk_size = config["seq_len"]
for epoch in range(config["num_epochs"]):
train_loss = MeanMetric(nan_strategy="error").to(model.device)
for step, batch in enumerate(tqdm(train_dataloader, disable=not main_process)):
epoch_step = epoch * len(train_dataloader) + step * config["segments"]
seq_len = batch["input_ids"].shape[1]
model.train()
for i, chunk_start in enumerate(range(0, seq_len, chunk_size)):
curr_step = epoch_step + i
chunk_end = min(seq_len, chunk_start + chunk_size)
inputs = batch["input_ids"][:, chunk_start:chunk_end]
labels = inputs.clone()
labels[labels == tokenizer.pad_token_id] = -100
outputs = model(input_ids=inputs,
attention_mask=inputs.ne(tokenizer.pad_token_id),
labels=labels,
log_attn_scores=True,
step=curr_step,
save_kv=True,
)
loss = outputs.loss / config["segments"]
if config["wandb"]:
accelerator.log({"loss": loss}, step=curr_step)
# gather loss before backprop in case of gradient accumulation
loss_values = accelerator.gather_for_metrics({"loss": loss.detach().float()})
train_loss.update(loss_values["loss"])
loss = loss / gradient_accumulation_steps
accelerator.backward(loss)
# log LR in case something weird happens
if config["wandb"]:
if step > 0 and step % (config["log_lr_every"] ) == 0:
lr = optimizer.param_groups[0]["lr"]
accelerator.log({"lr": lr}, step=curr_step)
optimizer.step()
if use_scheduler:
scheduler.step()
optimizer.zero_grad()
# reset index on batch end
index.reset()
if step > 0 and config["save_every"] > 0 and step % config["save_every"] == 0:
# accelerator.save_state(f"{config['output_dir']}/step_{curr_step}")
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(
f"{config['output_dir']}/step_{step}",
is_main_process=accelerator.is_main_process,
save_function=accelerator.save,
state_dict=accelerator.get_state_dict(model),
)
if step > 0 and (step % config["eval_every"] == 0 or step == len(train_dataloader) - 1):
val_loss = evaluate(model, index, tokenizer.pad_token_id, config, val_dataloader, main_process=main_process)
log_train = {
"train_loss": train_loss.compute()
}
log_val = {
"val_loss": val_loss.compute(),
}
if config["wandb"]:
curr_step = step + epoch * len(train_dataloader)
accelerator.log({**log_train, **log_val}, step=curr_step)
accelerator.print(f"Current LR: {optimizer.param_groups[0]['lr']}")
accelerator.print(format_metrics(log_train, "train", f" step {step} "))
accelerator.print(format_metrics(log_val, "val", f" step {step} "))
train_loss.reset()
accelerator.print(f"Epoch {epoch} finished")
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
if config["push_to_hub"]:
accelerator.print(f"Pushing to HF hub")
try:
if accelerator.is_main_process:
unwrapped_model.push_to_hub(config["save_name"] + f"-epoch_{epoch}", private=True)
except Exception as e:
accelerator.print(e)
accelerator.print(f"Failed to push to hub")
unwrapped_model.save_pretrained(
f"{config['output_dir']}/epoch_{epoch}",
is_main_process=accelerator.is_main_process,
save_function=accelerator.save,
state_dict=accelerator.get_state_dict(model),
)
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(
f"{config['output_dir']}/final",
is_main_process=accelerator.is_main_process,
save_function=accelerator.save,
state_dict=accelerator.get_state_dict(model),
)
accelerator.end_training()
if __name__ == "__main__":
# parse arguments by reading in a config
parser = ArgumentParser()
parser.add_argument("--config", type=str, default="config.yaml")
args = parser.parse_args()
config = read_config(args.config)
if config["wandb"]:
accelerator = Accelerator(log_with="wandb")
accelerator.init_trackers(
project_name=config["wandb_project_name"],
config=config,
init_kwargs={"wandb": {"entity": config["wandb_entity"]}},
)
else:
accelerator = Accelerator()
train(accelerator, config=config)

View File

@ -116,6 +116,7 @@ def train(accelerator, config):
if config["checkpoint"]:
accelerator.load_state(config["checkpoint"])
import pdb; pdb.set_trace()
accelerator.print(f"Resumed from checkpoint: {config['checkpoint']}")
path = os.path.basename(config["train_args"]["resume_from_checkpoint"])
training_difference = os.path.splitext(path)[0]
@ -131,9 +132,12 @@ def train(accelerator, config):
for epoch in range(config["num_epochs"]):
train_loss = MeanMetric(nan_strategy="error").to(model.device)
for step, batch in enumerate(tqdm(train_dataloader)):
curr_step = step + epoch * len(train_dataloader)
model.train()
outputs = model(**batch)
loss = outputs.loss
if config["wandb"]:
accelerator.log({"loss": loss}, step=curr_step)
# gather loss before backprop in case of gradient accumulation
loss_values = accelerator.gather_for_metrics({"loss": loss.detach().float()})
@ -157,7 +161,13 @@ def train(accelerator, config):
if step > 0 and step % config["save_every"] == 0:
curr_step = step + epoch * len(train_dataloader)
accelerator.save_state(f"{config['output_dir']}/step_{curr_step}")
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(
f"{config['output_dir']}/step_{curr_step}",
is_main_process=accelerator.is_main_process,
save_function=accelerator.save,
state_dict=accelerator.get_state_dict(model),
)
if step > 0 and (step % config["eval_every"] == 0 or step == len(train_dataloader) - 1):
val_loss = evaluate(model, val_dataloader)

View File

@ -11,7 +11,7 @@ from gpt4all.data.retrieval_dataloader import load_memory_augmented_data
from torchmetrics import MeanMetric
from tqdm import tqdm
from gpt4all.models import LetheForCausalLM
from gpt4all.models.lethe.modeling_lethe import MemoryIndex
from gpt4all.models.lethe.modeling_lethe import BatchedMemory
import wandb
import pyarrow as pa
from pyarrow import feather
@ -58,13 +58,15 @@ def evaluate(model, index, config, val_dataloader, main_process=False):
qa_labels = batch["labels"]
outputs = model(input_ids=qa_inputs,
labels=qa_labels,
save_kv=False
)
del memories
torch.cuda.empty_cache()
loss_values = accelerator.gather_for_metrics({"loss": outputs["loss"].detach()})
index.reset()
for ind in index.values():
ind.reset()
val_loss.update(loss_values["loss"])
per_example_loss = calculate_per_example_loss(outputs["logits"], qa_labels)
@ -110,16 +112,18 @@ def train(accelerator, config):
model_config = AutoConfig.from_pretrained(config["model_name"])
head_size = model_config.hidden_size // model_config.num_attention_heads
index = MemoryIndex(head_size,
indices = {i - 1: BatchedMemory(config["batch_size"],
head_size,
config["num_memories_per_index"],
model_config.num_attention_heads
)
model_config.num_attention_heads,
) for i in config["memory_attn_layer"]}
model = LetheForCausalLM.from_pretrained(config["model_name"],
revision=config['version'] if 'version' in config else None,
use_cache=False if checkpoint else True,
memory_attn_layer=config["memory_attn_layer"],
num_neighbors_to_retrieve=config["num_neighbors_to_retrieve"],
index=index,
index=indices,
tracker=accelerator.get_tracker("wandb"),
)
@ -206,8 +210,10 @@ def train(accelerator, config):
model.train()
qa_inputs = batch["input_ids"]
attn_mask = batch["attention_mask"]
qa_labels = batch["labels"]
outputs = model(input_ids=qa_inputs,
attention_mask=attn_mask,
labels=qa_labels,
log_attn_scores=True,
step=curr_step,
@ -225,6 +231,7 @@ def train(accelerator, config):
loss = loss / gradient_accumulation_steps
accelerator.backward(loss)
# !! don't reset index until after backwards pass
for index in indices.values():
index.reset()
# get gradient norm of all params
@ -251,7 +258,7 @@ def train(accelerator, config):
)
if step > 0 and (step % config["eval_every"] == 0 or step == len(train_dataloader) - 1):
val_loss, loss_table = evaluate(model, index, config, val_dataloader, main_process=main_process)
val_loss, loss_table = evaluate(model, indices, config, val_dataloader, main_process=main_process)
local_rank = accelerator.process_index
feather.write_feather(loss_table, f"{config['output_dir']}/val_losses_step_{curr_step}_rank_{local_rank}.feather")