commit 3128db96ca
parent 55fef489ad

    wip
@@ -14,7 +14,7 @@
     },
     "gradient_clipping": 1.0,
     "zero_optimization": {
-        "stage": 2,
+        "stage": 1,
         "offload_param": {
             "device": "none"
         },
@@ -35,5 +35,15 @@
       ],
       "eps": 1e-08
     }
+  },
+  "scheduler": {
+    "type": "WarmupDecayLR",
+    "params": {
+      "warmup_min_lr": 0,
+      "warmup_max_lr": "auto",
+      "warmup_num_steps": "auto",
+      "warmup_type": "linear",
+      "total_num_steps": "auto"
+    }
   }
 }
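
Note: DeepSpeed's WarmupDecayLR (with the "auto" fields filled in at launch time) ramps the learning rate linearly up to the peak over warmup_num_steps and then decays it linearly toward zero by total_num_steps. A rough sketch of that shape, with illustrative numbers rather than DeepSpeed's exact implementation or the resolved "auto" values:

    def warmup_decay_lr(step, max_lr=1e-4, min_lr=0.0, warmup_num_steps=200, total_num_steps=10_000):
        # linear warmup from min_lr to max_lr, then (roughly) linear decay toward zero
        if step < warmup_num_steps:
            return min_lr + (max_lr - min_lr) * step / warmup_num_steps
        remaining = max(total_num_steps - step, 0)
        return max_lr * remaining / (total_num_steps - warmup_num_steps)
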
@@ -1,9 +1,10 @@
 # model/tokenizer
-model_name: "/home/paperspace/gpt4all/gpt4all/train/ckpts/mem_attn/step_1000"
+model_name: "/home/paperspace/gpt4all/gpt4all/train/ckpts/qk_no_norm/step_500"
 tokenizer_name: "EleutherAI/pythia-1b"
 version: null
-gradient_checkpointing: false
+gradient_checkpointing: true
 memory_attn_layer: 12
+seed: 42
 
 
 # dataset
@@ -5,7 +5,7 @@ version: null
 gradient_checkpointing: true
 save_name: "nomic-ai/lethe"
 push_to_hub: false
-memory_attn_layer: 12
+memory_attn_layer: [9, 12, 15]
 
 # dataset
 streaming: false
@@ -17,7 +17,7 @@ pct_test: 0.05
 q_column: "question"
 a_column: "answer"
 context_column: "text"
-num_memories_per_index: 2000000
+num_memories_per_index: 2048
 num_neighbors_to_retrieve: 2
 num_neighbors_to_store: 1
 mem_chunk_size: 64
@@ -26,15 +26,15 @@ mem_chunk_size: 64
 lr: 1.0e-5
 min_lr: 0
 weight_decay: 0.0
-eval_every: 100
-save_every: 100
+eval_every: 250
+save_every: 250
 log_grads_every: 100
 log_lr_every: 10
-output_dir: "ckpts/mem_attn_no_cosine_sim"
+output_dir: "ckpts/qk_no_norm"
 checkpoint: null
 lora: false
 warmup_steps: 200
-num_epochs: 5
+num_epochs: 2
 debug: false
 scheduler: false
 
@@ -42,5 +42,4 @@ scheduler: false
 wandb: true
 wandb_entity: gpt4all
 wandb_project_name: mem_attn
 seed: 42
-
@@ -10,20 +10,21 @@ memory_attn_layer: 12
 # dataset
 streaming: false
 num_proc: 64
-dataset_path: "JeanKaddour/minipile"
+dataset_path: "pg19"
 max_length: 2048
-batch_size: 64
+seq_len: 512
+segments: 16
+batch_size: 16
 pct_test: 0.05
-num_memories_per_index: 5000000
+num_memories_per_index: 100000
 mem_chunk_size: 512
-num_chunks: 10
 num_neighbors_to_retrieve: 32
 
 # train dynamics
-lr: 1.0e-4
+lr: 2.0e-4
 min_lr: 0
 weight_decay: 0.0
-eval_every: 100
+eval_every: 250
 save_every: -1
 log_grads_every: 100
 log_lr_every: 10
@@ -38,6 +39,6 @@ scheduler: false
 # logging
 wandb: true
 wandb_entity: gpt4all
-wandb_project_name: minipile
+wandb_project_name: enwik8
 seed: 42
 
gpt4all/data/enwik8.py (new file, 62 lines)
@@ -0,0 +1,62 @@
+import numpy as np
+import torch
+from datasets import load_dataset
+from torch.utils.data import Dataset, DataLoader
+from transformers import DefaultDataCollator
+
+
+
+class EnWik8Dataset(Dataset):
+    def __init__(self, data, seq_len):
+        # pyarrow chunked array
+        self.data = torch.from_numpy(data)
+        self.seq_len = seq_len
+
+    def __getitem__(self, index):
+        full_seq = self.data[index].long()
+        return full_seq.cuda()
+
+    def __len__(self):
+        return len(self.data)
+
+
+def load_enwik8_dataloader(config, tokenizer):
+    ds = load_dataset(config["dataset_path"], split="train")
+
+    ds = ds.train_test_split(test_size=0.05, seed=config['seed'])
+
+    train_ds, val_ds = ds["train"], ds["test"]
+
+    keep_cols = ["input_ids"]
+
+    train_ds = train_ds.map(lambda x: {"len": [len(t) for t in x["text"]]}, batched=True)
+    train_ds = train_ds.sort("len")
+    train_ds = train_ds.map(lambda x: tokenizer(x["text"], padding="longest", truncation=True, return_tensors="pt"),
+                            batched=True,
+                            batch_size=config["batch_size"])
+
+    remove_cols = [col for col in train_ds.column_names if col not in keep_cols]
+    train_ds = train_ds.remove_columns(remove_cols)
+
+    val_ds = val_ds.map(lambda x: {"len": [len(t) for t in x["text"]]}, batched=True)
+    val_ds = val_ds.sort("len")
+    val_ds = val_ds.map(lambda x: tokenizer(x["text"], padding="longest", truncation=True, return_tensors="pt"),
+                        batched=True,
+                        batch_size=config["batch_size"])
+
+    remove_cols = [col for col in train_ds.column_names if col not in keep_cols]
+    val_ds = val_ds.remove_columns(remove_cols)
+
+    train_dl = DataLoader(train_ds,
+                          batch_size=config["batch_size"],
+                          shuffle=True,
+                          drop_last=True,
+                          collate_fn=DefaultDataCollator())
+
+    val_dl = DataLoader(val_ds,
+                        batch_size=config["batch_size"],
+                        shuffle=True,
+                        drop_last=True,
+                        collate_fn=DefaultDataCollator())
+
+    return train_dl, val_dl
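
A minimal sketch of how this loader might be driven, assuming a config dict with the keys used above (dataset_path, batch_size, seed; the values mirror the accompanying YAML) and a Pythia tokenizer as in those configs:

    from transformers import AutoTokenizer
    from gpt4all.data.enwik8 import load_enwik8_dataloader

    # hypothetical config; keys mirror the training YAML above
    config = {"dataset_path": "pg19", "batch_size": 16, "seed": 42}
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-1b")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    train_dl, val_dl = load_enwik8_dataloader(config, tokenizer)
    batch = next(iter(train_dl))
    # each batch is a dict with "input_ids"; padded length depends on the map batch it came from
    print(batch["input_ids"].shape)
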
@@ -1,6 +1,6 @@
 import glob
 import torch
-from datasets import load_dataset
+from datasets import load_dataset, load_from_disk
 import os
 import hnswlib
 from torch.utils.data import DataLoader
@@ -12,20 +12,23 @@ def load_data(config, tokenizer):
     dataset_path = config["dataset_path"]
 
     if os.path.exists(dataset_path):
-        if os.path.isdir(dataset_path):
-            files = glob.glob(os.path.join(dataset_path, "*_clean.jsonl"))
-        else:
-            files = [dataset_path]
+        dataset = load_from_disk(dataset_path)
+        # if os.path.isdir(dataset_path):
+        #     files = glob.glob(os.path.join(dataset_path, "*_clean.jsonl"))
+        # else:
+        #     files = [dataset_path]
 
-        print(f"Reading files {files}")
+        # print(f"Reading files {files}")
 
-        dataset = load_dataset("json", data_files=files, split="train")
+        # dataset = load_dataset("json", data_files=files, split="train")
 
     else:
         dataset = load_dataset(dataset_path, split="train")
 
     dataset = dataset.train_test_split(test_size=.05, seed=config["seed"])
 
+    dataset = dataset.map(lambda x: {"prompt": [text + " " + question for text, question in zip(x["text"], x["question"])]}, batched=True)
+
     train_dataset, val_dataset = dataset["train"], dataset["test"]
 
     if config["streaming"] is False:
@@ -33,19 +36,27 @@ def load_data(config, tokenizer):
     else:
         kwargs = {}
 
+    cols_to_keep = ["input_ids", "labels", "attention_mask"]
+
     # tokenize inputs and return labels and attention mask
     train_dataset = train_dataset.map(
-        lambda ele: tokenize_inputs(config, tokenizer, ele, "prompt", "response"),
+        lambda ele: tokenize_inputs(config, tokenizer, ele, "prompt", "answer"),
         batched=True,
-        remove_columns=["source", "prompt"],
+        # remove_columns=["source", "prompt"],
         **kwargs
     )
 
+    cols_to_remove = [col for col in train_dataset.column_names if col not in cols_to_keep]
+    train_dataset = train_dataset.remove_columns(cols_to_remove)
+
     val_dataset = val_dataset.map(
-        lambda ele: tokenize_inputs(config, tokenizer, ele, "prompt", "response"),
+        lambda ele: tokenize_inputs(config, tokenizer, ele, "prompt", "answer"),
         batched=True,
-        remove_columns=["source", "prompt"],
+        # remove_columns=["source", "prompt"],
         **kwargs
     )
+    cols_to_remove = [col for col in val_dataset.column_names if col not in cols_to_keep]
+    val_dataset = val_dataset.remove_columns(cols_to_remove)
 
     train_dataset = train_dataset.with_format("torch")
     val_dataset = val_dataset.with_format("torch")
@@ -56,12 +67,14 @@ def load_data(config, tokenizer):
         train_dataset,
         collate_fn=DefaultDataCollator(),
         batch_size=config["batch_size"],
+        shuffle=True,
     )
 
     val_dataloader = DataLoader(
         val_dataset,
         collate_fn=DefaultDataCollator(),
         batch_size=config["batch_size"],
+        shuffle=True,
     )
 
     return train_dataloader, val_dataloader
@@ -5,7 +5,7 @@ def tokenize_inputs(config, tokenizer, examples, input_col, target_col):
 
     # hacky backward compatible
     different_eos = tokenizer.eos_token != "</s>"
-    out = {"labels": [], "input_ids": []}
+    out = {"labels": [], "input_ids": [], "attention_mask": []}
     for prompt, response in zip(examples[input_col], examples[target_col]):
         if different_eos:
             if response.count("</s> \n") > 0:
@@ -42,9 +42,10 @@ def tokenize_inputs(config, tokenizer, examples, input_col, target_col):
             print(response)
             raise
 
-        input_tokens = tokenizer.pad({"input_ids": input_tokens}, padding="max_length", max_length=max_length)["input_ids"]
+        tokenized = tokenizer.pad({"input_ids": input_tokens}, padding="max_length", max_length=max_length)
         out["labels"].append(labels)
-        out["input_ids"].append(input_tokens)
+        out["input_ids"].append(tokenized["input_ids"])
+        out["attention_mask"].append(tokenized["attention_mask"])
 
     out = {k: torch.stack(v) if isinstance(v, list) else v for k, v in out.items()}
 
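
This works because tokenizer.pad returns a dict carrying both input_ids and attention_mask when padding to max_length; a small illustration of that behavior with a Hugging Face tokenizer that has a pad token set (the token ids here are arbitrary):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("EleutherAI/pythia-1b")
    tok.pad_token = tok.eos_token

    padded = tok.pad({"input_ids": [[5, 6, 7]]}, padding="max_length", max_length=8, return_tensors="pt")
    print(padded["input_ids"].shape)    # torch.Size([1, 8])
    print(padded["attention_mask"][0])  # 1 over real tokens, 0 over padding
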
@@ -41,7 +41,7 @@ def load_retrieval_augmented_data(config, tokenizer, split="train", split_datase
     if encoder_column != "encoder_hidden_states":
         dataset = dataset.rename_column(encoder_column, "encoder_hidden_states")
 
-    columns_to_keep = ["input_ids", "labels", "encoder_hidden_states"]
+    columns_to_keep = ["input_ids", "attention_mask", "labels", "encoder_hidden_states"]
 
     col_names_to_rm = [col for col in dataset.column_names if col not in columns_to_keep]
     dataset = dataset.remove_columns(col_names_to_rm)
@@ -115,7 +115,7 @@ def load_memory_augmented_data(config, tokenizer, split="train", split_dataset=T
         **kwargs
     )
 
-    columns_to_keep = ["id", "input_ids", "labels", "retrieved_context"]
+    columns_to_keep = ["id", "input_ids", "attention_mask", "labels", "retrieved_context"]
 
     col_names_to_rm = [col for col in dataset.column_names if col not in columns_to_keep]
     dataset = dataset.remove_columns(col_names_to_rm)
@@ -128,12 +128,16 @@ def load_memory_augmented_data(config, tokenizer, split="train", split_dataset=T
         train_dataset.remove_columns("id"),
         batch_size=config["batch_size"],
         collate_fn=DefaultDataCollator(),
+        shuffle=True,
+        drop_last=True,
     )
 
     val_dataloader = DataLoader(
         val_dataset,
        batch_size=config["batch_size"],
         collate_fn=DefaultDataCollator(),
+        shuffle=True,
+        drop_last=True,
     )
 
     return train_dataloader, val_dataloader
@@ -1,7 +1,7 @@
 import torch
 import torch.nn.functional as F
 from gpt4all.models import LetheForCausalLM
-from gpt4all.models.lethe.modeling_lethe import MemoryIndex
+from gpt4all.models.lethe.modeling_lethe import BatchedMemory
 from gpt4all.data.retrieval_dataloader import load_memory_augmented_data
 from gpt4all.train.metrics import f1_score, exact_match_score
 from gpt4all.utils.read import read_config
@@ -28,17 +28,21 @@ def greedy_search(input_ids, model, tokenizer, max_new_tokens=100):
     while True:
         if num_new_tokens >= max_new_tokens:
             break
-        outputs = model(input_ids, save_kv=False)
+        attention_mask = input_ids.ne(tokenizer.pad_token_id)
+        outputs = model(input_ids, attention_mask=attention_mask, save_kv=False)
 
-        new_tokens = torch.argmax(outputs.logits[:, -1, :], dim=-1)
+        next_token_idx = torch.argmax((input_ids == tokenizer.pad_token_id).type(torch.float32))
+        # -1 because logits at last position predict next token
+        new_token = torch.argmax(outputs.logits[:, next_token_idx - 1, :], dim=-1)
+
+        input_ids[:, next_token_idx] = new_token
 
-        input_ids = torch.cat([input_ids, new_tokens.unsqueeze(1)], dim=-1)
         num_new_tokens += 1
 
-        if torch.equal(input_ids[0, -1].cpu(), torch.tensor(tokenizer.eos_token_id)):
+        if torch.equal(new_token.cpu(), torch.tensor(tokenizer.eos_token_id)):
             break
 
-    print(tokenizer.batch_decode(input_ids, skip_special_tokens=True))
+    print(f"GENERATED: {tokenizer.batch_decode(input_ids, skip_special_tokens=True)}")
 
     return input_ids
 
@@ -63,10 +67,11 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model_config = AutoConfig.from_pretrained(config["model_name"])
 
 head_size = model_config.hidden_size // model_config.num_attention_heads
-index = MemoryIndex(head_size,
-                    config["num_memories_per_index"],
-                    model_config.num_attention_heads
-)
+index = BatchedMemory(config["batch_size"],
+                      head_size,
+                      config["num_memories_per_index"],
+                      model_config.num_attention_heads,
+)
 model = LetheForCausalLM.from_pretrained(config["model_name"],
                                          revision=config['version'] if 'version' in config else None,
                                          memory_attn_layer=config["memory_attn_layer"],
@@ -90,16 +95,19 @@ with torch.no_grad():
         mem_chunk = memories[chunk_start:chunk_end]
         model(input_ids=mem_chunk.to(device))
 
-    del memories
     torch.cuda.empty_cache()
     qa_inputs = batch["input_ids"]
     qa_labels = batch["labels"]
     for i in range(qa_inputs.shape[0]):
         inputs = qa_inputs[i].to(device)
+        print(f"EXPECTED: {tokenizer.decode(inputs, skip_special_tokens=True)}")
         labels = qa_labels[i].to(device)
 
         cutoff = torch.argmax((labels != -100).type(torch.float32))
-        greedy_search(inputs[:cutoff.item()].unsqueeze(0).to(device), model, tokenizer)
-        print(tokenizer.decode(inputs, skip_special_tokens=True))
+        inputs[cutoff:] = tokenizer.pad_token_id
+        greedy_search(inputs.unsqueeze(0).to(device), model, tokenizer)
+        print(f"CONTEXT: {tokenizer.decode(memories[i], skip_special_tokens=True)}")
+        import pdb; pdb.set_trace()
 
 
     # batch_loss = calc_loss_per_item(outputs.logits, qa_labels.to(device))
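
The reworked greedy_search expects prompts that are already right-padded with pad_token_id up to a fixed length, and it writes each new token into the first pad slot instead of concatenating. A sketch of preparing such an input, assuming a Lethe model, tokenizer, and device are already loaded as in the script above (the prompt text and buffer length are illustrative):

    import torch

    prompt_ids = tokenizer("Who wrote SQuAD?", return_tensors="pt")["input_ids"][0]
    max_len = 128  # illustrative fixed buffer length
    padded = torch.full((max_len,), tokenizer.pad_token_id, dtype=torch.long)
    padded[: prompt_ids.shape[0]] = prompt_ids

    # generation then fills pad positions in place, one token per step
    output_ids = greedy_search(padded.unsqueeze(0).to(device), model, tokenizer)
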
gpt4all/eval/eval_squad_atlas_map/dataset_info.json (new file, 96 lines)
@@ -0,0 +1,96 @@
+{
+  "builder_name": "squad",
+  "citation": "@article{2016arXiv160605250R,\n author = {{Rajpurkar}, Pranav and {Zhang}, Jian and {Lopyrev},\n Konstantin and {Liang}, Percy},\n title = \"{SQuAD: 100,000+ Questions for Machine Comprehension of Text}\",\n journal = {arXiv e-prints},\n year = 2016,\n eid = {arXiv:1606.05250},\n pages = {arXiv:1606.05250},\narchivePrefix = {arXiv},\n eprint = {1606.05250},\n}\n",
+  "config_name": "plain_text",
+  "dataset_size": 89819092,
+  "description": "Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset, consisting of questions posed by crowdworkers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span, from the corresponding reading passage, or the question might be unanswerable.\n",
+  "download_checksums": {
+    "https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json": {
+      "num_bytes": 30288272,
+      "checksum": null
+    },
+    "https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json": {
+      "num_bytes": 4854279,
+      "checksum": null
+    }
+  },
+  "download_size": 35142551,
+  "features": {
+    "id": {
+      "dtype": "string",
+      "_type": "Value"
+    },
+    "title": {
+      "dtype": "string",
+      "_type": "Value"
+    },
+    "context": {
+      "dtype": "string",
+      "_type": "Value"
+    },
+    "question": {
+      "dtype": "string",
+      "_type": "Value"
+    },
+    "answers": {
+      "feature": {
+        "text": {
+          "dtype": "string",
+          "_type": "Value"
+        },
+        "answer_start": {
+          "dtype": "int32",
+          "_type": "Value"
+        }
+      },
+      "_type": "Sequence"
+    },
+    "neighbor_ids": {
+      "feature": {
+        "dtype": "uint64",
+        "_type": "Value"
+      },
+      "_type": "Sequence"
+    },
+    "neighbor_text": {
+      "feature": {
+        "dtype": "string",
+        "_type": "Value"
+      },
+      "_type": "Sequence"
+    },
+    "loss": {
+      "dtype": "float64",
+      "_type": "Value"
+    }
+  },
+  "homepage": "https://rajpurkar.github.io/SQuAD-explorer/",
+  "license": "",
+  "size_in_bytes": 124961643,
+  "splits": {
+    "train": {
+      "name": "train",
+      "num_bytes": 79346108,
+      "num_examples": 87599,
+      "dataset_name": "squad"
+    },
+    "validation": {
+      "name": "validation",
+      "num_bytes": 10472984,
+      "num_examples": 10570,
+      "dataset_name": "squad"
+    }
+  },
+  "task_templates": [
+    {
+      "task": "question-answering-extractive"
+    }
+  ],
+  "version": {
+    "version_str": "1.0.0",
+    "description": "",
+    "major": 1,
+    "minor": 0,
+    "patch": 0
+  }
+}
gpt4all/eval/eval_squad_atlas_map/state.json (new file, 16 lines)
@@ -0,0 +1,16 @@
+{
+  "_data_files": [
+    {
+      "filename": "data-00000-of-00002.arrow"
+    },
+    {
+      "filename": "data-00001-of-00002.arrow"
+    }
+  ],
+  "_fingerprint": "c178f5c8269012a2",
+  "_format_columns": null,
+  "_format_kwargs": {},
+  "_format_type": null,
+  "_output_all_columns": false,
+  "_split": "validation"
+}
gpt4all/eval/eval_synthetic.py (new file, 42 lines)
@@ -0,0 +1,42 @@
+import torch
+from gpt4all.data.instruction_tuning_dataloader import load_data
+from gpt4all.utils.read import read_config
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from argparse import ArgumentParser
+from tqdm import tqdm
+
+parser = ArgumentParser()
+parser.add_argument("--config", type=str, default="config.yaml")
+
+args = parser.parse_args()
+
+config = read_config(args.config)
+
+tokenizer = AutoTokenizer.from_pretrained(config["tokenizer_name"])
+if tokenizer.pad_token is None:
+    tokenizer.pad_token = tokenizer.eos_token
+
+train_dataloader, val_dataloader = load_data(config, tokenizer)
+
+
+model = AutoModelForCausalLM.from_pretrained(config["model_name"],
+                                             trust_remote_code=True)
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+model = model.half().to(device)
+model.eval()
+
+# Evaluate the model on the SQUAD dataset
+f1s = []
+exact_matches = []
+with torch.no_grad():
+    for batch in tqdm(val_dataloader):
+        inputs = batch["input_ids"].to(device)
+        labels = batch["labels"].to(device)
+        cutoff = torch.argmax((labels != -100).type(torch.float32))
+        outputs = model.generate(inputs[:, :cutoff], max_new_tokens=100)
+        print(f"Predicted: {tokenizer.batch_decode(outputs, skip_special_tokens=True)}")
+        print(f"Ground truth: {tokenizer.batch_decode(inputs[:, cutoff:], skip_special_tokens=True)}")
+        print(tokenizer.batch_decode(inputs, skip_special_tokens=True))
@@ -2,7 +2,6 @@ from .gpt_jr.configuration_gpt_jr import GPTJRConfig
 from .gpt_jr.modeling_gpt_jr import GPTJRForCausalLM
 
 from .pythiaseek import PythiaSeekForCausalLM, PythiaSeekConfig
-from .pythia_retro import PythiaRetroForCausalLM, PythiaRetroConfig
 from .lethe import LetheConfig, LetheForCausalLM
 
 
@@ -18,6 +18,8 @@ import wandb
 import math
 import torch.nn.functional as F
 import matplotlib.pyplot as plt
+import plotly.express as px
+import pandas as pd
 from typing import Optional, Tuple, Union
 
 import torch
@@ -121,60 +123,106 @@ class MemoryIndex:
         # NOTE: we are storing kv pairs, instead indices for both keys and values
         self.key_indices = [HNSWIndex(num_mems, hidden_dim) for _ in range(nheads)]
 
-        shape = (num_mems, nheads, 2, hidden_dim)
+        shape = (nheads, num_mems, 2, hidden_dim)
         self.kv_pairs = np.zeros(shape, dtype=np.float32)
         self.idx_offset = 0
 
     def add(self, keys, values):
-        # k/v are (bs, num_attention_heads, seq_len, head_size)
-        reshaped_keys = keys.reshape(keys.shape[0] * keys.shape[2], keys.shape[1], keys.shape[3])
-        reshaped_values = values.reshape(values.shape[0] * values.shape[2], values.shape[1], values.shape[3])
+        # k/v are (num_attention_heads, seq_len, head_size)
+        # keys = keys.reshape(keys.shape[1], keys.shape[0], keys.shape[2])
+        # values = values.reshape(values.shape[1], values.shape[0], values.shape[2])
 
         for head in range(self.nheads):
-            self.key_indices[head].add(reshaped_keys[:, head, :])
+            self.key_indices[head].add(keys[head, :, :])
 
-        kv_pairs = np.stack((reshaped_keys, reshaped_values), axis=2)
+        kv_pairs = np.stack((keys, values), axis=2)
 
-        if self.idx_offset + kv_pairs.shape[0] > self.kv_pairs.shape[0]:
-            raise ValueError("Not enough memory!")
+        if self.idx_offset + kv_pairs.shape[1] > self.kv_pairs.shape[1]:
+            # reset to 0 to overwrite oldest memories
+            self.idx_offset = 0
 
-        self.kv_pairs[self.idx_offset:self.idx_offset + kv_pairs.shape[0]] = kv_pairs
-        self.idx_offset += kv_pairs.shape[0]
+        self.kv_pairs[:, self.idx_offset:self.idx_offset + kv_pairs.shape[1]] = kv_pairs
+        self.idx_offset += kv_pairs.shape[1]
 
     def knn_query(self, query, k=1):
-        reshaped_query = query.reshape(query.shape[0] * query.shape[2], query.shape[1], query.shape[3])
-
         mem_keys = []
         mem_values = []
         mem_indices = []
 
         # we can prob make this better
         for head in range(self.nheads):
-            knn_indices = self.key_indices[head].query(reshaped_query[:, head, :], k=k)
-            kv_pairs = self.kv_pairs[:, head, :, :][knn_indices]
+            knn_indices = self.key_indices[head].query(query[head, :, :], k=k)
+            kv_pairs = self.kv_pairs[head, :, :, :][knn_indices]
 
             mem_keys.append(kv_pairs[:, :, 0, :])
             mem_values.append(kv_pairs[:, :, 1, :])
             mem_indices.append(knn_indices)
 
-        mem_keys = torch.from_numpy(np.stack(mem_keys, axis=1))
-        # (bs, num_attention_heads, seq_len, k, head_size)
-        mem_keys = mem_keys.view(query.shape[:-1] + (k,) + (query.shape[-1],))
+        mem_keys = torch.from_numpy(np.stack(mem_keys, axis=0))
+        # (num_attention_heads, seq_len, k, head_size)
+        # mem_keys = mem_keys.view(query.shape[:-1] + (k,) + (query.shape[-1],))
 
-        mem_values = torch.from_numpy(np.stack(mem_values, axis=1))
-        # (bs, num_attention_heads, seq_len, k, head_size)
-        mem_values = mem_values.view(query.shape[:-1] + (k,) + (query.shape[-1],))
+        mem_values = torch.from_numpy(np.stack(mem_values, axis=0))
+        # (num_attention_heads, seq_len, k, head_size)
+        # mem_values = mem_values.view(query.shape[:-1] + (k,) + (query.shape[-1],))
 
-        return mem_keys, mem_values, np.stack(mem_indices, axis=1)
+        return mem_keys, mem_values, np.stack(mem_indices, axis=0)
 
 
     def reset(self):
         for head in range(self.nheads):
             self.key_indices[head].reset()
 
-        self.kv_pairs = np.zeros((self.kv_pairs.shape[0], self.nheads, 2, self.kv_pairs.shape[-1]), dtype=np.float32)
+        self.kv_pairs = np.zeros(self.kv_pairs.shape, dtype=np.float32)
+        self.idx_offset = 0
+
+
+class BatchedMemory:
+    def __init__(self, batch_size, hidden_dim, num_mems, nheads):
+        self.indices = [MemoryIndex(hidden_dim, num_mems, nheads) for _ in range(batch_size)]
+
+    def add(self, keys, values):
+        for bs in range(len(self.indices)):
+            self.indices[bs].add(keys[bs], values[bs])
+
+    def knn_query(self, query, k=1):
+        batched_mem_keys = []
+        batched_mem_values = []
+        batched_labels = []
+
+        for bs in range(len(self.indices)):
+            knn_keys, knn_values, knn_labels = self.indices[bs].knn_query(query[bs], k=k)
+            batched_mem_keys.append(knn_keys)
+            batched_mem_values.append(knn_values)
+            batched_labels.append(knn_labels)
+
+        return torch.stack(batched_mem_keys, dim=0), torch.stack(batched_mem_values, dim=0), np.stack(batched_labels, axis=0)
+
+    def reset(self):
+        for bs in range(len(self.indices)):
+            self.indices[bs].reset()
+
+
+class BatchedStorage:
+    def __init__(self, batch_size, hidden_dim, nheads, seq_len):
+        self.indices = np.zeros((batch_size, nheads, seq_len, 2, hidden_dim), dtype=np.float32)
+
+    def knn_query(self, query, k=None):
+        return torch.from_numpy(self.indices[:, :, :, 0, :]), torch.from_numpy(self.indices[:, :, :, 1, :])
+
+    def add(self, keys, values):
+        self.indices[:, :, :, 0, :] = keys
+        self.indices[:, :, :, 1, :] = values
+
+    def reset(self):
+        self.indices = np.zeros_like(self.indices)
+
+
 class LethePreTrainedModel(PreTrainedModel):
     """
     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
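
A small sketch of how BatchedMemory is meant to be exercised, assuming per-example key/value arrays shaped (batch, heads, seq_len, head_size) as in the attention code; the concrete sizes here are illustrative:

    import numpy as np

    batch_size, nheads, seq_len, head_size = 2, 8, 16, 64
    memory = BatchedMemory(batch_size, head_size, num_mems=1024, nheads=nheads)

    keys = np.random.randn(batch_size, nheads, seq_len, head_size).astype(np.float32)
    values = np.random.randn(batch_size, nheads, seq_len, head_size).astype(np.float32)
    memory.add(keys, values)   # one MemoryIndex per batch element

    queries = np.random.randn(batch_size, nheads, seq_len, head_size).astype(np.float32)
    mem_keys, mem_values, labels = memory.knn_query(queries, k=2)
    # mem_keys / mem_values: (batch, heads, seq_len, k, head_size) torch tensors
    memory.reset()             # clear all per-example indices between batches
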
@@ -206,12 +254,13 @@ class LethePreTrainedModel(PreTrainedModel):
 
 
 class LetheAttention(nn.Module):
-    def __init__(self, config, memory_attention=False, index=None, tracker=None):
+    def __init__(self, config, memory_attention=False, index=None, layer_idx=None, tracker=None):
         super().__init__()
         self.num_attention_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_attention_heads
         self.rotary_ndims = int(self.head_size * config.rotary_pct)
+        self.layer_idx = layer_idx
         max_positions = config.max_position_embeddings
         self.register_buffer(
             "bias",
@@ -274,7 +323,6 @@ class LetheAttention(nn.Module):
         key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3)
         value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3)
 
-        # if self.memory:
         if self.memory:
             # QKNorm: https://arxiv.org/abs/2010.04245
             query = F.normalize(query, dim=-1)
@@ -309,17 +357,28 @@ class LetheAttention(nn.Module):
             self.index.add(key.cpu().detach().to(torch.float32).numpy(), value.cpu().detach().to(torch.float32).numpy())
 
             knn_keys, knn_values, knn_labels = self.index.knn_query(query.cpu().detach().to(torch.float32).numpy(), k=self.num_neighbors)
+            # if log_attn_scores:
+            #     batch_size = query.shape[0]
+            #     seq_len = query.shape[-2]
+
+            #     key_labels = knn_labels // seq_len
+            #     key_labels = key_labels.reshape(batch_size, seq_len, self.num_attention_heads, -1)
+            #     correct_keys = np.equal(key_labels, np.arange(batch_size)[:, np.newaxis, np.newaxis, np.newaxis])
+            #     # calculate the accuracy
+            #     key_acc = np.sum(correct_keys) / np.prod(correct_keys.shape)
+
+            #     self.tracker.log({"retrieved_acc": key_acc}, step=step)
+
             if log_attn_scores:
-                batch_size = query.shape[0]
-                seq_len = query.shape[-2]
+                total_examples = 0
+                unique_examples = 0
+                for bs in range(query.shape[0]):
+                    for head in range(query.shape[1]):
+                        labels_per_head = knn_labels[bs, head, :, 0].tolist()
+                        total_examples += len(labels_per_head)
+                        unique_examples += len(set(labels_per_head))
 
-                key_labels = knn_labels // seq_len
-                key_labels = key_labels.reshape(batch_size, seq_len, self.num_attention_heads, -1)
-                correct_keys = np.equal(key_labels, np.arange(batch_size)[:, np.newaxis, np.newaxis, np.newaxis])
-                # calculate the accuracy
-                key_acc = np.sum(correct_keys) / np.prod(correct_keys.shape)
-
-                self.tracker.log({"retrieved_acc": key_acc}, step=step)
+                self.tracker.log({"unique_retrieved_pct": unique_examples / total_examples}, step=step)
 
             attn_output = self._mem_attn(query,
                                          knn_keys.to(query.device).to(value.dtype),
@@ -407,6 +466,7 @@ class LetheAttention(nn.Module):
             local_attn_scores = local_attn_scores + attention_mask
 
         mem_attn_scores = torch.einsum("bhsd, bhsnd-> bhsn", query, knn_key)
+        # mem_attn_scores = torch.matmul(query, knn_key.transpose(-1, -2))
         # attn_scores: [bs, seq_len, num_attention_heads, knn]
         mem_attn_scores = mem_attn_scores * scale
 
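
The einsum contracts the query's head dimension against each retrieved neighbor, so for query (bs, heads, seq, head_size) and knn_key (bs, heads, seq, knn, head_size) the scores come out as (bs, heads, seq, knn). A quick shape check with illustrative sizes:

    import torch

    bs, heads, seq, knn, d = 2, 8, 16, 4, 64
    query = torch.randn(bs, heads, seq, d)
    knn_key = torch.randn(bs, heads, seq, knn, d)

    scores = torch.einsum("bhsd, bhsnd-> bhsn", query, knn_key)
    print(scores.shape)  # torch.Size([2, 8, 16, 4])
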
@@ -417,48 +477,56 @@ class LetheAttention(nn.Module):
         attn_weights = attn_weights.to(local_value.dtype)
 
         mem_attn_weights, local_attn_weights = attn_weights.split([self.num_neighbors, local_attn_scores.size(-1)], dim=-1)
-        if log_attn_scores:
-            # (bs, seq_len, num_attention_heads, knn) probabilities
-            # curate (x,y) pairs
-            # where x is attention weight, y is accuracy of retrieved token
-            bs, seq_len = mem_attn_weights.size(0), mem_attn_weights.size(2)
-            key_labels = knn_labels // seq_len
-            key_labels = key_labels.reshape(bs, self.num_attention_heads, seq_len, -1)
-            correct_keys = np.equal(key_labels, np.arange(bs)[:, np.newaxis, np.newaxis, np.newaxis])
-
-            bin_width = 0.05
-
-            # Calculate the number of bins
-            num_bins = int(1 / bin_width)
-
-            # Create empty lists for storing bin probabilities and accuracies
-            bin_probabilities = []
-            bin_accuracies = []
-
-            probs = mem_attn_weights.clone().detach().cpu().numpy().reshape(-1).tolist()
-            correct_keys = correct_keys.reshape(-1).tolist()
-
-            # Iterate over each bin
-            for i in range(num_bins):
-                bin_lower = i * bin_width
-                bin_upper = (i + 1) * bin_width
-
-                # Filter data points within the current bin range
-                bin_x_values = [x for x in probs if bin_lower <= x < bin_upper]
-                bin_y_values = [y for j, y in enumerate(correct_keys) if bin_lower <= probs[j] < bin_upper]
-
-                # Calculate accuracy for the bin
-                total = len(bin_x_values)
-                correct = sum(bin_y_values)
-                accuracy = correct / total if total > 0 else 0
-
-                # Store the probability and accuracy for the bin
-                bin_probabilities.append((bin_lower + bin_upper) / 2)
-                bin_accuracies.append(accuracy)
-
-            data = [[x, y] for x, y in zip(bin_probabilities, bin_accuracies)]
-            table = wandb.Table(data=data, columns=["attn_prob", "retrieved_acc"])
-            self.tracker.log({"attn_vs_acc": wandb.plot.scatter(table, "attn_prob", "retrieved_acc")}, step=step)
+        # mem_attn_weights, local_attn_weights = attn_weights.chunk(2, dim=-1)
+
+        # if log_attn_scores:
+        #     # (bs, seq_len, num_attention_heads, knn) probabilities
+        #     # curate (x,y) pairs
+        #     # where x is attention weight, y is accuracy of retrieved token
+        #     bs, seq_len = mem_attn_weights.size(0), mem_attn_weights.size(2)
+        #     key_labels = knn_labels // seq_len
+        #     key_labels = key_labels.reshape(bs, self.num_attention_heads, seq_len, -1)
+        #     correct_keys = np.equal(key_labels, np.arange(bs)[:, np.newaxis, np.newaxis, np.newaxis])
+
+        #     bin_width = 0.05
+
+        #     # Calculate the number of bins
+        #     num_bins = int(1 / bin_width)
+
+        #     # Create empty lists for storing bin probabilities and accuracies
+        #     bin_probabilities = []
+        #     bin_accuracies = []
+        #     bin_sizes = []
+
+        #     probs = mem_attn_weights.clone().detach().cpu().numpy().reshape(-1).tolist()
+        #     correct_keys = correct_keys.reshape(-1).tolist()
+
+        #     # Iterate over each bin
+        #     for i in range(num_bins):
+        #         bin_lower = i * bin_width
+        #         bin_upper = (i + 1) * bin_width
+
+        #         # Filter data points within the current bin range
+        #         bin_x_values = [x for x in probs if bin_lower <= x < bin_upper]
+        #         bin_y_values = [y for j, y in enumerate(correct_keys) if bin_lower <= probs[j] < bin_upper]
+
+        #         # Calculate accuracy for the bin
+        #         total = len(bin_x_values)
+        #         correct = sum(bin_y_values)
+        #         accuracy = correct / total if total > 0 else 0
+
+        #         # Store the probability and accuracy for the bin
+        #         bin_probabilities.append((bin_lower + bin_upper) / 2)
+        #         bin_accuracies.append(accuracy)
+        #         bin_sizes.append(len(bin_x_values))
+
+        #     df = pd.DataFrame({"attn_prob": bin_probabilities, "retrieved_acc": bin_accuracies, "bin_size": bin_sizes})
+
+        #     fig = px.scatter(df, x="attn_prob", y="retrieved_acc",
+        #                      color="bin_size", hover_data=["attn_prob", "retrieved_acc", "bin_size"],
+        #                      title="Attention Probability vs Retrieved Accuracy")
+        #     self.tracker.log({"attn_vs_acc": fig}, step=step)
 
         if log_attn_scores:
@@ -470,10 +538,10 @@ class LetheAttention(nn.Module):
                 mem_hist = torch.histc(mem_flat, bins=20, min=0, max=1)
                 mem_bins = torch.linspace(0, 1, steps=20 + 1)
                 plt.stairs(mem_hist.tolist(), mem_bins.tolist())
-                plt.title(f"mem_attn_score_{head}")
+                plt.title(f"mem_attn_score_{head}_layer_{self.layer_idx}")
                 # set arbitrarily but we want to see those peaks!!
                 plt.ylim((0, 1000))
-                self.tracker.log({f"mem_attn_score_{head}": wandb.Image(plt)}, step=step)
+                self.tracker.log({f"mem_attn_score_{head}_layer_{self.layer_idx}": wandb.Image(plt)}, step=step)
                 plt.close()
 
 
@@ -482,15 +550,16 @@ class LetheAttention(nn.Module):
                 local_hist = torch.histc(local_flat, bins=20, min=0, max=1)
                 local_bins = torch.linspace(0, 1, steps=20 + 1)
                 plt.stairs(local_hist.tolist(), local_bins.tolist())
-                plt.title(f"local_attn_score_{head}")
+                plt.title(f"local_attn_score_{head}_layer_{self.layer_idx}")
                 # set arbitrarily but we want to see those peaks!!
                 plt.ylim((0, 1000))
-                self.tracker.log({f"local_attn_score_{head}": wandb.Image(plt)}, step=step)
+                self.tracker.log({f"local_attn_score_{head}_layer_{self.layer_idx}": wandb.Image(plt)}, step=step)
                 plt.close()
 
 
         # attn_output: [bs, num_attention_heads, seq_len, attn_head_size]
         mem_attn_output = torch.einsum("bhsn, bhsnd-> bhsd", mem_attn_weights, knn_value)
+        # mem_attn_output = torch.matmul(mem_attn_weights, knn_value)
         local_attn_output = torch.matmul(local_attn_weights, local_value)
 
         # TODO: do we need flamingo style gating
@@ -524,8 +593,6 @@ class LetheAttention(nn.Module):
             alpha=1.0 if self.memory else (torch.tensor(1.0, dtype=self.norm_factor.dtype, device=self.norm_factor.device) / self.norm_factor)
         )
         attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)
-        if self.memory:
-            attn_scores = attn_scores * self.scale.exp()
 
         mask_value = torch.finfo(attn_scores.dtype).min
         # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
@@ -610,12 +677,15 @@ class LetheMLP(nn.Module):
 
 
 class LetheLayer(nn.Module):
-    def __init__(self, config, memory_attention=False, index=None, tracker=None):
+    def __init__(self, config, memory_attention=False, layer_idx=None, index=None, tracker=None):
         super().__init__()
         self.use_parallel_residual = config.use_parallel_residual
         self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-        self.attention = LetheAttention(config, memory_attention=memory_attention, index=index, tracker=tracker)
+        self.attention = LetheAttention(config, memory_attention=memory_attention,
+                                        layer_idx=layer_idx,
+                                        index=index[layer_idx] if memory_attention else None,
+                                        tracker=tracker)
         self.mlp = LetheMLP(config)
 
     def forward(
@@ -676,8 +746,9 @@ class LetheModel(LethePreTrainedModel):
         self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size)
 
         self.layers = nn.ModuleList([LetheLayer(config,
-                                                memory_attention=i+1 == config.memory_attn_layer,
-                                                index=index if i+1 == config.memory_attn_layer else None,
+                                                memory_attention=i+1 in config.memory_attn_layer,
+                                                layer_idx=i,
+                                                index=index,
                                                 tracker=tracker)
                                     for i in range(config.num_hidden_layers)])
         self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -1,2 +0,0 @@
-from .configuration_pythia_retro import PythiaRetroConfig
-from .modeling_pythia_retro import PythiaRetroForCausalLM
288
gpt4all/train/pretrain_mem_retrieval.py
Normal file
288
gpt4all/train/pretrain_mem_retrieval.py
Normal file
@ -0,0 +1,288 @@
|
|||||||
|
import os
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from transformers import AutoTokenizer, get_scheduler, AutoConfig
|
||||||
|
import torch
|
||||||
|
from torch.optim import AdamW
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
from gpt4all.utils.read import read_config
|
||||||
|
from accelerate import Accelerator
|
||||||
|
from accelerate.utils import DummyScheduler, DummyOptim, set_seed
|
||||||
|
from gpt4all.data.enwik8 import load_enwik8_dataloader
|
||||||
|
from torchmetrics import MeanMetric
|
||||||
|
from tqdm import tqdm
|
||||||
|
from gpt4all.models import LetheForCausalLM, LetheConfig
|
||||||
|
from gpt4all.models.lethe.modeling_lethe import BatchedMemory
|
||||||
|
import wandb
|
||||||
|
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
|
||||||
|
def format_metrics(metrics, split, prefix=""):
|
||||||
|
log = f"[{split}]" + prefix
|
||||||
|
log += " ".join([f"{key}: {value:.4f}" for key, value in metrics.items()])
|
||||||
|
|
||||||
|
return log
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate(model, index, pad_token_id, config, val_dataloader, main_process=False):
|
||||||
|
model.eval()
|
||||||
|
val_loss = MeanMetric(nan_strategy="error").to(model.device)
|
||||||
|
|
||||||
|
chunk_size = config["seq_len"]
|
||||||
|
with torch.no_grad():
|
||||||
|
for batch in tqdm(val_dataloader, disable=not main_process):
|
||||||
|
seq_len = batch.shape[1]
|
||||||
|
for chunk_start in range(0, seq_len, chunk_size):
|
||||||
|
chunk_end = min(seq_len, chunk_start + chunk_size)
|
||||||
|
inputs = batch[:, chunk_start:chunk_end].to(model.device)
|
||||||
|
labels = inputs.clone()
|
||||||
|
outputs = model(input_ids=inputs,
|
||||||
|
attention_mask=inputs.ne(pad_token_id),
|
||||||
|
labels=labels,
|
||||||
|
log_attn_scores=False,
|
||||||
|
step=None,
|
||||||
|
save_kv=True,
|
||||||
|
)
|
||||||
|
loss = outputs.loss / config["segments"]
|
||||||
|
loss_values = accelerator.gather_for_metrics({"loss": loss.item()})
|
||||||
|
val_loss.update(loss_values["loss"])
|
||||||
|
|
||||||
|
index.reset()
|
||||||
|
|
||||||
|
return val_loss
|
||||||
|
|
||||||
|
|
||||||
|
def train(accelerator, config):
|
||||||
|
set_seed(config['seed'])
|
||||||
|
|
||||||
|
accelerator.print(config)
|
||||||
|
accelerator.print(f"Using {accelerator.num_processes} GPUs")
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(config['tokenizer_name'], model_max_length=config['max_length'])
|
||||||
|
# if no pad token, set it to eos
|
||||||
|
if tokenizer.pad_token is None:
|
||||||
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
|
||||||
|
|
||||||
|
with accelerator.main_process_first():
|
||||||
|
train_dataloader, val_dataloader = load_enwik8_dataloader(config, tokenizer)
|
||||||
|
|
||||||
|
|
||||||
|
if accelerator.state.deepspeed_plugin is not None:
|
||||||
|
gradient_accumulation_steps = accelerator.state.deepspeed_plugin.deepspeed_config[
|
||||||
|
"gradient_accumulation_steps"
|
||||||
|
]
|
||||||
|
|
||||||
|
accelerator.print(f"Len of train_dataloader: {len(train_dataloader)}")
|
||||||
|
total_num_steps = (len(train_dataloader) / gradient_accumulation_steps) * config["num_epochs"]
|
||||||
|
# instead of decaying to zero, decay to ratio of min_lr / lr
|
||||||
|
accelerator.print(f"Total training steps: {total_num_steps}")
|
||||||
|
|
||||||
|
checkpoint = config["gradient_checkpointing"]
|
||||||
|
|
||||||
|
model_config = LetheConfig.from_pretrained(config["model_name"])
|
||||||
|
model_config.memory_attn_layer = config["memory_attn_layer"]
|
||||||
|
model_config.num_neighbors_to_retrieve = config["num_neighbors_to_retrieve"]
|
||||||
|
model_config.use_cache = False if checkpoint else True
|
||||||
|
|
||||||
|
head_size = model_config.hidden_size // model_config.num_attention_heads
|
||||||
|
index = BatchedMemory(config["batch_size"],
|
||||||
|
head_size,
|
||||||
|
config["num_memories_per_index"],
|
||||||
|
model_config.num_attention_heads,
|
||||||
|
)
|
||||||
|
|
||||||
|
model = LetheForCausalLM(model_config,
|
||||||
|
index=index,
|
||||||
|
tracker=accelerator.get_tracker("wandb"))
|
||||||
|
|
||||||
|
|
||||||
|
accelerator.print(f"Training a {model.num_parameters():,} parameter model")
|
||||||
|
if checkpoint:
|
||||||
|
model.gradient_checkpointing_enable()
|
||||||
|
|
||||||
|
optimizer_cls = (
|
||||||
|
AdamW
|
||||||
|
if accelerator.state.deepspeed_plugin is None
|
||||||
|
or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config
|
||||||
|
else DummyOptim
|
||||||
|
)
|
||||||
|
|
||||||
|
# karpathy doesn't decay embeddding, maybe we should exclude
|
||||||
|
# https://github.com/karpathy/minGPT/commit/bbbdac74fa9b2e55574d70056163ffbae42310c1#diff-2075fa9c224b395be5bda85544dd36572b59c76c54562819eadadbf268602834R157s
|
||||||
|
optimizer = optimizer_cls(model.parameters(), lr=config["lr"], weight_decay=config["weight_decay"])
|
||||||
|
|
||||||
|
|
||||||
|
# Creates Dummy Scheduler if `scheduler` was spcified in the config file else creates `args.lr_scheduler_type` Scheduler
|
||||||
|
if config["scheduler"] or "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config:
|
||||||
|
if (
|
||||||
|
accelerator.state.deepspeed_plugin is None
|
||||||
|
or "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
|
||||||
|
):
|
||||||
|
scheduler = get_scheduler(
|
||||||
|
name="cosine",
|
||||||
|
optimizer=optimizer,
|
||||||
|
num_warmup_steps=config["warmup_steps"] * accelerator.num_processes,
|
||||||
|
num_training_steps=total_num_steps,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
scheduler = DummyScheduler(
|
||||||
|
optimizer, total_num_steps=total_num_steps, warmup_num_steps=config["warmup_steps"]
|
||||||
|
)
|
||||||
|
model, optimizer, scheduler, train_dataloader, val_dataloader = accelerator.prepare(
|
||||||
|
model, optimizer, scheduler, train_dataloader, val_dataloader
|
||||||
|
)
|
||||||
|
use_scheduler = True
|
||||||
|
else:
|
||||||
|
model, optimizer, train_dataloader, val_dataloader = accelerator.prepare(
|
||||||
|
model, optimizer, train_dataloader, val_dataloader
|
||||||
|
)
|
||||||
|
use_scheduler = False
|
||||||
|
|
||||||
|
# setup for saving training states in case preemption
|
||||||
|
if use_scheduler:
|
||||||
|
accelerator.register_for_checkpointing(scheduler)
|
||||||
|
|
||||||
|
if config["checkpoint"]:
|
||||||
|
accelerator.load_state(config["checkpoint"])
|
||||||
|
accelerator.print(f"Resumed from checkpoint: {config['checkpoint']}")
|
||||||
|
path = os.path.basename(config["train_args"]["resume_from_checkpoint"])
|
||||||
|
training_difference = os.path.splitext(path)[0]
|
||||||
|
resume_step = int(training_difference.replace("step_", ""))
|
||||||
|
accelerator.skip_first_batches(train_dataloader, resume_step)
|
||||||
|
accelerator.print(f"Resuming from step {resume_step}")
|
||||||
|
|
||||||
|
# log gradients
|
||||||
|
if accelerator.is_main_process and config["wandb"]:
|
||||||
|
wandb.watch(model, log_freq=config["log_grads_every"], log="all")
|
||||||
|
|
||||||
|
main_process = accelerator.is_main_process
|
||||||
|
|
||||||
|
chunk_size = config["seq_len"]
|
||||||
|
for epoch in range(config["num_epochs"]):
|
||||||
|
train_loss = MeanMetric(nan_strategy="error").to(model.device)
|
||||||
|
for step, batch in enumerate(tqdm(train_dataloader, disable=not main_process)):
|
||||||
|
epoch_step = epoch * len(train_dataloader) + step * config["segments"]
|
||||||
|
seq_len = batch["input_ids"].shape[1]
|
||||||
|
model.train()
|
||||||
|
            for i, chunk_start in enumerate(range(0, seq_len, chunk_size)):
                curr_step = epoch_step + i
                chunk_end = min(seq_len, chunk_start + chunk_size)
                inputs = batch["input_ids"][:, chunk_start:chunk_end]
                labels = inputs.clone()
                labels[labels == tokenizer.pad_token_id] = -100
                outputs = model(input_ids=inputs,
                                attention_mask=inputs.ne(tokenizer.pad_token_id),
                                labels=labels,
                                log_attn_scores=True,
                                step=curr_step,
                                save_kv=True,
                                )
                loss = outputs.loss / config["segments"]

                if config["wandb"]:
                    accelerator.log({"loss": loss}, step=curr_step)

                # gather loss before backprop in case of gradient accumulation
                loss_values = accelerator.gather_for_metrics({"loss": loss.detach().float()})
                train_loss.update(loss_values["loss"])

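                # the loss logged above is already scaled by `segments`; the extra division below
                # only rescales the gradients for accumulation, not the reported metric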
                loss = loss / gradient_accumulation_steps
                accelerator.backward(loss)

                # log LR in case something weird happens
                if config["wandb"]:
                    if step > 0 and step % config["log_lr_every"] == 0:
                        lr = optimizer.param_groups[0]["lr"]
                        accelerator.log({"lr": lr}, step=curr_step)

                optimizer.step()
                if use_scheduler:
                    scheduler.step()
                optimizer.zero_grad()

            # reset index on batch end
            index.reset()
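            # Resetting only after every chunk of the batch has been processed keeps earlier
            # chunks' memories retrievable by later ones; the next batch then starts from an
            # empty index.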

            if step > 0 and config["save_every"] > 0 and step % config["save_every"] == 0:
                # accelerator.save_state(f"{config['output_dir']}/step_{curr_step}")
                unwrapped_model = accelerator.unwrap_model(model)
                unwrapped_model.save_pretrained(
                    f"{config['output_dir']}/step_{step}",
                    is_main_process=accelerator.is_main_process,
                    save_function=accelerator.save,
                    state_dict=accelerator.get_state_dict(model),
                )

if step > 0 and (step % config["eval_every"] == 0 or step == len(train_dataloader) - 1):
|
||||||
|
val_loss = evaluate(model, index, tokenizer.pad_token_id, config, val_dataloader, main_process=main_process)
|
||||||
|
|
||||||
|
log_train = {
|
||||||
|
"train_loss": train_loss.compute()
|
||||||
|
}
|
||||||
|
log_val = {
|
||||||
|
"val_loss": val_loss.compute(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if config["wandb"]:
|
||||||
|
curr_step = step + epoch * len(train_dataloader)
|
||||||
|
accelerator.log({**log_train, **log_val}, step=curr_step)
|
||||||
|
|
||||||
|
accelerator.print(f"Current LR: {optimizer.param_groups[0]['lr']}")
|
||||||
|
accelerator.print(format_metrics(log_train, "train", f" step {step} "))
|
||||||
|
accelerator.print(format_metrics(log_val, "val", f" step {step} "))
|
||||||
|
|
||||||
|
train_loss.reset()
|
||||||
|
|
||||||
|
accelerator.print(f"Epoch {epoch} finished")
|
||||||
|
accelerator.wait_for_everyone()
|
||||||
|
unwrapped_model = accelerator.unwrap_model(model)
|
||||||
|
if config["push_to_hub"]:
|
||||||
|
accelerator.print(f"Pushing to HF hub")
|
||||||
|
try:
|
||||||
|
if accelerator.is_main_process:
|
||||||
|
unwrapped_model.push_to_hub(config["save_name"] + f"-epoch_{epoch}", private=True)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
accelerator.print(e)
|
||||||
|
accelerator.print(f"Failed to push to hub")
|
||||||
|
|
||||||
|
unwrapped_model.save_pretrained(
|
||||||
|
f"{config['output_dir']}/epoch_{epoch}",
|
||||||
|
is_main_process=accelerator.is_main_process,
|
||||||
|
save_function=accelerator.save,
|
||||||
|
state_dict=accelerator.get_state_dict(model),
|
||||||
|
)
|
||||||
|
|
||||||
|
accelerator.wait_for_everyone()
|
||||||
|
unwrapped_model = accelerator.unwrap_model(model)
|
||||||
|
unwrapped_model.save_pretrained(
|
||||||
|
f"{config['output_dir']}/final",
|
||||||
|
is_main_process=accelerator.is_main_process,
|
||||||
|
save_function=accelerator.save,
|
||||||
|
state_dict=accelerator.get_state_dict(model),
|
||||||
|
)
|
||||||
|
|
||||||
|
accelerator.end_training()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# parse arguments by reading in a config
|
||||||
|
parser = ArgumentParser()
|
||||||
|
parser.add_argument("--config", type=str, default="config.yaml")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
config = read_config(args.config)
|
||||||
|
|
||||||
|
if config["wandb"]:
|
||||||
|
accelerator = Accelerator(log_with="wandb")
|
||||||
|
accelerator.init_trackers(
|
||||||
|
project_name=config["wandb_project_name"],
|
||||||
|
config=config,
|
||||||
|
init_kwargs={"wandb": {"entity": config["wandb_entity"]}},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
accelerator = Accelerator()
|
||||||
|
|
||||||
|
train(accelerator, config=config)
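
`read_config` and `format_metrics` are imported from elsewhere in the repo and are not part of this commit. A minimal sketch of what `read_config` is assumed to do (a thin YAML loader); the actual helper may differ:

import yaml

def read_config(path: str) -> dict:
    # load the training hyperparameters (lr, seq_len, segments, wandb settings, ...) as a plain dict
    with open(path) as f:
        return yaml.safe_load(f)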

@@ -116,6 +116,7 @@ def train(accelerator, config):

     if config["checkpoint"]:
         accelerator.load_state(config["checkpoint"])
+        import pdb; pdb.set_trace()
         accelerator.print(f"Resumed from checkpoint: {config['checkpoint']}")
         path = os.path.basename(config["train_args"]["resume_from_checkpoint"])
         training_difference = os.path.splitext(path)[0]
@@ -131,9 +132,12 @@ def train(accelerator, config):
     for epoch in range(config["num_epochs"]):
         train_loss = MeanMetric(nan_strategy="error").to(model.device)
         for step, batch in enumerate(tqdm(train_dataloader)):
+            curr_step = step + epoch * len(train_dataloader)
             model.train()
             outputs = model(**batch)
             loss = outputs.loss
+            if config["wandb"]:
+                accelerator.log({"loss": loss}, step=curr_step)
+
             # gather loss before backprop in case of gradient accumulation
             loss_values = accelerator.gather_for_metrics({"loss": loss.detach().float()})
@@ -157,7 +161,13 @@ def train(accelerator, config):

         if step > 0 and step % config["save_every"] == 0:
             curr_step = step + epoch * len(train_dataloader)
-            accelerator.save_state(f"{config['output_dir']}/step_{curr_step}")
+            unwrapped_model = accelerator.unwrap_model(model)
+            unwrapped_model.save_pretrained(
+                f"{config['output_dir']}/step_{curr_step}",
+                is_main_process=accelerator.is_main_process,
+                save_function=accelerator.save,
+                state_dict=accelerator.get_state_dict(model),
+            )

         if step > 0 and (step % config["eval_every"] == 0 or step == len(train_dataloader) - 1):
             val_loss = evaluate(model, val_dataloader)
@@ -11,7 +11,7 @@ from gpt4all.data.retrieval_dataloader import load_memory_augmented_data
 from torchmetrics import MeanMetric
 from tqdm import tqdm
 from gpt4all.models import LetheForCausalLM
-from gpt4all.models.lethe.modeling_lethe import MemoryIndex
+from gpt4all.models.lethe.modeling_lethe import BatchedMemory
 import wandb
 import pyarrow as pa
 from pyarrow import feather
@@ -58,13 +58,15 @@ def evaluate(model, index, config, val_dataloader, main_process=False):
         qa_labels = batch["labels"]
         outputs = model(input_ids=qa_inputs,
                         labels=qa_labels,
+                        save_kv=False
                         )

         del memories
         torch.cuda.empty_cache()
         loss_values = accelerator.gather_for_metrics({"loss": outputs["loss"].detach()})

-        index.reset()
+        for ind in index.values():
+            ind.reset()
         val_loss.update(loss_values["loss"])

         per_example_loss = calculate_per_example_loss(outputs["logits"], qa_labels)
@@ -110,16 +112,18 @@ def train(accelerator, config):
     model_config = AutoConfig.from_pretrained(config["model_name"])

     head_size = model_config.hidden_size // model_config.num_attention_heads
-    index = MemoryIndex(head_size,
+    indices = {i - 1: BatchedMemory(config["batch_size"],
+                                    head_size,
                         config["num_memories_per_index"],
-                        model_config.num_attention_heads
-    )
+                        model_config.num_attention_heads,
+    ) for i in config["memory_attn_layer"]}

     model = LetheForCausalLM.from_pretrained(config["model_name"],
                                              revision=config['version'] if 'version' in config else None,
                                              use_cache=False if checkpoint else True,
                                              memory_attn_layer=config["memory_attn_layer"],
                                              num_neighbors_to_retrieve=config["num_neighbors_to_retrieve"],
-                                             index=index,
+                                             index=indices,
                                              tracker=accelerator.get_tracker("wandb"),
     )

@@ -206,8 +210,10 @@ def train(accelerator, config):

         model.train()
         qa_inputs = batch["input_ids"]
+        attn_mask = batch["attention_mask"]
         qa_labels = batch["labels"]
         outputs = model(input_ids=qa_inputs,
+                        attention_mask=attn_mask,
                         labels=qa_labels,
                         log_attn_scores=True,
                         step=curr_step,
@@ -225,7 +231,8 @@ def train(accelerator, config):
         loss = loss / gradient_accumulation_steps
         accelerator.backward(loss)
         # !! don't reset index until after backwards pass
-        index.reset()
+        for index in indices.values():
+            index.reset()
         # get gradient norm of all params

         # log LR in case something weird happens
@@ -251,7 +258,7 @@ def train(accelerator, config):
         )

         if step > 0 and (step % config["eval_every"] == 0 or step == len(train_dataloader) - 1):
-            val_loss, loss_table = evaluate(model, index, config, val_dataloader, main_process=main_process)
+            val_loss, loss_table = evaluate(model, indices, config, val_dataloader, main_process=main_process)

             local_rank = accelerator.process_index
             feather.write_feather(loss_table, f"{config['output_dir']}/val_losses_step_{curr_step}_rank_{local_rank}.feather")