fix: wip mem xf
This commit is contained in: parent 3677935ce8, commit 55fef489ad
configs/eval/evaluate_lethe.yaml (new file, 23 lines)
@@ -0,0 +1,23 @@
# model/tokenizer
model_name: "/home/paperspace/gpt4all/gpt4all/train/ckpts/mem_attn/step_1000"
tokenizer_name: "EleutherAI/pythia-1b"
version: null
gradient_checkpointing: false
memory_attn_layer: 12


# dataset
streaming: false
num_proc: 64
dataset_path: "/home/paperspace/gpt4all/gpt4all/inference/synth_data_combined_174"
# dataset_path: "/home/paperspace/gpt4all/gpt4all/index/squad_supplemented_validation"
max_length: 1024
batch_size: 1
pct_test: 0.05
q_column: "question"
a_column: "answer"
context_column: "text"
num_memories_per_index: 2000000
num_neighbors_to_retrieve: 2
num_neighbors_to_store: 1
mem_chunk_size: 64
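For reference, a minimal sketch (not part of this commit) of how a config like the one above is consumed; read_config in gpt4all/utils/read.py presumably amounts to a plain YAML load:

import yaml

with open("configs/eval/evaluate_lethe.yaml") as f:
    config = yaml.safe_load(f)

print(config["memory_attn_layer"])          # 12
print(config["num_neighbors_to_retrieve"])  # 2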
configs/inference/synth_data.yaml (new file, 19 lines)
@@ -0,0 +1,19 @@
model_name: "nomic-ai/gpt4all-j"
revision: "v1.3-groovy"
tokenizer_name: "nomic-ai/gpt4all-j"

# dataset
streaming: false
num_proc: 64
dataset_path: "nomic-ai/cohere-wiki-sbert"
batch_size: 32
output_path: "synth_qa_pairs"

# generation
max_new_tokens: 75
max_generations: 200000

save_every: 1000


seed: 42
@@ -10,35 +10,36 @@ memory_attn_layer: 12
# dataset
streaming: false
num_proc: 64
dataset_path: "/home/paperspace/gpt4all/gpt4all/index/squad_supplemented_train"
dataset_path: "/home/paperspace/gpt4all/gpt4all/inference/synth_data_combined_174"
max_length: 1024
batch_size: 64
batch_size: 32
pct_test: 0.05
q_column: "question"
a_column: "answers"
context_column: "neighbor_text"
num_memories_per_index: 500000
num_neighbors_to_retrieve: 32
a_column: "answer"
context_column: "text"
num_memories_per_index: 2000000
num_neighbors_to_retrieve: 2
num_neighbors_to_store: 1
mem_chunk_size: 64

# train dynamics
lr: 1.0e-4
lr: 1.0e-5
min_lr: 0
weight_decay: 0.0
eval_every: 100
save_every: -1
save_every: 100
log_grads_every: 100
log_lr_every: 10
output_dir: "ckpts/mem_attn"
output_dir: "ckpts/mem_attn_no_cosine_sim"
checkpoint: null
lora: false
warmup_steps: 500
warmup_steps: 200
num_epochs: 5
debug: false
scheduler: false

# logging
wandb: false
wandb: true
wandb_entity: gpt4all
wandb_project_name: mem_attn
seed: 42
configs/train/pretrain_minipile.yaml (new file, 43 lines)
@@ -0,0 +1,43 @@
# model/tokenizer
model_name: "EleutherAI/pythia-1b"
tokenizer_name: "EleutherAI/pythia-1b"
version: null
gradient_checkpointing: true
save_name: "nomic-ai/minipille"
push_to_hub: false
memory_attn_layer: 12

# dataset
streaming: false
num_proc: 64
dataset_path: "JeanKaddour/minipile"
max_length: 2048
batch_size: 64
pct_test: 0.05
num_memories_per_index: 5000000
mem_chunk_size: 512
num_chunks: 10
num_neighbors_to_retrieve: 32

# train dynamics
lr: 1.0e-4
min_lr: 0
weight_decay: 0.0
eval_every: 100
save_every: -1
log_grads_every: 100
log_lr_every: 10
output_dir: "ckpts/minipile"
checkpoint: null
lora: false
warmup_steps: 500
num_epochs: 5
debug: false
scheduler: false

# logging
wandb: true
wandb_entity: gpt4all
wandb_project_name: minipile
seed: 42
@@ -85,6 +85,7 @@ def load_memory_augmented_data(config, tokenizer, split="train", split_dataset=T

    question_col = config["q_column"]
    answer_col = config["a_column"]
    context_col = config["context_column"]

    if config["streaming"] is False:
        kwargs = {"num_proc": config["num_proc"]}
@@ -96,7 +97,8 @@ def load_memory_augmented_data(config, tokenizer, split="train", split_dataset=T
    dataset = dataset.map(lambda ele: {question_col: [q.strip() for q in ele[question_col]]}, batched=True)
    # in squad, the data is formatted where each ele in answers is a dict where the key text holds
    # a list of the answer
    dataset = dataset.map(lambda ele: {answer_col: [t["text"][0] for t in ele[answer_col]]}, batched=True)
    dataset = dataset.map(lambda ele: {answer_col: [t.strip() for t in ele[answer_col]]}, batched=True)
    # dataset = dataset.map(lambda ele: {answer_col: [t["text"][0] for t in ele[answer_col]]}, batched=True)

    dataset = dataset.map(
        lambda ele: tokenize_inputs(config, tokenizer, ele, question_col, answer_col),
@@ -106,19 +108,73 @@ def load_memory_augmented_data(config, tokenizer, split="train", split_dataset=T

    # tokenize contexts for each example
    dataset = dataset.map(
        lambda ele: {"retrieved_context": tokenizer(ele["context"],
        lambda ele: {"retrieved_context": tokenizer([ele[context_col]],
                                                    return_tensors="pt",
                                                    padding="max_length",
                                                    truncation=True)["input_ids"]},
        batched=True,
        **kwargs
    )

    columns_to_keep = ["input_ids", "labels", "retrieved_context"]
    columns_to_keep = ["id", "input_ids", "labels", "retrieved_context"]

    col_names_to_rm = [col for col in dataset.column_names if col not in columns_to_keep]
    dataset = dataset.remove_columns(col_names_to_rm)

    if split_dataset:
        dataset = dataset.train_test_split(test_size=config["pct_test"], seed=config["seed"])
        train_dataset, val_dataset = dataset["train"], dataset["test"]

        train_dataloader = DataLoader(
            train_dataset.remove_columns("id"),
            batch_size=config["batch_size"],
            collate_fn=DefaultDataCollator(),
        )

        val_dataloader = DataLoader(
            val_dataset,
            batch_size=config["batch_size"],
            collate_fn=DefaultDataCollator(),
        )

        return train_dataloader, val_dataloader

    else:
        dataloader = DataLoader(
            dataset,
            batch_size=config["batch_size"],
            collate_fn=DefaultDataCollator(),
        )

        return dataloader


def load_memory_pretraining_data(config, tokenizer, split="train", split_dataset=True):
    dataset_path = config["dataset_path"]

    if os.path.exists(dataset_path):
        dataset = Dataset.load_from_disk(dataset_path)
    else:
        dataset = load_dataset(dataset_path, split=split)

    if config["streaming"] is False:
        kwargs = {"num_proc": config["num_proc"]}
    else:
        kwargs = {}

    # e.g. 512 * 10 = 5120 sequence length split up
    max_length = config["mem_chunk_size"] * config["num_chunks"]
    dataset = dataset.map(lambda ele: tokenizer(ele["text"], padding="max_length", truncation=True, max_length=max_length),
                          batched=True, **kwargs)

    dataset = dataset.map(lambda x: {"labels": x["input_ids"]}, batched=True, **kwargs)


    columns_to_keep = ["input_ids", "labels", "attention_mask"]

    col_names_to_rm = [col for col in dataset.column_names if col not in columns_to_keep]
    dataset = dataset.remove_columns(col_names_to_rm)

    # we can shuffle since the docs are in one row not split across rows
    if split_dataset:
        dataset = dataset.train_test_split(test_size=config["pct_test"], seed=config["seed"])
        train_dataset, val_dataset = dataset["train"], dataset["test"]
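A small sketch (not part of this commit) of the chunking arithmetic the loaders and training loops rely on: pretraining documents are padded or truncated to mem_chunk_size * num_chunks tokens, and retrieved contexts are later pushed through the model in mem_chunk_size-sized slices, as the loops elsewhere in this diff do.

mem_chunk_size = 512   # values from configs/train/pretrain_minipile.yaml
num_chunks = 10
max_length = mem_chunk_size * num_chunks   # 5120 tokens per document

def iter_mem_chunks(memories, chunk_size):
    # memories: (num_memories, seq_len) token ids, yielded chunk_size rows at a time
    for start in range(0, memories.shape[0], chunk_size):
        yield memories[start:min(memories.shape[0], start + chunk_size)]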
gpt4all/eval/eval_squad_atlas_map.py (new file, 116 lines)
@@ -0,0 +1,116 @@
import torch
import torch.nn.functional as F
from gpt4all.models import LetheForCausalLM
from gpt4all.models.lethe.modeling_lethe import MemoryIndex
from gpt4all.data.retrieval_dataloader import load_memory_augmented_data
from gpt4all.train.metrics import f1_score, exact_match_score
from gpt4all.utils.read import read_config
from transformers import AutoTokenizer, AutoConfig
from argparse import ArgumentParser
from tqdm import tqdm
from nomic import atlas
from datasets import load_from_disk


def calc_loss_per_item(logits, labels):
    lm_logits = logits[:, :-1, :].contiguous()
    lm_labels = labels[:, 1:].contiguous()
    loss = F.cross_entropy(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1), reduction="none")
    loss = loss.reshape(labels.shape[0], -1).mean(dim=-1)

    # return tensor of shape (B,) where B is the batch size
    return loss.cpu().tolist()


def greedy_search(input_ids, model, tokenizer, max_new_tokens=100):
    num_new_tokens = 0
    with torch.no_grad():
        while True:
            if num_new_tokens >= max_new_tokens:
                break
            outputs = model(input_ids, save_kv=False)

            new_tokens = torch.argmax(outputs.logits[:, -1, :], dim=-1)

            input_ids = torch.cat([input_ids, new_tokens.unsqueeze(1)], dim=-1)
            num_new_tokens += 1

            if torch.equal(input_ids[0, -1].cpu(), torch.tensor(tokenizer.eos_token_id)):
                break

    print(tokenizer.batch_decode(input_ids, skip_special_tokens=True))

    return input_ids


parser = ArgumentParser()
parser.add_argument("--config", type=str, default="config.yaml")

args = parser.parse_args()

config = read_config(args.config)

tokenizer = AutoTokenizer.from_pretrained(config["tokenizer_name"], model_max_length=config["max_length"])
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

dataloader = load_memory_augmented_data(config, tokenizer, split_dataset=False)

dataset = load_from_disk(config["dataset_path"])


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_config = AutoConfig.from_pretrained(config["model_name"])

head_size = model_config.hidden_size // model_config.num_attention_heads
index = MemoryIndex(head_size,
                    config["num_memories_per_index"],
                    model_config.num_attention_heads
)
model = LetheForCausalLM.from_pretrained(config["model_name"],
                                         revision=config['version'] if 'version' in config else None,
                                         memory_attn_layer=config["memory_attn_layer"],
                                         num_neighbors_to_retrieve=config["num_neighbors_to_retrieve"],
                                         index=index,
).to(device)
model.eval()

# Evaluate the model on the SQUAD dataset
losses = []
with torch.no_grad():
    for i, batch in enumerate(tqdm(dataloader)):
        memories = batch["retrieved_context"]
        memories = memories[:, :config["num_neighbors_to_store"], :]
        memories = memories.reshape(-1, memories.shape[-1])

        # need to set to eval so we don't do mem attn as it's slow
        with torch.no_grad():
            for chunk_start in range(0, memories.shape[0], config["mem_chunk_size"]):
                chunk_end = min(memories.shape[0], chunk_start + config["mem_chunk_size"])
                mem_chunk = memories[chunk_start:chunk_end]
                model(input_ids=mem_chunk.to(device))

        del memories
        torch.cuda.empty_cache()
        qa_inputs = batch["input_ids"]
        qa_labels = batch["labels"]
        for i in range(qa_inputs.shape[0]):
            inputs = qa_inputs[i].to(device)
            labels = qa_labels[i].to(device)
            cutoff = torch.argmax((labels != -100).type(torch.float32))
            greedy_search(inputs[:cutoff.item()].unsqueeze(0).to(device), model, tokenizer)
            print(tokenizer.decode(inputs, skip_special_tokens=True))


        # batch_loss = calc_loss_per_item(outputs.logits, qa_labels.to(device))
        # losses.extend(batch_loss)
        index.reset()


dataset = dataset.add_column("loss", losses)

dataset.save_to_disk("eval_squad_atlas_map")
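An illustration (toy values, not part of this commit) of the cutoff computation above: labels are -100 over the question tokens, so the first non-masked position marks where the answer starts, and everything before it is what greedy_search receives as the prompt.

import torch

labels = torch.tensor([-100, -100, -100, 42, 43, 44])
cutoff = torch.argmax((labels != -100).type(torch.float32))
print(cutoff.item())  # 3 -> inputs[:3] is the question prompt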
gpt4all/inference/combine_synth_data.py (new file, 57 lines)
@@ -0,0 +1,57 @@
import glob
from argparse import ArgumentParser
from datasets import Dataset, load_from_disk, concatenate_datasets


PROMPT = "Write a question answer pair based on the following context. If the context isn't specific enough, ignore and return 'No question answer pair`. Context : {}\n"


def load_synth_data(data_dir):
    files = glob.glob(data_dir + "/*")

    ds = concatenate_datasets([load_from_disk(f) for f in files])
    table = ds.data.table

    filtered = table.filter(table["valid"])
    ds = Dataset.from_dict(filtered.to_pydict())
    return ds


def remove_prompt(examples):
    outputs = {"text": [], "generated": [], "question": [], "answer": []}
    for context, generated in zip(examples["text"], examples["generated"]):
        prompt_w_ctx = PROMPT.format(context)
        gen_wo_ctx = generated[len(prompt_w_ctx):]

        assert prompt_w_ctx not in gen_wo_ctx

        question = gen_wo_ctx.split("Answer:")[0].replace("Question:", "").strip()
        answer = gen_wo_ctx.split("Answer:")[1].strip()

        outputs["text"].append(context)
        outputs["generated"].append(gen_wo_ctx)
        outputs["question"].append(question)
        outputs["answer"].append(answer)

    return outputs




def combine_synth_data(data_dir):
    ds = load_synth_data(data_dir)

    ds = ds.map(lambda ele: remove_prompt(ele), batched=True, num_proc=64)

    ds.save_to_disk(f"synth_data_combined_{len(ds)/1000:.0f}")


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--dataset_dir", type=str, required=True)

    args = parser.parse_args()

    combine_synth_data(args.dataset_dir)
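A hypothetical example (values invented, not part of this commit) of what remove_prompt expects: the generation echoes the prompt plus context, so the prompt prefix is sliced off and the remainder is split on "Answer:".

context = "The Eiffel Tower is in Paris."
generated = PROMPT.format(context) + "Question: Where is the Eiffel Tower?\nAnswer: Paris."

gen_wo_ctx = generated[len(PROMPT.format(context)):]
question = gen_wo_ctx.split("Answer:")[0].replace("Question:", "").strip()
answer = gen_wo_ctx.split("Answer:")[1].strip()
print(question)  # Where is the Eiffel Tower?
print(answer)    # Paris.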
gpt4all/inference/generate_synth_data.py (new file, 149 lines)
@@ -0,0 +1,149 @@
from argparse import ArgumentParser
from datasets import load_dataset, Dataset
import torch
from torch.utils.data import DataLoader, DistributedSampler
from accelerate.utils import set_seed
from transformers import AutoTokenizer, AutoModelForCausalLM, DefaultDataCollator
from gpt4all.utils.read import read_config
from gpt4all.utils.distributed_utils import rank0_print, main_process_first
from tqdm import tqdm
import pyarrow.compute as pc
import pyarrow as pa
import torch.distributed as dist


PROMPT = "Write a question answer pair based on the following context. If the context isn't specific enough, ignore and return 'No question answer pair`. Context : {}\n"


def prepare_data(config, tokenizer, num_processes, local_rank):
    dataset = load_dataset(config["dataset_path"], split="train")
    dataset = dataset.remove_columns("embedding")

    shuffled = dataset.shuffle(seed=config["seed"])
    indices = shuffled[:config["max_generations"]]["id"]

    table = dataset.data
    mask = pc.is_in(table["id"], value_set=pa.array(indices, pa.int32()))
    filtered_table = table.filter(mask)

    # convert from pyarrow to Dataset
    orig_dataset = Dataset.from_dict(filtered_table.to_pydict())

    dataset = orig_dataset.map(lambda ele: {"prompted_text": [PROMPT.format(context) for context in ele["text"]]},
                               batched=True,
                               num_proc=config["num_proc"] if "num_proc" in config else None)

    dataset = dataset.map(lambda ele: {"prompt_len": [len(prompt) for prompt in ele["prompted_text"]]}, batched=True,
                          num_proc=config["num_proc"] if "num_proc" in config else None)

    dataset = dataset.sort("prompt_len")
    dataset = dataset.map(lambda ele: tokenizer(ele["prompted_text"], return_tensors="pt", padding="longest", truncation=True,
                                                max_length=tokenizer.model_max_length - config["max_new_tokens"]), batched=True,
                          batch_size=num_processes * config["batch_size"],
    )

    columns_to_keep = ["id", "input_ids", "attention_mask"]
    col_names_to_rm = [col for col in dataset.column_names if col not in columns_to_keep]

    dataset = dataset.remove_columns(col_names_to_rm)

    sampler = DistributedSampler(
        dataset,
        shuffle=False,
        drop_last=True,
        num_replicas=num_processes,
        rank=local_rank,
    )

    dataloader = DataLoader(
        dataset,
        batch_size=config["batch_size"],
        collate_fn=DefaultDataCollator(),
        sampler=sampler,
        drop_last=True
    )

    return dataloader, orig_dataset


def generate_data(config):
    set_seed(config["seed"])

    rank0_print(f"World size: {dist.get_world_size()}")

    tokenizer = AutoTokenizer.from_pretrained(config["tokenizer_name"])
    # since we're doing generation, pad left for autoregressive generation
    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    num_processes = dist.get_world_size()
    local_rank = dist.get_rank()

    dataloader, dataset = prepare_data(config, tokenizer, num_processes, local_rank)

    dist.barrier()
    print(dataset[:10]["id"])

    model = AutoModelForCausalLM.from_pretrained(config["model_name"],
                                                 revision=config["revision"] if "revision" in config else None,
                                                 use_cache=True,
                                                 torch_dtype=torch.bfloat16,)
    model.to(f"cuda:{local_rank}")

    synth_data = []
    valid = []
    ids = []
    total_valid = 0

    with torch.no_grad():
        for i, batch in enumerate(tqdm(dataloader, disable=local_rank != 0)):
            # keep this simple for now, can add temperature and other sampling techniques later
            generated = model.generate(batch["input_ids"].to(model.device),
                                       attention_mask=batch["attention_mask"].to(model.device),
                                       max_new_tokens=config["max_new_tokens"])

            decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
            num_valid = ["\nQuestion:" in t and "\nAnswer:" in t for t in decoded]
            rank0_print(f"Num valid: {sum(num_valid)/ len(num_valid):.2f}")
            total_valid += sum(num_valid)

            synth_data.extend(decoded)
            valid.extend(num_valid)
            ids.extend(batch["id"].tolist())

            if i > 0 and i % config["save_every"] == 0:
                table = dataset.data.table
                mask = pc.is_in(table["id"], value_set=pa.array(ids, pa.int32()))
                filtered_table = table.filter(mask)

                chunk_table = pa.Table.from_pydict({"id": ids, "generated": synth_data, "valid": valid})
                joined = filtered_table.join(chunk_table, "id")
                curr_dataset = Dataset.from_dict(joined.to_pydict())
                curr_dataset.save_to_disk(f'{config["output_path"]}/chunk_{i}_rank_{local_rank}')

    table = dataset.data.table
    mask = pc.is_in(table["id"], value_set=pa.array(ids, pa.int32()))
    filtered_table = table.filter(mask)

    chunk_table = pa.Table.from_pydict({"id": ids, "generated": synth_data, "valid": valid})
    joined = filtered_table.join(chunk_table, "id")
    full_dataset = Dataset.from_dict(joined.to_pydict())
    full_dataset.save_to_disk(f'{config["output_path"]}_{config["max_generations"]}_rank_{local_rank}')

    rank0_print(f"Total valid: {total_valid}/{config['max_generations']}")


def main():
    dist.init_process_group("nccl")
    parser = ArgumentParser()
    parser.add_argument("--config", type=str, default="config.yaml")

    args = parser.parse_args()
    config = read_config(args.config)

    generate_data(config)


if __name__ == "__main__":
    # parse arguments by reading in a config
    main()
@@ -111,6 +111,7 @@ class LetheConfig(PretrainedConfig):
        memory_attn_layer=9,
        num_neighbors_to_retrieve=32,
        num_neighbors_stored=128,
        attn_scale_init=20.0,
        **kwargs,
    ):
        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
@@ -133,4 +134,5 @@ class LetheConfig(PretrainedConfig):
        # index of cross attention layer to add
        self.memory_attn_layer = memory_attn_layer
        self.num_neighbors_to_retrieve = num_neighbors_to_retrieve
        self.num_neighbors_stored = num_neighbors_stored
        self.num_neighbors_stored = num_neighbors_stored
        self.attn_scale_init = attn_scale_init
@@ -12,8 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch PythiaSeek model."""
""" PyTorch Lethe model."""

import wandb
import math
import torch.nn.functional as F
import matplotlib.pyplot as plt
from typing import Optional, Tuple, Union

import torch
@@ -29,8 +33,9 @@ from transformers.modeling_outputs import (
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from gpt4all.models.lethe import LetheConfig
import hnswlib
import numpy as np
import faiss
import faiss.contrib.torch_utils


logger = logging.get_logger(__name__)
@@ -40,17 +45,18 @@ GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "EleutherAI/gpt-neox-20b",
]

# TODO: understand why Phil only does this per batch and doesn't persist across many batches -> he uses multi-query attention
# TODO: do we need to implement masking for the dense vectors we pull from?
# TODO: i think phil is using a memmapped database to pull out rather than using the index


class HNSWIndex:
    def __init__(self, max_memories, dimension):
        # num_memories will be batch size * num_neighbors
        # can memmap this too like
        self.index = hnswlib.Index(space="l2", dim=dimension)
        self.index.init_index(max_elements=max_memories, ef_construction=50, M=16)
        self.index = faiss.IndexHNSWFlat(dimension, 16, faiss.METRIC_INNER_PRODUCT)
        # taking params from: https://www.pinecone.io/learn/vector-indexes/#hnsw-implementation
        # and https://www.pinecone.io/learn/hnsw/#hnsw-performance
        # seems like efConstruction dictates how long the index takes to build
        # and efSearch and M (second arg to faiss.Index) dictates how long it takes to search
        self.index.hnsw.efConstruction = 16
        self.index.hnsw.efSearch = 32
        self.max_memories = max_memories
        self.dimension = dimension
@@ -60,45 +66,65 @@ class HNSWIndex:

    def query(self, query, k=1):
        # hack what should we do here?
        if self.index.get_current_count() == 0:
            return np.ones((query.shape[0], k, query.shape[1]), dtype=np.float32)
        if self.index.ntotal == 0:
            return np.ones((query.shape[0], k), dtype=np.int32)

        assert query.ndim == 2
        bs_seq_len, _ = query.shape
        _, labels = self.index.search(np.ascontiguousarray(query), k=k)

        labels, _ = self.index.knn_query(query, k=k)
        neighbors = torch.tensor(self.index.get_items(labels.reshape(-1)))
        neighbors = neighbors.reshape((bs_seq_len, k, query.shape[1]))

        assert neighbors.ndim == 3
        assert neighbors.shape[0] == bs_seq_len

        return neighbors
        return labels

    def add(self, memories):
        assert memories.ndim == 2
        bs_seq_len, _ = memories.shape

        ids = np.arange(self.idx_offset, self.idx_offset + bs_seq_len)

        self.index.add_items(memories, ids)

        self.idx_offset += bs_seq_len
        return self.index.add(np.ascontiguousarray(memories))


    def reset(self):
        self.index = hnswlib.Index(space="l2", dim=self.dimension)
        self.index.init_index(max_elements=self.max_memories, ef_construction=50, M=16)
        self.index.reset()


class NumpyKNNIndex:
    def __init__(self, max_memories, dimension):
        # num_memories will be batch size * num_neighbors
        # can memmap this too like
        self.index = np.zeros((max_memories, dimension), dtype=np.float32)
        self.max_memories = max_memories
        self.dimension = dimension

        # if we want to allow for insertion of len(elements) > max_memories
        # we need to figure out a way to get the most recent memories
        self.idx_offset = 0

    def query(self, query, k=1):
        # hack what should we do here?
        if self.index.sum() == 0:
            return np.ones((query.shape[0], k), dtype=np.int32)

        dots = query.dot(self.index[:self.idx_offset].T)
        labels = np.argsort(dots, axis=1)[:, -k:]

        return labels


    def add(self, memories):
        self.index[self.idx_offset:self.idx_offset + memories.shape[0]] = memories
        self.idx_offset += memories.shape[0]


    def reset(self):
        self.index.reset()



class MemoryIndex:
    def __init__(self, hidden_dim, num_mems, nheads):
        # we store an index for each k/v for each head
        self.key_indices = [HNSWIndex(num_mems, hidden_dim) for _ in range(nheads)]
        self.value_indices = [HNSWIndex(num_mems, hidden_dim) for _ in range(nheads)]
        self.nheads = nheads

        # NOTE: we are storing kv pairs, instead of indices for both keys and values
        self.key_indices = [HNSWIndex(num_mems, hidden_dim) for _ in range(nheads)]

        shape = (num_mems, nheads, 2, hidden_dim)
        self.kv_pairs = np.zeros(shape, dtype=np.float32)
        self.idx_offset = 0

    def add(self, keys, values):
        # k/v are (bs, num_attention_heads, seq_len, head_size)
        reshaped_keys = keys.reshape(keys.shape[0] * keys.shape[2], keys.shape[1], keys.shape[3])
@@ -106,22 +132,30 @@ class MemoryIndex:

        for head in range(self.nheads):
            self.key_indices[head].add(reshaped_keys[:, head, :])
            self.value_indices[head].add(reshaped_values[:, head, :])

        kv_pairs = np.stack((reshaped_keys, reshaped_values), axis=2)

        if self.idx_offset + kv_pairs.shape[0] > self.kv_pairs.shape[0]:
            raise ValueError("Not enough memory!")

        self.kv_pairs[self.idx_offset:self.idx_offset + kv_pairs.shape[0]] = kv_pairs
        self.idx_offset += kv_pairs.shape[0]

    def knn_query(self, query, k=1):
        reshaped_query = query.reshape(query.shape[0] * query.shape[2], query.shape[1], query.shape[3])

        mem_keys = []
        mem_values = []
        mem_indices = []

        # this is prob so so slow
        # we can prob make this better
        for head in range(self.nheads):
            knn_keys = self.key_indices[head].query(reshaped_query[:, head, :], k=k)
            knn_values = self.value_indices[head].query(reshaped_query[:, head, :], k=k)

            mem_keys.append(knn_keys)
            mem_values.append(knn_values)
            knn_indices = self.key_indices[head].query(reshaped_query[:, head, :], k=k)
            kv_pairs = self.kv_pairs[:, head, :, :][knn_indices]

            mem_keys.append(kv_pairs[:, :, 0, :])
            mem_values.append(kv_pairs[:, :, 1, :])
            mem_indices.append(knn_indices)

        mem_keys = torch.from_numpy(np.stack(mem_keys, axis=1))
        # (bs, num_attention_heads, seq_len, k, head_size)
@@ -131,13 +165,14 @@
        # (bs, num_attention_heads, seq_len, k, head_size)
        mem_values = mem_values.view(query.shape[:-1] + (k,) + (query.shape[-1],))

        return mem_keys, mem_values
        return mem_keys, mem_values, np.stack(mem_indices, axis=1)


    def reset(self):
        for head in range(self.nheads):
            self.key_indices[head].reset()
            self.value_indices[head].reset()

        self.kv_pairs = np.zeros((self.kv_pairs.shape[0], self.nheads, 2, self.kv_pairs.shape[-1]), dtype=np.float32)


class LethePreTrainedModel(PreTrainedModel):
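For orientation, a standalone sketch (not part of this commit) of the faiss HNSW calls that HNSWIndex wraps; the parameter values mirror the diff.

import faiss
import numpy as np

dim = 256
index = faiss.IndexHNSWFlat(dim, 16, faiss.METRIC_INNER_PRODUCT)
index.hnsw.efConstruction = 16   # build-time effort
index.hnsw.efSearch = 32         # query-time effort

keys = np.random.randn(1024, dim).astype(np.float32)
index.add(np.ascontiguousarray(keys))    # ids are implicit: 0 .. ntotal-1
print(index.ntotal)                      # 1024

queries = np.random.randn(4, dim).astype(np.float32)
_, labels = index.search(np.ascontiguousarray(queries), 8)
print(labels.shape)                      # (4, 8) neighbor ids into the kv_pairs store
index.reset()                            # drops all stored vectors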
@@ -171,7 +206,7 @@ class LethePreTrainedModel(PreTrainedModel):


class LetheAttention(nn.Module):
    def __init__(self, config, memory_attention=False, index=None):
    def __init__(self, config, memory_attention=False, index=None, tracker=None):
        super().__init__()
        self.num_attention_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
@@ -188,21 +223,24 @@ class LetheAttention(nn.Module):
        self.rotary_emb = RotaryEmbedding(
            self.rotary_ndims, config.max_position_embeddings, base=config.rotary_emb_base
        )
        self.register_buffer(
        if not memory_attention:
            self.register_buffer(
                "norm_factor",
                torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(torch.get_default_dtype()),
                persistent=False,
            )

        self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.memory = False

        if memory_attention:
            self.scale = nn.Parameter(torch.ones(self.num_attention_heads, 1, 1) * math.log(config.attn_scale_init))
            self.memory = True
            self.alpha = nn.Parameter(torch.zeros(self.num_attention_heads))
            self.num_neighbors = config.num_neighbors_to_retrieve
            # for testing, just using np array since it's easy
            self.index = index
            self.tracker = tracker


    def forward(
@@ -214,7 +252,9 @@ class LetheAttention(nn.Module):
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        use_mem_attn: Optional[bool] = True,
        log_attn_scores: Optional[bool] = False,
        step: Optional[int] = None,
        save_kv: Optional[bool] = True,
    ):
        has_layer_past = layer_past is not None

@@ -234,6 +274,12 @@ class LetheAttention(nn.Module):
        key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3)
        value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3)

        # if self.memory:
        if self.memory:
            # QKNorm: https://arxiv.org/abs/2010.04245
            query = F.normalize(query, dim=-1)
            key = F.normalize(key, dim=-1)

        # Compute rotary embeddings on rotary_ndims
        query_rot = query[..., : self.rotary_ndims]
        query_pass = query[..., self.rotary_ndims :]
@@ -257,27 +303,38 @@ class LetheAttention(nn.Module):
            value = torch.cat((past_value, value), dim=-2)
        present = (key, value) if use_cache else None

        # Compute attention
        attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

        # TODO: need to do masking??
        # memory attention
        if self.memory:
            # get knns
            # since we do an eval batch w context before, let's not do the expensive step until we need to
            # [batch, knn, num_attention_heads, seq_len, head_size]
            if use_mem_attn:
                knn_keys, knn_values = self.index.knn_query(query.cpu().detach().to(torch.float32).numpy(), k=self.num_neighbors)
                mem_attn = self._mem_attn(query,
                                          knn_keys.to(query.device),
                                          knn_values.to(query.device),
                                          attention_mask,
                                          head_mask
                )
            if save_kv:
                self.index.add(key.cpu().detach().to(torch.float32).numpy(), value.cpu().detach().to(torch.float32).numpy())

            expanded_alpha = self.alpha[None, :, None, None]
            attn_output = (attn_output * (1 - expanded_alpha)) + (mem_attn * expanded_alpha)
            knn_keys, knn_values, knn_labels = self.index.knn_query(query.cpu().detach().to(torch.float32).numpy(), k=self.num_neighbors)
            if log_attn_scores:
                batch_size = query.shape[0]
                seq_len = query.shape[-2]

            self.index.add(key.cpu().detach().to(torch.float32).numpy(), value.cpu().detach().to(torch.float32).numpy())
                key_labels = knn_labels // seq_len
                key_labels = key_labels.reshape(batch_size, seq_len, self.num_attention_heads, -1)
                correct_keys = np.equal(key_labels, np.arange(batch_size)[:, np.newaxis, np.newaxis, np.newaxis])
                # calculate the accuracy
                key_acc = np.sum(correct_keys) / np.prod(correct_keys.shape)

                self.tracker.log({"retrieved_acc": key_acc}, step=step)

            attn_output = self._mem_attn(query,
                                         knn_keys.to(query.device).to(value.dtype),
                                         knn_values.to(query.device).to(value.dtype),
                                         key,
                                         value,
                                         attention_mask,
                                         head_mask,
                                         log_attn_scores=log_attn_scores,
                                         step=step,
                                         knn_labels=knn_labels,
            )
        else:
            # Normal self-attention
            attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

        # Reshape outputs
        attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size)
@@ -315,28 +372,131 @@ class LetheAttention(nn.Module):
        return tensor


    def _mem_attn(self, query, key, value, attention_mask=None, head_mask=None):
    def _mem_attn(self,
                  query,
                  knn_key,
                  knn_value,
                  local_key,
                  local_value,
                  attention_mask=None,
                  head_mask=None,
                  log_attn_scores=False,
                  step=None,
                  knn_labels=None):
        # local self-attention
        # q: [bs, num_attention_heads, seq_len, attn_head_size]
        # k,v: [bs, num_attention_heads, seq_len, knn, attn_head_size]
        query_length = query.size(-2)
        key_length = local_key.size(-2)

        attn_scores = torch.einsum("bhsd, bhsnd-> bshn", query, key)
        # attn_scores: [bs, seq_len, num_attention_heads, knn]
        attn_scores = attn_scores / self.norm_factor
        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]

        # softmax over knns
        attn_weights = nn.functional.softmax(attn_scores, dim=-1)
        attn_weights = attn_weights.to(value.dtype)
        local_attn_scores = torch.matmul(query, local_key.transpose(-1, -2))
        scale = self.scale.exp()

        local_attn_scores = local_attn_scores * scale

        mask_value = torch.finfo(local_attn_scores.dtype).min
        # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
        # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
        mask_value = torch.tensor(mask_value, dtype=local_attn_scores.dtype).to(local_attn_scores.device)
        local_attn_scores = torch.where(causal_mask, local_attn_scores, mask_value)

        if attention_mask is not None:
            # Apply the attention mask
            attn_scores = attn_scores + attention_mask
            local_attn_scores = local_attn_scores + attention_mask

        mem_attn_scores = torch.einsum("bhsd, bhsnd-> bhsn", query, knn_key)
        # attn_scores: [bs, seq_len, num_attention_heads, knn]
        mem_attn_scores = mem_attn_scores * scale

        attn_scores = torch.cat((mem_attn_scores, local_attn_scores), dim=-1)

        # softmax over knns
        attn_weights = nn.functional.softmax(attn_scores, dim=-1)
        attn_weights = attn_weights.to(local_value.dtype)

        mem_attn_weights, local_attn_weights = attn_weights.split([self.num_neighbors, local_attn_scores.size(-1)], dim=-1)
        if log_attn_scores:
            # (bs, seq_len, num_attention_heads, knn) probabilities
            # curate (x,y) pairs
            # where x is attention weight, y is accuracy of retrieved token
            bs, seq_len = mem_attn_weights.size(0), mem_attn_weights.size(2)
            key_labels = knn_labels // seq_len
            key_labels = key_labels.reshape(bs, self.num_attention_heads, seq_len, -1)
            correct_keys = np.equal(key_labels, np.arange(bs)[:, np.newaxis, np.newaxis, np.newaxis])

            bin_width = 0.05

            # Calculate the number of bins
            num_bins = int(1 / bin_width)

            # Create empty lists for storing bin probabilities and accuracies
            bin_probabilities = []
            bin_accuracies = []

            probs = mem_attn_weights.clone().detach().cpu().numpy().reshape(-1).tolist()
            correct_keys = correct_keys.reshape(-1).tolist()

            # Iterate over each bin
            for i in range(num_bins):
                bin_lower = i * bin_width
                bin_upper = (i + 1) * bin_width

                # Filter data points within the current bin range
                bin_x_values = [x for x in probs if bin_lower <= x < bin_upper]
                bin_y_values = [y for j, y in enumerate(correct_keys) if bin_lower <= probs[j] < bin_upper]

                # Calculate accuracy for the bin
                total = len(bin_x_values)
                correct = sum(bin_y_values)
                accuracy = correct / total if total > 0 else 0

                # Store the probability and accuracy for the bin
                bin_probabilities.append((bin_lower + bin_upper) / 2)
                bin_accuracies.append(accuracy)

            data = [[x, y] for x, y in zip(bin_probabilities, bin_accuracies)]
            table = wandb.Table(data=data, columns=["attn_prob", "retrieved_acc"])
            self.tracker.log({"attn_vs_acc": wandb.plot.scatter(table, "attn_prob", "retrieved_acc")}, step=step)


        if log_attn_scores:
            # this def won't work well on multi-gpu machines
            num_attention_heads = mem_attn_weights.size(1)
            for head in range(num_attention_heads):
                mem_attn_score_per_head = mem_attn_weights[:, head].reshape(-1)
                mem_flat = mem_attn_score_per_head.clone().detach().cpu()
                mem_hist = torch.histc(mem_flat, bins=20, min=0, max=1)
                mem_bins = torch.linspace(0, 1, steps=20 + 1)
                plt.stairs(mem_hist.tolist(), mem_bins.tolist())
                plt.title(f"mem_attn_score_{head}")
                # set arbitrarily but we want to see those peaks!!
                plt.ylim((0, 1000))
                self.tracker.log({f"mem_attn_score_{head}": wandb.Image(plt)}, step=step)
                plt.close()


                local_attn_scores_per_head = local_attn_weights[:, head].reshape(-1)
                local_flat = local_attn_scores_per_head.clone().detach().cpu()
                local_hist = torch.histc(local_flat, bins=20, min=0, max=1)
                local_bins = torch.linspace(0, 1, steps=20 + 1)
                plt.stairs(local_hist.tolist(), local_bins.tolist())
                plt.title(f"local_attn_score_{head}")
                # set arbitrarily but we want to see those peaks!!
                plt.ylim((0, 1000))
                self.tracker.log({f"local_attn_score_{head}": wandb.Image(plt)}, step=step)
                plt.close()

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        # attn_output: [bs, num_attention_heads, seq_len, attn_head_size]
        attn_output = torch.einsum("bshn, bhsnd-> bhsd", attn_scores, value)
        mem_attn_output = torch.einsum("bhsn, bhsnd-> bhsd", mem_attn_weights, knn_value)
        local_attn_output = torch.matmul(local_attn_weights, local_value)

        # TODO: do we need flamingo style gating
        # of output_gate.tanh * attn_output
        attn_output = mem_attn_output + local_attn_output

        return attn_output

    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
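A toy sketch (not part of this commit, omitting the causal mask, QK-norm and learned scale) of the joint softmax used in _mem_attn above: memory scores over the k retrieved neighbors and the local scores are concatenated, softmaxed together, split back, and each half weights its own values.

import torch
import torch.nn.functional as F

bs, heads, seq, k, d = 2, 4, 8, 2, 16
query = torch.randn(bs, heads, seq, d)
knn_key = torch.randn(bs, heads, seq, k, d)
knn_value = torch.randn(bs, heads, seq, k, d)
local_key = torch.randn(bs, heads, seq, d)
local_value = torch.randn(bs, heads, seq, d)

mem_scores = torch.einsum("bhsd,bhsnd->bhsn", query, knn_key)   # (bs, heads, seq, k)
local_scores = query @ local_key.transpose(-1, -2)              # (bs, heads, seq, seq)
weights = F.softmax(torch.cat([mem_scores, local_scores], dim=-1), dim=-1)
mem_w, local_w = weights.split([k, seq], dim=-1)

out = torch.einsum("bhsn,bhsnd->bhsd", mem_w, knn_value) + local_w @ local_value
print(out.shape)  # torch.Size([2, 4, 8, 16])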
@@ -361,9 +521,11 @@ class LetheAttention(nn.Module):
            query,
            key.transpose(1, 2),
            beta=1.0,
            alpha=(torch.tensor(1.0, dtype=self.norm_factor.dtype, device=self.norm_factor.device) / self.norm_factor),
            alpha=1.0 if self.memory else (torch.tensor(1.0, dtype=self.norm_factor.dtype, device=self.norm_factor.device) / self.norm_factor)
        )
        attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)
        if self.memory:
            attn_scores = attn_scores * self.scale.exp()

        mask_value = torch.finfo(attn_scores.dtype).min
        # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
@@ -413,7 +575,7 @@ class RotaryEmbedding(torch.nn.Module):
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.cos_cached = emb.cos()[None, None, :, :]
            self.sin_cached = emb.sin()[None, None, :, :]
        return self.cos_cached[:seq_len, ...].to(x.device), self.sin_cached[:seq_len, ...].to(x.device)
        return self.cos_cached[:seq_len, ...].to(x.device).to(x.dtype), self.sin_cached[:seq_len, ...].to(x.device).to(x.dtype)


def rotate_half(x):
@@ -448,12 +610,12 @@ class LetheMLP(nn.Module):


class LetheLayer(nn.Module):
    def __init__(self, config, memory_attention=False, index=None):
    def __init__(self, config, memory_attention=False, index=None, tracker=None):
        super().__init__()
        self.use_parallel_residual = config.use_parallel_residual
        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.attention = LetheAttention(config, memory_attention=memory_attention, index=index)
        self.attention = LetheAttention(config, memory_attention=memory_attention, index=index, tracker=tracker)
        self.mlp = LetheMLP(config)

    def forward(
@@ -465,7 +627,9 @@ class LetheLayer(nn.Module):
        use_cache: Optional[bool] = False,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_mem_attn: Optional[bool] = True,
        log_attn_scores: Optional[bool] = False,
        step: Optional[int] = None,
        save_kv: Optional[bool] = True
    ):
        ln_hidden_states = self.input_layernorm(hidden_states)
        attention_layer_outputs = self.attention(
@@ -476,7 +640,9 @@ class LetheLayer(nn.Module):
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            use_mem_attn=use_mem_attn,
            log_attn_scores=log_attn_scores,
            step=step,
            save_kv=save_kv,
        )
        attn_output = attention_layer_outputs[0]  # output_attn: attn_output, present, (attn_weights)
        outputs = attention_layer_outputs[1:]
@@ -503,7 +669,7 @@ class LetheLayer(nn.Module):


class LetheModel(LethePreTrainedModel):
    def __init__(self, config, index):
    def __init__(self, config, index, tracker=None):
        super().__init__(config)
        self.config = config

@@ -511,7 +677,8 @@ class LetheModel(LethePreTrainedModel):

        self.layers = nn.ModuleList([LetheLayer(config,
                                                memory_attention=i+1 == config.memory_attn_layer,
                                                index=index if i+1 == config.memory_attn_layer else None)
                                                index=index if i+1 == config.memory_attn_layer else None,
                                                tracker=tracker)
                                     for i in range(config.num_hidden_layers)])
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

@@ -538,7 +705,9 @@ class LetheModel(LethePreTrainedModel):
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        use_mem_attn: Optional[bool] = True,
        log_attn_scores: Optional[bool] = False,
        step: Optional[int] = None,
        save_kv: Optional[bool] = True,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        r"""
        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
@@ -631,7 +800,7 @@ class LetheModel(LethePreTrainedModel):
            def create_custom_forward(module):
                def custom_forward(*inputs):
                    # None for layer_past
                    return module(*inputs, use_cache, None, output_attentions, use_mem_attn)
                    return module(*inputs, use_cache, None, output_attentions, log_attn_scores, step, save_kv)

                return custom_forward

@@ -651,7 +820,9 @@ class LetheModel(LethePreTrainedModel):
                layer_past=layer_past,
                use_cache=use_cache,
                output_attentions=output_attentions,
                use_mem_attn=use_mem_attn,
                log_attn_scores=log_attn_scores,
                step=step,
                save_kv=save_kv,
            )
            hidden_states = outputs[0]
            if use_cache is True:
@@ -678,10 +849,10 @@ class LetheModel(LethePreTrainedModel):
class LetheForCausalLM(LethePreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

    def __init__(self, config, index):
    def __init__(self, config, index, tracker=None):
        super().__init__(config)

        self.gpt_neox = LetheModel(config, index)
        self.gpt_neox = LetheModel(config, index, tracker=tracker)
        self.embed_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.hidden_size = config.hidden_size
@@ -709,7 +880,9 @@ class LetheForCausalLM(LethePreTrainedModel):
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        use_mem_attn: Optional[bool] = True,
        log_attn_scores: Optional[bool] = None,
        step: Optional[int] = None,
        save_kv: Optional[bool] = True,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
@@ -763,7 +936,9 @@ class LetheForCausalLM(LethePreTrainedModel):
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            use_mem_attn=use_mem_attn
            log_attn_scores=log_attn_scores,
            step=step,
            save_kv=save_kv,
        )

        hidden_states = outputs[0]
gpt4all/models/lethe/test_index.py (new file, 21 lines)
@@ -0,0 +1,21 @@
import time
import numpy as np
from gpt4all.models.lethe.modeling_lethe import MemoryIndex


index = MemoryIndex(256,
                    575000,
                    8
)

keys = np.random.randn(32, 8, 1024, 256)
values = np.random.randn(32, 8, 1024, 256)
start = time.time()
index.add(keys, values)
print(f"index.add time: {time.time() - start}")

print(index.key_indices[0].index.ntotal)
queries = np.random.randn(32, 8, 1024, 256)
start = time.time()
index.knn_query(queries, k=32)
print(f"index.knn_query time: {time.time() - start}")
@@ -23,8 +23,9 @@ print("loading model")
dimension = config.max_position_embeddings * config.hidden_size
head_size = config.hidden_size // config.num_attention_heads
index = MemoryIndex(head_size,
                    500_000,
                    config.num_attention_heads
                    5_000_000,
                    # 2 since multi-query attention and storing one each for key and value
                    config.num_attention_heads,
)
model = LetheForCausalLM(config, index)
model.to("cuda:0")
@@ -69,7 +70,7 @@ with torch.no_grad():
    for chunk_start in tqdm(range(0, memories.shape[0], 32)):
        chunk_end = min(memories.shape[0], chunk_start + 32)
        mem_chunk = memories[chunk_start:chunk_end].to(model.device)
        model(input_ids=mem_chunk, labels=None)
        model(input_ids=mem_chunk, labels=None,)

model.train()
gpt4all/models/pythia_retro/__init__.py (new file, 2 lines)
@@ -0,0 +1,2 @@
from .configuration_pythia_retro import PythiaRetroConfig
from .modeling_pythia_retro import PythiaRetroForCausalLM
@@ -1,4 +1,5 @@
import os
import torch.nn.functional as F
from transformers import AutoTokenizer, get_scheduler, AutoConfig
import torch
from torch.optim import AdamW
@@ -12,6 +13,8 @@ from tqdm import tqdm
from gpt4all.models import LetheForCausalLM
from gpt4all.models.lethe.modeling_lethe import MemoryIndex
import wandb
import pyarrow as pa
from pyarrow import feather

torch.backends.cuda.matmul.allow_tf32 = True

@@ -22,39 +25,58 @@ def format_metrics(metrics, split, prefix=""):
    return log


def evaluate(model, config, val_dataloader, main_process=False):
def calculate_per_example_loss(logits, labels):
    lm_logits = logits[:, :-1, :].contiguous()
    lm_labels = labels[:, 1:].contiguous()
    loss = F.cross_entropy(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1), reduction="none")
    loss = loss.reshape(labels.shape[0], -1).mean(dim=-1)

    # return tensor of shape (B,) where B is the batch size
    return loss.cpu().tolist()


def evaluate(model, index, config, val_dataloader, main_process=False):
    model.eval()
    val_loss = MeanMetric(nan_strategy="error").to(model.device)

    head_size = model.config.hidden_size // model.config.num_attention_heads
    index = MemoryIndex(head_size,
                        config["num_memories_per_index"],
                        model.config.num_attention_heads
    )

    ids = []
    losses = []
    with torch.no_grad():
        for batch in tqdm(val_dataloader, disable=not main_process):
            batch["id"] = batch["id"].detach().cpu()
            memories = batch["retrieved_context"]

            # need to set to eval so we don't do mem attn as it's slow
            model.eval()
            memories = memories[:, :config["num_neighbors_to_store"], :]
            memories = memories.reshape(-1, memories.shape[-1])

            for chunk_start in range(0, memories.shape[0], config["mem_chunk_size"]):
                chunk_end = min(memories.shape[0], chunk_start + config["mem_chunk_size"])
                mem_chunk = memories[chunk_start:chunk_end]
                model(input_ids=mem_chunk, labels=None, use_mem_attn=False)
                model(input_ids=mem_chunk)

            qa_inputs = batch["input_ids"]
            qa_labels = batch["labels"]
            outputs = model(input_ids=qa_inputs,
                            labels=qa_labels,
            )

            del memories
            torch.cuda.empty_cache()
            loss_values = accelerator.gather_for_metrics({"loss": outputs["loss"].detach()})


            val_loss.update(loss_values["loss"])
            index.reset()
            val_loss.update(loss_values["loss"])

    return val_loss
            per_example_loss = calculate_per_example_loss(outputs["logits"], qa_labels)

            losses.extend(per_example_loss)
            ids.extend(batch["id"].tolist())

    ids = pa.array(ids)
    losses = pa.array(losses)
    schema = pa.schema([("loss", pa.float64()), ("id", pa.int32())])
    table = pa.Table.from_arrays([losses, ids], schema=schema)
    return val_loss, table


def train(accelerator, config):
@@ -72,6 +94,7 @@ def train(accelerator, config):
    with accelerator.main_process_first():
        train_dataloader, val_dataloader = load_memory_augmented_data(config, tokenizer)


    if accelerator.state.deepspeed_plugin is not None:
        gradient_accumulation_steps = accelerator.state.deepspeed_plugin.deepspeed_config[
            "gradient_accumulation_steps"
@@ -97,6 +120,7 @@ def train(accelerator, config):
                                         memory_attn_layer=config["memory_attn_layer"],
                                         num_neighbors_to_retrieve=config["num_neighbors_to_retrieve"],
                                         index=index,
                                         tracker=accelerator.get_tracker("wandb"),
    )


@@ -135,16 +159,15 @@ def train(accelerator, config):
        model, optimizer, train_dataloader, val_dataloader, scheduler = accelerator.prepare(
            model, optimizer, train_dataloader, val_dataloader, scheduler
        )
        scheduler = True
        use_scheduler = True
    else:
        model, optimizer, train_dataloader, val_dataloader = accelerator.prepare(
            model, optimizer, train_dataloader, val_dataloader
        )
        scheduler = False

        use_scheduler = False

    # setup for saving training states in case preemption
    if scheduler:
    if use_scheduler:
        accelerator.register_for_checkpointing(scheduler)

    if config["checkpoint"]:
@@ -167,24 +190,33 @@ def train(accelerator, config):
        for step, batch in enumerate(tqdm(train_dataloader, disable=not main_process)):
            curr_step = step + epoch * len(train_dataloader)
            memories = batch["retrieved_context"]
            memories = memories[:, :config["num_neighbors_to_store"], :]
            memories = memories.reshape(-1, memories.shape[-1])

            # need to set to eval so we don't do mem attn as it's slow
            model.eval()
            with torch.no_grad():
                for chunk_start in range(0, memories.shape[0], config["mem_chunk_size"]):
                    chunk_end = min(memories.shape[0], chunk_start + 32)
                    chunk_end = min(memories.shape[0], chunk_start + config["mem_chunk_size"])
                    mem_chunk = memories[chunk_start:chunk_end]
                    model(input_ids=mem_chunk, labels=None, use_mem_attn=False)
                    model(input_ids=mem_chunk)

            del memories
            torch.cuda.empty_cache()

            model.train()
            qa_inputs = batch["input_ids"]
            qa_labels = batch["labels"]
            outputs = model(input_ids=qa_inputs,
                            labels=qa_labels,
                            log_attn_scores=True,
                            step=curr_step,
                            save_kv=False,
            )
            loss = outputs.loss
            if config["wandb"]:
                accelerator.log({"loss": loss}, step=curr_step)

            index.reset()

            # gather loss before backprop in case of gradient accumulation
            loss_values = accelerator.gather_for_metrics({"loss": loss.detach().float()})
@@ -192,6 +224,8 @@ def train(accelerator, config):

            loss = loss / gradient_accumulation_steps
            accelerator.backward(loss)
            # !! don't reset index until after backwards pass
            index.reset()
            # get gradient norm of all params

            # log LR in case something weird happens
@@ -202,15 +236,25 @@ def train(accelerator, config):

            if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                optimizer.step()
                if scheduler:
                if use_scheduler:
                    scheduler.step()
                optimizer.zero_grad()

            if step > 0 and config["save_every"] > 0 and step % config["save_every"] == 0:
                accelerator.save_state(f"{config['output_dir']}/step_{curr_step}")
                # accelerator.save_state(f"{config['output_dir']}/step_{curr_step}")
                unwrapped_model = accelerator.unwrap_model(model)
                unwrapped_model.save_pretrained(
                    f"{config['output_dir']}/step_{step}",
                    is_main_process=accelerator.is_main_process,
                    save_function=accelerator.save,
                    state_dict=accelerator.get_state_dict(model),
                )

            if step > 0 and (step % config["eval_every"] == 0 or step == len(train_dataloader) - 1):
                val_loss = evaluate(model, config, val_dataloader, main_process=main_process)
                val_loss, loss_table = evaluate(model, index, config, val_dataloader, main_process=main_process)

                local_rank = accelerator.process_index
                feather.write_feather(loss_table, f"{config['output_dir']}/val_losses_step_{curr_step}_rank_{local_rank}.feather")

                log_train = {
                    "train_loss": train_loss.compute()
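Not part of this commit: a sketch of how the per-rank loss tables written above could be inspected afterwards; the step number and output_dir below are illustrative, and pyarrow's feather round-trips to a pandas DataFrame.

from pyarrow import feather

losses = feather.read_feather("ckpts/mem_attn_no_cosine_sim/val_losses_step_100_rank_0.feather")
print(losses.sort_values("loss", ascending=False).head())  # worst validation examples by id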
@@ -1,4 +1,5 @@
import torch.distributed as dist
from contextlib import contextmanager


def rank0_print(msg):
@@ -7,3 +8,20 @@ def rank0_print(msg):
        print(msg)
    else:
        print(msg)


@contextmanager
def main_process_first(is_main):
    yield from _goes_first(is_main)



def _goes_first(is_main):
    if not is_main:
        dist.barrier()

    yield

    if is_main:
        dist.barrier()
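A usage sketch (not part of this commit) for the context manager added above: the main process runs the body first, e.g. to build a cache, while the other ranks wait at the barrier and then proceed.

import torch.distributed as dist
from gpt4all.utils.distributed_utils import main_process_first

def build_or_load_cached_dataset():
    ...  # hypothetical helper: writes a cache on rank 0, reads it on the other ranks

with main_process_first(dist.get_rank() == 0):
    dataset = build_or_load_cached_dataset()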