diff --git a/configs/train/finetune_gptjr.yaml b/configs/train/finetune_gptjr.yaml
index a03f8ff0..352487c7 100644
--- a/configs/train/finetune_gptjr.yaml
+++ b/configs/train/finetune_gptjr.yaml
@@ -1,41 +1,40 @@
 # model/tokenizer
-model_name: "nomic-ai/gpt4all-j"
-tokenizer_name: "nomic-ai/gpt4all-j"
-version: 'v1.2-jazzy'
+model_name: "EleutherAI/gpt-j-6B"
+tokenizer_name: "EleutherAI/gpt-j-6B"
+version: null
 gradient_checkpointing: true
-save_name: # CHANGE
+save_name: "nomic-ai/gpt-jr-decay-alpha"
+encoder_dim: 384
 
 # dataset
 streaming: false
 num_proc: 64
-dataset_path: "squad"
+dataset_path: "/home/paperspace/gpt4all/gpt4all/index/squad_supplemented_train"
 max_length: 1024
 batch_size: 32
+pct_test: 0.05
+q_column: "question"
+a_column: "answers"
+encoder_column: "neighbor_embeddings"
 
-#index
-index_path: "/home/paperspace/gpt4all/gpt4all/index/wiki-sample-index.bin"
-index_database: "/home/paperspace/gpt4all/gpt4all/index/wiki_sample_tokenized_embedded_with_text"
-index_space: "cosine"
-index_dim: 384
-query_embedding_field: 'question'
 
 # train dynamics
-lr: 2.0e-5
+lr: 1.0e-4
min_lr: 0
 weight_decay: 0.0
-eval_every: 500
-eval_steps: 105
+eval_every: 50
 save_every: 500
 log_grads_every: 100
-output_dir: # CHANGE
+log_lr_every: 10
+output_dir: "ckpts/decay_alpha"
 checkpoint: null
 lora: false
 warmup_steps: 500
-num_epochs: 2
+num_epochs: 5
 
 # logging
-wandb: false
-wandb_entity: # CHANGE
-wandb_project_name: # CHANGE
+wandb: true
+wandb_entity: gpt4all
+wandb_project_name: retrieval
 seed: 42
diff --git a/gpt4all/data/instruction_tuning_dataloader.py b/gpt4all/data/instruction_tuning_dataloader.py
index 37fef4d2..5803ea76 100644
--- a/gpt4all/data/instruction_tuning_dataloader.py
+++ b/gpt4all/data/instruction_tuning_dataloader.py
@@ -5,58 +5,7 @@ import os
 import hnswlib
 from torch.utils.data import DataLoader
 from transformers import DefaultDataCollator
-
-
-
-def tokenize_inputs(config, tokenizer, examples):
-    max_length = config["max_length"]
-
-    # hacky backward compatible
-    different_eos = tokenizer.eos_token != "</s>"
-    out = {"labels": [], "input_ids": []}
-    for prompt, response in zip(examples["prompt"], examples["response"]):
-        if different_eos:
-            if response.count("</s> \n") > 0:
-                response = response.replace("</s> \n", f"{tokenizer.eos_token} \n")
-
-        prompt_len = len(tokenizer(prompt + "\n", return_tensors="pt")["input_ids"][0])
-
-        # hack if our prompt is super long
-        # we need to include some labels so we arbitrarily trunacate at max_length // 2
-        # if the length is too long
-        if prompt_len >= max_length // 2:
-            # if prompt is too long, truncate
-            # but make sure to truncate to at max 1024 tokens
-            new_len = min(max_length // 2, len(prompt) // 2)
-            prompt = prompt[:new_len]
-            # get new prompt length
-            prompt_len = tokenizer(prompt + "\n", return_tensors="pt", max_length=max_length // 2, truncation=True).input_ids.ne(tokenizer.pad_token_id).sum().item()
-
-        assert prompt_len <= max_length // 2, f"prompt length {prompt_len} exceeds max length {max_length}"
-
-        input_tokens = tokenizer(prompt + "\n" + response + tokenizer.eos_token,
-                                 truncation=True, max_length=max_length, return_tensors="pt")["input_ids"].squeeze()
-
-        labels = input_tokens.clone()
-        labels[:prompt_len] = -100
-        if len(labels) < max_length:
-            # pad to max_length with -100
-            labels = torch.cat([labels, torch.full((max_length - len(labels),), -100)])
-
-        assert (labels == -100).sum() < len(labels), f"Labels are all -100, something wrong. prompt length {prompt_len} exceeds max length {max_length}"
-
-        if (labels == -100).sum() == len(labels) - 1:
-            print(prompt)
-            print(response)
-            raise
-
-        input_tokens = tokenizer.pad({"input_ids": input_tokens}, padding="max_length", max_length=max_length)["input_ids"]
-        out["labels"].append(labels)
-        out["input_ids"].append(input_tokens)
-
-    out = {k: torch.stack(v) if isinstance(v, list) else v for k, v in out.items()}
-
-    return out
+from .preprocess import tokenize_inputs
 
 
 def load_data(config, tokenizer):
@@ -86,13 +35,13 @@ def load_data(config, tokenizer):
 
     # tokenize inputs and return labels and attention mask
     train_dataset = train_dataset.map(
-        lambda ele: tokenize_inputs(config, tokenizer, ele),
+        lambda ele: tokenize_inputs(config, tokenizer, ele, "prompt", "response"),
         batched=True,
         remove_columns=["source", "prompt"],
         **kwargs
     )
     val_dataset = val_dataset.map(
-        lambda ele: tokenize_inputs(config, tokenizer, ele),
+        lambda ele: tokenize_inputs(config, tokenizer, ele, "prompt", "response"),
         batched=True,
         remove_columns=["source", "prompt"],
         **kwargs
@@ -154,12 +103,12 @@ def load_data_for_inference(config, tokenizer):
 
     # tokenize inputs and return labels and attention mask
     train_dataset = train_dataset.map(
-        lambda ele: tokenize_inputs(config, tokenizer, ele),
+        lambda ele: tokenize_inputs(config, tokenizer, ele, "prompt", "response"),
         batched=True,
         **kwargs
     )
     val_dataset = val_dataset.map(
-        lambda ele: tokenize_inputs(config, tokenizer, ele),
+        lambda ele: tokenize_inputs(config, tokenizer, ele, "prompt", "response"),
         batched=True,
         **kwargs
     )
diff --git a/gpt4all/train/train_r.py b/gpt4all/train/train_retrieval.py
similarity index 81%
rename from gpt4all/train/train_r.py
rename to gpt4all/train/train_retrieval.py
index 9254d6ef..1e31f01d 100644
--- a/gpt4all/train/train_r.py
+++ b/gpt4all/train/train_retrieval.py
@@ -7,11 +7,13 @@ from gpt4all.utils.read import read_config
 from accelerate import Accelerator
 from accelerate.utils import DummyScheduler, DummyOptim, set_seed
 from peft import get_peft_model, LoraConfig, TaskType
-from gpt4all.utils.data import load_data, load_retrieval_augmented_data
+from gpt4all.data.retrieval_dataloader import load_retrieval_augmented_data
 from torchmetrics import MeanMetric
 from tqdm import tqdm
 from gpt4all.models import GPTJRForCausalLM
+from gpt4all.train.metrics import f1_score, exact_match_score
 import wandb
+import torch.distributed as dist
 
 torch.backends.cuda.matmul.allow_tf32 = True
 
@@ -22,15 +24,18 @@ def format_metrics(metrics, split, prefix=""):
 
     return log
 
 
-def evaluate(model, val_dataloader):
+def evaluate(model, val_dataloader, step, main_process=False):
     model.eval()
     val_loss = MeanMetric(nan_strategy="error").to(model.device)
 
     with torch.no_grad():
-        for batch in tqdm(val_dataloader):
-            loss = model(**batch).loss
+        for batch in tqdm(val_dataloader, disable=not main_process):
+            outputs = model(input_ids=batch["input_ids"],
+                            labels=batch["labels"],
+                            encoder_hidden_states=batch["encoder_hidden_states"],
+                            step=step)
+            loss_values = accelerator.gather_for_metrics({"loss": outputs["loss"].detach()})
 
-            loss_values = accelerator.gather_for_metrics({"loss": loss.detach()})
 
             val_loss.update(loss_values["loss"])
@@ -50,20 +55,18 @@ def train(accelerator, config):
 
     with accelerator.main_process_first():
-
-        if 'index_path' in config:
-            train_dataloader, val_dataloader = load_retrieval_augmented_data(config, tokenizer)
-        else:
-            train_dataloader, val_dataloader = load_data(config, tokenizer)
+        train_dataloader, val_dataloader = load_retrieval_augmented_data(config, tokenizer)
 
     checkpoint = config["gradient_checkpointing"]
 
 
     #ensures back compat with non retrieval models
-    if 'index_path' in config:
-        model = GPTJRForCausalLM.from_pretrained(config["model_name"],
-                                                 revision=config['version'],
-                                                 use_cache=False if checkpoint else True,
-                                                 trust_remote_code=True)
+    if 'encoder_dim' in config:
+        with accelerator.main_process_first():
+            model = GPTJRForCausalLM.from_pretrained(config["model_name"],
+                                                     revision=config['version'] if 'version' in config else None,
+                                                     use_cache=False if checkpoint else True,
+                                                     encoder_dim=config["encoder_dim"],
+                                                     )
     else:
         model = AutoModelForCausalLM.from_pretrained(config["model_name"],
                                                      use_cache=False if checkpoint else True,
@@ -117,13 +120,14 @@ def train(accelerator, config):
         )
     else:
         scheduler = DummyScheduler(
-            optimizer, total_num_steps=config["warmup_steps"], warmup_num_steps=config["warmup_steps"]
+            optimizer, total_num_steps=total_num_steps, warmup_num_steps=config["warmup_steps"]
         )
 
     model, optimizer, train_dataloader, val_dataloader, scheduler = accelerator.prepare(
         model, optimizer, train_dataloader, val_dataloader, scheduler
     )
 
+    # setup for saving training states in case preemption
     accelerator.register_for_checkpointing(scheduler)
 
@@ -141,11 +145,16 @@ def train(accelerator, config):
     if accelerator.is_main_process and config["wandb"]:
         wandb.watch(model, log_freq=config["log_grads_every"], log="all")
 
+    main_process = accelerator.is_main_process
+
     for epoch in range(config["num_epochs"]):
         train_loss = MeanMetric(nan_strategy="error").to(model.device)
-        for step, batch in enumerate(tqdm(train_dataloader)):
+        for step, batch in enumerate(tqdm(train_dataloader, disable=not main_process)):
             model.train()
-            outputs = model(**batch)
+            outputs = model(input_ids=batch["input_ids"],
+                            labels=batch["labels"],
+                            encoder_hidden_states=batch["encoder_hidden_states"],
+                            step=step)
             loss = outputs.loss
 
             # gather loss before backprop in case of gradient accumulation
@@ -157,8 +166,8 @@ def train(accelerator, config):
 
             # get gradient norm of all params
 
             # log LR in case something weird happens
-            if step > 0 and step % (config["eval_every"] // 10) == 0:
-                if config["wandb"]:
+            if config["wandb"]:
+                if step > 0 and step % (config["log_lr_every"] ) == 0:
                     curr_step = step + epoch * len(train_dataloader)
                     accelerator.log({"lr": scheduler.get_last_lr()[0]}, step=curr_step)
@@ -173,13 +182,14 @@ def train(accelerator, config):
                 accelerator.save_state(f"{config['output_dir']}/step_{curr_step}")
 
             if step > 0 and (step % config["eval_every"] == 0 or step == len(train_dataloader) - 1):
-                val_loss = evaluate(model, val_dataloader)
+                curr_step = step + epoch * len(train_dataloader)
+                val_loss = evaluate(model, val_dataloader, step=curr_step, main_process=main_process)
 
                 log_train = {
                     "train_loss": train_loss.compute()
                 }
                 log_val = {
-                    "val_loss": val_loss.compute()
+                    "val_loss": val_loss.compute(),
                 }
 
                 if config["wandb"]: