From 80d810322a42b1eb49b9addfd129c41c1426943a Mon Sep 17 00:00:00 2001
From: Zach Nussbaum
Date: Mon, 1 May 2023 21:38:01 +0000
Subject: [PATCH] fix: lr schedule

---
 .../instruction_tuning_dataloader.py}         | 63 -------------------
 gpt4all/train/train.py                        |  6 +-
 2 files changed, 3 insertions(+), 66 deletions(-)
 rename gpt4all/{utils/data.py => data/instruction_tuning_dataloader.py} (75%)

diff --git a/gpt4all/utils/data.py b/gpt4all/data/instruction_tuning_dataloader.py
similarity index 75%
rename from gpt4all/utils/data.py
rename to gpt4all/data/instruction_tuning_dataloader.py
index 3216441f..37fef4d2 100644
--- a/gpt4all/utils/data.py
+++ b/gpt4all/data/instruction_tuning_dataloader.py
@@ -117,69 +117,6 @@ def load_data(config, tokenizer):
 
     return train_dataloader, val_dataloader
 
 
-def load_retrieval_augmented_data(config, tokenizer):
-    dataset_path = config["dataset_path"]
-    index_path = config['index_path']
-
-    #TODO this should precache at some point
-    index = hnswlib.Index(space=config['index_space'], dim=config['index_dim'])
-    index.load_index(index_path)
-
-    if os.path.exists(dataset_path):
-        if os.path.isdir(dataset_path):
-            files = glob.glob(os.path.join(dataset_path, "*_clean.jsonl"))
-        else:
-            files = [dataset_path]
-
-        print(f"Reading files {files}")
-
-        dataset = load_dataset("json", data_files=files, split="train")
-
-    else:
-        dataset = load_dataset(dataset_path, split="train")
-
-    dataset = dataset.train_test_split(test_size=.05, seed=config["seed"])
-
-    train_dataset, val_dataset = dataset["train"], dataset["test"]
-
-    if config["streaming"] is False:
-        kwargs = {"num_proc": config["num_proc"]}
-    else:
-        kwargs = {}
-
-    # tokenize inputs and return labels and attention mask
-    train_dataset = train_dataset.map(
-        lambda ele: tokenize_inputs(config, tokenizer, ele),
-        batched=True,
-        remove_columns=["source", "prompt"],
-        **kwargs
-    )
-    val_dataset = val_dataset.map(
-        lambda ele: tokenize_inputs(config, tokenizer, ele),
-        batched=True,
-        remove_columns=["source", "prompt"],
-        **kwargs
-    )
-
-    train_dataset = train_dataset.with_format("torch")
-    val_dataset = val_dataset.with_format("torch")
-
-    # create dataloader with default data collator since we already have labels
-
-    train_dataloader = DataLoader(
-        train_dataset,
-        collate_fn=DefaultDataCollator(),
-        batch_size=config["batch_size"],
-    )
-
-    val_dataloader = DataLoader(
-        val_dataset,
-        collate_fn=DefaultDataCollator(),
-        batch_size=config["batch_size"],
-    )
-
-    return train_dataloader, val_dataloader
-
 
 def load_data_for_inference(config, tokenizer):
diff --git a/gpt4all/train/train.py b/gpt4all/train/train.py
index 97b6c9a8..75d8eda3 100644
--- a/gpt4all/train/train.py
+++ b/gpt4all/train/train.py
@@ -1,5 +1,5 @@
 import os
-from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler, LlamaForCausalLM
+from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler
 import torch
 from torch.optim import AdamW
 from argparse import ArgumentParser
@@ -7,7 +7,7 @@ from gpt4all.utils.read import read_config
 from accelerate import Accelerator
 from accelerate.utils import DummyScheduler, DummyOptim, set_seed
 from peft import get_peft_model, LoraConfig, TaskType
-from gpt4all.utils.data import load_data
+from gpt4all.data.instruction_tuning_dataloader import load_data
 from torchmetrics import MeanMetric
 from tqdm import tqdm
 import wandb
@@ -104,7 +104,7 @@ def train(accelerator, config):
         )
     else:
         scheduler = DummyScheduler(
-            optimizer, total_num_steps=config["warmup_steps"], warmup_num_steps=config["warmup_steps"]
+            optimizer, total_num_steps=total_num_steps, warmup_num_steps=config["warmup_steps"]
         )
 
     model, optimizer, train_dataloader, val_dataloader, scheduler = accelerator.prepare(