mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2025-07-18 17:32:00 +00:00
fix: lr schedule
This commit is contained in:
parent
8a917ad4e1
commit
80d810322a
@ -117,69 +117,6 @@ def load_data(config, tokenizer):
|
|||||||
|
|
||||||
return train_dataloader, val_dataloader
|
return train_dataloader, val_dataloader
|
||||||
|
|
||||||
def load_retrieval_augmented_data(config, tokenizer):
|
|
||||||
dataset_path = config["dataset_path"]
|
|
||||||
index_path = config['index_path']
|
|
||||||
|
|
||||||
#TODO this should precache at some point
|
|
||||||
index = hnswlib.Index(space=config['index_space'], dim=config['index_dim'])
|
|
||||||
index.load_index(index_path)
|
|
||||||
|
|
||||||
if os.path.exists(dataset_path):
|
|
||||||
if os.path.isdir(dataset_path):
|
|
||||||
files = glob.glob(os.path.join(dataset_path, "*_clean.jsonl"))
|
|
||||||
else:
|
|
||||||
files = [dataset_path]
|
|
||||||
|
|
||||||
print(f"Reading files {files}")
|
|
||||||
|
|
||||||
dataset = load_dataset("json", data_files=files, split="train")
|
|
||||||
|
|
||||||
else:
|
|
||||||
dataset = load_dataset(dataset_path, split="train")
|
|
||||||
|
|
||||||
dataset = dataset.train_test_split(test_size=.05, seed=config["seed"])
|
|
||||||
|
|
||||||
train_dataset, val_dataset = dataset["train"], dataset["test"]
|
|
||||||
|
|
||||||
if config["streaming"] is False:
|
|
||||||
kwargs = {"num_proc": config["num_proc"]}
|
|
||||||
else:
|
|
||||||
kwargs = {}
|
|
||||||
|
|
||||||
# tokenize inputs and return labels and attention mask
|
|
||||||
train_dataset = train_dataset.map(
|
|
||||||
lambda ele: tokenize_inputs(config, tokenizer, ele),
|
|
||||||
batched=True,
|
|
||||||
remove_columns=["source", "prompt"],
|
|
||||||
**kwargs
|
|
||||||
)
|
|
||||||
val_dataset = val_dataset.map(
|
|
||||||
lambda ele: tokenize_inputs(config, tokenizer, ele),
|
|
||||||
batched=True,
|
|
||||||
remove_columns=["source", "prompt"],
|
|
||||||
**kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
train_dataset = train_dataset.with_format("torch")
|
|
||||||
val_dataset = val_dataset.with_format("torch")
|
|
||||||
|
|
||||||
# create dataloader with default data collator since we already have labels
|
|
||||||
|
|
||||||
train_dataloader = DataLoader(
|
|
||||||
train_dataset,
|
|
||||||
collate_fn=DefaultDataCollator(),
|
|
||||||
batch_size=config["batch_size"],
|
|
||||||
)
|
|
||||||
|
|
||||||
val_dataloader = DataLoader(
|
|
||||||
val_dataset,
|
|
||||||
collate_fn=DefaultDataCollator(),
|
|
||||||
batch_size=config["batch_size"],
|
|
||||||
)
|
|
||||||
|
|
||||||
return train_dataloader, val_dataloader
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def load_data_for_inference(config, tokenizer):
|
def load_data_for_inference(config, tokenizer):
|
@ -1,5 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler, LlamaForCausalLM
|
from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler
|
||||||
import torch
|
import torch
|
||||||
from torch.optim import AdamW
|
from torch.optim import AdamW
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
@ -7,7 +7,7 @@ from gpt4all.utils.read import read_config
|
|||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
from accelerate.utils import DummyScheduler, DummyOptim, set_seed
|
from accelerate.utils import DummyScheduler, DummyOptim, set_seed
|
||||||
from peft import get_peft_model, LoraConfig, TaskType
|
from peft import get_peft_model, LoraConfig, TaskType
|
||||||
from gpt4all.utils.data import load_data
|
from gpt4all.data.instruction_tuning_dataloader import load_data
|
||||||
from torchmetrics import MeanMetric
|
from torchmetrics import MeanMetric
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import wandb
|
import wandb
|
||||||
@ -104,7 +104,7 @@ def train(accelerator, config):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
scheduler = DummyScheduler(
|
scheduler = DummyScheduler(
|
||||||
optimizer, total_num_steps=config["warmup_steps"], warmup_num_steps=config["warmup_steps"]
|
optimizer, total_num_steps=total_num_steps, warmup_num_steps=config["warmup_steps"]
|
||||||
)
|
)
|
||||||
|
|
||||||
model, optimizer, train_dataloader, val_dataloader, scheduler = accelerator.prepare(
|
model, optimizer, train_dataloader, val_dataloader, scheduler = accelerator.prepare(
|
||||||
|
Loading…
Reference in New Issue
Block a user