diff --git a/gpt4all/data/retrieval_dataloader.py b/gpt4all/data/retrieval_dataloader.py
index 75c87574..6a04737c 100644
--- a/gpt4all/data/retrieval_dataloader.py
+++ b/gpt4all/data/retrieval_dataloader.py
@@ -73,3 +73,75 @@ def load_retrieval_augmented_data(config, tokenizer, split="train", split_datase
         return dataloader
+
+def load_memory_augmented_data(config, tokenizer, split="train", split_dataset=True):
+    dataset_path = config["dataset_path"]
+
+    if os.path.exists(dataset_path):
+        dataset = Dataset.load_from_disk(dataset_path)
+    else:
+        dataset = load_dataset(dataset_path, split=split)
+
+
+    question_col = config["q_column"]
+    answer_col = config["a_column"]
+
+    if config["streaming"] is False:
+        kwargs = {"num_proc": config["num_proc"]}
+    else:
+        kwargs = {}
+
+    # strip any unnecessary whitespace
+    # there's one question that includes a ton of whitespace
+    dataset = dataset.map(lambda ele: {question_col: [q.strip() for q in ele[question_col]]}, batched=True)
+    # in squad, each element of answers is a dict whose "text" key holds
+    # a list of answer strings; keep the first answer
+    dataset = dataset.map(lambda ele: {answer_col: [t["text"][0] for t in ele[answer_col]]}, batched=True)
+
+    dataset = dataset.map(
+        lambda ele: tokenize_inputs(config, tokenizer, ele, question_col, answer_col),
+        batched=True,
+        **kwargs
+    )
+
+    # tokenize contexts for each example
+    dataset = dataset.map(
+        lambda ele: {"retrieved_context": tokenizer(ele["context"],
+                                                    return_tensors="pt",
+                                                    padding="max_length",
+                                                    truncation=True)["input_ids"]},
+        batched=True,
+        **kwargs
+    )
+
+    columns_to_keep = ["input_ids", "labels", "retrieved_context"]
+
+    col_names_to_rm = [col for col in dataset.column_names if col not in columns_to_keep]
+    dataset = dataset.remove_columns(col_names_to_rm)
+
+    if split_dataset:
+        dataset = dataset.train_test_split(test_size=config["pct_test"], seed=config["seed"])
+        train_dataset, val_dataset = dataset["train"], dataset["test"]
+
+        train_dataloader = DataLoader(
+            train_dataset,
+            batch_size=config["batch_size"],
+            collate_fn=DefaultDataCollator(),
+        )
+
+        val_dataloader = DataLoader(
+            val_dataset,
+            batch_size=config["batch_size"],
+            collate_fn=DefaultDataCollator(),
+        )
+
+        return train_dataloader, val_dataloader
+
+    else:
+        dataloader = DataLoader(
+            dataset,
+            batch_size=config["batch_size"],
+            collate_fn=DefaultDataCollator(),
+        )
+
+        return dataloader
diff --git a/gpt4all/models/lethe/test_lethe.py b/gpt4all/models/lethe/test_lethe.py
index 4b274d53..13360319 100644
--- a/gpt4all/models/lethe/test_lethe.py
+++ b/gpt4all/models/lethe/test_lethe.py
@@ -2,19 +2,20 @@ import torch
 from gpt4all.models import LetheForCausalLM, LetheConfig
 from gpt4all.models.lethe.modeling_lethe import MemoryIndex
 from transformers import AutoTokenizer, AutoModel
+from datasets import load_from_disk
+from tqdm import tqdm
 
 # seed torch
 torch.manual_seed(0)
 
-config = LetheConfig(num_hidden_layers=12,
-                     hidden_size=1024,
-                     intermediate_size=4096,
+config = LetheConfig(num_hidden_layers=15,
+                     hidden_size=2048,
+                     intermediate_size=8192,
                      num_attention_heads=8,
-                     cross_attn_layer=9,
-                     nn_index_path="/home/paperspace/gpt4all/gpt4all/train",
-                     num_neighbors_stored=32768,
-                     num_neighbors_to_retrieve=2,
+                     memory_attn_layer=12,
+                     num_neighbors_stored=6_000_000,
+                     num_neighbors_to_retrieve=32,
                      )
 print("loaded config")
 
@@ -22,10 +23,11 @@ print("loading model")
 dimension = config.max_position_embeddings * config.hidden_size
 head_size = config.hidden_size // config.num_attention_heads
 index = MemoryIndex(head_size,
-                    64_000,
+                    500_000,
                     config.num_attention_heads
 )
 model = LetheForCausalLM(config, index)
+model.to("cuda:0")
 print("loaded model")
 
@@ -33,12 +35,16 @@ tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-1b")
 tokenizer.pad_token = tokenizer.eos_token
 tokenizer.model_max_length = 2048
 
-question = "Where was George Washington born?"
-answer = "Virginia"
+dataset = load_from_disk("/home/paperspace/gpt4all/gpt4all/index/squad_supplemented_train")
 
-contexts = ["The Washington family was a wealthy Virginia planter family that had made its fortune through land speculation and the cultivation of tobacco.",
-            "George Washington was born on February 22, 1732,[b] at Popes Creek in Westmoreland County, in the British colony of Virginia,[18] and was the first of six children of Augustine and Mary Ball Washington.",
-            "His father was a justice of the peace and a prominent public figure who had four additional children from his first marriage to Jane Butler.[20] The family moved to Little Hunting Creek in 1735"]
+item = dataset[0]
+
+question = item["question"]
+answer = item["answers"]["text"][0]
+print(f"question: {question}")
+print(f"answer: {answer}")
+
+contexts = item["neighbor_text"]
 
 contexts_encoded = tokenizer(contexts, padding="max_length", truncation=True, return_tensors="pt")
 tokenized_input = tokenizer(question + "\n" + answer, return_tensors="pt", padding="max_length", truncation=True)
@@ -53,9 +59,23 @@ labels[:-1] = -100
 labels[-1, :question_len] = -100
 
+memory_mask = token_type_ids == 0
+# should be shape (num_memories, sequence_length)
+memories = inputs[memory_mask]
 
 print("Running model")
-outputs = model(input_ids=inputs, token_type_ids=token_type_ids, labels=labels)
+model.eval()
+with torch.no_grad():
+    for chunk_start in tqdm(range(0, memories.shape[0], 32)):
+        chunk_end = min(memories.shape[0], chunk_start + 32)
+        mem_chunk = memories[chunk_start:chunk_end].to(model.device)
+        model(input_ids=mem_chunk, labels=None)
+
+model.train()
+
+qa_inputs = inputs[~memory_mask]
+qa_labels = labels[~memory_mask]
+outputs = model(input_ids=qa_inputs.to(model.device), labels=qa_labels.to(model.device))
 print(outputs)
 print(outputs.logits.shape)
diff --git a/gpt4all/train/train_mem_retrieval.py b/gpt4all/train/train_mem_retrieval.py
new file mode 100644
index 00000000..bf56d612
--- /dev/null
+++ b/gpt4all/train/train_mem_retrieval.py
@@ -0,0 +1,284 @@
+import os
+from transformers import AutoTokenizer, get_scheduler, AutoConfig
+import torch
+from torch.optim import AdamW
+from argparse import ArgumentParser
+from gpt4all.utils.read import read_config
+from accelerate import Accelerator
+from accelerate.utils import DummyScheduler, DummyOptim, set_seed
+from gpt4all.data.retrieval_dataloader import load_memory_augmented_data
+from torchmetrics import MeanMetric
+from tqdm import tqdm
+from gpt4all.models import LetheForCausalLM
+from gpt4all.models.lethe.modeling_lethe import MemoryIndex
+import wandb
+
+torch.backends.cuda.matmul.allow_tf32 = True
+
+def format_metrics(metrics, split, prefix=""):
+    log = f"[{split}]" + prefix
+    log += " ".join([f"{key}: {value:.4f}" for key, value in metrics.items()])
+
+    return log
+
+
+def evaluate(model, config, val_dataloader, main_process=False):
+    model.eval()
+    val_loss = MeanMetric(nan_strategy="error").to(model.device)
+
+    head_size = model.config.hidden_size // model.config.num_attention_heads
+    index = MemoryIndex(head_size,
+                        config["num_memories_per_index"],
+                        model.config.num_attention_heads
+    )
+
+    with torch.no_grad():
+        for batch in tqdm(val_dataloader, disable=not main_process):
+            memories = batch["retrieved_context"]
+
+            # need to set to eval so we don't do mem attn as it's slow
+            model.eval()
+            for chunk_start in range(0, memories.shape[0], config["mem_chunk_size"]):
+                chunk_end = min(memories.shape[0], chunk_start + config["mem_chunk_size"])
+                mem_chunk = memories[chunk_start:chunk_end]
+                model(input_ids=mem_chunk, labels=None, use_mem_attn=False)
+
+            qa_inputs = batch["input_ids"]
+            qa_labels = batch["labels"]
+            outputs = model(input_ids=qa_inputs,
+                            labels=qa_labels,
+            )
+            loss_values = accelerator.gather_for_metrics({"loss": outputs["loss"].detach()})
+
+
+            val_loss.update(loss_values["loss"])
+            index.reset()
+
+    return val_loss
+
+
+def train(accelerator, config):
+    set_seed(config['seed'])
+
+    accelerator.print(config)
+    accelerator.print(f"Using {accelerator.num_processes} GPUs")
+
+    tokenizer = AutoTokenizer.from_pretrained(config['tokenizer_name'], model_max_length=config['max_length'])
+    # if no pad token, set it to eos
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+
+    with accelerator.main_process_first():
+        train_dataloader, val_dataloader = load_memory_augmented_data(config, tokenizer)
+
+    if accelerator.state.deepspeed_plugin is not None:
+        gradient_accumulation_steps = accelerator.state.deepspeed_plugin.deepspeed_config[
+            "gradient_accumulation_steps"
+        ]
+
+    accelerator.print(f"Len of train_dataloader: {len(train_dataloader)}")
+    total_num_steps = (len(train_dataloader) / gradient_accumulation_steps) * config["num_epochs"]
+    # instead of decaying to zero, decay to ratio of min_lr / lr
+    accelerator.print(f"Total training steps: {total_num_steps}")
+
+    checkpoint = config["gradient_checkpointing"]
+
+    model_config = AutoConfig.from_pretrained(config["model_name"])
+
+    head_size = model_config.hidden_size // model_config.num_attention_heads
+    index = MemoryIndex(head_size,
+                        config["num_memories_per_index"],
+                        model_config.num_attention_heads
+    )
+    model = LetheForCausalLM.from_pretrained(config["model_name"],
+                                             revision=config['version'] if 'version' in config else None,
+                                             use_cache=False if checkpoint else True,
+                                             memory_attn_layer=config["memory_attn_layer"],
+                                             num_neighbors_to_retrieve=config["num_neighbors_to_retrieve"],
+                                             index=index,
+    )
+
+
+    accelerator.print(f"Training a {model.num_parameters():,} parameter model")
+    if checkpoint:
+        model.gradient_checkpointing_enable()
+
+    optimizer_cls = (
+        AdamW
+        if accelerator.state.deepspeed_plugin is None
+        or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config
+        else DummyOptim
+    )
+
+    # karpathy doesn't decay embedding, maybe we should exclude
+    # https://github.com/karpathy/minGPT/commit/bbbdac74fa9b2e55574d70056163ffbae42310c1#diff-2075fa9c224b395be5bda85544dd36572b59c76c54562819eadadbf268602834R157s
+    optimizer = optimizer_cls(model.parameters(), lr=config["lr"], weight_decay=config["weight_decay"])
+
+
+    # create a DummyScheduler if the DeepSpeed config defines a scheduler, otherwise use a cosine schedule
+    if config["scheduler"] or (accelerator.state.deepspeed_plugin is not None and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config):
+        if (
+            accelerator.state.deepspeed_plugin is None
+            or "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
+        ):
+            scheduler = get_scheduler(
+                name="cosine",
+                optimizer=optimizer,
+                num_warmup_steps=config["warmup_steps"] * accelerator.num_processes,
+                num_training_steps=total_num_steps,
+            )
+        else:
+            scheduler = DummyScheduler(
+                optimizer, total_num_steps=total_num_steps, warmup_num_steps=config["warmup_steps"]
+            )
+        model, optimizer, train_dataloader, val_dataloader, scheduler = accelerator.prepare(
+            model, optimizer, train_dataloader, val_dataloader, scheduler
+        )
+        scheduler = True
+    else:
+        model, optimizer, train_dataloader, val_dataloader = accelerator.prepare(
+            model, optimizer, train_dataloader, val_dataloader
+        )
+        scheduler = False
+
+
+    # setup for saving training states in case of preemption
+    if scheduler:
+        accelerator.register_for_checkpointing(scheduler)
+
+    if config["checkpoint"]:
+        accelerator.load_state(config["checkpoint"])
+        accelerator.print(f"Resumed from checkpoint: {config['checkpoint']}")
+        path = os.path.basename(config["checkpoint"])
+        training_difference = os.path.splitext(path)[0]
+        resume_step = int(training_difference.replace("step_", ""))
+        train_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
+        accelerator.print(f"Resuming from step {resume_step}")
+
+    # log gradients
+    if accelerator.is_main_process and config["wandb"]:
+        wandb.watch(model, log_freq=config["log_grads_every"], log="all")
+
+    main_process = accelerator.is_main_process
+
+    for epoch in range(config["num_epochs"]):
+        train_loss = MeanMetric(nan_strategy="error").to(model.device)
+        for step, batch in enumerate(tqdm(train_dataloader, disable=not main_process)):
+            curr_step = step + epoch * len(train_dataloader)
+            memories = batch["retrieved_context"]
+
+            # need to set to eval so we don't do mem attn as it's slow
+            model.eval()
+            with torch.no_grad():
+                for chunk_start in range(0, memories.shape[0], config["mem_chunk_size"]):
+                    chunk_end = min(memories.shape[0], chunk_start + config["mem_chunk_size"])
+                    mem_chunk = memories[chunk_start:chunk_end]
+                    model(input_ids=mem_chunk, labels=None, use_mem_attn=False)
+
+            model.train()
+            qa_inputs = batch["input_ids"]
+            qa_labels = batch["labels"]
+            outputs = model(input_ids=qa_inputs,
+                            labels=qa_labels,
+            )
+            loss = outputs.loss
+
+            index.reset()
+
+            # gather loss before backprop in case of gradient accumulation
+            loss_values = accelerator.gather_for_metrics({"loss": loss.detach().float()})
+            train_loss.update(loss_values["loss"])
+
+            loss = loss / gradient_accumulation_steps
+            accelerator.backward(loss)
+            # get gradient norm of all params
+
+            # log LR in case something weird happens
+            if config["wandb"]:
+                if step > 0 and step % config["log_lr_every"] == 0:
+                    lr = optimizer.param_groups[0]["lr"]
+                    accelerator.log({"lr": lr}, step=curr_step)
+
+            if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+                optimizer.step()
+                if scheduler:
+                    scheduler.step()
+                optimizer.zero_grad()
+
+            if step > 0 and config["save_every"] > 0 and step % config["save_every"] == 0:
+                accelerator.save_state(f"{config['output_dir']}/step_{curr_step}")
+
+            if step > 0 and (step % config["eval_every"] == 0 or step == len(train_dataloader) - 1):
+                val_loss = evaluate(model, config, val_dataloader, main_process=main_process)
+
+                log_train = {
+                    "train_loss": train_loss.compute()
+                }
+                log_val = {
+                    "val_loss": val_loss.compute(),
+                }
+
+                if config["wandb"]:
+                    curr_step = step + epoch * len(train_dataloader)
+                    accelerator.log({**log_train, **log_val}, step=curr_step)
+
+                accelerator.print(f"Current LR: {optimizer.param_groups[0]['lr']}")
+                accelerator.print(format_metrics(log_train, "train", f" step {step} "))
+                accelerator.print(format_metrics(log_val, "val", f" step {step} "))
+
+                train_loss.reset()
+
+        accelerator.print(f"Epoch {epoch} finished")
finished") + accelerator.wait_for_everyone() + unwrapped_model = accelerator.unwrap_model(model) + if config["push_to_hub"]: + accelerator.print(f"Pushing to HF hub") + try: + if accelerator.is_main_process: + unwrapped_model.push_to_hub(config["save_name"] + f"-epoch_{epoch}", private=True) + + except Exception as e: + accelerator.print(e) + accelerator.print(f"Failed to push to hub") + + unwrapped_model.save_pretrained( + f"{config['output_dir']}/epoch_{epoch}", + is_main_process=accelerator.is_main_process, + save_function=accelerator.save, + state_dict=accelerator.get_state_dict(model), + ) + + accelerator.wait_for_everyone() + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained( + f"{config['output_dir']}/final", + is_main_process=accelerator.is_main_process, + save_function=accelerator.save, + state_dict=accelerator.get_state_dict(model), + ) + + accelerator.end_training() + + + +if __name__ == "__main__": + # parse arguments by reading in a config + parser = ArgumentParser() + parser.add_argument("--config", type=str, default="config.yaml") + + args = parser.parse_args() + + config = read_config(args.config) + + if config["wandb"]: + accelerator = Accelerator(log_with="wandb") + accelerator.init_trackers( + project_name=config["wandb_project_name"], + config=config, + init_kwargs={"wandb": {"entity": config["wandb_entity"]}}, + ) + else: + accelerator = Accelerator() + + train(accelerator, config=config)