wip: train of mem attn

This commit is contained in:
Zach Nussbaum 2023-05-18 19:44:30 +00:00
parent 5f28e60a9a
commit 3677935ce8
3 changed files with 390 additions and 14 deletions

View File

@ -73,3 +73,75 @@ def load_retrieval_augmented_data(config, tokenizer, split="train", split_datase
return dataloader
def load_memory_augmented_data(config, tokenizer, split="train", split_dataset=True):
dataset_path = config["dataset_path"]
if os.path.exists(dataset_path):
dataset = Dataset.load_from_disk(dataset_path)
else:
dataset = load_dataset(dataset_path, split=split)
question_col = config["q_column"]
answer_col = config["a_column"]
if config["streaming"] is False:
kwargs = {"num_proc": config["num_proc"]}
else:
kwargs = {}
# strip any unneccessary whitespace
# there's one question that's includes a ton of whitespace
dataset = dataset.map(lambda ele: {question_col: [q.strip() for q in ele[question_col]]}, batched=True)
# in squad, the data is formatted where each ele in answers is a dict where the key text holds
# a list of the answer
dataset = dataset.map(lambda ele: {answer_col: [t["text"][0] for t in ele[answer_col]]}, batched=True)
dataset = dataset.map(
lambda ele: tokenize_inputs(config, tokenizer, ele, question_col, answer_col),
batched=True,
**kwargs
)
# tokenize contexts for each example
dataset = dataset.map(
lambda ele: {"retrieved_context": tokenizer(ele["context"],
return_tensors="pt",
padding="max_length",
truncation=True)["input_ids"]},
batched=True,
**kwargs
)
columns_to_keep = ["input_ids", "labels", "retrieved_context"]
col_names_to_rm = [col for col in dataset.column_names if col not in columns_to_keep]
dataset = dataset.remove_columns(col_names_to_rm)
if split_dataset:
dataset = dataset.train_test_split(test_size=config["pct_test"], seed=config["seed"])
train_dataset, val_dataset = dataset["train"], dataset["test"]
train_dataloader = DataLoader(
train_dataset,
batch_size=config["batch_size"],
collate_fn=DefaultDataCollator(),
)
val_dataloader = DataLoader(
val_dataset,
batch_size=config["batch_size"],
collate_fn=DefaultDataCollator(),
)
return train_dataloader, val_dataloader
else:
dataloader = DataLoader(
dataset,
batch_size=config["batch_size"],
collate_fn=DefaultDataCollator(),
)
return dataloader

View File

@ -2,19 +2,20 @@ import torch
from gpt4all.models import LetheForCausalLM, LetheConfig
from gpt4all.models.lethe.modeling_lethe import MemoryIndex
from transformers import AutoTokenizer, AutoModel
from datasets import load_from_disk
from tqdm import tqdm
# seed torch
torch.manual_seed(0)
config = LetheConfig(num_hidden_layers=12,
hidden_size=1024,
intermediate_size=4096,
config = LetheConfig(num_hidden_layers=15,
hidden_size=2048,
intermediate_size=8192,
num_attention_heads=8,
cross_attn_layer=9,
nn_index_path="/home/paperspace/gpt4all/gpt4all/train",
num_neighbors_stored=32768,
num_neighbors_to_retrieve=2,
memory_attn_layer=12,
num_neighbors_stored=6_000_000,
num_neighbors_to_retrieve=32,
)
print("loaded config")
@ -22,10 +23,11 @@ print("loading model")
dimension = config.max_position_embeddings * config.hidden_size
head_size = config.hidden_size // config.num_attention_heads
index = MemoryIndex(head_size,
64_000,
500_000,
config.num_attention_heads
)
model = LetheForCausalLM(config, index)
model.to("cuda:0")
print("loaded model")
@ -33,12 +35,16 @@ tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-1b")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.model_max_length = 2048
question = "Where was George Washington born?"
answer = "Virginia"
dataset = load_from_disk("/home/paperspace/gpt4all/gpt4all/index/squad_supplemented_train")
contexts = ["The Washington family was a wealthy Virginia planter family that had made its fortune through land speculation and the cultivation of tobacco.",
"George Washington was born on February 22, 1732,[b] at Popes Creek in Westmoreland County, in the British colony of Virginia,[18] and was the first of six children of Augustine and Mary Ball Washington.",
"His father was a justice of the peace and a prominent public figure who had four additional children from his first marriage to Jane Butler.[20] The family moved to Little Hunting Creek in 1735"]
item = dataset[0]
question = item["question"]
answer = item["answers"]["text"][0]
print(f"question: {question}")
print(f"answer: {answer}")
contexts = item["neighbor_text"]
contexts_encoded = tokenizer(contexts, padding="max_length", truncation=True, return_tensors="pt")
tokenized_input = tokenizer(question + "\n" + answer, return_tensors="pt", padding="max_length", truncation=True)
@ -53,9 +59,23 @@ labels[:-1] = -100
labels[-1, :question_len] = -100
memory_mask = token_type_ids == 0
# should be shape (num_memories, sequence_length)
memories = inputs[memory_mask]
print("Running model")
outputs = model(input_ids=inputs, token_type_ids=token_type_ids, labels=labels)
model.eval()
with torch.no_grad():
for chunk_start in tqdm(range(0, memories.shape[0], 32)):
chunk_end = min(memories.shape[0], chunk_start + 32)
mem_chunk = memories[chunk_start:chunk_end].to(model.device)
model(input_ids=mem_chunk, labels=None)
model.train()
qa_inputs = inputs[~memory_mask]
qa_labels = labels[~memory_mask]
outputs = model(input_ids=qa_inputs.to(model.device), labels=qa_labels.to(model.device))
print(outputs)
print(outputs.logits.shape)

View File

@ -0,0 +1,284 @@
import os
from transformers import AutoTokenizer, get_scheduler, AutoConfig
import torch
from torch.optim import AdamW
from argparse import ArgumentParser
from gpt4all.utils.read import read_config
from accelerate import Accelerator
from accelerate.utils import DummyScheduler, DummyOptim, set_seed
from gpt4all.data.retrieval_dataloader import load_memory_augmented_data
from torchmetrics import MeanMetric
from tqdm import tqdm
from gpt4all.models import LetheForCausalLM
from gpt4all.models.lethe.modeling_lethe import MemoryIndex
import wandb
torch.backends.cuda.matmul.allow_tf32 = True
def format_metrics(metrics, split, prefix=""):
log = f"[{split}]" + prefix
log += " ".join([f"{key}: {value:.4f}" for key, value in metrics.items()])
return log
def evaluate(model, config, val_dataloader, main_process=False):
model.eval()
val_loss = MeanMetric(nan_strategy="error").to(model.device)
head_size = model.config.hidden_size // model.config.num_attention_heads
index = MemoryIndex(head_size,
config["num_memories_per_index"],
model.config.num_attention_heads
)
with torch.no_grad():
for batch in tqdm(val_dataloader, disable=not main_process):
memories = batch["retrieved_context"]
# need to set to eval so we don't do mem attn as it's slow
model.eval()
for chunk_start in range(0, memories.shape[0], config["mem_chunk_size"]):
chunk_end = min(memories.shape[0], chunk_start + config["mem_chunk_size"])
mem_chunk = memories[chunk_start:chunk_end]
model(input_ids=mem_chunk, labels=None, use_mem_attn=False)
qa_inputs = batch["input_ids"]
qa_labels = batch["labels"]
outputs = model(input_ids=qa_inputs,
labels=qa_labels,
)
loss_values = accelerator.gather_for_metrics({"loss": outputs["loss"].detach()})
val_loss.update(loss_values["loss"])
index.reset()
return val_loss
def train(accelerator, config):
set_seed(config['seed'])
accelerator.print(config)
accelerator.print(f"Using {accelerator.num_processes} GPUs")
tokenizer = AutoTokenizer.from_pretrained(config['tokenizer_name'], model_max_length=config['max_length'])
# if no pad token, set it to eos
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
with accelerator.main_process_first():
train_dataloader, val_dataloader = load_memory_augmented_data(config, tokenizer)
if accelerator.state.deepspeed_plugin is not None:
gradient_accumulation_steps = accelerator.state.deepspeed_plugin.deepspeed_config[
"gradient_accumulation_steps"
]
accelerator.print(f"Len of train_dataloader: {len(train_dataloader)}")
total_num_steps = (len(train_dataloader) / gradient_accumulation_steps) * config["num_epochs"]
# instead of decaying to zero, decay to ratio of min_lr / lr
accelerator.print(f"Total training steps: {total_num_steps}")
checkpoint = config["gradient_checkpointing"]
model_config = AutoConfig.from_pretrained(config["model_name"])
head_size = model_config.hidden_size // model_config.num_attention_heads
index = MemoryIndex(head_size,
config["num_memories_per_index"],
model_config.num_attention_heads
)
model = LetheForCausalLM.from_pretrained(config["model_name"],
revision=config['version'] if 'version' in config else None,
use_cache=False if checkpoint else True,
memory_attn_layer=config["memory_attn_layer"],
num_neighbors_to_retrieve=config["num_neighbors_to_retrieve"],
index=index,
)
accelerator.print(f"Training a {model.num_parameters():,} parameter model")
if checkpoint:
model.gradient_checkpointing_enable()
optimizer_cls = (
AdamW
if accelerator.state.deepspeed_plugin is None
or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config
else DummyOptim
)
# karpathy doesn't decay embeddding, maybe we should exclude
# https://github.com/karpathy/minGPT/commit/bbbdac74fa9b2e55574d70056163ffbae42310c1#diff-2075fa9c224b395be5bda85544dd36572b59c76c54562819eadadbf268602834R157s
optimizer = optimizer_cls(model.parameters(), lr=config["lr"], weight_decay=config["weight_decay"])
# Creates Dummy Scheduler if `scheduler` was spcified in the config file else creates `args.lr_scheduler_type` Scheduler
if config["scheduler"] or "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config:
if (
accelerator.state.deepspeed_plugin is None
or "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
):
scheduler = get_scheduler(
name="cosine",
optimizer=optimizer,
num_warmup_steps=config["warmup_steps"] * accelerator.num_processes,
num_training_steps=total_num_steps,
)
else:
scheduler = DummyScheduler(
optimizer, total_num_steps=total_num_steps, warmup_num_steps=config["warmup_steps"]
)
model, optimizer, train_dataloader, val_dataloader, scheduler = accelerator.prepare(
model, optimizer, train_dataloader, val_dataloader, scheduler
)
scheduler = True
else:
model, optimizer, train_dataloader, val_dataloader = accelerator.prepare(
model, optimizer, train_dataloader, val_dataloader
)
scheduler = False
# setup for saving training states in case preemption
if scheduler:
accelerator.register_for_checkpointing(scheduler)
if config["checkpoint"]:
accelerator.load_state(config["checkpoint"])
accelerator.print(f"Resumed from checkpoint: {config['checkpoint']}")
path = os.path.basename(config["train_args"]["resume_from_checkpoint"])
training_difference = os.path.splitext(path)[0]
resume_step = int(training_difference.replace("step_", ""))
accelerator.skip_first_batches(train_dataloader, resume_step)
accelerator.print(f"Resuming from step {resume_step}")
# log gradients
if accelerator.is_main_process and config["wandb"]:
wandb.watch(model, log_freq=config["log_grads_every"], log="all")
main_process = accelerator.is_main_process
for epoch in range(config["num_epochs"]):
train_loss = MeanMetric(nan_strategy="error").to(model.device)
for step, batch in enumerate(tqdm(train_dataloader, disable=not main_process)):
curr_step = step + epoch * len(train_dataloader)
memories = batch["retrieved_context"]
# need to set to eval so we don't do mem attn as it's slow
model.eval()
with torch.no_grad():
for chunk_start in range(0, memories.shape[0], config["mem_chunk_size"]):
chunk_end = min(memories.shape[0], chunk_start + 32)
mem_chunk = memories[chunk_start:chunk_end]
model(input_ids=mem_chunk, labels=None, use_mem_attn=False)
model.train()
qa_inputs = batch["input_ids"]
qa_labels = batch["labels"]
outputs = model(input_ids=qa_inputs,
labels=qa_labels,
)
loss = outputs.loss
index.reset()
# gather loss before backprop in case of gradient accumulation
loss_values = accelerator.gather_for_metrics({"loss": loss.detach().float()})
train_loss.update(loss_values["loss"])
loss = loss / gradient_accumulation_steps
accelerator.backward(loss)
# get gradient norm of all params
# log LR in case something weird happens
if config["wandb"]:
if step > 0 and step % (config["log_lr_every"] ) == 0:
lr = optimizer.param_groups[0]["lr"]
accelerator.log({"lr": lr}, step=curr_step)
if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
optimizer.step()
if scheduler:
scheduler.step()
optimizer.zero_grad()
if step > 0 and config["save_every"] > 0 and step % config["save_every"] == 0:
accelerator.save_state(f"{config['output_dir']}/step_{curr_step}")
if step > 0 and (step % config["eval_every"] == 0 or step == len(train_dataloader) - 1):
val_loss = evaluate(model, config, val_dataloader, main_process=main_process)
log_train = {
"train_loss": train_loss.compute()
}
log_val = {
"val_loss": val_loss.compute(),
}
if config["wandb"]:
curr_step = step + epoch * len(train_dataloader)
accelerator.log({**log_train, **log_val}, step=curr_step)
accelerator.print(f"Current LR: {optimizer.param_groups[0]['lr']}")
accelerator.print(format_metrics(log_train, "train", f" step {step} "))
accelerator.print(format_metrics(log_val, "val", f" step {step} "))
train_loss.reset()
accelerator.print(f"Epoch {epoch} finished")
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
if config["push_to_hub"]:
accelerator.print(f"Pushing to HF hub")
try:
if accelerator.is_main_process:
unwrapped_model.push_to_hub(config["save_name"] + f"-epoch_{epoch}", private=True)
except Exception as e:
accelerator.print(e)
accelerator.print(f"Failed to push to hub")
unwrapped_model.save_pretrained(
f"{config['output_dir']}/epoch_{epoch}",
is_main_process=accelerator.is_main_process,
save_function=accelerator.save,
state_dict=accelerator.get_state_dict(model),
)
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(
f"{config['output_dir']}/final",
is_main_process=accelerator.is_main_process,
save_function=accelerator.save,
state_dict=accelerator.get_state_dict(model),
)
accelerator.end_training()
if __name__ == "__main__":
# parse arguments by reading in a config
parser = ArgumentParser()
parser.add_argument("--config", type=str, default="config.yaml")
args = parser.parse_args()
config = read_config(args.config)
if config["wandb"]:
accelerator = Accelerator(log_with="wandb")
accelerator.init_trackers(
project_name=config["wandb_project_name"],
config=config,
init_kwargs={"wandb": {"entity": config["wandb_entity"]}},
)
else:
accelerator = Accelerator()
train(accelerator, config=config)