mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2025-07-08 21:05:53 +00:00
feat: data preprocessing
This commit is contained in:
parent
c9dd9152c3
commit
0c0a56acab
51
gpt4all/data/preprocess.py
Normal file
51
gpt4all/data/preprocess.py
Normal file
@ -0,0 +1,51 @@
|
||||
import torch
|
||||
|
||||
def tokenize_inputs(config, tokenizer, examples, input_col, target_col):
|
||||
max_length = config["max_length"]
|
||||
|
||||
# hacky backward compatible
|
||||
different_eos = tokenizer.eos_token != "</s>"
|
||||
out = {"labels": [], "input_ids": []}
|
||||
for prompt, response in zip(examples[input_col], examples[target_col]):
|
||||
if different_eos:
|
||||
if response.count("</s> \n") > 0:
|
||||
response = response.replace("</s> \n", f"{tokenizer.eos_token} \n")
|
||||
|
||||
prompt_len = len(tokenizer(prompt + "\n", return_tensors="pt")["input_ids"][0])
|
||||
|
||||
# hack if our prompt is super long
|
||||
# we need to include some labels so we arbitrarily trunacate at max_length // 2
|
||||
# if the length is too long
|
||||
if prompt_len >= max_length // 2:
|
||||
# if prompt is too long, truncate
|
||||
# but make sure to truncate to at max 1024 tokens
|
||||
new_len = min(max_length // 2, len(prompt) // 2)
|
||||
prompt = prompt[:new_len]
|
||||
# get new prompt length
|
||||
prompt_len = tokenizer(prompt + "\n", return_tensors="pt", max_length=max_length // 2, truncation=True).input_ids.ne(tokenizer.pad_token_id).sum().item()
|
||||
|
||||
assert prompt_len <= max_length // 2, f"prompt length {prompt_len} exceeds max length {max_length}"
|
||||
|
||||
input_tokens = tokenizer(prompt + "\n" + response + tokenizer.eos_token,
|
||||
truncation=True, max_length=max_length, return_tensors="pt")["input_ids"].squeeze()
|
||||
|
||||
labels = input_tokens.clone()
|
||||
labels[:prompt_len] = -100
|
||||
if len(labels) < max_length:
|
||||
# pad to max_length with -100
|
||||
labels = torch.cat([labels, torch.full((max_length - len(labels),), -100)])
|
||||
|
||||
assert (labels == -100).sum() < len(labels), f"Labels are all -100, something wrong. prompt length {prompt_len} exceeds max length {max_length}"
|
||||
|
||||
if (labels == -100).sum() == len(labels) - 1:
|
||||
print(prompt)
|
||||
print(response)
|
||||
raise
|
||||
|
||||
input_tokens = tokenizer.pad({"input_ids": input_tokens}, padding="max_length", max_length=max_length)["input_ids"]
|
||||
out["labels"].append(labels)
|
||||
out["input_ids"].append(input_tokens)
|
||||
|
||||
out = {k: torch.stack(v) if isinstance(v, list) else v for k, v in out.items()}
|
||||
|
||||
return out
|
75
gpt4all/data/retrieval_dataloader.py
Normal file
75
gpt4all/data/retrieval_dataloader.py
Normal file
@ -0,0 +1,75 @@
|
||||
from datasets import load_dataset, Dataset
|
||||
import os
|
||||
from torch.utils.data import DataLoader
|
||||
from .preprocess import tokenize_inputs
|
||||
from transformers import DefaultDataCollator
|
||||
|
||||
|
||||
def load_retrieval_augmented_data(config, tokenizer, split="train", split_dataset=True):
|
||||
dataset_path = config["dataset_path"]
|
||||
|
||||
if os.path.exists(dataset_path):
|
||||
dataset = Dataset.load_from_disk(dataset_path)
|
||||
else:
|
||||
dataset = load_dataset(dataset_path, split=split)
|
||||
|
||||
|
||||
question_col = config["q_column"]
|
||||
answer_col = config["a_column"]
|
||||
encoder_column = config["encoder_column"]
|
||||
|
||||
if config["streaming"] is False:
|
||||
kwargs = {"num_proc": config["num_proc"]}
|
||||
else:
|
||||
kwargs = {}
|
||||
|
||||
# strip any unneccessary whitespace
|
||||
# there's one question that's includes a ton of whitespace
|
||||
dataset = dataset.map(lambda ele: {question_col: [q.strip() for q in ele[question_col]]}, batched=True)
|
||||
# in squad, the data is formatted where each ele in answers is a dict where the key text holds
|
||||
# a list of the answer
|
||||
dataset = dataset.map(lambda ele: {answer_col: [t["text"][0] for t in ele[answer_col]]}, batched=True)
|
||||
|
||||
dataset = dataset.map(
|
||||
lambda ele: tokenize_inputs(config, tokenizer, ele, question_col, answer_col),
|
||||
batched=True,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
# tokenize inputs + labels in teacher-force format
|
||||
# rename encoder hidden states if not already called that
|
||||
if encoder_column != "encoder_hidden_states":
|
||||
dataset = dataset.rename_column(encoder_column, "encoder_hidden_states")
|
||||
|
||||
columns_to_keep = ["input_ids", "labels", "encoder_hidden_states"]
|
||||
|
||||
col_names_to_rm = [col for col in dataset.column_names if col not in columns_to_keep]
|
||||
dataset = dataset.remove_columns(col_names_to_rm)
|
||||
|
||||
if split_dataset:
|
||||
dataset = dataset.train_test_split(test_size=config["pct_test"], seed=config["seed"])
|
||||
train_dataset, val_dataset = dataset["train"], dataset["test"]
|
||||
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=config["batch_size"],
|
||||
collate_fn=DefaultDataCollator(),
|
||||
)
|
||||
|
||||
val_dataloader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=config["batch_size"],
|
||||
collate_fn=DefaultDataCollator(),
|
||||
)
|
||||
|
||||
return train_dataloader, val_dataloader
|
||||
|
||||
else:
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=config["batch_size"],
|
||||
collate_fn=DefaultDataCollator(),
|
||||
)
|
||||
|
||||
return dataloader
|
||||
|
Loading…
Reference in New Issue
Block a user