mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2025-08-17 15:37:07 +00:00
added initial files for dataset prep and ingestion for gpt4all jr
This commit is contained in:
parent
84acbc8225
commit
240feae277
41
configs/train/finetune_gptjr.yaml
Normal file
41
configs/train/finetune_gptjr.yaml
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
# model/tokenizer
|
||||||
|
model_name: "nomic-ai/gpt4all-j"
|
||||||
|
tokenizer_name: "nomic-ai/gpt4all-j"
|
||||||
|
version: 'v1.2-jazzy'
|
||||||
|
gradient_checkpointing: true
|
||||||
|
save_name: # CHANGE
|
||||||
|
|
||||||
|
# dataset
|
||||||
|
streaming: false
|
||||||
|
num_proc: 64
|
||||||
|
dataset_path: "squad"
|
||||||
|
max_length: 1024
|
||||||
|
batch_size: 32
|
||||||
|
|
||||||
|
#index
|
||||||
|
index_path: "/home/paperspace/gpt4all/gpt4all/index/wiki-sample-index.bin"
|
||||||
|
index_database: "/home/paperspace/gpt4all/gpt4all/index/wiki-full-tokenized_embedded_with_text"
|
||||||
|
index_space: "cosine"
|
||||||
|
index_dim: 384
|
||||||
|
query_embedding_field: 'question'
|
||||||
|
|
||||||
|
# train dynamics
|
||||||
|
lr: 2.0e-5
|
||||||
|
min_lr: 0
|
||||||
|
weight_decay: 0.0
|
||||||
|
eval_every: 500
|
||||||
|
eval_steps: 105
|
||||||
|
save_every: 500
|
||||||
|
log_grads_every: 100
|
||||||
|
output_dir: # CHANGE
|
||||||
|
checkpoint: null
|
||||||
|
lora: false
|
||||||
|
warmup_steps: 500
|
||||||
|
num_epochs: 2
|
||||||
|
|
||||||
|
# logging
|
||||||
|
wandb: false
|
||||||
|
wandb_entity: # CHANGE
|
||||||
|
wandb_project_name: # CHANGE
|
||||||
|
seed: 42
|
||||||
|
|
@ -5,6 +5,7 @@ from argparse import ArgumentParser
|
|||||||
import hnswlib
|
import hnswlib
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
import pyarrow.compute as pc
|
import pyarrow.compute as pc
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
@ -42,6 +43,7 @@ def join(original_ds, embedded_ds):
|
|||||||
mask = pc.is_in(embed_table["index"], value_set=pa.array(indices, pa.int32()))
|
mask = pc.is_in(embed_table["index"], value_set=pa.array(indices, pa.int32()))
|
||||||
filtered_table = embed_table.filter(mask)
|
filtered_table = embed_table.filter(mask)
|
||||||
|
|
||||||
|
import pdb; pdb.set_trace()
|
||||||
# sort to make sure we're adding in right order
|
# sort to make sure we're adding in right order
|
||||||
filtered_table = filtered_table.sort_by("id")
|
filtered_table = filtered_table.sort_by("id")
|
||||||
|
|
||||||
@ -60,6 +62,8 @@ def join(original_ds, embedded_ds):
|
|||||||
|
|
||||||
def build_index(orig_path, embed_folder_path, index_path):
|
def build_index(orig_path, embed_folder_path, index_path):
|
||||||
if not os.path.exists(orig_path + "_embedded_with_text"):
|
if not os.path.exists(orig_path + "_embedded_with_text"):
|
||||||
|
# TODO: this doesn't work for large datasets!
|
||||||
|
# just convert to pandas and do this manually
|
||||||
ds = Dataset.load_from_disk(orig_path)
|
ds = Dataset.load_from_disk(orig_path)
|
||||||
embed_ds = concat_embedded(embed_folder_path)
|
embed_ds = concat_embedded(embed_folder_path)
|
||||||
print("Concatenated embeddings")
|
print("Concatenated embeddings")
|
||||||
@ -79,7 +83,15 @@ def build_index(orig_path, embed_folder_path, index_path):
|
|||||||
# not sure what we should set M and ef_construction to
|
# not sure what we should set M and ef_construction to
|
||||||
index.init_index(max_elements=len(ds), M=64, ef_construction=200)
|
index.init_index(max_elements=len(ds), M=64, ef_construction=200)
|
||||||
print("Adding items")
|
print("Adding items")
|
||||||
index.add_items(ds["embedding"], ds["index"])
|
chunk_size = 50_000
|
||||||
|
num_chunks = len(ds) // chunk_size
|
||||||
|
progbar = tqdm(total=num_chunks)
|
||||||
|
start = 0
|
||||||
|
while start < len(ds):
|
||||||
|
chunk = ds[start:start + chunk_size]
|
||||||
|
index.add_items(chunk["embedding"], chunk["index"], num_threads=64)
|
||||||
|
progbar.update(1)
|
||||||
|
start += chunk_size
|
||||||
|
|
||||||
print("Saving index")
|
print("Saving index")
|
||||||
index.save_index(index_path + ".bin")
|
index.save_index(index_path + ".bin")
|
||||||
|
@ -45,16 +45,11 @@ class Embedder:
|
|||||||
|
|
||||||
return tokenized_text
|
return tokenized_text
|
||||||
|
|
||||||
|
def tokenize(self, text):
|
||||||
|
return self.tokenizer(text, return_tensors="pt", truncation=True, padding="max_length")
|
||||||
|
|
||||||
def __call__(self, batch):
|
def __call__(self, batch):
|
||||||
if isinstance(batch, dict):
|
if isinstance(batch, str):
|
||||||
outputs = self.embedder(
|
|
||||||
input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
|
|
||||||
)
|
|
||||||
embedding = self._mean_pool(outputs, batch["attention_mask"])
|
|
||||||
|
|
||||||
return {"id": batch["id"], "embedding": embedding}
|
|
||||||
|
|
||||||
elif isinstance(batch, str):
|
|
||||||
tokenized = self.tokenizer(batch, return_tensors="pt", truncation=True)
|
tokenized = self.tokenizer(batch, return_tensors="pt", truncation=True)
|
||||||
return self._mean_pool(
|
return self._mean_pool(
|
||||||
self.embedder(
|
self.embedder(
|
||||||
@ -63,6 +58,13 @@ class Embedder:
|
|||||||
),
|
),
|
||||||
tokenized["attention_mask"],
|
tokenized["attention_mask"],
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
outputs = self.embedder(
|
||||||
|
input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
|
||||||
|
)
|
||||||
|
embedding = self._mean_pool(outputs, batch["attention_mask"])
|
||||||
|
|
||||||
|
return {"id": batch["id"], "embedding": embedding}
|
||||||
|
|
||||||
def to(self, device):
|
def to(self, device):
|
||||||
self.embedder.to(device)
|
self.embedder.to(device)
|
||||||
|
@ -9,6 +9,7 @@ from transformers import BatchEncoding
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from datasets import load_dataset
|
||||||
|
|
||||||
|
|
||||||
class PadCollateInputs:
|
class PadCollateInputs:
|
||||||
@ -29,20 +30,29 @@ class PadCollateInputs:
|
|||||||
return padded_inputs
|
return padded_inputs
|
||||||
|
|
||||||
|
|
||||||
def embed_texts(ds_path, batch_size):
|
def embed_texts(ds_path, batch_size, embed_on='text', save_to_disk=False):
|
||||||
rank0_print(f"World size: {dist.get_world_size()}")
|
rank0_print(f"World size: {dist.get_world_size()}")
|
||||||
|
dataset = load_dataset(f"{ds_path}", split="train")
|
||||||
dataset = Dataset.load_from_disk(ds_path)
|
|
||||||
rank0_print(f"Dataset size: {len(dataset)}")
|
rank0_print(f"Dataset size: {len(dataset)}")
|
||||||
dataset = dataset.remove_columns(["url", "title", "text"])
|
|
||||||
|
model = Embedder()
|
||||||
|
|
||||||
|
dataset = dataset.map(lambda x: model.tokenize(x[embed_on]), batched=True, num_proc=64)
|
||||||
|
|
||||||
|
columns_to_keep = ["input_ids", "attention_mask"]
|
||||||
|
#to_remove = [e for e in dataset.column_names if not e in columns_to_keep]
|
||||||
|
print('cols: ', dataset.column_names)
|
||||||
|
#dataset = dataset.remove_columns(to_remove)
|
||||||
|
|
||||||
|
#dataset = Dataset.load_from_disk(ds_path)
|
||||||
|
#dataset = dataset.remove_columns(["url", "title", "text"])
|
||||||
dataset = dataset.with_format("torch")
|
dataset = dataset.with_format("torch")
|
||||||
|
|
||||||
num_processes = dist.get_world_size()
|
num_processes = dist.get_world_size()
|
||||||
local_rank = dist.get_rank()
|
local_rank = dist.get_rank()
|
||||||
|
|
||||||
model = Embedder()
|
|
||||||
|
|
||||||
collator = PadCollateInputs(model.tokenizer)
|
#collator = PadCollateInputs(model.tokenizer)
|
||||||
|
|
||||||
sampler = DistributedSampler(
|
sampler = DistributedSampler(
|
||||||
dataset,
|
dataset,
|
||||||
@ -53,7 +63,7 @@ def embed_texts(ds_path, batch_size):
|
|||||||
)
|
)
|
||||||
dataloader = DataLoader(
|
dataloader = DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
collate_fn=collator,
|
# collate_fn=collator,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
drop_last=False,
|
drop_last=False,
|
||||||
@ -77,7 +87,9 @@ def embed_texts(ds_path, batch_size):
|
|||||||
|
|
||||||
# feeling lazy, don't want to wait for all_gather to finish
|
# feeling lazy, don't want to wait for all_gather to finish
|
||||||
# will load and concat in a single process in another script
|
# will load and concat in a single process in another script
|
||||||
ds.save_to_disk(f"embedded/{ds_path}_embedded_rank_{local_rank}")
|
if save_to_disk:
|
||||||
|
ds.save_to_disk(f"{ds_path}_embedded/{ds_path}_embedded_rank_{local_rank}")
|
||||||
|
return ds
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
58
gpt4all/index/prep_index_for_train.py
Normal file
58
gpt4all/index/prep_index_for_train.py
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
import os
|
||||||
|
import hnswlib
|
||||||
|
import numpy as np
|
||||||
|
from datasets import Dataset
|
||||||
|
import torch.distributed as dist
|
||||||
|
from datasets import load_dataset
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
from gpt4all.utils.read import read_config
|
||||||
|
from gpt4all.index.embed_texts import embed_texts
|
||||||
|
|
||||||
|
CHUNK_SIZE = 1024
|
||||||
|
K = 5
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
dist.init_process_group("nccl")
|
||||||
|
|
||||||
|
parser = ArgumentParser()
|
||||||
|
parser.add_argument("--config", type=str, default="config.yaml")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
config = read_config(args.config)
|
||||||
|
|
||||||
|
#load index
|
||||||
|
index = hnswlib.Index(space=config['index_space'], dim=config['index_dim'])
|
||||||
|
index.load_index(config['index_path'])
|
||||||
|
|
||||||
|
#load query dataset
|
||||||
|
ds_path = config['dataset_path']
|
||||||
|
|
||||||
|
#load retrieval dataset
|
||||||
|
retrieval_dataset = Dataset.load_from_disk(config['index_database'])
|
||||||
|
print(type(retrieval_dataset._data))
|
||||||
|
raise
|
||||||
|
|
||||||
|
#vectorize queries
|
||||||
|
query_vector_path = f"{ds_path}_queries_embedded/{ds_path}_embedded"
|
||||||
|
if not os.path.exists(query_vector_path):
|
||||||
|
print('Embedding dataset...')
|
||||||
|
ds = embed_texts(ds_path, config['batch_size'], embed_on=config['query_embedding_field'], save_to_disk=False)
|
||||||
|
ds.save_to_disk(query_vector_path)
|
||||||
|
else:
|
||||||
|
print('Found cached embedding dataset!')
|
||||||
|
ds = Dataset.load_from_disk(query_vector_path)
|
||||||
|
|
||||||
|
#search the index for each query
|
||||||
|
for chunk_start in range(0, len(ds), CHUNK_SIZE):
|
||||||
|
chunk_end = chunk_start + CHUNK_SIZE
|
||||||
|
chunk = ds[chunk_start:chunk_end]
|
||||||
|
query_vectors = np.array(chunk['embedding'])
|
||||||
|
print(query_vectors.shape)
|
||||||
|
neighbor_ids, _ = index.knn_query(query_vectors, k=K, num_threads=-1)
|
||||||
|
raise
|
||||||
|
|
||||||
|
#get the embeddings for each of the neighbor ids
|
||||||
|
|
0
gpt4all/index/test_load_index.py
Normal file
0
gpt4all/index/test_load_index.py
Normal file
246
gpt4all/train/train_r.py
Normal file
246
gpt4all/train/train_r.py
Normal file
@ -0,0 +1,246 @@
|
|||||||
|
import os
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler, LlamaForCausalLM
|
||||||
|
import torch
|
||||||
|
from torch.optim import AdamW
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
from gpt4all.utils.read import read_config
|
||||||
|
from accelerate import Accelerator
|
||||||
|
from accelerate.utils import DummyScheduler, DummyOptim, set_seed
|
||||||
|
from peft import get_peft_model, LoraConfig, TaskType
|
||||||
|
from gpt4all.utils.data import load_data, load_retrieval_augmented_data
|
||||||
|
from torchmetrics import MeanMetric
|
||||||
|
from tqdm import tqdm
|
||||||
|
from gpt4all.models import GPTJRForCausalLM
|
||||||
|
import wandb
|
||||||
|
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
|
||||||
|
def format_metrics(metrics, split, prefix=""):
|
||||||
|
log = f"[{split}]" + prefix
|
||||||
|
log += " ".join([f"{key}: {value:.4f}" for key, value in metrics.items()])
|
||||||
|
|
||||||
|
return log
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate(model, val_dataloader):
|
||||||
|
model.eval()
|
||||||
|
val_loss = MeanMetric(nan_strategy="error").to(model.device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for batch in tqdm(val_dataloader):
|
||||||
|
loss = model(**batch).loss
|
||||||
|
|
||||||
|
loss_values = accelerator.gather_for_metrics({"loss": loss.detach()})
|
||||||
|
|
||||||
|
val_loss.update(loss_values["loss"])
|
||||||
|
|
||||||
|
return val_loss
|
||||||
|
|
||||||
|
|
||||||
|
def train(accelerator, config):
|
||||||
|
set_seed(config['seed'])
|
||||||
|
|
||||||
|
accelerator.print(config)
|
||||||
|
accelerator.print(f"Using {accelerator.num_processes} GPUs")
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(config['tokenizer_name'], model_max_length=config['max_length'])
|
||||||
|
# if no pad token, set it to eos
|
||||||
|
if tokenizer.pad_token is None:
|
||||||
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
|
||||||
|
|
||||||
|
with accelerator.main_process_first():
|
||||||
|
|
||||||
|
if 'index_path' in config:
|
||||||
|
train_dataloader, val_dataloader = load_retrieval_augmented_data(config, tokenizer)
|
||||||
|
else:
|
||||||
|
train_dataloader, val_dataloader = load_data(config, tokenizer)
|
||||||
|
|
||||||
|
|
||||||
|
checkpoint = config["gradient_checkpointing"]
|
||||||
|
#ensures back compat with non retrieval models
|
||||||
|
if 'index_path' in config:
|
||||||
|
model = GPTJRForCausalLM.from_pretrained(config["model_name"],
|
||||||
|
revision=config['version'],
|
||||||
|
use_cache=False if checkpoint else True,
|
||||||
|
trust_remote_code=True)
|
||||||
|
else:
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(config["model_name"],
|
||||||
|
use_cache=False if checkpoint else True,
|
||||||
|
trust_remote_code=True)
|
||||||
|
|
||||||
|
if checkpoint:
|
||||||
|
model.gradient_checkpointing_enable()
|
||||||
|
|
||||||
|
if config["lora"]:
|
||||||
|
peft_config = LoraConfig(
|
||||||
|
# should R be configurable?
|
||||||
|
task_type=TaskType.CAUSAL_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1
|
||||||
|
)
|
||||||
|
model = get_peft_model(model, peft_config)
|
||||||
|
model.print_trainable_parameters()
|
||||||
|
|
||||||
|
optimizer_cls = (
|
||||||
|
AdamW
|
||||||
|
if accelerator.state.deepspeed_plugin is None
|
||||||
|
or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config
|
||||||
|
else DummyOptim
|
||||||
|
)
|
||||||
|
|
||||||
|
# karpathy doesn't decay embeddding, maybe we should exclude
|
||||||
|
# https://github.com/karpathy/minGPT/commit/bbbdac74fa9b2e55574d70056163ffbae42310c1#diff-2075fa9c224b395be5bda85544dd36572b59c76c54562819eadadbf268602834R157s
|
||||||
|
optimizer = optimizer_cls(model.parameters(), lr=config["lr"], weight_decay=config["weight_decay"])
|
||||||
|
|
||||||
|
if accelerator.state.deepspeed_plugin is not None:
|
||||||
|
gradient_accumulation_steps = accelerator.state.deepspeed_plugin.deepspeed_config[
|
||||||
|
"gradient_accumulation_steps"
|
||||||
|
]
|
||||||
|
|
||||||
|
# decay to min_lr instead of 0
|
||||||
|
lr_ratio = config["min_lr"] / config["lr"]
|
||||||
|
accelerator.print(f"Len of train_dataloader: {len(train_dataloader)}")
|
||||||
|
total_num_steps = (len(train_dataloader) / gradient_accumulation_steps) * config["num_epochs"]
|
||||||
|
# instead of decaying to zero, decay to ratio of min_lr / lr
|
||||||
|
total_num_steps += int(total_num_steps * lr_ratio) + config["warmup_steps"]
|
||||||
|
accelerator.print(f"Total training steps: {total_num_steps}")
|
||||||
|
|
||||||
|
# Creates Dummy Scheduler if `scheduler` was spcified in the config file else creates `args.lr_scheduler_type` Scheduler
|
||||||
|
if (
|
||||||
|
accelerator.state.deepspeed_plugin is None
|
||||||
|
or "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
|
||||||
|
):
|
||||||
|
scheduler = get_scheduler(
|
||||||
|
name="cosine",
|
||||||
|
optimizer=optimizer,
|
||||||
|
num_warmup_steps=config["warmup_steps"] * accelerator.num_processes,
|
||||||
|
num_training_steps=total_num_steps,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
scheduler = DummyScheduler(
|
||||||
|
optimizer, total_num_steps=config["warmup_steps"], warmup_num_steps=config["warmup_steps"]
|
||||||
|
)
|
||||||
|
|
||||||
|
model, optimizer, train_dataloader, val_dataloader, scheduler = accelerator.prepare(
|
||||||
|
model, optimizer, train_dataloader, val_dataloader, scheduler
|
||||||
|
)
|
||||||
|
|
||||||
|
# setup for saving training states in case preemption
|
||||||
|
accelerator.register_for_checkpointing(scheduler)
|
||||||
|
|
||||||
|
if config["checkpoint"]:
|
||||||
|
accelerator.load_state(config["checkpoint"])
|
||||||
|
accelerator.print(f"Resumed from checkpoint: {config['checkpoint']}")
|
||||||
|
path = os.path.basename(config["train_args"]["resume_from_checkpoint"])
|
||||||
|
training_difference = os.path.splitext(path)[0]
|
||||||
|
resume_step = int(training_difference.replace("step_", ""))
|
||||||
|
accelerator.skip_first_batches(train_dataloader, resume_step)
|
||||||
|
accelerator.print(f"Resuming from step {resume_step}")
|
||||||
|
|
||||||
|
|
||||||
|
# log gradients
|
||||||
|
if accelerator.is_main_process and config["wandb"]:
|
||||||
|
wandb.watch(model, log_freq=config["log_grads_every"], log="all")
|
||||||
|
|
||||||
|
for epoch in range(config["num_epochs"]):
|
||||||
|
train_loss = MeanMetric(nan_strategy="error").to(model.device)
|
||||||
|
for step, batch in enumerate(tqdm(train_dataloader)):
|
||||||
|
model.train()
|
||||||
|
outputs = model(**batch)
|
||||||
|
loss = outputs.loss
|
||||||
|
|
||||||
|
# gather loss before backprop in case of gradient accumulation
|
||||||
|
loss_values = accelerator.gather_for_metrics({"loss": loss.detach().float()})
|
||||||
|
train_loss.update(loss_values["loss"])
|
||||||
|
|
||||||
|
loss = loss / gradient_accumulation_steps
|
||||||
|
accelerator.backward(loss)
|
||||||
|
# get gradient norm of all params
|
||||||
|
|
||||||
|
# log LR in case something weird happens
|
||||||
|
if step > 0 and step % (config["eval_every"] // 10) == 0:
|
||||||
|
if config["wandb"]:
|
||||||
|
curr_step = step + epoch * len(train_dataloader)
|
||||||
|
accelerator.log({"lr": scheduler.get_last_lr()[0]}, step=curr_step)
|
||||||
|
|
||||||
|
if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
|
||||||
|
optimizer.step()
|
||||||
|
scheduler.step()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
|
||||||
|
if step > 0 and step % config["save_every"] == 0:
|
||||||
|
curr_step = step + epoch * len(train_dataloader)
|
||||||
|
accelerator.save_state(f"{config['output_dir']}/step_{curr_step}")
|
||||||
|
|
||||||
|
if step > 0 and (step % config["eval_every"] == 0 or step == len(train_dataloader) - 1):
|
||||||
|
val_loss = evaluate(model, val_dataloader)
|
||||||
|
|
||||||
|
log_train = {
|
||||||
|
"train_loss": train_loss.compute()
|
||||||
|
}
|
||||||
|
log_val = {
|
||||||
|
"val_loss": val_loss.compute()
|
||||||
|
}
|
||||||
|
|
||||||
|
if config["wandb"]:
|
||||||
|
curr_step = step + epoch * len(train_dataloader)
|
||||||
|
accelerator.log({**log_train, **log_val}, step=curr_step)
|
||||||
|
|
||||||
|
accelerator.print(f"Current LR: {scheduler.get_last_lr()[0]}")
|
||||||
|
accelerator.print(format_metrics(log_train, "train", f" step {step} "))
|
||||||
|
accelerator.print(format_metrics(log_val, "val", f" step {step} "))
|
||||||
|
|
||||||
|
train_loss.reset()
|
||||||
|
|
||||||
|
accelerator.print(f"Epoch {epoch} finished")
|
||||||
|
accelerator.print(f"Pushing to HF hub")
|
||||||
|
accelerator.wait_for_everyone()
|
||||||
|
unwrapped_model = accelerator.unwrap_model(model)
|
||||||
|
try:
|
||||||
|
if accelerator.is_main_process:
|
||||||
|
unwrapped_model.push_to_hub(config["save_name"] + f"-epoch_{epoch}", private=True)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
accelerator.print(e)
|
||||||
|
accelerator.print(f"Failed to push to hub")
|
||||||
|
|
||||||
|
unwrapped_model.save_pretrained(
|
||||||
|
f"{config['output_dir']}/epoch_{epoch}",
|
||||||
|
is_main_process=accelerator.is_main_process,
|
||||||
|
save_function=accelerator.save,
|
||||||
|
state_dict=accelerator.get_state_dict(model),
|
||||||
|
)
|
||||||
|
|
||||||
|
accelerator.wait_for_everyone()
|
||||||
|
unwrapped_model = accelerator.unwrap_model(model)
|
||||||
|
unwrapped_model.save_pretrained(
|
||||||
|
f"{config['output_dir']}/final",
|
||||||
|
is_main_process=accelerator.is_main_process,
|
||||||
|
save_function=accelerator.save,
|
||||||
|
state_dict=accelerator.get_state_dict(model),
|
||||||
|
)
|
||||||
|
|
||||||
|
accelerator.end_training()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# parse arguments by reading in a config
|
||||||
|
parser = ArgumentParser()
|
||||||
|
parser.add_argument("--config", type=str, default="config.yaml")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
config = read_config(args.config)
|
||||||
|
|
||||||
|
if config["wandb"]:
|
||||||
|
accelerator = Accelerator(log_with="wandb")
|
||||||
|
accelerator.init_trackers(
|
||||||
|
project_name=config["wandb_project_name"],
|
||||||
|
config=config,
|
||||||
|
init_kwargs={"wandb": {"entity": config["wandb_entity"]}},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
accelerator = Accelerator()
|
||||||
|
|
||||||
|
train(accelerator, config=config)
|
@ -2,6 +2,7 @@ import glob
|
|||||||
import torch
|
import torch
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
import os
|
import os
|
||||||
|
import hnswlib
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from transformers import DefaultDataCollator
|
from transformers import DefaultDataCollator
|
||||||
|
|
||||||
@ -116,6 +117,70 @@ def load_data(config, tokenizer):
|
|||||||
|
|
||||||
return train_dataloader, val_dataloader
|
return train_dataloader, val_dataloader
|
||||||
|
|
||||||
|
def load_retrieval_augmented_data(config, tokenizer):
|
||||||
|
dataset_path = config["dataset_path"]
|
||||||
|
index_path = config['index_path']
|
||||||
|
|
||||||
|
#TODO this should precache at some point
|
||||||
|
index = hnswlib.Index(space=config['index_space'], dim=config['index_dim'])
|
||||||
|
index.load_index(index_path)
|
||||||
|
|
||||||
|
if os.path.exists(dataset_path):
|
||||||
|
if os.path.isdir(dataset_path):
|
||||||
|
files = glob.glob(os.path.join(dataset_path, "*_clean.jsonl"))
|
||||||
|
else:
|
||||||
|
files = [dataset_path]
|
||||||
|
|
||||||
|
print(f"Reading files {files}")
|
||||||
|
|
||||||
|
dataset = load_dataset("json", data_files=files, split="train")
|
||||||
|
|
||||||
|
else:
|
||||||
|
dataset = load_dataset(dataset_path, split="train")
|
||||||
|
|
||||||
|
dataset = dataset.train_test_split(test_size=.05, seed=config["seed"])
|
||||||
|
|
||||||
|
train_dataset, val_dataset = dataset["train"], dataset["test"]
|
||||||
|
|
||||||
|
if config["streaming"] is False:
|
||||||
|
kwargs = {"num_proc": config["num_proc"]}
|
||||||
|
else:
|
||||||
|
kwargs = {}
|
||||||
|
|
||||||
|
# tokenize inputs and return labels and attention mask
|
||||||
|
train_dataset = train_dataset.map(
|
||||||
|
lambda ele: tokenize_inputs(config, tokenizer, ele),
|
||||||
|
batched=True,
|
||||||
|
remove_columns=["source", "prompt"],
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
val_dataset = val_dataset.map(
|
||||||
|
lambda ele: tokenize_inputs(config, tokenizer, ele),
|
||||||
|
batched=True,
|
||||||
|
remove_columns=["source", "prompt"],
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
train_dataset = train_dataset.with_format("torch")
|
||||||
|
val_dataset = val_dataset.with_format("torch")
|
||||||
|
|
||||||
|
# create dataloader with default data collator since we already have labels
|
||||||
|
|
||||||
|
train_dataloader = DataLoader(
|
||||||
|
train_dataset,
|
||||||
|
collate_fn=DefaultDataCollator(),
|
||||||
|
batch_size=config["batch_size"],
|
||||||
|
)
|
||||||
|
|
||||||
|
val_dataloader = DataLoader(
|
||||||
|
val_dataset,
|
||||||
|
collate_fn=DefaultDataCollator(),
|
||||||
|
batch_size=config["batch_size"],
|
||||||
|
)
|
||||||
|
|
||||||
|
return train_dataloader, val_dataloader
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def load_data_for_inference(config, tokenizer):
|
def load_data_for_inference(config, tokenizer):
|
||||||
dataset_path = config["dataset_path"]
|
dataset_path = config["dataset_path"]
|
||||||
|
Loading…
Reference in New Issue
Block a user