From 240feae27728d1d526c502d10cd506458feff376 Mon Sep 17 00:00:00 2001
From: Zach Nussbaum
Date: Sun, 23 Apr 2023 20:02:22 +0000
Subject: [PATCH] added initial files for dataset prep and ingestion for gpt4all jr

---
 configs/train/finetune_gptjr.yaml     |  41 +++++
 gpt4all/index/build_index.py          |  14 +-
 gpt4all/index/embed.py                |  20 ++-
 gpt4all/index/embed_texts.py          |  28 ++-
 gpt4all/index/prep_index_for_train.py |  58 ++++++
 gpt4all/index/test_load_index.py      |   0
 gpt4all/train/train_r.py              | 246 ++++++++++++++++++++++++++
 gpt4all/utils/data.py                 |  65 +++++++
 8 files changed, 454 insertions(+), 18 deletions(-)
 create mode 100644 configs/train/finetune_gptjr.yaml
 create mode 100644 gpt4all/index/prep_index_for_train.py
 create mode 100644 gpt4all/index/test_load_index.py
 create mode 100644 gpt4all/train/train_r.py

diff --git a/configs/train/finetune_gptjr.yaml b/configs/train/finetune_gptjr.yaml
new file mode 100644
index 00000000..291cf841
--- /dev/null
+++ b/configs/train/finetune_gptjr.yaml
@@ -0,0 +1,41 @@
+# model/tokenizer
+model_name: "nomic-ai/gpt4all-j"
+tokenizer_name: "nomic-ai/gpt4all-j"
+version: 'v1.2-jazzy'
+gradient_checkpointing: true
+save_name: # CHANGE
+
+# dataset
+streaming: false
+num_proc: 64
+dataset_path: "squad"
+max_length: 1024
+batch_size: 32
+
+# index
+index_path: "/home/paperspace/gpt4all/gpt4all/index/wiki-sample-index.bin"
+index_database: "/home/paperspace/gpt4all/gpt4all/index/wiki-full-tokenized_embedded_with_text"
+index_space: "cosine"
+index_dim: 384
+query_embedding_field: 'question'
+
+# train dynamics
+lr: 2.0e-5
+min_lr: 0
+weight_decay: 0.0
+eval_every: 500
+eval_steps: 105
+save_every: 500
+log_grads_every: 100
+output_dir: # CHANGE
+checkpoint: null
+lora: false
+warmup_steps: 500
+num_epochs: 2
+
+# logging
+wandb: false
+wandb_entity: # CHANGE
+wandb_project_name: # CHANGE
+seed: 42
+
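The YAML above is what both the index-prep and training scripts consume via read_config, and the three index_* keys have to agree with how the hnswlib index was built. A minimal sketch of that contract (the paths and the 384-dim assumption come from this config; nothing else in the patch enforces them):

# Sketch only, not part of the patch: how the retrieval keys are consumed.
import hnswlib
from gpt4all.utils.read import read_config

config = read_config("configs/train/finetune_gptjr.yaml")

# index_dim must match the embedder's output width; index_space is the hnswlib metric.
index = hnswlib.Index(space=config["index_space"], dim=config["index_dim"])
index.load_index(config["index_path"])

# query_embedding_field names the dataset column that gets embedded for lookups,
# e.g. "question" for squad.
queries_column = config["query_embedding_field"]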
diff --git a/gpt4all/index/build_index.py b/gpt4all/index/build_index.py
index 7e61e2ac..566f5abd 100644
--- a/gpt4all/index/build_index.py
+++ b/gpt4all/index/build_index.py
@@ -5,6 +5,7 @@ from argparse import ArgumentParser
 import hnswlib
 import pyarrow as pa
 import pyarrow.compute as pc
+from tqdm import tqdm
 
 
 def parse_args():
@@ -60,6 +62,8 @@ def join(original_ds, embedded_ds):
 
 def build_index(orig_path, embed_folder_path, index_path):
     if not os.path.exists(orig_path + "_embedded_with_text"):
+        # TODO: this doesn't work for large datasets!
+        # just convert to pandas and do this manually
         ds = Dataset.load_from_disk(orig_path)
         embed_ds = concat_embedded(embed_folder_path)
         print("Concatenated embeddings")
@@ -79,7 +83,15 @@ def build_index(orig_path, embed_folder_path, index_path):
 
     # not sure what we should set M and ef_construction to
     index.init_index(max_elements=len(ds), M=64, ef_construction=200)
     print("Adding items")
-    index.add_items(ds["embedding"], ds["index"])
+    chunk_size = 50_000
+    num_chunks = len(ds) // chunk_size
+    progbar = tqdm(total=num_chunks)
+    start = 0
+    while start < len(ds):
+        chunk = ds[start:start + chunk_size]
+        index.add_items(chunk["embedding"], chunk["index"], num_threads=64)
+        progbar.update(1)
+        start += chunk_size
     print("Saving index")
     index.save_index(index_path + ".bin")
diff --git a/gpt4all/index/embed.py b/gpt4all/index/embed.py
index 6ac75a3b..2cda33b2 100644
--- a/gpt4all/index/embed.py
+++ b/gpt4all/index/embed.py
@@ -45,16 +45,11 @@ class Embedder:
 
         return tokenized_text
 
+    def tokenize(self, text):
+        return self.tokenizer(text, return_tensors="pt", truncation=True, padding="max_length")
+
     def __call__(self, batch):
-        if isinstance(batch, dict):
-            outputs = self.embedder(
-                input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
-            )
-            embedding = self._mean_pool(outputs, batch["attention_mask"])
-
-            return {"id": batch["id"], "embedding": embedding}
-
-        elif isinstance(batch, str):
+        if isinstance(batch, str):
             tokenized = self.tokenizer(batch, return_tensors="pt", truncation=True)
             return self._mean_pool(
                 self.embedder(
@@ -63,6 +58,13 @@ class Embedder:
                 ),
                 tokenized["attention_mask"],
             )
+        else:
+            outputs = self.embedder(
+                input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
+            )
+            embedding = self._mean_pool(outputs, batch["attention_mask"])
+
+            return {"id": batch["id"], "embedding": embedding}
 
     def to(self, device):
         self.embedder.to(device)
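For reference, a minimal sketch (not part of this patch) of the two call paths the reworked Embedder.__call__ now supports, raw strings and pre-tokenized batches; the output width being 384 is an assumption taken from index_dim in the config:

# Sketch only: exercising gpt4all/index/embed.py after this change.
from gpt4all.index.embed import Embedder

embedder = Embedder()

# Raw-string path: tokenizes internally and returns a mean-pooled vector.
query_vec = embedder("what is retrieval augmented generation?")

# Pre-tokenized path, as embed_texts.py feeds it: a dict-like batch with
# input_ids/attention_mask plus an id column comes back as {"id", "embedding"}.
batch = embedder.tokenize(["first passage", "second passage"])
batch["id"] = [0, 1]
embedded = embedder(batch)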
diff --git a/gpt4all/index/embed_texts.py b/gpt4all/index/embed_texts.py
index f130217d..a139f5df 100644
--- a/gpt4all/index/embed_texts.py
+++ b/gpt4all/index/embed_texts.py
@@ -9,6 +9,7 @@ from transformers import BatchEncoding
 from tqdm import tqdm
 import numpy as np
 import torch
+from datasets import load_dataset
 
 
 class PadCollateInputs:
@@ -29,20 +30,29 @@ class PadCollateInputs:
 
     return padded_inputs
 
-def embed_texts(ds_path, batch_size):
+def embed_texts(ds_path, batch_size, embed_on='text', save_to_disk=False):
     rank0_print(f"World size: {dist.get_world_size()}")
-
-    dataset = Dataset.load_from_disk(ds_path)
+    dataset = load_dataset(f"{ds_path}", split="train")
     rank0_print(f"Dataset size: {len(dataset)}")
-    dataset = dataset.remove_columns(["url", "title", "text"])
+
+    model = Embedder()
+
+    dataset = dataset.map(lambda x: model.tokenize(x[embed_on]), batched=True, num_proc=64)
+
+    columns_to_keep = ["input_ids", "attention_mask"]
+    #to_remove = [e for e in dataset.column_names if not e in columns_to_keep]
+    print('cols: ', dataset.column_names)
+    #dataset = dataset.remove_columns(to_remove)
+
+    #dataset = Dataset.load_from_disk(ds_path)
+    #dataset = dataset.remove_columns(["url", "title", "text"])
     dataset = dataset.with_format("torch")
 
     num_processes = dist.get_world_size()
     local_rank = dist.get_rank()
 
-    model = Embedder()
-    collator = PadCollateInputs(model.tokenizer)
+    #collator = PadCollateInputs(model.tokenizer)
 
     sampler = DistributedSampler(
         dataset,
@@ -53,7 +63,7 @@ def embed_texts(ds_path, batch_size):
     )
     dataloader = DataLoader(
         dataset,
-        collate_fn=collator,
+        # collate_fn=collator,
         batch_size=batch_size,
         sampler=sampler,
         drop_last=False,
@@ -77,7 +87,9 @@ def embed_texts(ds_path, batch_size):
 
     # feeling lazy, don't want to wait for all_gather to finish
     # will load and concat in a single process in another script
-    ds.save_to_disk(f"embedded/{ds_path}_embedded_rank_{local_rank}")
+    if save_to_disk:
+        ds.save_to_disk(f"{ds_path}_embedded/{ds_path}_embedded_rank_{local_rank}")
+    return ds
 
 
 def main():
diff --git a/gpt4all/index/prep_index_for_train.py b/gpt4all/index/prep_index_for_train.py
new file mode 100644
index 00000000..dcb90da3
--- /dev/null
+++ b/gpt4all/index/prep_index_for_train.py
@@ -0,0 +1,58 @@
+import os
+import hnswlib
+import numpy as np
+from datasets import Dataset
+import torch.distributed as dist
+from datasets import load_dataset
+from argparse import ArgumentParser
+from gpt4all.utils.read import read_config
+from gpt4all.index.embed_texts import embed_texts
+
+CHUNK_SIZE = 1024
+K = 5
+
+
+if __name__ == "__main__":
+
+    dist.init_process_group("nccl")
+
+    parser = ArgumentParser()
+    parser.add_argument("--config", type=str, default="config.yaml")
+
+    args = parser.parse_args()
+
+    config = read_config(args.config)
+
+    # load index
+    index = hnswlib.Index(space=config['index_space'], dim=config['index_dim'])
+    index.load_index(config['index_path'])
+
+    # load query dataset
+    ds_path = config['dataset_path']
+
+    # load retrieval dataset
+    retrieval_dataset = Dataset.load_from_disk(config['index_database'])
+
+    # vectorize queries
+    query_vector_path = f"{ds_path}_queries_embedded/{ds_path}_embedded"
+    if not os.path.exists(query_vector_path):
+        print('Embedding dataset...')
+        ds = embed_texts(ds_path, config['batch_size'], embed_on=config['query_embedding_field'], save_to_disk=False)
+        ds.save_to_disk(query_vector_path)
+    else:
+        print('Found cached embedding dataset!')
+        ds = Dataset.load_from_disk(query_vector_path)
+
+    # search the index for each query
+    for chunk_start in range(0, len(ds), CHUNK_SIZE):
+        chunk_end = chunk_start + CHUNK_SIZE
+        chunk = ds[chunk_start:chunk_end]
+        query_vectors = np.array(chunk['embedding'])
+        print(query_vectors.shape)
+        neighbor_ids, _ = index.knn_query(query_vectors, k=K, num_threads=-1)
+
+    # TODO: get the embeddings for each of the neighbor ids
diff --git a/gpt4all/index/test_load_index.py b/gpt4all/index/test_load_index.py
new file mode 100644
index 00000000..e69de29b
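The TODO that closes prep_index_for_train.py, getting the embeddings for each neighbor id, is not implemented in this patch. A sketch of one way it could work, assuming the retrieval dataset keeps the "index" and "embedding" columns that build_index.py works with (the column names are an assumption here, not something the diff guarantees):

# Sketch only: resolve the hnswlib labels returned by knn_query back to rows
# of the retrieval dataset and stack their embeddings per query.
import numpy as np

def gather_neighbor_embeddings(retrieval_dataset, neighbor_ids):
    # map hnswlib label -> row position once, then look up rows per query
    id_to_row = {int(label): pos for pos, label in enumerate(retrieval_dataset["index"])}
    per_query = []
    for row_of_ids in neighbor_ids:
        rows = [retrieval_dataset[id_to_row[int(i)]] for i in row_of_ids]
        per_query.append(np.stack([row["embedding"] for row in rows]))
    return np.stack(per_query)  # shape: (num_queries, K, index_dim)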
diff --git a/gpt4all/train/train_r.py b/gpt4all/train/train_r.py
new file mode 100644
index 00000000..9254d6ef
--- /dev/null
+++ b/gpt4all/train/train_r.py
@@ -0,0 +1,246 @@
+import os
+from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler, LlamaForCausalLM
+import torch
+from torch.optim import AdamW
+from argparse import ArgumentParser
+from gpt4all.utils.read import read_config
+from accelerate import Accelerator
+from accelerate.utils import DummyScheduler, DummyOptim, set_seed
+from peft import get_peft_model, LoraConfig, TaskType
+from gpt4all.utils.data import load_data, load_retrieval_augmented_data
+from torchmetrics import MeanMetric
+from tqdm import tqdm
+from gpt4all.models import GPTJRForCausalLM
+import wandb
+
+torch.backends.cuda.matmul.allow_tf32 = True
+
+
+def format_metrics(metrics, split, prefix=""):
+    log = f"[{split}]" + prefix
+    log += " ".join([f"{key}: {value:.4f}" for key, value in metrics.items()])
+
+    return log
+
+
+def evaluate(model, val_dataloader):
+    model.eval()
+    val_loss = MeanMetric(nan_strategy="error").to(model.device)
+
+    with torch.no_grad():
+        for batch in tqdm(val_dataloader):
+            loss = model(**batch).loss
+
+            loss_values = accelerator.gather_for_metrics({"loss": loss.detach()})
+
+            val_loss.update(loss_values["loss"])
+
+    return val_loss
+
+
+def train(accelerator, config):
+    set_seed(config['seed'])
+
+    accelerator.print(config)
+    accelerator.print(f"Using {accelerator.num_processes} GPUs")
+
+    tokenizer = AutoTokenizer.from_pretrained(config['tokenizer_name'], model_max_length=config['max_length'])
+    # if no pad token, set it to eos
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    with accelerator.main_process_first():
+        if 'index_path' in config:
+            train_dataloader, val_dataloader = load_retrieval_augmented_data(config, tokenizer)
+        else:
+            train_dataloader, val_dataloader = load_data(config, tokenizer)
+
+    checkpoint = config["gradient_checkpointing"]
+
+    # ensures back compat with non retrieval models
+    if 'index_path' in config:
+        model = GPTJRForCausalLM.from_pretrained(config["model_name"],
+                                                 revision=config['version'],
+                                                 use_cache=False if checkpoint else True,
+                                                 trust_remote_code=True)
+    else:
+        model = AutoModelForCausalLM.from_pretrained(config["model_name"],
+                                                     use_cache=False if checkpoint else True,
+                                                     trust_remote_code=True)
+
+    if checkpoint:
+        model.gradient_checkpointing_enable()
+
+    if config["lora"]:
+        peft_config = LoraConfig(
+            # should R be configurable?
+            task_type=TaskType.CAUSAL_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1
+        )
+        model = get_peft_model(model, peft_config)
+        model.print_trainable_parameters()
+
+    optimizer_cls = (
+        AdamW
+        if accelerator.state.deepspeed_plugin is None
+        or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config
+        else DummyOptim
+    )
+
+    # karpathy doesn't decay embedding, maybe we should exclude
+    # https://github.com/karpathy/minGPT/commit/bbbdac74fa9b2e55574d70056163ffbae42310c1#diff-2075fa9c224b395be5bda85544dd36572b59c76c54562819eadadbf268602834R157s
+    optimizer = optimizer_cls(model.parameters(), lr=config["lr"], weight_decay=config["weight_decay"])
+
+    # default to no gradient accumulation when DeepSpeed is not configured
+    gradient_accumulation_steps = 1
+    if accelerator.state.deepspeed_plugin is not None:
+        gradient_accumulation_steps = accelerator.state.deepspeed_plugin.deepspeed_config[
+            "gradient_accumulation_steps"
+        ]
+
+    # decay to min_lr instead of 0
+    lr_ratio = config["min_lr"] / config["lr"]
+    accelerator.print(f"Len of train_dataloader: {len(train_dataloader)}")
+    total_num_steps = (len(train_dataloader) / gradient_accumulation_steps) * config["num_epochs"]
+    # instead of decaying to zero, decay to ratio of min_lr / lr
+    total_num_steps += int(total_num_steps * lr_ratio) + config["warmup_steps"]
+    accelerator.print(f"Total training steps: {total_num_steps}")
+
+    # Creates Dummy Scheduler if `scheduler` was specified in the config file else creates `args.lr_scheduler_type` Scheduler
+    if (
+        accelerator.state.deepspeed_plugin is None
+        or "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
+    ):
+        scheduler = get_scheduler(
+            name="cosine",
+            optimizer=optimizer,
+            num_warmup_steps=config["warmup_steps"] * accelerator.num_processes,
+            num_training_steps=total_num_steps,
+        )
+    else:
+        scheduler = DummyScheduler(
+            optimizer, total_num_steps=total_num_steps, warmup_num_steps=config["warmup_steps"]
+        )
+
+    model, optimizer, train_dataloader, val_dataloader, scheduler = accelerator.prepare(
+        model, optimizer, train_dataloader, val_dataloader, scheduler
+    )
+
+    # setup for saving training states in case preemption
+    accelerator.register_for_checkpointing(scheduler)
+
+    if config["checkpoint"]:
+        accelerator.load_state(config["checkpoint"])
+        accelerator.print(f"Resumed from checkpoint: {config['checkpoint']}")
+        path = os.path.basename(config["checkpoint"])
+        training_difference = os.path.splitext(path)[0]
+        resume_step = int(training_difference.replace("step_", ""))
+        train_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
+        accelerator.print(f"Resuming from step {resume_step}")
+
+    # log gradients
+    if accelerator.is_main_process and config["wandb"]:
+        wandb.watch(model, log_freq=config["log_grads_every"], log="all")
+
+    for epoch in range(config["num_epochs"]):
+        train_loss = MeanMetric(nan_strategy="error").to(model.device)
+        for step, batch in enumerate(tqdm(train_dataloader)):
+            model.train()
+            outputs = model(**batch)
+            loss = outputs.loss
+
+            # gather loss before backprop in case of gradient accumulation
+            loss_values = accelerator.gather_for_metrics({"loss": loss.detach().float()})
+            train_loss.update(loss_values["loss"])
+
+            loss = loss / gradient_accumulation_steps
+            accelerator.backward(loss)
+            # get gradient norm of all params
+
+            # log LR in case something weird happens
+            if step > 0 and step % (config["eval_every"] // 10) == 0:
+                if config["wandb"]:
+                    curr_step = step + epoch * len(train_dataloader)
+                    accelerator.log({"lr": scheduler.get_last_lr()[0]}, step=curr_step)
+
+            if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+                optimizer.step()
+                scheduler.step()
+                optimizer.zero_grad()
+
+            if step > 0 and step % config["save_every"] == 0:
+                curr_step = step + epoch * len(train_dataloader)
+                accelerator.save_state(f"{config['output_dir']}/step_{curr_step}")
+
+            if step > 0 and (step % config["eval_every"] == 0 or step == len(train_dataloader) - 1):
+                val_loss = evaluate(model, val_dataloader)
+
+                log_train = {
+                    "train_loss": train_loss.compute()
+                }
+                log_val = {
+                    "val_loss": val_loss.compute()
+                }
+
+                if config["wandb"]:
+                    curr_step = step + epoch * len(train_dataloader)
+                    accelerator.log({**log_train, **log_val}, step=curr_step)
+
+                accelerator.print(f"Current LR: {scheduler.get_last_lr()[0]}")
+                accelerator.print(format_metrics(log_train, "train", f" step {step} "))
+                accelerator.print(format_metrics(log_val, "val", f" step {step} "))
+
+                train_loss.reset()
+
+        accelerator.print(f"Epoch {epoch} finished")
+        accelerator.print("Pushing to HF hub")
+        accelerator.wait_for_everyone()
+        unwrapped_model = accelerator.unwrap_model(model)
+        try:
+            if accelerator.is_main_process:
+                unwrapped_model.push_to_hub(config["save_name"] + f"-epoch_{epoch}", private=True)
+
+        except Exception as e:
+            accelerator.print(e)
+            accelerator.print("Failed to push to hub")
+
+        unwrapped_model.save_pretrained(
+            f"{config['output_dir']}/epoch_{epoch}",
+            is_main_process=accelerator.is_main_process,
+            save_function=accelerator.save,
+            state_dict=accelerator.get_state_dict(model),
+        )
+
+    accelerator.wait_for_everyone()
+    unwrapped_model = accelerator.unwrap_model(model)
+    unwrapped_model.save_pretrained(
+        f"{config['output_dir']}/final",
+        is_main_process=accelerator.is_main_process,
+        save_function=accelerator.save,
+        state_dict=accelerator.get_state_dict(model),
+    )
+
+    accelerator.end_training()
+
+
+if __name__ == "__main__":
+    # parse arguments by reading in a config
+    parser = ArgumentParser()
+    parser.add_argument("--config", type=str, default="config.yaml")
+
+    args = parser.parse_args()
+
+    config = read_config(args.config)
+
+    if config["wandb"]:
+        accelerator = Accelerator(log_with="wandb")
+        accelerator.init_trackers(
+            project_name=config["wandb_project_name"],
+            config=config,
+            init_kwargs={"wandb": {"entity": config["wandb_entity"]}},
+        )
+    else:
+        accelerator = Accelerator()
+
+    train(accelerator, config=config)
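Both new entry points take the config path as their only argument. Assuming a machine where accelerate (and, for the distributed embedding step, torchrun) is already configured, the launches would look roughly like the following; the process count is a placeholder, nothing in the patch fixes it:

  torchrun --nproc_per_node=8 gpt4all/index/prep_index_for_train.py --config configs/train/finetune_gptjr.yaml
  accelerate launch gpt4all/train/train_r.py --config configs/train/finetune_gptjr.yaml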
diff --git a/gpt4all/utils/data.py b/gpt4all/utils/data.py
index b55a589a..3216441f 100644
--- a/gpt4all/utils/data.py
+++ b/gpt4all/utils/data.py
@@ -2,6 +2,7 @@ import glob
 import torch
 from datasets import load_dataset
 import os
+import hnswlib
 
 from torch.utils.data import DataLoader
 from transformers import DefaultDataCollator
@@ -116,6 +117,70 @@ def load_data(config, tokenizer):
     return train_dataloader, val_dataloader
 
 
+def load_retrieval_augmented_data(config, tokenizer):
+    dataset_path = config["dataset_path"]
+    index_path = config['index_path']
+
+    # TODO: this should precache at some point
+    # NOTE: the index is only loaded here; retrieved neighbors are not yet attached to batches
+    index = hnswlib.Index(space=config['index_space'], dim=config['index_dim'])
+    index.load_index(index_path)
+
+    if os.path.exists(dataset_path):
+        if os.path.isdir(dataset_path):
+            files = glob.glob(os.path.join(dataset_path, "*_clean.jsonl"))
+        else:
+            files = [dataset_path]
+
+        print(f"Reading files {files}")
+
+        dataset = load_dataset("json", data_files=files, split="train")
+
+    else:
+        dataset = load_dataset(dataset_path, split="train")
+
+    dataset = dataset.train_test_split(test_size=0.05, seed=config["seed"])
+
+    train_dataset, val_dataset = dataset["train"], dataset["test"]
+
+    if config["streaming"] is False:
+        kwargs = {"num_proc": config["num_proc"]}
+    else:
+        kwargs = {}
+
+    # tokenize inputs and return labels and attention mask
+    train_dataset = train_dataset.map(
+        lambda ele: tokenize_inputs(config, tokenizer, ele),
+        batched=True,
+        remove_columns=["source", "prompt"],
+        **kwargs
+    )
+    val_dataset = val_dataset.map(
+        lambda ele: tokenize_inputs(config, tokenizer, ele),
+        batched=True,
+        remove_columns=["source", "prompt"],
+        **kwargs
+    )
+
+    train_dataset = train_dataset.with_format("torch")
+    val_dataset = val_dataset.with_format("torch")
+
+    # create dataloader with default data collator since we already have labels
+    train_dataloader = DataLoader(
+        train_dataset,
+        collate_fn=DefaultDataCollator(),
+        batch_size=config["batch_size"],
+    )
+
+    val_dataloader = DataLoader(
+        val_dataset,
+        collate_fn=DefaultDataCollator(),
+        batch_size=config["batch_size"],
+    )
+
+    return train_dataloader, val_dataloader
+
+
 def load_data_for_inference(config, tokenizer):
     dataset_path = config["dataset_path"]