mirror of https://github.com/nomic-ai/gpt4all.git, synced 2025-07-16 08:26:04 +00:00

commit 55fef489ad (parent 3677935ce8)
fix: wip mem xf

configs/eval/evaluate_lethe.yaml (new file, 23 lines)
@@ -0,0 +1,23 @@
# model/tokenizer
model_name: "/home/paperspace/gpt4all/gpt4all/train/ckpts/mem_attn/step_1000"
tokenizer_name: "EleutherAI/pythia-1b"
version: null
gradient_checkpointing: false
memory_attn_layer: 12


# dataset
streaming: false
num_proc: 64
dataset_path: "/home/paperspace/gpt4all/gpt4all/inference/synth_data_combined_174"
# dataset_path: "/home/paperspace/gpt4all/gpt4all/index/squad_supplemented_validation"
max_length: 1024
batch_size: 1
pct_test: 0.05
q_column: "question"
a_column: "answer"
context_column: "text"
num_memories_per_index: 2000000
num_neighbors_to_retrieve: 2
num_neighbors_to_store: 1
mem_chunk_size: 64
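These configs are plain YAML consumed by the training and evaluation scripts later in this commit via read_config. A minimal sketch of that loading step, assuming read_config is a thin wrapper around yaml.safe_load (the actual helper in gpt4all/utils/read.py is not part of this diff):

# Sketch only: load an eval config the way the scripts in this commit do.
# Assumption: read_config amounts to yaml.safe_load on the given path.
import yaml

def read_config(path):
    with open(path) as f:
        return yaml.safe_load(f)

config = read_config("configs/eval/evaluate_lethe.yaml")
print(config["model_name"], config["num_neighbors_to_retrieve"])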
configs/inference/synth_data.yaml (new file, 19 lines)
@@ -0,0 +1,19 @@
model_name: "nomic-ai/gpt4all-j"
revision: "v1.3-groovy"
tokenizer_name: "nomic-ai/gpt4all-j"

# dataset
streaming: false
num_proc: 64
dataset_path: "nomic-ai/cohere-wiki-sbert"
batch_size: 32
output_path: "synth_qa_pairs"

# generation
max_new_tokens: 75
max_generations: 200000

save_every: 1000


seed: 42
@@ -10,35 +10,36 @@ memory_attn_layer: 12
 # dataset
 streaming: false
 num_proc: 64
-dataset_path: "/home/paperspace/gpt4all/gpt4all/index/squad_supplemented_train"
+dataset_path: "/home/paperspace/gpt4all/gpt4all/inference/synth_data_combined_174"
 max_length: 1024
-batch_size: 64
+batch_size: 32
 pct_test: 0.05
 q_column: "question"
-a_column: "answers"
+a_column: "answer"
-context_column: "neighbor_text"
+context_column: "text"
-num_memories_per_index: 500000
+num_memories_per_index: 2000000
-num_neighbors_to_retrieve: 32
+num_neighbors_to_retrieve: 2
+num_neighbors_to_store: 1
 mem_chunk_size: 64

 # train dynamics
-lr: 1.0e-4
+lr: 1.0e-5
 min_lr: 0
 weight_decay: 0.0
 eval_every: 100
-save_every: -1
+save_every: 100
 log_grads_every: 100
 log_lr_every: 10
-output_dir: "ckpts/mem_attn"
+output_dir: "ckpts/mem_attn_no_cosine_sim"
 checkpoint: null
 lora: false
-warmup_steps: 500
+warmup_steps: 200
 num_epochs: 5
 debug: false
 scheduler: false

 # logging
-wandb: false
+wandb: true
 wandb_entity: gpt4all
 wandb_project_name: mem_attn
 seed: 42
configs/train/pretrain_minipile.yaml (new file, 43 lines)
@@ -0,0 +1,43 @@
# model/tokenizer
model_name: "EleutherAI/pythia-1b"
tokenizer_name: "EleutherAI/pythia-1b"
version: null
gradient_checkpointing: true
save_name: "nomic-ai/minipille"
push_to_hub: false
memory_attn_layer: 12

# dataset
streaming: false
num_proc: 64
dataset_path: "JeanKaddour/minipile"
max_length: 2048
batch_size: 64
pct_test: 0.05
num_memories_per_index: 5000000
mem_chunk_size: 512
num_chunks: 10
num_neighbors_to_retrieve: 32

# train dynamics
lr: 1.0e-4
min_lr: 0
weight_decay: 0.0
eval_every: 100
save_every: -1
log_grads_every: 100
log_lr_every: 10
output_dir: "ckpts/minipile"
checkpoint: null
lora: false
warmup_steps: 500
num_epochs: 5
debug: false
scheduler: false

# logging
wandb: true
wandb_entity: gpt4all
wandb_project_name: minipile
seed: 42
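Note on the chunking arithmetic: the pretraining dataloader added below (load_memory_pretraining_data) pads or truncates each document to mem_chunk_size * num_chunks tokens rather than max_length, so the values above imply 5120-token sequences:

# Worked out from pretrain_minipile.yaml, mirroring the comment in
# load_memory_pretraining_data ("e.g. 512 * 10 = 5120").
mem_chunk_size, num_chunks = 512, 10
print(mem_chunk_size * num_chunks)  # 5120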
@@ -85,6 +85,7 @@ def load_memory_augmented_data(config, tokenizer, split="train", split_dataset=T

     question_col = config["q_column"]
     answer_col = config["a_column"]
+    context_col = config["context_column"]

     if config["streaming"] is False:
         kwargs = {"num_proc": config["num_proc"]}
@@ -96,7 +97,8 @@ def load_memory_augmented_data(config, tokenizer, split="train", split_dataset=T
     dataset = dataset.map(lambda ele: {question_col: [q.strip() for q in ele[question_col]]}, batched=True)
     # in squad, the data is formatted where each ele in answers is a dict where the key text holds
     # a list of the answer
-    dataset = dataset.map(lambda ele: {answer_col: [t["text"][0] for t in ele[answer_col]]}, batched=True)
+    dataset = dataset.map(lambda ele: {answer_col: [t.strip() for t in ele[answer_col]]}, batched=True)
+    # dataset = dataset.map(lambda ele: {answer_col: [t["text"][0] for t in ele[answer_col]]}, batched=True)

     dataset = dataset.map(
         lambda ele: tokenize_inputs(config, tokenizer, ele, question_col, answer_col),
@@ -106,19 +108,73 @@ def load_memory_augmented_data(config, tokenizer, split="train", split_dataset=T

     # tokenize contexts for each example
     dataset = dataset.map(
-        lambda ele: {"retrieved_context": tokenizer(ele["context"],
+        lambda ele: {"retrieved_context": tokenizer([ele[context_col]],
                                                     return_tensors="pt",
                                                     padding="max_length",
                                                     truncation=True)["input_ids"]},
-        batched=True,
         **kwargs
     )

-    columns_to_keep = ["input_ids", "labels", "retrieved_context"]
+    columns_to_keep = ["id", "input_ids", "labels", "retrieved_context"]

     col_names_to_rm = [col for col in dataset.column_names if col not in columns_to_keep]
     dataset = dataset.remove_columns(col_names_to_rm)

+    if split_dataset:
+        dataset = dataset.train_test_split(test_size=config["pct_test"], seed=config["seed"])
+        train_dataset, val_dataset = dataset["train"], dataset["test"]
+
+        train_dataloader = DataLoader(
+            train_dataset.remove_columns("id"),
+            batch_size=config["batch_size"],
+            collate_fn=DefaultDataCollator(),
+        )
+
+        val_dataloader = DataLoader(
+            val_dataset,
+            batch_size=config["batch_size"],
+            collate_fn=DefaultDataCollator(),
+        )
+
+        return train_dataloader, val_dataloader
+
+    else:
+        dataloader = DataLoader(
+            dataset,
+            batch_size=config["batch_size"],
+            collate_fn=DefaultDataCollator(),
+        )
+
+        return dataloader
+
+
+def load_memory_pretraining_data(config, tokenizer, split="train", split_dataset=True):
+    dataset_path = config["dataset_path"]
+
+    if os.path.exists(dataset_path):
+        dataset = Dataset.load_from_disk(dataset_path)
+    else:
+        dataset = load_dataset(dataset_path, split=split)
+
+    if config["streaming"] is False:
+        kwargs = {"num_proc": config["num_proc"]}
+    else:
+        kwargs = {}
+
+    # e.g. 512 * 10 = 5120 sequence length split up
+    max_length = config["mem_chunk_size"] * config["num_chunks"]
+    dataset = dataset.map(lambda ele: tokenizer(ele["text"], padding="max_length", truncation=True, max_length=max_length),
+                          batched=True, **kwargs)
+
+    dataset = dataset.map(lambda x: {"labels": x["input_ids"]}, batched=True, **kwargs)
+
+    columns_to_keep = ["input_ids", "labels", "attention_mask"]
+
+    col_names_to_rm = [col for col in dataset.column_names if col not in columns_to_keep]
+    dataset = dataset.remove_columns(col_names_to_rm)
+
+    # we can shuffle since the docs are in one row not split across rows
     if split_dataset:
         dataset = dataset.train_test_split(test_size=config["pct_test"], seed=config["seed"])
         train_dataset, val_dataset = dataset["train"], dataset["test"]
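A hedged sketch of how the reworked dataloader is driven, mirroring the call made in eval_squad_atlas_map.py below; the config fields match evaluate_lethe.yaml above:

# Illustrative wiring only; the real entry point is the eval script added below.
from transformers import AutoTokenizer
from gpt4all.utils.read import read_config
from gpt4all.data.retrieval_dataloader import load_memory_augmented_data

config = read_config("configs/eval/evaluate_lethe.yaml")
tokenizer = AutoTokenizer.from_pretrained(config["tokenizer_name"],
                                          model_max_length=config["max_length"])
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# split_dataset=False returns a single DataLoader whose batches carry
# "input_ids", "labels", and the tokenized "retrieved_context" memories.
dataloader = load_memory_augmented_data(config, tokenizer, split_dataset=False)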
gpt4all/eval/eval_squad_atlas_map.py (new file, 116 lines)
@@ -0,0 +1,116 @@
import torch
import torch.nn.functional as F
from gpt4all.models import LetheForCausalLM
from gpt4all.models.lethe.modeling_lethe import MemoryIndex
from gpt4all.data.retrieval_dataloader import load_memory_augmented_data
from gpt4all.train.metrics import f1_score, exact_match_score
from gpt4all.utils.read import read_config
from transformers import AutoTokenizer, AutoConfig
from argparse import ArgumentParser
from tqdm import tqdm
from nomic import atlas
from datasets import load_from_disk


def calc_loss_per_item(logits, labels):
    lm_logits = logits[:, :-1, :].contiguous()
    lm_labels = labels[:, 1:].contiguous()
    loss = F.cross_entropy(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1), reduction="none")
    loss = loss.reshape(labels.shape[0], -1).mean(dim=-1)

    # return tensor of shape (B,) where B is the batch size
    return loss.cpu().tolist()


def greedy_search(input_ids, model, tokenizer, max_new_tokens=100):
    num_new_tokens = 0
    with torch.no_grad():
        while True:
            if num_new_tokens >= max_new_tokens:
                break
            outputs = model(input_ids, save_kv=False)

            new_tokens = torch.argmax(outputs.logits[:, -1, :], dim=-1)

            input_ids = torch.cat([input_ids, new_tokens.unsqueeze(1)], dim=-1)
            num_new_tokens += 1

            if torch.equal(input_ids[0, -1].cpu(), torch.tensor(tokenizer.eos_token_id)):
                break

    print(tokenizer.batch_decode(input_ids, skip_special_tokens=True))

    return input_ids


parser = ArgumentParser()
parser.add_argument("--config", type=str, default="config.yaml")

args = parser.parse_args()

config = read_config(args.config)

tokenizer = AutoTokenizer.from_pretrained(config["tokenizer_name"], model_max_length=config["max_length"])
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

dataloader = load_memory_augmented_data(config, tokenizer, split_dataset=False)

dataset = load_from_disk(config["dataset_path"])


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_config = AutoConfig.from_pretrained(config["model_name"])

head_size = model_config.hidden_size // model_config.num_attention_heads
index = MemoryIndex(head_size,
                    config["num_memories_per_index"],
                    model_config.num_attention_heads
)
model = LetheForCausalLM.from_pretrained(config["model_name"],
                                         revision=config['version'] if 'version' in config else None,
                                         memory_attn_layer=config["memory_attn_layer"],
                                         num_neighbors_to_retrieve=config["num_neighbors_to_retrieve"],
                                         index=index,
                                         ).to(device)
model.eval()

# Evaluate the model on the SQUAD dataset
losses = []
with torch.no_grad():
    for i, batch in enumerate(tqdm(dataloader)):
        memories = batch["retrieved_context"]
        memories = memories[:, :config["num_neighbors_to_store"], :]
        memories = memories.reshape(-1, memories.shape[-1])

        # need to set to eval so we don't do mem attn as it's slow
        with torch.no_grad():
            for chunk_start in range(0, memories.shape[0], config["mem_chunk_size"]):
                chunk_end = min(memories.shape[0], chunk_start + config["mem_chunk_size"])
                mem_chunk = memories[chunk_start:chunk_end]
                model(input_ids=mem_chunk.to(device))

        del memories
        torch.cuda.empty_cache()
        qa_inputs = batch["input_ids"]
        qa_labels = batch["labels"]
        for i in range(qa_inputs.shape[0]):
            inputs = qa_inputs[i].to(device)
            labels = qa_labels[i].to(device)
            cutoff = torch.argmax((labels != -100).type(torch.float32))
            greedy_search(inputs[:cutoff.item()].unsqueeze(0).to(device), model, tokenizer)
            print(tokenizer.decode(inputs, skip_special_tokens=True))

        # batch_loss = calc_loss_per_item(outputs.logits, qa_labels.to(device))
        # losses.extend(batch_loss)
        index.reset()


dataset = dataset.add_column("loss", losses)

dataset.save_to_disk("eval_squad_atlas_map")
gpt4all/inference/combine_synth_data.py (new file, 57 lines)
@@ -0,0 +1,57 @@
import glob
from argparse import ArgumentParser
from datasets import Dataset, load_from_disk, concatenate_datasets


PROMPT = "Write a question answer pair based on the following context. If the context isn't specific enough, ignore and return 'No question answer pair`. Context : {}\n"

def load_synth_data(data_dir):
    files = glob.glob(data_dir + "/*")

    ds = concatenate_datasets([load_from_disk(f) for f in files])
    table = ds.data.table

    filtered = table.filter(table["valid"])
    ds = Dataset.from_dict(filtered.to_pydict())
    return ds


def remove_prompt(examples):
    outputs = {"text": [], "generated": [], "question": [], "answer": []}
    for context, generated in zip(examples["text"], examples["generated"]):
        prompt_w_ctx = PROMPT.format(context)
        gen_wo_ctx = generated[len(prompt_w_ctx):]

        assert prompt_w_ctx not in gen_wo_ctx

        question = gen_wo_ctx.split("Answer:")[0].replace("Question:", "").strip()
        answer = gen_wo_ctx.split("Answer:")[1].strip()

        outputs["text"].append(context)
        outputs["generated"].append(gen_wo_ctx)
        outputs["question"].append(question)
        outputs["answer"].append(answer)

    return outputs


def combine_synth_data(data_dir):
    ds = load_synth_data(data_dir)

    ds = ds.map(lambda ele: remove_prompt(ele), batched=True, num_proc=64)

    ds.save_to_disk(f"synth_data_combined_{len(ds)/1000:.0f}")


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--dataset_dir", type=str, required=True)

    args = parser.parse_args()

    combine_synth_data(args.dataset_dir)
gpt4all/inference/generate_synth_data.py (new file, 149 lines)
@@ -0,0 +1,149 @@
from argparse import ArgumentParser
from datasets import load_dataset, Dataset
import torch
from torch.utils.data import DataLoader, DistributedSampler
from accelerate.utils import set_seed
from transformers import AutoTokenizer, AutoModelForCausalLM, DefaultDataCollator
from gpt4all.utils.read import read_config
from gpt4all.utils.distributed_utils import rank0_print, main_process_first
from tqdm import tqdm
import pyarrow.compute as pc
import pyarrow as pa
import torch.distributed as dist


PROMPT = "Write a question answer pair based on the following context. If the context isn't specific enough, ignore and return 'No question answer pair`. Context : {}\n"


def prepare_data(config, tokenizer, num_processes, local_rank):
    dataset = load_dataset(config["dataset_path"], split="train")
    dataset = dataset.remove_columns("embedding")

    shuffled = dataset.shuffle(seed=config["seed"])
    indices = shuffled[:config["max_generations"]]["id"]

    table = dataset.data
    mask = pc.is_in(table["id"], value_set=pa.array(indices, pa.int32()))
    filtered_table = table.filter(mask)

    # convert from pyarrow to Dataset
    orig_dataset = Dataset.from_dict(filtered_table.to_pydict())

    dataset = orig_dataset.map(lambda ele: {"prompted_text": [PROMPT.format(context) for context in ele["text"]]},
                               batched=True,
                               num_proc=config["num_proc"] if "num_proc" in config else None)

    dataset = dataset.map(lambda ele: {"prompt_len": [len(prompt) for prompt in ele["prompted_text"]]}, batched=True,
                          num_proc=config["num_proc"] if "num_proc" in config else None)

    dataset = dataset.sort("prompt_len")
    dataset = dataset.map(lambda ele: tokenizer(ele["prompted_text"], return_tensors="pt", padding="longest", truncation=True,
                                                max_length=tokenizer.model_max_length - config["max_new_tokens"]), batched=True,
                          batch_size=num_processes * config["batch_size"],
                          )

    columns_to_keep = ["id", "input_ids", "attention_mask"]
    col_names_to_rm = [col for col in dataset.column_names if col not in columns_to_keep]

    dataset = dataset.remove_columns(col_names_to_rm)

    sampler = DistributedSampler(
        dataset,
        shuffle=False,
        drop_last=True,
        num_replicas=num_processes,
        rank=local_rank,
    )

    dataloader = DataLoader(
        dataset,
        batch_size=config["batch_size"],
        collate_fn=DefaultDataCollator(),
        sampler=sampler,
        drop_last=True
    )

    return dataloader, orig_dataset


def generate_data(config):
    set_seed(config["seed"])

    rank0_print(f"World size: {dist.get_world_size()}")

    tokenizer = AutoTokenizer.from_pretrained(config["tokenizer_name"])
    # since we're doing generation, pad left for autoregressive generation
    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    num_processes = dist.get_world_size()
    local_rank = dist.get_rank()

    dataloader, dataset = prepare_data(config, tokenizer, num_processes, local_rank)

    dist.barrier()
    print(dataset[:10]["id"])

    model = AutoModelForCausalLM.from_pretrained(config["model_name"],
                                                 revision=config["revision"] if "revision" in config else None,
                                                 use_cache=True,
                                                 torch_dtype=torch.bfloat16,)
    model.to(f"cuda:{local_rank}")

    synth_data = []
    valid = []
    ids = []
    total_valid = 0

    with torch.no_grad():
        for i, batch in enumerate(tqdm(dataloader, disable=local_rank != 0)):
            # keep this simple for now, can add temperature and other sampling techniques later
            generated = model.generate(batch["input_ids"].to(model.device),
                                       attention_mask=batch["attention_mask"].to(model.device),
                                       max_new_tokens=config["max_new_tokens"])

            decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
            num_valid = ["\nQuestion:" in t and "\nAnswer:" in t for t in decoded]
            rank0_print(f"Num valid: {sum(num_valid)/ len(num_valid):.2f}")
            total_valid += sum(num_valid)

            synth_data.extend(decoded)
            valid.extend(num_valid)
            ids.extend(batch["id"].tolist())

            if i > 0 and i % config["save_every"] == 0:
                table = dataset.data.table
                mask = pc.is_in(table["id"], value_set=pa.array(ids, pa.int32()))
                filtered_table = table.filter(mask)

                chunk_table = pa.Table.from_pydict({"id": ids, "generated": synth_data, "valid": valid})
                joined = filtered_table.join(chunk_table, "id")
                curr_dataset = Dataset.from_dict(joined.to_pydict())
                curr_dataset.save_to_disk(f'{config["output_path"]}/chunk_{i}_rank_{local_rank}')

    table = dataset.data.table
    mask = pc.is_in(table["id"], value_set=pa.array(ids, pa.int32()))
    filtered_table = table.filter(mask)

    chunk_table = pa.Table.from_pydict({"id": ids, "generated": synth_data, "valid": valid})
    joined = filtered_table.join(chunk_table, "id")
    full_dataset = Dataset.from_dict(joined.to_pydict())
    full_dataset.save_to_disk(f'{config["output_path"]}_{config["max_generations"]}_rank_{local_rank}')

    rank0_print(f"Total valid: {total_valid}/{config['max_generations']}")


def main():
    dist.init_process_group("nccl")
    parser = ArgumentParser()
    parser.add_argument("--config", type=str, default="config.yaml")

    args = parser.parse_args()
    config = read_config(args.config)

    generate_data(config)


if __name__ == "__main__":
    # parse arguments by reading in a config
    main()
@@ -111,6 +111,7 @@ class LetheConfig(PretrainedConfig):
         memory_attn_layer=9,
         num_neighbors_to_retrieve=32,
         num_neighbors_stored=128,
+        attn_scale_init=20.0,
         **kwargs,
     ):
         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
@@ -133,4 +134,5 @@ class LetheConfig(PretrainedConfig):
         # index of cross attention layer to add
         self.memory_attn_layer = memory_attn_layer
         self.num_neighbors_to_retrieve = num_neighbors_to_retrieve
         self.num_neighbors_stored = num_neighbors_stored
+        self.attn_scale_init = attn_scale_init
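For reference, a hedged sketch of constructing the updated config; the three memory fields and attn_scale_init are the ones touched here, and the remaining constructor arguments are assumed to keep their existing GPT-NeoX-style defaults (not shown in this hunk):

# Illustrative only.
from gpt4all.models.lethe import LetheConfig

config = LetheConfig(
    memory_attn_layer=12,           # layer that gets the memory-attention block
    num_neighbors_to_retrieve=2,
    num_neighbors_stored=1,
    attn_scale_init=20.0,           # initial value for the learned attention temperature
)
print(config.attn_scale_init)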
@@ -12,8 +12,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" PyTorch PythiaSeek model."""
+""" PyTorch Lethe model."""
+
+import wandb
+import math
+import torch.nn.functional as F
+import matplotlib.pyplot as plt
 from typing import Optional, Tuple, Union

 import torch
@@ -29,8 +33,9 @@ from transformers.modeling_outputs import (
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import logging
 from gpt4all.models.lethe import LetheConfig
-import hnswlib
 import numpy as np
+import faiss
+import faiss.contrib.torch_utils


 logger = logging.get_logger(__name__)
@@ -40,17 +45,18 @@ GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST = [
     "EleutherAI/gpt-neox-20b",
 ]

-# TODO: understand why Phil only does this per batch and doens't persist across many batches -> he uses multi-query attention
-# TODO: do we need to implement masking for the dense vectors we pull from?
-# TODO: i think phil is using a memmapped database to pull out rather than using the index
+
+
 class HNSWIndex:
     def __init__(self, max_memories, dimension):
         # num_memories will be batch size * num_neighbors
         # can memmap this too like
-        self.index = hnswlib.Index(space="l2", dim=dimension)
-        self.index.init_index(max_elements=max_memories, ef_construction=50, M=16)
+        self.index = faiss.IndexHNSWFlat(dimension, 16, faiss.METRIC_INNER_PRODUCT)
+        # taking params from: https://www.pinecone.io/learn/vector-indexes/#hnsw-implementation
+        # and https://www.pinecone.io/learn/hnsw/#hnsw-performance
+        # seems like efConstruction dictates how long the index takes to build
+        # and efSearch and M (second arg to faiss.Index) dictates how long it takes to search
+        self.index.hnsw.efConstruction = 16
+        self.index.hnsw.efSearch = 32
         self.max_memories = max_memories
         self.dimension = dimension

@@ -60,45 +66,65 @@ class HNSWIndex:

     def query(self, query, k=1):
         # hack what should we do here?
-        if self.index.get_current_count() == 0:
-            return np.ones((query.shape[0], k, query.shape[1]), dtype=np.float32)
+        if self.index.ntotal == 0:
+            return np.ones((query.shape[0], k), dtype=np.int32)

-        assert query.ndim == 2
-        bs_seq_len, _ = query.shape
-
-        labels, _ = self.index.knn_query(query, k=k)
-        neighbors = torch.tensor(self.index.get_items(labels.reshape(-1)))
-        neighbors = neighbors.reshape((bs_seq_len, k, query.shape[1]))
-
-        assert neighbors.ndim == 3
-        assert neighbors.shape[0] == bs_seq_len
-
-        return neighbors
+        _, labels = self.index.search(np.ascontiguousarray(query), k=k)
+
+        return labels

     def add(self, memories):
-        assert memories.ndim == 2
-        bs_seq_len, _ = memories.shape
-
-        ids = np.arange(self.idx_offset, self.idx_offset + bs_seq_len)
-
-        self.index.add_items(memories, ids)
-
-        self.idx_offset += bs_seq_len
+        return self.index.add(np.ascontiguousarray(memories))

     def reset(self):
-        self.index = hnswlib.Index(space="l2", dim=self.dimension)
-        self.index.init_index(max_elements=self.max_memories, ef_construction=50, M=16)
+        self.index.reset()
+
+
+class NumpyKNNIndex:
+    def __init__(self, max_memories, dimension):
+        # num_memories will be batch size * num_neighbors
+        # can memmap this too like
+        self.index = np.zeros((max_memories, dimension), dtype=np.float32)
+        self.max_memories = max_memories
+        self.dimension = dimension
+
+        # if we want to allow for insertion of len(elements) > max_memories
+        # we need to figure out a way to get the most recent memories
+        self.idx_offset = 0
+
+    def query(self, query, k=1):
+        # hack what should we do here?
+        if self.index.sum() == 0:
+            return np.ones((query.shape[0], k), dtype=np.int32)
+
+        dots = query.dot(self.index[:self.idx_offset].T)
+        labels = np.argsort(dots, axis=1)[:, -k:]
+
+        return labels
+
+    def add(self, memories):
+        self.index[self.idx_offset:self.idx_offset + memories.shape[0]] = memories
+        self.idx_offset += memories.shape[0]
+
+    def reset(self):
+        self.index.reset()
+
+
 class MemoryIndex:
     def __init__(self, hidden_dim, num_mems, nheads):
-        # we store an index for each k/v for each head
-        self.key_indices = [HNSWIndex(num_mems, hidden_dim) for _ in range(nheads)]
-        self.value_indices = [HNSWIndex(num_mems, hidden_dim) for _ in range(nheads)]
         self.nheads = nheads
+
+        # NOTE: we are storing kv pairs, instead indices for both keys and values
+        self.key_indices = [HNSWIndex(num_mems, hidden_dim) for _ in range(nheads)]
+
+        shape = (num_mems, nheads, 2, hidden_dim)
+        self.kv_pairs = np.zeros(shape, dtype=np.float32)
+        self.idx_offset = 0

     def add(self, keys, values):
         # k/v are (bs, num_attention_heads, seq_len, head_size)
         reshaped_keys = keys.reshape(keys.shape[0] * keys.shape[2], keys.shape[1], keys.shape[3])
@@ -106,22 +132,30 @@ class MemoryIndex:

         for head in range(self.nheads):
             self.key_indices[head].add(reshaped_keys[:, head, :])
-            self.value_indices[head].add(reshaped_values[:, head, :])
+
+        kv_pairs = np.stack((reshaped_keys, reshaped_values), axis=2)
+
+        if self.idx_offset + kv_pairs.shape[0] > self.kv_pairs.shape[0]:
+            raise ValueError("Not enough memory!")
+
+        self.kv_pairs[self.idx_offset:self.idx_offset + kv_pairs.shape[0]] = kv_pairs
+        self.idx_offset += kv_pairs.shape[0]

     def knn_query(self, query, k=1):
         reshaped_query = query.reshape(query.shape[0] * query.shape[2], query.shape[1], query.shape[3])

         mem_keys = []
         mem_values = []
+        mem_indices = []

-        # this is prob so so slow
+        # we can prob make this better
         for head in range(self.nheads):
-            knn_keys = self.key_indices[head].query(reshaped_query[:, head, :], k=k)
-            knn_values = self.value_indices[head].query(reshaped_query[:, head, :], k=k)
+            knn_indices = self.key_indices[head].query(reshaped_query[:, head, :], k=k)
+            kv_pairs = self.kv_pairs[:, head, :, :][knn_indices]

-            mem_keys.append(knn_keys)
-            mem_values.append(knn_values)
+            mem_keys.append(kv_pairs[:, :, 0, :])
+            mem_values.append(kv_pairs[:, :, 1, :])
+            mem_indices.append(knn_indices)

         mem_keys = torch.from_numpy(np.stack(mem_keys, axis=1))
         # (bs, num_attention_heads, seq_len, k, head_size)
@@ -131,13 +165,14 @@ class MemoryIndex:
         # (bs, num_attention_heads, seq_len, k, head_size)
         mem_values = mem_values.view(query.shape[:-1] + (k,) + (query.shape[-1],))

-        return mem_keys, mem_values
+        return mem_keys, mem_values, np.stack(mem_indices, axis=1)

     def reset(self):
         for head in range(self.nheads):
             self.key_indices[head].reset()
-            self.value_indices[head].reset()
+        self.kv_pairs = np.zeros((self.kv_pairs.shape[0], self.nheads, 2, self.kv_pairs.shape[-1]), dtype=np.float32)


 class LethePreTrainedModel(PreTrainedModel):
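The hunks above swap the hnswlib index for FAISS's HNSW implementation and move value storage into a pre-allocated kv_pairs array, so the ANN index only has to return integer labels. A standalone sketch of the FAISS calls the new HNSWIndex wraps (shapes and counts are illustrative; the metric, M=16, efConstruction=16 and efSearch=32 match the diff):

# Sketch of the faiss usage behind HNSWIndex; not the module itself.
import numpy as np
import faiss

dim = 64
index = faiss.IndexHNSWFlat(dim, 16, faiss.METRIC_INNER_PRODUCT)  # second arg is M
index.hnsw.efConstruction = 16   # build-time effort
index.hnsw.efSearch = 32         # query-time effort

keys = np.random.randn(1024, dim).astype(np.float32)
index.add(np.ascontiguousarray(keys))            # ids are assigned 0..ntotal-1

queries = np.random.randn(8, dim).astype(np.float32)
distances, labels = index.search(np.ascontiguousarray(queries), 2)
print(labels.shape)  # (8, 2); labels index into the externally stored kv_pairs buffer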
@@ -171,7 +206,7 @@ class LethePreTrainedModel(PreTrainedModel):


 class LetheAttention(nn.Module):
-    def __init__(self, config, memory_attention=False, index=None):
+    def __init__(self, config, memory_attention=False, index=None, tracker=None):
         super().__init__()
         self.num_attention_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
@@ -188,21 +223,24 @@ class LetheAttention(nn.Module):
         self.rotary_emb = RotaryEmbedding(
             self.rotary_ndims, config.max_position_embeddings, base=config.rotary_emb_base
         )
-        self.register_buffer(
+        if not memory_attention:
+            self.register_buffer(
                 "norm_factor",
                 torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(torch.get_default_dtype()),
                 persistent=False,
             )
+
         self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size)
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.memory = False

         if memory_attention:
+            self.scale = nn.Parameter(torch.ones(self.num_attention_heads, 1, 1) * math.log(config.attn_scale_init))
             self.memory = True
-            self.alpha = nn.Parameter(torch.zeros(self.num_attention_heads))
             self.num_neighbors = config.num_neighbors_to_retrieve
             # for testing, just using np array since it's easy
             self.index = index
+            self.tracker = tracker

     def forward(
@@ -214,7 +252,9 @@ class LetheAttention(nn.Module):
         layer_past: Optional[Tuple[torch.Tensor]] = None,
         use_cache: Optional[bool] = False,
         output_attentions: Optional[bool] = False,
-        use_mem_attn: Optional[bool] = True,
+        log_attn_scores: Optional[bool] = False,
+        step: Optional[int] = None,
+        save_kv: Optional[bool] = True,
     ):
         has_layer_past = layer_past is not None
@@ -234,6 +274,12 @@ class LetheAttention(nn.Module):
         key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3)
         value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3)

+        # if self.memory:
+        if self.memory:
+            # QKNorm: https://arxiv.org/abs/2010.04245
+            query = F.normalize(query, dim=-1)
+            key = F.normalize(key, dim=-1)
+
         # Compute rotary embeddings on rotary_ndims
         query_rot = query[..., : self.rotary_ndims]
         query_pass = query[..., self.rotary_ndims :]
@@ -257,27 +303,38 @@ class LetheAttention(nn.Module):
             value = torch.cat((past_value, value), dim=-2)
         present = (key, value) if use_cache else None

-        # Compute attention
-        attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
-
-        # TODO: need to do masking??
+        # memory attention
         if self.memory:
-            # get knns
-            # since we do an eval batch w context before, let's not do the expensive step until we need to
-            # [batch, knn, num_attention_heads, seq_len, head_size]
-            if use_mem_attn:
-                knn_keys, knn_values = self.index.knn_query(query.cpu().detach().to(torch.float32).numpy(), k=self.num_neighbors)
-                mem_attn = self._mem_attn(query,
-                                          knn_keys.to(query.device),
-                                          knn_values.to(query.device),
-                                          attention_mask,
-                                          head_mask
-                                          )
-
-                expanded_alpha = self.alpha[None, :, None, None]
-                attn_output = (attn_output * (1 - expanded_alpha)) + (mem_attn * expanded_alpha)
-
-            self.index.add(key.cpu().detach().to(torch.float32).numpy(), value.cpu().detach().to(torch.float32).numpy())
+            if save_kv:
+                self.index.add(key.cpu().detach().to(torch.float32).numpy(), value.cpu().detach().to(torch.float32).numpy())
+
+            knn_keys, knn_values, knn_labels = self.index.knn_query(query.cpu().detach().to(torch.float32).numpy(), k=self.num_neighbors)
+            if log_attn_scores:
+                batch_size = query.shape[0]
+                seq_len = query.shape[-2]
+
+                key_labels = knn_labels // seq_len
+                key_labels = key_labels.reshape(batch_size, seq_len, self.num_attention_heads, -1)
+                correct_keys = np.equal(key_labels, np.arange(batch_size)[:, np.newaxis, np.newaxis, np.newaxis])
+                # calculate the accuracy
+                key_acc = np.sum(correct_keys) / np.prod(correct_keys.shape)
+
+                self.tracker.log({"retrieved_acc": key_acc}, step=step)
+
+            attn_output = self._mem_attn(query,
+                                         knn_keys.to(query.device).to(value.dtype),
+                                         knn_values.to(query.device).to(value.dtype),
+                                         key,
+                                         value,
+                                         attention_mask,
+                                         head_mask,
+                                         log_attn_scores=log_attn_scores,
+                                         step=step,
+                                         knn_labels=knn_labels,
+                                         )
+        else:
+            # Normal self-attention
+            attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

         # Reshape outputs
         attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size)
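The memory-attention path above normalizes queries and keys before rotary embedding (QKNorm, https://arxiv.org/abs/2010.04245) and scores them with a learned, exponentiated per-head scale initialized from attn_scale_init, instead of the usual 1/sqrt(head_size) factor. A minimal sketch of that scoring rule with illustrative shapes:

# Toy illustration of QK-normalized attention with a learned temperature.
import math
import torch
import torch.nn.functional as F

bs, heads, seq, head_size = 2, 8, 16, 32
query = torch.randn(bs, heads, seq, head_size)
key = torch.randn(bs, heads, seq, head_size)

# QKNorm: unit-normalize the feature dimension so raw scores are cosine similarities
query = F.normalize(query, dim=-1)
key = F.normalize(key, dim=-1)

# learned per-head temperature, kept in log space and exponentiated at use time
scale = torch.nn.Parameter(torch.ones(heads, 1, 1) * math.log(20.0))
scores = torch.matmul(query, key.transpose(-1, -2)) * scale.exp()
weights = scores.softmax(dim=-1)
print(weights.shape)  # (2, 8, 16, 16)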
@@ -315,28 +372,131 @@ class LetheAttention(nn.Module):
         return tensor

-    def _mem_attn(self, query, key, value, attention_mask=None, head_mask=None):
+    def _mem_attn(self,
+                  query,
+                  knn_key,
+                  knn_value,
+                  local_key,
+                  local_value,
+                  attention_mask=None,
+                  head_mask=None,
+                  log_attn_scores=False,
+                  step=None,
+                  knn_labels=None):
+        # local self-attention
         # q: [bs, num_attention_heads, seq_len, attn_head_size]
         # k,v: [bs, num_attention_heads, seq_len, knn, attn_head_size]
-        attn_scores = torch.einsum("bhsd, bhsnd-> bshn", query, key)
-        # attn_scores: [bs, seq_len, num_attention_heads, knn]
-        attn_scores = attn_scores / self.norm_factor
-
-        # softmax over knns
-        attn_weights = nn.functional.softmax(attn_scores, dim=-1)
-        attn_weights = attn_weights.to(value.dtype)
+        query_length = query.size(-2)
+        key_length = local_key.size(-2)
+
+        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
+
+        local_attn_scores = torch.matmul(query, local_key.transpose(-1, -2))
+        scale = self.scale.exp()
+        local_attn_scores = local_attn_scores * scale
+
+        mask_value = torch.finfo(local_attn_scores.dtype).min
+        # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
+        # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
+        mask_value = torch.tensor(mask_value, dtype=local_attn_scores.dtype).to(local_attn_scores.device)
+        local_attn_scores = torch.where(causal_mask, local_attn_scores, mask_value)

         if attention_mask is not None:
             # Apply the attention mask
-            attn_scores = attn_scores + attention_mask
+            local_attn_scores = local_attn_scores + attention_mask
+
+        mem_attn_scores = torch.einsum("bhsd, bhsnd-> bhsn", query, knn_key)
+        # attn_scores: [bs, seq_len, num_attention_heads, knn]
+        mem_attn_scores = mem_attn_scores * scale
+
+        attn_scores = torch.cat((mem_attn_scores, local_attn_scores), dim=-1)
+
+        # softmax over knns
+        attn_weights = nn.functional.softmax(attn_scores, dim=-1)
+        attn_weights = attn_weights.to(local_value.dtype)
+
+        mem_attn_weights, local_attn_weights = attn_weights.split([self.num_neighbors, local_attn_scores.size(-1)], dim=-1)
+        if log_attn_scores:
+            # (bs, seq_len, num_attention_heads, knn) probabilities
+            # curate (x,y) pairs
+            # where x is attention weight, y is accuracy of retrieved token
+            bs, seq_len = mem_attn_weights.size(0), mem_attn_weights.size(2)
+            key_labels = knn_labels // seq_len
+            key_labels = key_labels.reshape(bs, self.num_attention_heads, seq_len, -1)
+            correct_keys = np.equal(key_labels, np.arange(bs)[:, np.newaxis, np.newaxis, np.newaxis])
+
+            bin_width = 0.05
+
+            # Calculate the number of bins
+            num_bins = int(1 / bin_width)
+
+            # Create empty lists for storing bin probabilities and accuracies
+            bin_probabilities = []
+            bin_accuracies = []
+
+            probs = mem_attn_weights.clone().detach().cpu().numpy().reshape(-1).tolist()
+            correct_keys = correct_keys.reshape(-1).tolist()
+
+            # Iterate over each bin
+            for i in range(num_bins):
+                bin_lower = i * bin_width
+                bin_upper = (i + 1) * bin_width
+
+                # Filter data points within the current bin range
+                bin_x_values = [x for x in probs if bin_lower <= x < bin_upper]
+                bin_y_values = [y for j, y in enumerate(correct_keys) if bin_lower <= probs[j] < bin_upper]
+
+                # Calculate accuracy for the bin
+                total = len(bin_x_values)
+                correct = sum(bin_y_values)
+                accuracy = correct / total if total > 0 else 0
+
+                # Store the probability and accuracy for the bin
+                bin_probabilities.append((bin_lower + bin_upper) / 2)
+                bin_accuracies.append(accuracy)
+
+            data = [[x, y] for x, y in zip(bin_probabilities, bin_accuracies)]
+            table = wandb.Table(data=data, columns=["attn_prob", "retrieved_acc"])
+            self.tracker.log({"attn_vs_acc": wandb.plot.scatter(table, "attn_prob", "retrieved_acc")}, step=step)
+
+        if log_attn_scores:
+            # this def won't work well on multi-gpu machines
+            num_attention_heads = mem_attn_weights.size(1)
+            for head in range(num_attention_heads):
+                mem_attn_score_per_head = mem_attn_weights[:, head].reshape(-1)
+                mem_flat = mem_attn_score_per_head.clone().detach().cpu()
+                mem_hist = torch.histc(mem_flat, bins=20, min=0, max=1)
+                mem_bins = torch.linspace(0, 1, steps=20 + 1)
+                plt.stairs(mem_hist.tolist(), mem_bins.tolist())
+                plt.title(f"mem_attn_score_{head}")
+                # set arbitrarily but we want to see those peaks!!
+                plt.ylim((0, 1000))
+                self.tracker.log({f"mem_attn_score_{head}": wandb.Image(plt)}, step=step)
+                plt.close()
+
+                local_attn_scores_per_head = local_attn_weights[:, head].reshape(-1)
+                local_flat = local_attn_scores_per_head.clone().detach().cpu()
+                local_hist = torch.histc(local_flat, bins=20, min=0, max=1)
+                local_bins = torch.linspace(0, 1, steps=20 + 1)
+                plt.stairs(local_hist.tolist(), local_bins.tolist())
+                plt.title(f"local_attn_score_{head}")
+                # set arbitrarily but we want to see those peaks!!
+                plt.ylim((0, 1000))
+                self.tracker.log({f"local_attn_score_{head}": wandb.Image(plt)}, step=step)
+                plt.close()

-        # Mask heads if we want to
-        if head_mask is not None:
-            attn_weights = attn_weights * head_mask
-
         # attn_output: [bs, num_attention_heads, seq_len, attn_head_size]
-        attn_output = torch.einsum("bshn, bhsnd-> bhsd", attn_scores, value)
+        mem_attn_output = torch.einsum("bhsn, bhsnd-> bhsd", mem_attn_weights, knn_value)
+        local_attn_output = torch.matmul(local_attn_weights, local_value)
+
+        # TODO: do we need flamingo style gating
+        # of output_gate.tanh * attn_output
+        attn_output = mem_attn_output + local_attn_output

         return attn_output

     def _attn(self, query, key, value, attention_mask=None, head_mask=None):
@@ -361,9 +521,11 @@ class LetheAttention(nn.Module):
             query,
             key.transpose(1, 2),
             beta=1.0,
-            alpha=(torch.tensor(1.0, dtype=self.norm_factor.dtype, device=self.norm_factor.device) / self.norm_factor),
+            alpha=1.0 if self.memory else (torch.tensor(1.0, dtype=self.norm_factor.dtype, device=self.norm_factor.device) / self.norm_factor)
         )
         attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)
+        if self.memory:
+            attn_scores = attn_scores * self.scale.exp()

         mask_value = torch.finfo(attn_scores.dtype).min
         # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
@ -413,7 +575,7 @@ class RotaryEmbedding(torch.nn.Module):
|
|||||||
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
|
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
|
||||||
self.cos_cached = emb.cos()[None, None, :, :]
|
self.cos_cached = emb.cos()[None, None, :, :]
|
||||||
self.sin_cached = emb.sin()[None, None, :, :]
|
self.sin_cached = emb.sin()[None, None, :, :]
|
||||||
return self.cos_cached[:seq_len, ...].to(x.device), self.sin_cached[:seq_len, ...].to(x.device)
|
return self.cos_cached[:seq_len, ...].to(x.device).to(x.dtype), self.sin_cached[:seq_len, ...].to(x.device).to(x.dtype)
|
||||||
|
|
||||||
|
|
||||||
def rotate_half(x):
|
def rotate_half(x):
|
||||||
@ -448,12 +610,12 @@ class LetheMLP(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class LetheLayer(nn.Module):
|
class LetheLayer(nn.Module):
|
||||||
def __init__(self, config, memory_attention=False, index=None):
|
def __init__(self, config, memory_attention=False, index=None, tracker=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.use_parallel_residual = config.use_parallel_residual
|
self.use_parallel_residual = config.use_parallel_residual
|
||||||
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
self.attention = LetheAttention(config, memory_attention=memory_attention, index=index)
|
self.attention = LetheAttention(config, memory_attention=memory_attention, index=index, tracker=tracker)
|
||||||
self.mlp = LetheMLP(config)
|
self.mlp = LetheMLP(config)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -465,7 +627,9 @@ class LetheLayer(nn.Module):
|
|||||||
use_cache: Optional[bool] = False,
|
use_cache: Optional[bool] = False,
|
||||||
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
use_mem_attn: Optional[bool] = True,
|
log_attn_scores: Optional[bool] = False,
|
||||||
|
step: Optional[int] = None,
|
||||||
|
save_kv: Optional[bool] = True
|
||||||
):
|
):
|
||||||
ln_hidden_states = self.input_layernorm(hidden_states)
|
ln_hidden_states = self.input_layernorm(hidden_states)
|
||||||
attention_layer_outputs = self.attention(
|
attention_layer_outputs = self.attention(
|
||||||
@ -476,7 +640,9 @@ class LetheLayer(nn.Module):
|
|||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
use_mem_attn=use_mem_attn,
|
log_attn_scores=log_attn_scores,
|
||||||
|
step=step,
|
||||||
|
save_kv=save_kv,
|
||||||
)
|
)
|
||||||
attn_output = attention_layer_outputs[0] # output_attn: attn_output, present, (attn_weights)
|
attn_output = attention_layer_outputs[0] # output_attn: attn_output, present, (attn_weights)
|
||||||
outputs = attention_layer_outputs[1:]
|
outputs = attention_layer_outputs[1:]
|
||||||
@ -503,7 +669,7 @@ class LetheLayer(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class LetheModel(LethePreTrainedModel):
|
class LetheModel(LethePreTrainedModel):
|
||||||
def __init__(self, config, index):
|
def __init__(self, config, index, tracker=None):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
@ -511,7 +677,8 @@ class LetheModel(LethePreTrainedModel):
|
|||||||
|
|
||||||
self.layers = nn.ModuleList([LetheLayer(config,
|
self.layers = nn.ModuleList([LetheLayer(config,
|
||||||
memory_attention=i+1 == config.memory_attn_layer,
|
memory_attention=i+1 == config.memory_attn_layer,
|
||||||
index=index if i+1 == config.memory_attn_layer else None)
|
index=index if i+1 == config.memory_attn_layer else None,
|
||||||
|
tracker=tracker)
|
||||||
for i in range(config.num_hidden_layers)])
|
for i in range(config.num_hidden_layers)])
|
||||||
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
|
||||||
@ -538,7 +705,9 @@ class LetheModel(LethePreTrainedModel):
|
|||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
use_mem_attn: Optional[bool] = True,
|
log_attn_scores: Optional[bool] = False,
|
||||||
|
step: Optional[int] = None,
|
||||||
|
save_kv: Optional[bool] = True,
|
||||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||||
r"""
|
r"""
|
||||||
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||||
@ -631,7 +800,7 @@ class LetheModel(LethePreTrainedModel):
|
|||||||
def create_custom_forward(module):
|
def create_custom_forward(module):
|
||||||
def custom_forward(*inputs):
|
def custom_forward(*inputs):
|
||||||
# None for layer_past
|
# None for layer_past
|
||||||
return module(*inputs, use_cache, None, output_attentions, use_mem_attn)
|
return module(*inputs, use_cache, None, output_attentions, log_attn_scores, step, save_kv)
|
||||||
|
|
||||||
return custom_forward
|
return custom_forward
|
||||||
|
|
||||||
@ -651,7 +820,9 @@ class LetheModel(LethePreTrainedModel):
|
|||||||
layer_past=layer_past,
|
layer_past=layer_past,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
use_mem_attn=use_mem_attn,
|
log_attn_scores=log_attn_scores,
|
||||||
|
step=step,
|
||||||
|
save_kv=save_kv,
|
||||||
)
|
)
|
||||||
hidden_states = outputs[0]
|
hidden_states = outputs[0]
|
||||||
if use_cache is True:
|
if use_cache is True:
|
||||||
@@ -678,10 +849,10 @@ class LetheModel(LethePreTrainedModel):
 class LetheForCausalLM(LethePreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

-    def __init__(self, config, index):
+    def __init__(self, config, index, tracker=None):
         super().__init__(config)

-        self.gpt_neox = LetheModel(config, index)
+        self.gpt_neox = LetheModel(config, index, tracker=tracker)
         self.embed_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

         self.hidden_size = config.hidden_size

@@ -709,7 +880,9 @@ class LetheForCausalLM(LethePreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        use_mem_attn: Optional[bool] = True,
+        log_attn_scores: Optional[bool] = None,
+        step: Optional[int] = None,
+        save_kv: Optional[bool] = True,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):

@@ -763,7 +936,9 @@ class LetheForCausalLM(LethePreTrainedModel):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            use_mem_attn=use_mem_attn
+            log_attn_scores=log_attn_scores,
+            step=step,
+            save_kv=save_kv,
         )

         hidden_states = outputs[0]
21  gpt4all/models/lethe/test_index.py  Normal file
@@ -0,0 +1,21 @@
+import time
+import numpy as np
+from gpt4all.models.lethe.modeling_lethe import MemoryIndex
+
+
+index = MemoryIndex(256,
+                    575000,
+                    8
+                    )
+
+keys = np.random.randn(32, 8, 1024, 256)
+values = np.random.randn(32, 8, 1024, 256)
+start = time.time()
+index.add(keys, values)
+print(f"index.add time: {time.time() - start}")
+
+print(index.key_indices[0].index.ntotal)
+queries = np.random.randn(32, 8, 1024, 256)
+start = time.time()
+index.knn_query(queries, k=32)
+print(f"index.knn_query time: {time.time() - start}")
@@ -23,8 +23,9 @@ print("loading model")
 dimension = config.max_position_embeddings * config.hidden_size
 head_size = config.hidden_size // config.num_attention_heads
 index = MemoryIndex(head_size,
-                    500_000,
-                    config.num_attention_heads
+                    5_000_000,
+                    # 2 since multi-query attention and storing one each for key and value
+                    config.num_attention_heads,
                     )
 model = LetheForCausalLM(config, index)
 model.to("cuda:0")

@@ -69,7 +70,7 @@ with torch.no_grad():
     for chunk_start in tqdm(range(0, memories.shape[0], 32)):
         chunk_end = min(memories.shape[0], chunk_start + 32)
         mem_chunk = memories[chunk_start:chunk_end].to(model.device)
-        model(input_ids=mem_chunk, labels=None)
+        model(input_ids=mem_chunk, labels=None,)

 model.train()
2  gpt4all/models/pythia_retro/__init__.py  Normal file
@@ -0,0 +1,2 @@
+from .configuration_pythia_retro import PythiaRetroConfig
+from .modeling_pythia_retro import PythiaRetroForCausalLM
@@ -1,4 +1,5 @@
 import os
+import torch.nn.functional as F
 from transformers import AutoTokenizer, get_scheduler, AutoConfig
 import torch
 from torch.optim import AdamW

@@ -12,6 +13,8 @@ from tqdm import tqdm
 from gpt4all.models import LetheForCausalLM
 from gpt4all.models.lethe.modeling_lethe import MemoryIndex
 import wandb
+import pyarrow as pa
+from pyarrow import feather

 torch.backends.cuda.matmul.allow_tf32 = True
@@ -22,39 +25,58 @@ def format_metrics(metrics, split, prefix=""):
     return log


-def evaluate(model, config, val_dataloader, main_process=False):
+def calculate_per_example_loss(logits, labels):
+    lm_logits = logits[:, :-1, :].contiguous()
+    lm_labels = labels[:, 1:].contiguous()
+    loss = F.cross_entropy(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1), reduction="none")
+    loss = loss.reshape(labels.shape[0], -1).mean(dim=-1)
+
+    # return tensor of shape (B,) where B is the batch size
+    return loss.cpu().tolist()
+
+
+def evaluate(model, index, config, val_dataloader, main_process=False):
     model.eval()
     val_loss = MeanMetric(nan_strategy="error").to(model.device)

-    head_size = model.config.hidden_size // model.config.num_attention_heads
-    index = MemoryIndex(head_size,
-                        config["num_memories_per_index"],
-                        model.config.num_attention_heads
-                        )
+    ids = []
+    losses = []

     with torch.no_grad():
         for batch in tqdm(val_dataloader, disable=not main_process):
+            batch["id"] = batch["id"].detach().cpu()
             memories = batch["retrieved_context"]

-            # need to set to eval so we don't do mem attn as it's slow
-            model.eval()
+            memories = memories[:, :config["num_neighbors_to_store"], :]
+            memories = memories.reshape(-1, memories.shape[-1])

             for chunk_start in range(0, memories.shape[0], config["mem_chunk_size"]):
                 chunk_end = min(memories.shape[0], chunk_start + config["mem_chunk_size"])
                 mem_chunk = memories[chunk_start:chunk_end]
-                model(input_ids=mem_chunk, labels=None, use_mem_attn=False)
+                model(input_ids=mem_chunk)

             qa_inputs = batch["input_ids"]
             qa_labels = batch["labels"]
             outputs = model(input_ids=qa_inputs,
                             labels=qa_labels,
                             )

+            del memories
+            torch.cuda.empty_cache()
             loss_values = accelerator.gather_for_metrics({"loss": outputs["loss"].detach()})

-            val_loss.update(loss_values["loss"])
             index.reset()
+            val_loss.update(loss_values["loss"])

-    return val_loss
+            per_example_loss = calculate_per_example_loss(outputs["logits"], qa_labels)
+
+            losses.extend(per_example_loss)
+            ids.extend(batch["id"].tolist())
+
+    ids = pa.array(ids)
+    losses = pa.array(losses)
+    schema = pa.schema([("loss", pa.float64()), ("id", pa.int32())])
+    table = pa.Table.from_arrays([losses, ids], schema=schema)
+    return val_loss, table


 def train(accelerator, config):
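For reference, a quick sanity check of what calculate_per_example_loss returns: one mean token loss per example in the batch. This is not part of the commit; the shapes are made up for illustration.

import torch
import torch.nn.functional as F

logits = torch.randn(2, 5, 10)          # (batch, seq_len, vocab) -- dummy values
labels = torch.randint(0, 10, (2, 5))   # (batch, seq_len)

# same arithmetic as calculate_per_example_loss above
lm_logits = logits[:, :-1, :].contiguous()
lm_labels = labels[:, 1:].contiguous()
loss = F.cross_entropy(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1), reduction="none")
loss = loss.reshape(labels.shape[0], -1).mean(dim=-1)
print(loss.shape)  # torch.Size([2]): one scalar loss per example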
@@ -72,6 +94,7 @@ def train(accelerator, config):
     with accelerator.main_process_first():
         train_dataloader, val_dataloader = load_memory_augmented_data(config, tokenizer)
+

     if accelerator.state.deepspeed_plugin is not None:
         gradient_accumulation_steps = accelerator.state.deepspeed_plugin.deepspeed_config[
             "gradient_accumulation_steps"

@@ -97,6 +120,7 @@ def train(accelerator, config):
         memory_attn_layer=config["memory_attn_layer"],
         num_neighbors_to_retrieve=config["num_neighbors_to_retrieve"],
         index=index,
+        tracker=accelerator.get_tracker("wandb"),
     )

@@ -135,16 +159,15 @@ def train(accelerator, config):
         model, optimizer, train_dataloader, val_dataloader, scheduler = accelerator.prepare(
             model, optimizer, train_dataloader, val_dataloader, scheduler
         )
-        scheduler = True
+        use_scheduler = True
     else:
         model, optimizer, train_dataloader, val_dataloader = accelerator.prepare(
             model, optimizer, train_dataloader, val_dataloader
         )
-        scheduler = False
+        use_scheduler = False

     # setup for saving training states in case preemption
-    if scheduler:
+    if use_scheduler:
         accelerator.register_for_checkpointing(scheduler)

     if config["checkpoint"]:
@@ -167,24 +190,33 @@ def train(accelerator, config):
         for step, batch in enumerate(tqdm(train_dataloader, disable=not main_process)):
             curr_step = step + epoch * len(train_dataloader)
             memories = batch["retrieved_context"]
+            memories = memories[:, :config["num_neighbors_to_store"], :]
+            memories = memories.reshape(-1, memories.shape[-1])
+
             # need to set to eval so we don't do mem attn as it's slow
             model.eval()
             with torch.no_grad():
                 for chunk_start in range(0, memories.shape[0], config["mem_chunk_size"]):
-                    chunk_end = min(memories.shape[0], chunk_start + 32)
+                    chunk_end = min(memories.shape[0], chunk_start + config["mem_chunk_size"])
                     mem_chunk = memories[chunk_start:chunk_end]
-                    model(input_ids=mem_chunk, labels=None, use_mem_attn=False)
+                    model(input_ids=mem_chunk)
+
+            del memories
+            torch.cuda.empty_cache()

             model.train()
             qa_inputs = batch["input_ids"]
             qa_labels = batch["labels"]
             outputs = model(input_ids=qa_inputs,
                             labels=qa_labels,
+                            log_attn_scores=True,
+                            step=curr_step,
+                            save_kv=False,
                             )
             loss = outputs.loss
+            if config["wandb"]:
+                accelerator.log({"loss": loss}, step=curr_step)

-            index.reset()

             # gather loss before backprop in case of gradient accumulation
             loss_values = accelerator.gather_for_metrics({"loss": loss.detach().float()})
@@ -192,6 +224,8 @@ def train(accelerator, config):

             loss = loss / gradient_accumulation_steps
             accelerator.backward(loss)
+            # !! don't reset index until after backwards pass
+            index.reset()
             # get gradient norm of all params

             # log LR in case something weird happens
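Taken together, the training-loop hunks above amount to the following per-batch pattern. This is a condensed sketch for readability (names follow the diff), not code from the commit.

# Prime the memory index with the retrieved contexts (eval mode, no gradients).
memories = batch["retrieved_context"][:, :config["num_neighbors_to_store"], :]
memories = memories.reshape(-1, memories.shape[-1])
model.eval()
with torch.no_grad():
    for chunk_start in range(0, memories.shape[0], config["mem_chunk_size"]):
        chunk_end = min(memories.shape[0], chunk_start + config["mem_chunk_size"])
        model(input_ids=memories[chunk_start:chunk_end])

# Train on the QA pair against the populated index, then reset the index,
# but only after the backward pass has consumed the stored keys/values.
model.train()
outputs = model(input_ids=batch["input_ids"], labels=batch["labels"],
                log_attn_scores=True, step=curr_step, save_kv=False)
accelerator.backward(outputs.loss / gradient_accumulation_steps)
index.reset()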
@@ -202,15 +236,25 @@ def train(accelerator, config):

             if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                 optimizer.step()
-                if scheduler:
+                if use_scheduler:
                     scheduler.step()
                 optimizer.zero_grad()

             if step > 0 and config["save_every"] > 0 and step % config["save_every"] == 0:
-                accelerator.save_state(f"{config['output_dir']}/step_{curr_step}")
+                # accelerator.save_state(f"{config['output_dir']}/step_{curr_step}")
+                unwrapped_model = accelerator.unwrap_model(model)
+                unwrapped_model.save_pretrained(
+                    f"{config['output_dir']}/step_{step}",
+                    is_main_process=accelerator.is_main_process,
+                    save_function=accelerator.save,
+                    state_dict=accelerator.get_state_dict(model),
+                )

             if step > 0 and (step % config["eval_every"] == 0 or step == len(train_dataloader) - 1):
-                val_loss = evaluate(model, config, val_dataloader, main_process=main_process)
+                val_loss, loss_table = evaluate(model, index, config, val_dataloader, main_process=main_process)
+
+                local_rank = accelerator.process_index
+                feather.write_feather(loss_table, f"{config['output_dir']}/val_losses_step_{curr_step}_rank_{local_rank}.feather")

                 log_train = {
                     "train_loss": train_loss.compute()
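The per-rank feather files written above can be collected after a run to inspect which validation examples the model struggles with. A minimal sketch, not part of the commit; the output directory and glob pattern are purely illustrative.

import glob
import pandas as pd
from pyarrow import feather

# one file per rank and eval step, each with "loss" (float64) and "id" (int32) columns
output_dir = "ckpts"  # illustrative; stands in for config["output_dir"]
files = glob.glob(f"{output_dir}/val_losses_step_*_rank_*.feather")
df = pd.concat([feather.read_feather(f) for f in files], ignore_index=True)
print(df.sort_values("loss", ascending=False).head())  # hardest validation examples first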
@@ -1,4 +1,5 @@
 import torch.distributed as dist
+from contextlib import contextmanager


 def rank0_print(msg):
@@ -7,3 +8,20 @@ def rank0_print(msg):
         print(msg)
     else:
         print(msg)
+
+
+@contextmanager
+def main_process_first(is_main):
+    yield from _goes_first(is_main)
+
+
+def _goes_first(is_main):
+    if not is_main:
+        dist.barrier()
+
+    yield
+
+    if is_main:
+        dist.barrier()
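A minimal usage sketch for the new main_process_first helper, not part of the commit; it assumes torch.distributed is already initialized and that the guarded work is something like dataset preprocessing.

with main_process_first(is_main=dist.get_rank() == 0):
    # Rank 0 enters first and does the work (e.g. builds/caches a dataset);
    # the other ranks block on the first barrier and run after rank 0 releases it.
    dataset = prepare_dataset()  # hypothetical helper, stands in for any cached setup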