fix: wip mem xf

This commit is contained in:
Zach Nussbaum 2023-06-02 23:02:25 +00:00
parent 3677935ce8
commit 55fef489ad
15 changed files with 860 additions and 133 deletions

View File

@ -0,0 +1,23 @@
# model/tokenizer
model_name: "/home/paperspace/gpt4all/gpt4all/train/ckpts/mem_attn/step_1000"
tokenizer_name: "EleutherAI/pythia-1b"
version: null
gradient_checkpointing: false
memory_attn_layer: 12
# dataset
streaming: false
num_proc: 64
dataset_path: "/home/paperspace/gpt4all/gpt4all/inference/synth_data_combined_174"
# dataset_path: "/home/paperspace/gpt4all/gpt4all/index/squad_supplemented_validation"
max_length: 1024
batch_size: 1
pct_test: 0.05
q_column: "question"
a_column: "answer"
context_column: "text"
num_memories_per_index: 2000000
num_neighbors_to_retrieve: 2
num_neighbors_to_store: 1
mem_chunk_size: 64

View File

@ -0,0 +1,19 @@
model_name: "nomic-ai/gpt4all-j"
revision: "v1.3-groovy"
tokenizer_name: "nomic-ai/gpt4all-j"
# dataset
streaming: false
num_proc: 64
dataset_path: "nomic-ai/cohere-wiki-sbert"
batch_size: 32
output_path: "synth_qa_pairs"
# generation
max_new_tokens: 75
max_generations: 200000
save_every: 1000
seed: 42

View File

@ -10,35 +10,36 @@ memory_attn_layer: 12
# dataset
streaming: false
num_proc: 64
dataset_path: "/home/paperspace/gpt4all/gpt4all/index/squad_supplemented_train"
dataset_path: "/home/paperspace/gpt4all/gpt4all/inference/synth_data_combined_174"
max_length: 1024
batch_size: 64
batch_size: 32
pct_test: 0.05
q_column: "question"
a_column: "answers"
context_column: "neighbor_text"
num_memories_per_index: 500000
num_neighbors_to_retrieve: 32
a_column: "answer"
context_column: "text"
num_memories_per_index: 2000000
num_neighbors_to_retrieve: 2
num_neighbors_to_store: 1
mem_chunk_size: 64
# train dynamics
lr: 1.0e-4
lr: 1.0e-5
min_lr: 0
weight_decay: 0.0
eval_every: 100
save_every: -1
save_every: 100
log_grads_every: 100
log_lr_every: 10
output_dir: "ckpts/mem_attn"
output_dir: "ckpts/mem_attn_no_cosine_sim"
checkpoint: null
lora: false
warmup_steps: 500
warmup_steps: 200
num_epochs: 5
debug: false
scheduler: false
# logging
wandb: false
wandb: true
wandb_entity: gpt4all
wandb_project_name: mem_attn
seed: 42

View File

@ -0,0 +1,43 @@
# model/tokenizer
model_name: "EleutherAI/pythia-1b"
tokenizer_name: "EleutherAI/pythia-1b"
version: null
gradient_checkpointing: true
save_name: "nomic-ai/minipille"
push_to_hub: false
memory_attn_layer: 12
# dataset
streaming: false
num_proc: 64
dataset_path: "JeanKaddour/minipile"
max_length: 2048
batch_size: 64
pct_test: 0.05
num_memories_per_index: 5000000
mem_chunk_size: 512
num_chunks: 10
num_neighbors_to_retrieve: 32
# train dynamics
lr: 1.0e-4
min_lr: 0
weight_decay: 0.0
eval_every: 100
save_every: -1
log_grads_every: 100
log_lr_every: 10
output_dir: "ckpts/minipile"
checkpoint: null
lora: false
warmup_steps: 500
num_epochs: 5
debug: false
scheduler: false
# logging
wandb: true
wandb_entity: gpt4all
wandb_project_name: minipile
seed: 42

View File

@ -85,6 +85,7 @@ def load_memory_augmented_data(config, tokenizer, split="train", split_dataset=T
question_col = config["q_column"]
answer_col = config["a_column"]
context_col = config["context_column"]
if config["streaming"] is False:
kwargs = {"num_proc": config["num_proc"]}
@ -96,7 +97,8 @@ def load_memory_augmented_data(config, tokenizer, split="train", split_dataset=T
dataset = dataset.map(lambda ele: {question_col: [q.strip() for q in ele[question_col]]}, batched=True)
# in squad, the data is formatted where each ele in answers is a dict where the key text holds
# a list of the answer
dataset = dataset.map(lambda ele: {answer_col: [t["text"][0] for t in ele[answer_col]]}, batched=True)
dataset = dataset.map(lambda ele: {answer_col: [t.strip() for t in ele[answer_col]]}, batched=True)
# dataset = dataset.map(lambda ele: {answer_col: [t["text"][0] for t in ele[answer_col]]}, batched=True)
dataset = dataset.map(
lambda ele: tokenize_inputs(config, tokenizer, ele, question_col, answer_col),
@ -106,19 +108,73 @@ def load_memory_augmented_data(config, tokenizer, split="train", split_dataset=T
# tokenize contexts for each example
dataset = dataset.map(
lambda ele: {"retrieved_context": tokenizer(ele["context"],
lambda ele: {"retrieved_context": tokenizer([ele[context_col]],
return_tensors="pt",
padding="max_length",
truncation=True)["input_ids"]},
batched=True,
**kwargs
)
columns_to_keep = ["input_ids", "labels", "retrieved_context"]
columns_to_keep = ["id", "input_ids", "labels", "retrieved_context"]
col_names_to_rm = [col for col in dataset.column_names if col not in columns_to_keep]
dataset = dataset.remove_columns(col_names_to_rm)
if split_dataset:
dataset = dataset.train_test_split(test_size=config["pct_test"], seed=config["seed"])
train_dataset, val_dataset = dataset["train"], dataset["test"]
train_dataloader = DataLoader(
train_dataset.remove_columns("id"),
batch_size=config["batch_size"],
collate_fn=DefaultDataCollator(),
)
val_dataloader = DataLoader(
val_dataset,
batch_size=config["batch_size"],
collate_fn=DefaultDataCollator(),
)
return train_dataloader, val_dataloader
else:
dataloader = DataLoader(
dataset,
batch_size=config["batch_size"],
collate_fn=DefaultDataCollator(),
)
return dataloader
def load_memory_pretraining_data(config, tokenizer, split="train", split_dataset=True):
dataset_path = config["dataset_path"]
if os.path.exists(dataset_path):
dataset = Dataset.load_from_disk(dataset_path)
else:
dataset = load_dataset(dataset_path, split=split)
if config["streaming"] is False:
kwargs = {"num_proc": config["num_proc"]}
else:
kwargs = {}
# e.g. 512 * 10 = 5120 sequence length split up
max_length = config["mem_chunk_size"] * config["num_chunks"]
dataset = dataset.map(lambda ele: tokenizer(ele["text"], padding="max_length", truncation=True, max_length=max_length),
batched=True, **kwargs)
dataset = dataset.map(lambda x: {"labels": x["input_ids"]}, batched=True, **kwargs)
columns_to_keep = ["input_ids", "labels", "attention_mask"]
col_names_to_rm = [col for col in dataset.column_names if col not in columns_to_keep]
dataset = dataset.remove_columns(col_names_to_rm)
# we can shuffle since the docs are in one row not split across rows
if split_dataset:
dataset = dataset.train_test_split(test_size=config["pct_test"], seed=config["seed"])
train_dataset, val_dataset = dataset["train"], dataset["test"]

View File

@ -0,0 +1,116 @@
import torch
import torch.nn.functional as F
from gpt4all.models import LetheForCausalLM
from gpt4all.models.lethe.modeling_lethe import MemoryIndex
from gpt4all.data.retrieval_dataloader import load_memory_augmented_data
from gpt4all.train.metrics import f1_score, exact_match_score
from gpt4all.utils.read import read_config
from transformers import AutoTokenizer, AutoConfig
from argparse import ArgumentParser
from tqdm import tqdm
from nomic import atlas
from datasets import load_from_disk
def calc_loss_per_item(logits, labels):
lm_logits = logits[:, :-1, :].contiguous()
lm_labels = labels[:, 1:].contiguous()
loss = F.cross_entropy(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1), reduction="none")
loss = loss.reshape(labels.shape[0], -1).mean(dim=-1)
# return tensor of shape (B,) where B is the batch size
return loss.cpu().tolist()
def greedy_search(input_ids, model, tokenizer, max_new_tokens=100):
num_new_tokens = 0
with torch.no_grad():
while True:
if num_new_tokens >= max_new_tokens:
break
outputs = model(input_ids, save_kv=False)
new_tokens = torch.argmax(outputs.logits[:, -1, :], dim=-1)
input_ids = torch.cat([input_ids, new_tokens.unsqueeze(1)], dim=-1)
num_new_tokens += 1
if torch.equal(input_ids[0, -1].cpu(), torch.tensor(tokenizer.eos_token_id)):
break
print(tokenizer.batch_decode(input_ids, skip_special_tokens=True))
return input_ids
parser = ArgumentParser()
parser.add_argument("--config", type=str, default="config.yaml")
args = parser.parse_args()
config = read_config(args.config)
tokenizer = AutoTokenizer.from_pretrained(config["tokenizer_name"], model_max_length=config["max_length"])
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
dataloader = load_memory_augmented_data(config, tokenizer, split_dataset=False)
dataset = load_from_disk(config["dataset_path"])
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_config = AutoConfig.from_pretrained(config["model_name"])
head_size = model_config.hidden_size // model_config.num_attention_heads
index = MemoryIndex(head_size,
config["num_memories_per_index"],
model_config.num_attention_heads
)
model = LetheForCausalLM.from_pretrained(config["model_name"],
revision=config['version'] if 'version' in config else None,
memory_attn_layer=config["memory_attn_layer"],
num_neighbors_to_retrieve=config["num_neighbors_to_retrieve"],
index=index,
).to(device)
model.eval()
# Evaluate the model on the SQUAD dataset
losses = []
with torch.no_grad():
for i, batch in enumerate(tqdm(dataloader)):
memories = batch["retrieved_context"]
memories = memories[:, :config["num_neighbors_to_store"], :]
memories = memories.reshape(-1, memories.shape[-1])
# need to set to eval so we don't do mem attn as it's slow
with torch.no_grad():
for chunk_start in range(0, memories.shape[0], config["mem_chunk_size"]):
chunk_end = min(memories.shape[0], chunk_start + config["mem_chunk_size"])
mem_chunk = memories[chunk_start:chunk_end]
model(input_ids=mem_chunk.to(device))
del memories
torch.cuda.empty_cache()
qa_inputs = batch["input_ids"]
qa_labels = batch["labels"]
for i in range(qa_inputs.shape[0]):
inputs = qa_inputs[i].to(device)
labels = qa_labels[i].to(device)
cutoff = torch.argmax((labels != -100).type(torch.float32))
greedy_search(inputs[:cutoff.item()].unsqueeze(0).to(device), model, tokenizer)
print(tokenizer.decode(inputs, skip_special_tokens=True))
# batch_loss = calc_loss_per_item(outputs.logits, qa_labels.to(device))
# losses.extend(batch_loss)
index.reset()
dataset = dataset.add_column("loss", losses)
dataset.save_to_disk("eval_squad_atlas_map")

View File

@ -0,0 +1,57 @@
import glob
from argparse import ArgumentParser
from datasets import Dataset, load_from_disk, concatenate_datasets
PROMPT = "Write a question answer pair based on the following context. If the context isn't specific enough, ignore and return 'No question answer pair`. Context : {}\n"
def load_synth_data(data_dir):
files = glob.glob(data_dir + "/*")
ds = concatenate_datasets([load_from_disk(f) for f in files])
table = ds.data.table
filtered = table.filter(table["valid"])
ds = Dataset.from_dict(filtered.to_pydict())
return ds
def remove_prompt(examples):
outputs = {"text": [], "generated": [], "question": [], "answer": []}
for context, generated in zip(examples["text"], examples["generated"]):
prompt_w_ctx = PROMPT.format(context)
gen_wo_ctx = generated[len(prompt_w_ctx):]
assert prompt_w_ctx not in gen_wo_ctx
question = gen_wo_ctx.split("Answer:")[0].replace("Question:", "").strip()
answer = gen_wo_ctx.split("Answer:")[1].strip()
outputs["text"].append(context)
outputs["generated"].append(gen_wo_ctx)
outputs["question"].append(question)
outputs["answer"].append(answer)
return outputs
def combine_synth_data(data_dir):
ds = load_synth_data(data_dir)
ds = ds.map(lambda ele: remove_prompt(ele), batched=True, num_proc=64)
ds.save_to_disk(f"synth_data_combined_{len(ds)/1000:.0f}")
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--dataset_dir", type=str, required=True)
args = parser.parse_args()
combine_synth_data(args.dataset_dir)

View File

@ -0,0 +1,149 @@
from argparse import ArgumentParser
from datasets import load_dataset, Dataset
import torch
from torch.utils.data import DataLoader, DistributedSampler
from accelerate.utils import set_seed
from transformers import AutoTokenizer, AutoModelForCausalLM, DefaultDataCollator
from gpt4all.utils.read import read_config
from gpt4all.utils.distributed_utils import rank0_print, main_process_first
from tqdm import tqdm
import pyarrow.compute as pc
import pyarrow as pa
import torch.distributed as dist
PROMPT = "Write a question answer pair based on the following context. If the context isn't specific enough, ignore and return 'No question answer pair`. Context : {}\n"
def prepare_data(config, tokenizer, num_processes, local_rank):
dataset = load_dataset(config["dataset_path"], split="train")
dataset = dataset.remove_columns("embedding")
shuffled = dataset.shuffle(seed=config["seed"])
indices = shuffled[:config["max_generations"]]["id"]
table = dataset.data
mask = pc.is_in(table["id"], value_set=pa.array(indices, pa.int32()))
filtered_table = table.filter(mask)
# convert from pyarrow to Dataset
orig_dataset = Dataset.from_dict(filtered_table.to_pydict())
dataset = orig_dataset.map(lambda ele: {"prompted_text": [PROMPT.format(context) for context in ele["text"]]},
batched=True,
num_proc=config["num_proc"] if "num_proc" in config else None)
dataset = dataset.map(lambda ele: {"prompt_len": [len(prompt) for prompt in ele["prompted_text"]]}, batched=True,
num_proc=config["num_proc"] if "num_proc" in config else None)
dataset = dataset.sort("prompt_len")
dataset = dataset.map(lambda ele: tokenizer(ele["prompted_text"], return_tensors="pt", padding="longest", truncation=True,
max_length=tokenizer.model_max_length - config["max_new_tokens"]), batched=True,
batch_size=num_processes * config["batch_size"],
)
columns_to_keep = ["id", "input_ids", "attention_mask"]
col_names_to_rm = [col for col in dataset.column_names if col not in columns_to_keep]
dataset = dataset.remove_columns(col_names_to_rm)
sampler = DistributedSampler(
dataset,
shuffle=False,
drop_last=True,
num_replicas=num_processes,
rank=local_rank,
)
dataloader = DataLoader(
dataset,
batch_size=config["batch_size"],
collate_fn=DefaultDataCollator(),
sampler=sampler,
drop_last=True
)
return dataloader, orig_dataset
def generate_data(config):
set_seed(config["seed"])
rank0_print(f"World size: {dist.get_world_size()}")
tokenizer = AutoTokenizer.from_pretrained(config["tokenizer_name"])
# since we're doing generation, pad left for autoregressive generation
tokenizer.padding_side = "left"
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
num_processes = dist.get_world_size()
local_rank = dist.get_rank()
dataloader, dataset = prepare_data(config, tokenizer, num_processes, local_rank)
dist.barrier()
print(dataset[:10]["id"])
model = AutoModelForCausalLM.from_pretrained(config["model_name"],
revision=config["revision"] if "revision" in config else None,
use_cache=True,
torch_dtype=torch.bfloat16,)
model.to(f"cuda:{local_rank}")
synth_data = []
valid = []
ids = []
total_valid = 0
with torch.no_grad():
for i, batch in enumerate(tqdm(dataloader, disable=local_rank != 0)):
# keep this simple for now, can add temperature and other sampling techniques later
generated = model.generate(batch["input_ids"].to(model.device),
attention_mask=batch["attention_mask"].to(model.device),
max_new_tokens=config["max_new_tokens"])
decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
num_valid = ["\nQuestion:" in t and "\nAnswer:" in t for t in decoded]
rank0_print(f"Num valid: {sum(num_valid)/ len(num_valid):.2f}")
total_valid += sum(num_valid)
synth_data.extend(decoded)
valid.extend(num_valid)
ids.extend(batch["id"].tolist())
if i > 0 and i % config["save_every"] == 0:
table = dataset.data.table
mask = pc.is_in(table["id"], value_set=pa.array(ids, pa.int32()))
filtered_table = table.filter(mask)
chunk_table = pa.Table.from_pydict({"id": ids, "generated": synth_data, "valid": valid})
joined = filtered_table.join(chunk_table, "id")
curr_dataset = Dataset.from_dict(joined.to_pydict())
curr_dataset.save_to_disk(f'{config["output_path"]}/chunk_{i}_rank_{local_rank}')
table = dataset.data.table
mask = pc.is_in(table["id"], value_set=pa.array(ids, pa.int32()))
filtered_table = table.filter(mask)
chunk_table = pa.Table.from_pydict({"id": ids, "generated": synth_data, "valid": valid})
joined = filtered_table.join(chunk_table, "id")
full_dataset = Dataset.from_dict(joined.to_pydict())
full_dataset.save_to_disk(f'{config["output_path"]}_{config["max_generations"]}_rank_{local_rank}')
rank0_print(f"Total valid: {total_valid}/{config['max_generations']}")
def main():
dist.init_process_group("nccl")
parser = ArgumentParser()
parser.add_argument("--config", type=str, default="config.yaml")
args = parser.parse_args()
config = read_config(args.config)
generate_data(config)
if __name__ == "__main__":
# parse arguments by reading in a config
main()

View File

@ -111,6 +111,7 @@ class LetheConfig(PretrainedConfig):
memory_attn_layer=9,
num_neighbors_to_retrieve=32,
num_neighbors_stored=128,
attn_scale_init=20.0,
**kwargs,
):
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
@ -133,4 +134,5 @@ class LetheConfig(PretrainedConfig):
# index of cross attention layer to add
self.memory_attn_layer = memory_attn_layer
self.num_neighbors_to_retrieve = num_neighbors_to_retrieve
self.num_neighbors_stored = num_neighbors_stored
self.num_neighbors_stored = num_neighbors_stored
self.attn_scale_init = attn_scale_init

View File

@ -12,8 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch PythiaSeek model."""
""" PyTorch Lethe model."""
import wandb
import math
import torch.nn.functional as F
import matplotlib.pyplot as plt
from typing import Optional, Tuple, Union
import torch
@ -29,8 +33,9 @@ from transformers.modeling_outputs import (
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from gpt4all.models.lethe import LetheConfig
import hnswlib
import numpy as np
import faiss
import faiss.contrib.torch_utils
logger = logging.get_logger(__name__)
@ -40,17 +45,18 @@ GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST = [
"EleutherAI/gpt-neox-20b",
]
# TODO: understand why Phil only does this per batch and doens't persist across many batches -> he uses multi-query attention
# TODO: do we need to implement masking for the dense vectors we pull from?
# TODO: i think phil is using a memmapped database to pull out rather than using the index
class HNSWIndex:
def __init__(self, max_memories, dimension):
# num_memories will be batch size * num_neighbors
# can memmap this too like
self.index = hnswlib.Index(space="l2", dim=dimension)
self.index.init_index(max_elements=max_memories, ef_construction=50, M=16)
self.index = faiss.IndexHNSWFlat(dimension, 16, faiss.METRIC_INNER_PRODUCT)
# taking params from: https://www.pinecone.io/learn/vector-indexes/#hnsw-implementation
# and https://www.pinecone.io/learn/hnsw/#hnsw-performance
# seems like efConstruction dictates how long the index takes to build
# and efSearch and M (second arg to faiss.Index) dictates how long it takes to search
self.index.hnsw.efConstruction = 16
self.index.hnsw.efSearch = 32
self.max_memories = max_memories
self.dimension = dimension
@ -60,45 +66,65 @@ class HNSWIndex:
def query(self, query, k=1):
# hack what should we do here?
if self.index.get_current_count() == 0:
return np.ones((query.shape[0], k, query.shape[1]), dtype=np.float32)
if self.index.ntotal == 0:
return np.ones((query.shape[0], k), dtype=np.int32)
assert query.ndim == 2
bs_seq_len, _ = query.shape
_, labels = self.index.search(np.ascontiguousarray(query), k=k)
labels, _ = self.index.knn_query(query, k=k)
neighbors = torch.tensor(self.index.get_items(labels.reshape(-1)))
neighbors = neighbors.reshape((bs_seq_len, k, query.shape[1]))
assert neighbors.ndim == 3
assert neighbors.shape[0] == bs_seq_len
return neighbors
return labels
def add(self, memories):
assert memories.ndim == 2
bs_seq_len, _ = memories.shape
ids = np.arange(self.idx_offset, self.idx_offset + bs_seq_len)
self.index.add_items(memories, ids)
self.idx_offset += bs_seq_len
return self.index.add(np.ascontiguousarray(memories))
def reset(self):
self.index = hnswlib.Index(space="l2", dim=self.dimension)
self.index.init_index(max_elements=self.max_memories, ef_construction=50, M=16)
self.index.reset()
class NumpyKNNIndex:
def __init__(self, max_memories, dimension):
# num_memories will be batch size * num_neighbors
# can memmap this too like
self.index = np.zeros((max_memories, dimension), dtype=np.float32)
self.max_memories = max_memories
self.dimension = dimension
# if we want to allow for insertion of len(elements) > max_memories
# we need to figure out a way to get the most recent memories
self.idx_offset = 0
def query(self, query, k=1):
# hack what should we do here?
if self.index.sum() == 0:
return np.ones((query.shape[0], k), dtype=np.int32)
dots = query.dot(self.index[:self.idx_offset].T)
labels = np.argsort(dots, axis=1)[:, -k:]
return labels
def add(self, memories):
self.index[self.idx_offset:self.idx_offset + memories.shape[0]] = memories
self.idx_offset += memories.shape[0]
def reset(self):
self.index.reset()
class MemoryIndex:
def __init__(self, hidden_dim, num_mems, nheads):
# we store an index for each k/v for each head
self.key_indices = [HNSWIndex(num_mems, hidden_dim) for _ in range(nheads)]
self.value_indices = [HNSWIndex(num_mems, hidden_dim) for _ in range(nheads)]
self.nheads = nheads
# NOTE: we are storing kv pairs, instead indices for both keys and values
self.key_indices = [HNSWIndex(num_mems, hidden_dim) for _ in range(nheads)]
shape = (num_mems, nheads, 2, hidden_dim)
self.kv_pairs = np.zeros(shape, dtype=np.float32)
self.idx_offset = 0
def add(self, keys, values):
# k/v are (bs, num_attention_heads, seq_len, head_size)
reshaped_keys = keys.reshape(keys.shape[0] * keys.shape[2], keys.shape[1], keys.shape[3])
@ -106,22 +132,30 @@ class MemoryIndex:
for head in range(self.nheads):
self.key_indices[head].add(reshaped_keys[:, head, :])
self.value_indices[head].add(reshaped_values[:, head, :])
kv_pairs = np.stack((reshaped_keys, reshaped_values), axis=2)
if self.idx_offset + kv_pairs.shape[0] > self.kv_pairs.shape[0]:
raise ValueError("Not enough memory!")
self.kv_pairs[self.idx_offset:self.idx_offset + kv_pairs.shape[0]] = kv_pairs
self.idx_offset += kv_pairs.shape[0]
def knn_query(self, query, k=1):
reshaped_query = query.reshape(query.shape[0] * query.shape[2], query.shape[1], query.shape[3])
mem_keys = []
mem_values = []
mem_indices = []
# this is prob so so slow
# we can prob make this better
for head in range(self.nheads):
knn_keys = self.key_indices[head].query(reshaped_query[:, head, :], k=k)
knn_values = self.value_indices[head].query(reshaped_query[:, head, :], k=k)
mem_keys.append(knn_keys)
mem_values.append(knn_values)
knn_indices = self.key_indices[head].query(reshaped_query[:, head, :], k=k)
kv_pairs = self.kv_pairs[:, head, :, :][knn_indices]
mem_keys.append(kv_pairs[:, :, 0, :])
mem_values.append(kv_pairs[:, :, 1, :])
mem_indices.append(knn_indices)
mem_keys = torch.from_numpy(np.stack(mem_keys, axis=1))
# (bs, num_attention_heads, seq_len, k, head_size)
@ -131,13 +165,14 @@ class MemoryIndex:
# (bs, num_attention_heads, seq_len, k, head_size)
mem_values = mem_values.view(query.shape[:-1] + (k,) + (query.shape[-1],))
return mem_keys, mem_values
return mem_keys, mem_values, np.stack(mem_indices, axis=1)
def reset(self):
for head in range(self.nheads):
self.key_indices[head].reset()
self.value_indices[head].reset()
self.kv_pairs = np.zeros((self.kv_pairs.shape[0], self.nheads, 2, self.kv_pairs.shape[-1]), dtype=np.float32)
class LethePreTrainedModel(PreTrainedModel):
@ -171,7 +206,7 @@ class LethePreTrainedModel(PreTrainedModel):
class LetheAttention(nn.Module):
def __init__(self, config, memory_attention=False, index=None):
def __init__(self, config, memory_attention=False, index=None, tracker=None):
super().__init__()
self.num_attention_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
@ -188,21 +223,24 @@ class LetheAttention(nn.Module):
self.rotary_emb = RotaryEmbedding(
self.rotary_ndims, config.max_position_embeddings, base=config.rotary_emb_base
)
self.register_buffer(
if not memory_attention:
self.register_buffer(
"norm_factor",
torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(torch.get_default_dtype()),
persistent=False,
)
self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size)
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.memory = False
if memory_attention:
self.scale = nn.Parameter(torch.ones(self.num_attention_heads, 1, 1) * math.log(config.attn_scale_init))
self.memory = True
self.alpha = nn.Parameter(torch.zeros(self.num_attention_heads))
self.num_neighbors = config.num_neighbors_to_retrieve
# for testing, just using np array since it's easy
self.index = index
self.tracker = tracker
def forward(
@ -214,7 +252,9 @@ class LetheAttention(nn.Module):
layer_past: Optional[Tuple[torch.Tensor]] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
use_mem_attn: Optional[bool] = True,
log_attn_scores: Optional[bool] = False,
step: Optional[int] = None,
save_kv: Optional[bool] = True,
):
has_layer_past = layer_past is not None
@ -234,6 +274,12 @@ class LetheAttention(nn.Module):
key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3)
value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3)
# if self.memory:
if self.memory:
# QKNorm: https://arxiv.org/abs/2010.04245
query = F.normalize(query, dim=-1)
key = F.normalize(key, dim=-1)
# Compute rotary embeddings on rotary_ndims
query_rot = query[..., : self.rotary_ndims]
query_pass = query[..., self.rotary_ndims :]
@ -257,27 +303,38 @@ class LetheAttention(nn.Module):
value = torch.cat((past_value, value), dim=-2)
present = (key, value) if use_cache else None
# Compute attention
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
# TODO: need to do masking??
# memory attention
if self.memory:
# get knns
# since we do an eval batch w context before, let's not do the expensive step until we need to
# [batch, knn, num_attention_heads, seq_len, head_size]
if use_mem_attn:
knn_keys, knn_values = self.index.knn_query(query.cpu().detach().to(torch.float32).numpy(), k=self.num_neighbors)
mem_attn = self._mem_attn(query,
knn_keys.to(query.device),
knn_values.to(query.device),
attention_mask,
head_mask
)
if save_kv:
self.index.add(key.cpu().detach().to(torch.float32).numpy(), value.cpu().detach().to(torch.float32).numpy())
expanded_alpha = self.alpha[None, :, None, None]
attn_output = (attn_output * (1 - expanded_alpha)) + (mem_attn * expanded_alpha)
knn_keys, knn_values, knn_labels = self.index.knn_query(query.cpu().detach().to(torch.float32).numpy(), k=self.num_neighbors)
if log_attn_scores:
batch_size = query.shape[0]
seq_len = query.shape[-2]
self.index.add(key.cpu().detach().to(torch.float32).numpy(), value.cpu().detach().to(torch.float32).numpy())
key_labels = knn_labels // seq_len
key_labels = key_labels.reshape(batch_size, seq_len, self.num_attention_heads, -1)
correct_keys = np.equal(key_labels, np.arange(batch_size)[:, np.newaxis, np.newaxis, np.newaxis])
# calculate the accuracy
key_acc = np.sum(correct_keys) / np.prod(correct_keys.shape)
self.tracker.log({"retrieved_acc": key_acc}, step=step)
attn_output = self._mem_attn(query,
knn_keys.to(query.device).to(value.dtype),
knn_values.to(query.device).to(value.dtype),
key,
value,
attention_mask,
head_mask,
log_attn_scores=log_attn_scores,
step=step,
knn_labels=knn_labels,
)
else:
# Normal self-attention
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
# Reshape outputs
attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size)
@ -315,28 +372,131 @@ class LetheAttention(nn.Module):
return tensor
def _mem_attn(self, query, key, value, attention_mask=None, head_mask=None):
def _mem_attn(self,
query,
knn_key,
knn_value,
local_key,
local_value,
attention_mask=None,
head_mask=None,
log_attn_scores=False,
step=None,
knn_labels=None):
# local self-attention
# q: [bs, num_attention_heads, seq_len, attn_head_size]
# k,v: [bs, num_attention_heads, seq_len, knn, attn_head_size]
query_length = query.size(-2)
key_length = local_key.size(-2)
attn_scores = torch.einsum("bhsd, bhsnd-> bshn", query, key)
# attn_scores: [bs, seq_len, num_attention_heads, knn]
attn_scores = attn_scores / self.norm_factor
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
# softmax over knns
attn_weights = nn.functional.softmax(attn_scores, dim=-1)
attn_weights = attn_weights.to(value.dtype)
local_attn_scores = torch.matmul(query, local_key.transpose(-1, -2))
scale = self.scale.exp()
local_attn_scores = local_attn_scores * scale
mask_value = torch.finfo(local_attn_scores.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
mask_value = torch.tensor(mask_value, dtype=local_attn_scores.dtype).to(local_attn_scores.device)
local_attn_scores = torch.where(causal_mask, local_attn_scores, mask_value)
if attention_mask is not None:
# Apply the attention mask
attn_scores = attn_scores + attention_mask
local_attn_scores = local_attn_scores + attention_mask
mem_attn_scores = torch.einsum("bhsd, bhsnd-> bhsn", query, knn_key)
# attn_scores: [bs, seq_len, num_attention_heads, knn]
mem_attn_scores = mem_attn_scores * scale
attn_scores = torch.cat((mem_attn_scores, local_attn_scores), dim=-1)
# softmax over knns
attn_weights = nn.functional.softmax(attn_scores, dim=-1)
attn_weights = attn_weights.to(local_value.dtype)
mem_attn_weights, local_attn_weights = attn_weights.split([self.num_neighbors, local_attn_scores.size(-1)], dim=-1)
if log_attn_scores:
# (bs, seq_len, num_attention_heads, knn) probabilities
# curate (x,y) pairs
# where x is attention weight, y is accuracy of retrieved token
bs, seq_len = mem_attn_weights.size(0), mem_attn_weights.size(2)
key_labels = knn_labels // seq_len
key_labels = key_labels.reshape(bs, self.num_attention_heads, seq_len, -1)
correct_keys = np.equal(key_labels, np.arange(bs)[:, np.newaxis, np.newaxis, np.newaxis])
bin_width = 0.05
# Calculate the number of bins
num_bins = int(1 / bin_width)
# Create empty lists for storing bin probabilities and accuracies
bin_probabilities = []
bin_accuracies = []
probs = mem_attn_weights.clone().detach().cpu().numpy().reshape(-1).tolist()
correct_keys = correct_keys.reshape(-1).tolist()
# Iterate over each bin
for i in range(num_bins):
bin_lower = i * bin_width
bin_upper = (i + 1) * bin_width
# Filter data points within the current bin range
bin_x_values = [x for x in probs if bin_lower <= x < bin_upper]
bin_y_values = [y for j, y in enumerate(correct_keys) if bin_lower <= probs[j] < bin_upper]
# Calculate accuracy for the bin
total = len(bin_x_values)
correct = sum(bin_y_values)
accuracy = correct / total if total > 0 else 0
# Store the probability and accuracy for the bin
bin_probabilities.append((bin_lower + bin_upper) / 2)
bin_accuracies.append(accuracy)
data = [[x, y] for x, y in zip(bin_probabilities, bin_accuracies)]
table = wandb.Table(data=data, columns=["attn_prob", "retrieved_acc"])
self.tracker.log({"attn_vs_acc": wandb.plot.scatter(table, "attn_prob", "retrieved_acc")}, step=step)
if log_attn_scores:
# this def won't work well on multi-gpu machines
num_attention_heads = mem_attn_weights.size(1)
for head in range(num_attention_heads):
mem_attn_score_per_head = mem_attn_weights[:, head].reshape(-1)
mem_flat = mem_attn_score_per_head.clone().detach().cpu()
mem_hist = torch.histc(mem_flat, bins=20, min=0, max=1)
mem_bins = torch.linspace(0, 1, steps=20 + 1)
plt.stairs(mem_hist.tolist(), mem_bins.tolist())
plt.title(f"mem_attn_score_{head}")
# set arbitrarily but we want to see those peaks!!
plt.ylim((0, 1000))
self.tracker.log({f"mem_attn_score_{head}": wandb.Image(plt)}, step=step)
plt.close()
local_attn_scores_per_head = local_attn_weights[:, head].reshape(-1)
local_flat = local_attn_scores_per_head.clone().detach().cpu()
local_hist = torch.histc(local_flat, bins=20, min=0, max=1)
local_bins = torch.linspace(0, 1, steps=20 + 1)
plt.stairs(local_hist.tolist(), local_bins.tolist())
plt.title(f"local_attn_score_{head}")
# set arbitrarily but we want to see those peaks!!
plt.ylim((0, 1000))
self.tracker.log({f"local_attn_score_{head}": wandb.Image(plt)}, step=step)
plt.close()
# Mask heads if we want to
if head_mask is not None:
attn_weights = attn_weights * head_mask
# attn_output: [bs, num_attention_heads, seq_len, attn_head_size]
attn_output = torch.einsum("bshn, bhsnd-> bhsd", attn_scores, value)
mem_attn_output = torch.einsum("bhsn, bhsnd-> bhsd", mem_attn_weights, knn_value)
local_attn_output = torch.matmul(local_attn_weights, local_value)
# TODO: do we need flamingo style gating
# of output_gate.tanh * attn_output
attn_output = mem_attn_output + local_attn_output
return attn_output
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
@ -361,9 +521,11 @@ class LetheAttention(nn.Module):
query,
key.transpose(1, 2),
beta=1.0,
alpha=(torch.tensor(1.0, dtype=self.norm_factor.dtype, device=self.norm_factor.device) / self.norm_factor),
alpha=1.0 if self.memory else (torch.tensor(1.0, dtype=self.norm_factor.dtype, device=self.norm_factor.device) / self.norm_factor)
)
attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)
if self.memory:
attn_scores = attn_scores * self.scale.exp()
mask_value = torch.finfo(attn_scores.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
@ -413,7 +575,7 @@ class RotaryEmbedding(torch.nn.Module):
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
self.cos_cached = emb.cos()[None, None, :, :]
self.sin_cached = emb.sin()[None, None, :, :]
return self.cos_cached[:seq_len, ...].to(x.device), self.sin_cached[:seq_len, ...].to(x.device)
return self.cos_cached[:seq_len, ...].to(x.device).to(x.dtype), self.sin_cached[:seq_len, ...].to(x.device).to(x.dtype)
def rotate_half(x):
@ -448,12 +610,12 @@ class LetheMLP(nn.Module):
class LetheLayer(nn.Module):
def __init__(self, config, memory_attention=False, index=None):
def __init__(self, config, memory_attention=False, index=None, tracker=None):
super().__init__()
self.use_parallel_residual = config.use_parallel_residual
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.attention = LetheAttention(config, memory_attention=memory_attention, index=index)
self.attention = LetheAttention(config, memory_attention=memory_attention, index=index, tracker=tracker)
self.mlp = LetheMLP(config)
def forward(
@ -465,7 +627,9 @@ class LetheLayer(nn.Module):
use_cache: Optional[bool] = False,
layer_past: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_mem_attn: Optional[bool] = True,
log_attn_scores: Optional[bool] = False,
step: Optional[int] = None,
save_kv: Optional[bool] = True
):
ln_hidden_states = self.input_layernorm(hidden_states)
attention_layer_outputs = self.attention(
@ -476,7 +640,9 @@ class LetheLayer(nn.Module):
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
use_mem_attn=use_mem_attn,
log_attn_scores=log_attn_scores,
step=step,
save_kv=save_kv,
)
attn_output = attention_layer_outputs[0] # output_attn: attn_output, present, (attn_weights)
outputs = attention_layer_outputs[1:]
@ -503,7 +669,7 @@ class LetheLayer(nn.Module):
class LetheModel(LethePreTrainedModel):
def __init__(self, config, index):
def __init__(self, config, index, tracker=None):
super().__init__(config)
self.config = config
@ -511,7 +677,8 @@ class LetheModel(LethePreTrainedModel):
self.layers = nn.ModuleList([LetheLayer(config,
memory_attention=i+1 == config.memory_attn_layer,
index=index if i+1 == config.memory_attn_layer else None)
index=index if i+1 == config.memory_attn_layer else None,
tracker=tracker)
for i in range(config.num_hidden_layers)])
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@ -538,7 +705,9 @@ class LetheModel(LethePreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
use_mem_attn: Optional[bool] = True,
log_attn_scores: Optional[bool] = False,
step: Optional[int] = None,
save_kv: Optional[bool] = True,
) -> Union[Tuple, BaseModelOutputWithPast]:
r"""
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
@ -631,7 +800,7 @@ class LetheModel(LethePreTrainedModel):
def create_custom_forward(module):
def custom_forward(*inputs):
# None for layer_past
return module(*inputs, use_cache, None, output_attentions, use_mem_attn)
return module(*inputs, use_cache, None, output_attentions, log_attn_scores, step, save_kv)
return custom_forward
@ -651,7 +820,9 @@ class LetheModel(LethePreTrainedModel):
layer_past=layer_past,
use_cache=use_cache,
output_attentions=output_attentions,
use_mem_attn=use_mem_attn,
log_attn_scores=log_attn_scores,
step=step,
save_kv=save_kv,
)
hidden_states = outputs[0]
if use_cache is True:
@ -678,10 +849,10 @@ class LetheModel(LethePreTrainedModel):
class LetheForCausalLM(LethePreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
def __init__(self, config, index):
def __init__(self, config, index, tracker=None):
super().__init__(config)
self.gpt_neox = LetheModel(config, index)
self.gpt_neox = LetheModel(config, index, tracker=tracker)
self.embed_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.hidden_size = config.hidden_size
@ -709,7 +880,9 @@ class LetheForCausalLM(LethePreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
use_mem_attn: Optional[bool] = True,
log_attn_scores: Optional[bool] = None,
step: Optional[int] = None,
save_kv: Optional[bool] = True,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
@ -763,7 +936,9 @@ class LetheForCausalLM(LethePreTrainedModel):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
use_mem_attn=use_mem_attn
log_attn_scores=log_attn_scores,
step=step,
save_kv=save_kv,
)
hidden_states = outputs[0]

View File

@ -0,0 +1,21 @@
import time
import numpy as np
from gpt4all.models.lethe.modeling_lethe import MemoryIndex
index = MemoryIndex(256,
575000,
8
)
keys = np.random.randn(32, 8, 1024, 256)
values = np.random.randn(32, 8, 1024, 256)
start = time.time()
index.add(keys, values)
print(f"index.add time: {time.time() - start}")
print(index.key_indices[0].index.ntotal)
queries = np.random.randn(32, 8, 1024, 256)
start = time.time()
index.knn_query(queries, k=32)
print(f"index.knn_query time: {time.time() - start}")

View File

@ -23,8 +23,9 @@ print("loading model")
dimension = config.max_position_embeddings * config.hidden_size
head_size = config.hidden_size // config.num_attention_heads
index = MemoryIndex(head_size,
500_000,
config.num_attention_heads
5_000_000,
# 2 since multi-query attention and storing one each for key and value
config.num_attention_heads,
)
model = LetheForCausalLM(config, index)
model.to("cuda:0")
@ -69,7 +70,7 @@ with torch.no_grad():
for chunk_start in tqdm(range(0, memories.shape[0], 32)):
chunk_end = min(memories.shape[0], chunk_start + 32)
mem_chunk = memories[chunk_start:chunk_end].to(model.device)
model(input_ids=mem_chunk, labels=None)
model(input_ids=mem_chunk, labels=None,)
model.train()

View File

@ -0,0 +1,2 @@
from .configuration_pythia_retro import PythiaRetroConfig
from .modeling_pythia_retro import PythiaRetroForCausalLM

View File

@ -1,4 +1,5 @@
import os
import torch.nn.functional as F
from transformers import AutoTokenizer, get_scheduler, AutoConfig
import torch
from torch.optim import AdamW
@ -12,6 +13,8 @@ from tqdm import tqdm
from gpt4all.models import LetheForCausalLM
from gpt4all.models.lethe.modeling_lethe import MemoryIndex
import wandb
import pyarrow as pa
from pyarrow import feather
torch.backends.cuda.matmul.allow_tf32 = True
@ -22,39 +25,58 @@ def format_metrics(metrics, split, prefix=""):
return log
def evaluate(model, config, val_dataloader, main_process=False):
def calculate_per_example_loss(logits, labels):
lm_logits = logits[:, :-1, :].contiguous()
lm_labels = labels[:, 1:].contiguous()
loss = F.cross_entropy(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1), reduction="none")
loss = loss.reshape(labels.shape[0], -1).mean(dim=-1)
# return tensor of shape (B,) where B is the batch size
return loss.cpu().tolist()
def evaluate(model, index, config, val_dataloader, main_process=False):
model.eval()
val_loss = MeanMetric(nan_strategy="error").to(model.device)
head_size = model.config.hidden_size // model.config.num_attention_heads
index = MemoryIndex(head_size,
config["num_memories_per_index"],
model.config.num_attention_heads
)
ids = []
losses = []
with torch.no_grad():
for batch in tqdm(val_dataloader, disable=not main_process):
batch["id"] = batch["id"].detach().cpu()
memories = batch["retrieved_context"]
# need to set to eval so we don't do mem attn as it's slow
model.eval()
memories = memories[:, :config["num_neighbors_to_store"], :]
memories = memories.reshape(-1, memories.shape[-1])
for chunk_start in range(0, memories.shape[0], config["mem_chunk_size"]):
chunk_end = min(memories.shape[0], chunk_start + config["mem_chunk_size"])
mem_chunk = memories[chunk_start:chunk_end]
model(input_ids=mem_chunk, labels=None, use_mem_attn=False)
model(input_ids=mem_chunk)
qa_inputs = batch["input_ids"]
qa_labels = batch["labels"]
outputs = model(input_ids=qa_inputs,
labels=qa_labels,
)
del memories
torch.cuda.empty_cache()
loss_values = accelerator.gather_for_metrics({"loss": outputs["loss"].detach()})
val_loss.update(loss_values["loss"])
index.reset()
val_loss.update(loss_values["loss"])
return val_loss
per_example_loss = calculate_per_example_loss(outputs["logits"], qa_labels)
losses.extend(per_example_loss)
ids.extend(batch["id"].tolist())
ids = pa.array(ids)
losses = pa.array(losses)
schema = pa.schema([("loss", pa.float64()), ("id", pa.int32())])
table = pa.Table.from_arrays([losses, ids], schema=schema)
return val_loss, table
def train(accelerator, config):
@ -72,6 +94,7 @@ def train(accelerator, config):
with accelerator.main_process_first():
train_dataloader, val_dataloader = load_memory_augmented_data(config, tokenizer)
if accelerator.state.deepspeed_plugin is not None:
gradient_accumulation_steps = accelerator.state.deepspeed_plugin.deepspeed_config[
"gradient_accumulation_steps"
@ -97,6 +120,7 @@ def train(accelerator, config):
memory_attn_layer=config["memory_attn_layer"],
num_neighbors_to_retrieve=config["num_neighbors_to_retrieve"],
index=index,
tracker=accelerator.get_tracker("wandb"),
)
@ -135,16 +159,15 @@ def train(accelerator, config):
model, optimizer, train_dataloader, val_dataloader, scheduler = accelerator.prepare(
model, optimizer, train_dataloader, val_dataloader, scheduler
)
scheduler = True
use_scheduler = True
else:
model, optimizer, train_dataloader, val_dataloader = accelerator.prepare(
model, optimizer, train_dataloader, val_dataloader
)
scheduler = False
use_scheduler = False
# setup for saving training states in case preemption
if scheduler:
if use_scheduler:
accelerator.register_for_checkpointing(scheduler)
if config["checkpoint"]:
@ -167,24 +190,33 @@ def train(accelerator, config):
for step, batch in enumerate(tqdm(train_dataloader, disable=not main_process)):
curr_step = step + epoch * len(train_dataloader)
memories = batch["retrieved_context"]
memories = memories[:, :config["num_neighbors_to_store"], :]
memories = memories.reshape(-1, memories.shape[-1])
# need to set to eval so we don't do mem attn as it's slow
model.eval()
with torch.no_grad():
for chunk_start in range(0, memories.shape[0], config["mem_chunk_size"]):
chunk_end = min(memories.shape[0], chunk_start + 32)
chunk_end = min(memories.shape[0], chunk_start + config["mem_chunk_size"])
mem_chunk = memories[chunk_start:chunk_end]
model(input_ids=mem_chunk, labels=None, use_mem_attn=False)
model(input_ids=mem_chunk)
del memories
torch.cuda.empty_cache()
model.train()
qa_inputs = batch["input_ids"]
qa_labels = batch["labels"]
outputs = model(input_ids=qa_inputs,
labels=qa_labels,
log_attn_scores=True,
step=curr_step,
save_kv=False,
)
loss = outputs.loss
if config["wandb"]:
accelerator.log({"loss": loss}, step=curr_step)
index.reset()
# gather loss before backprop in case of gradient accumulation
loss_values = accelerator.gather_for_metrics({"loss": loss.detach().float()})
@ -192,6 +224,8 @@ def train(accelerator, config):
loss = loss / gradient_accumulation_steps
accelerator.backward(loss)
# !! don't reset index until after backwards pass
index.reset()
# get gradient norm of all params
# log LR in case something weird happens
@ -202,15 +236,25 @@ def train(accelerator, config):
if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
optimizer.step()
if scheduler:
if use_scheduler:
scheduler.step()
optimizer.zero_grad()
if step > 0 and config["save_every"] > 0 and step % config["save_every"] == 0:
accelerator.save_state(f"{config['output_dir']}/step_{curr_step}")
# accelerator.save_state(f"{config['output_dir']}/step_{curr_step}")
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(
f"{config['output_dir']}/step_{step}",
is_main_process=accelerator.is_main_process,
save_function=accelerator.save,
state_dict=accelerator.get_state_dict(model),
)
if step > 0 and (step % config["eval_every"] == 0 or step == len(train_dataloader) - 1):
val_loss = evaluate(model, config, val_dataloader, main_process=main_process)
val_loss, loss_table = evaluate(model, index, config, val_dataloader, main_process=main_process)
local_rank = accelerator.process_index
feather.write_feather(loss_table, f"{config['output_dir']}/val_losses_step_{curr_step}_rank_{local_rank}.feather")
log_train = {
"train_loss": train_loss.compute()

View File

@ -1,4 +1,5 @@
import torch.distributed as dist
from contextlib import contextmanager
def rank0_print(msg):
@ -7,3 +8,20 @@ def rank0_print(msg):
print(msg)
else:
print(msg)
@contextmanager
def main_process_first(is_main):
yield from _goes_first(is_main)
def _goes_first(is_main):
if not is_main:
dist.barrier()
yield
if is_main:
dist.barrier()