diff --git a/gpt4all/data/preprocess.py b/gpt4all/data/preprocess.py
new file mode 100644
index 00000000..6f5c2e26
--- /dev/null
+++ b/gpt4all/data/preprocess.py
@@ -0,0 +1,51 @@
+import torch
+
+def tokenize_inputs(config, tokenizer, examples, input_col, target_col):
+    max_length = config["max_length"]
+
+    # hacky backwards compatibility for tokenizers whose eos token is not </s>
+    different_eos = tokenizer.eos_token != "</s>"
+    out = {"labels": [], "input_ids": []}
+    for prompt, response in zip(examples[input_col], examples[target_col]):
+        if different_eos:
+            if response.count("</s> \n") > 0:
+                response = response.replace("</s> \n", f"{tokenizer.eos_token} \n")
+
+        prompt_len = len(tokenizer(prompt + "\n", return_tensors="pt")["input_ids"][0])
+
+        # hack if our prompt is super long
+        # we need to include some labels, so we arbitrarily truncate at max_length // 2
+        # if the length is too long
+        if prompt_len >= max_length // 2:
+            # if the prompt is too long, truncate it
+            # but make sure to truncate to at most max_length // 2 tokens
+            new_len = min(max_length // 2, len(prompt) // 2)
+            prompt = prompt[:new_len]
+            # get new prompt length
+            prompt_len = tokenizer(prompt + "\n", return_tensors="pt", max_length=max_length // 2, truncation=True).input_ids.ne(tokenizer.pad_token_id).sum().item()
+
+        assert prompt_len <= max_length // 2, f"prompt length {prompt_len} exceeds max length {max_length}"
+
+        input_tokens = tokenizer(prompt + "\n" + response + tokenizer.eos_token,
+                                 truncation=True, max_length=max_length, return_tensors="pt")["input_ids"].squeeze()
+
+        labels = input_tokens.clone()
+        labels[:prompt_len] = -100
+        if len(labels) < max_length:
+            # pad to max_length with -100
+            labels = torch.cat([labels, torch.full((max_length - len(labels),), -100)])
+
+        assert (labels == -100).sum() < len(labels), f"Labels are all -100, something wrong. prompt length {prompt_len} exceeds max length {max_length}"
+
+        if (labels == -100).sum() == len(labels) - 1:
+            print(prompt)
+            print(response)
+            raise ValueError("only one token is unmasked in labels; prompt is likely too long for max_length")
+
+        input_tokens = tokenizer.pad({"input_ids": input_tokens}, padding="max_length", max_length=max_length)["input_ids"]
+        out["labels"].append(labels)
+        out["input_ids"].append(input_tokens)
+
+    out = {k: torch.stack(v) if isinstance(v, list) else v for k, v in out.items()}
+
+    return out
\ No newline at end of file
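As a quick illustration of what `tokenize_inputs` produces (not part of the diff), here is a minimal usage sketch. It assumes the package is importable as `gpt4all.data.preprocess`, uses a GPT-2 tokenizer with its eos token reused as the pad token, and a hypothetical config dict in which only `max_length` is read:

```python
from transformers import AutoTokenizer

from gpt4all.data.preprocess import tokenize_inputs

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token

config = {"max_length": 64}  # hypothetical; only max_length is used here
examples = {
    "question": ["What is the capital of France?"],
    "answer": ["Paris"],
}

out = tokenize_inputs(config, tokenizer, examples, "question", "answer")

# input_ids: prompt + "\n" + response + eos, right-padded to max_length
# labels:    the same tokens with the prompt span and padding masked to -100
print(out["input_ids"].shape)             # (1, 64)
print((out["labels"] == -100).sum())      # number of masked positions
```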
diff --git a/gpt4all/data/retrieval_dataloader.py b/gpt4all/data/retrieval_dataloader.py
new file mode 100644
index 00000000..75c87574
--- /dev/null
+++ b/gpt4all/data/retrieval_dataloader.py
@@ -0,0 +1,75 @@
+from datasets import load_dataset, Dataset
+import os
+from torch.utils.data import DataLoader
+from .preprocess import tokenize_inputs
+from transformers import DefaultDataCollator
+
+
+def load_retrieval_augmented_data(config, tokenizer, split="train", split_dataset=True):
+    dataset_path = config["dataset_path"]
+
+    if os.path.exists(dataset_path):
+        dataset = Dataset.load_from_disk(dataset_path)
+    else:
+        dataset = load_dataset(dataset_path, split=split)
+
+
+    question_col = config["q_column"]
+    answer_col = config["a_column"]
+    encoder_column = config["encoder_column"]
+
+    if config["streaming"] is False:
+        kwargs = {"num_proc": config["num_proc"]}
+    else:
+        kwargs = {}
+
+    # strip any unnecessary whitespace
+    # there's one question that includes a ton of whitespace
+    dataset = dataset.map(lambda ele: {question_col: [q.strip() for q in ele[question_col]]}, batched=True)
+    # in squad, each element of the answers column is a dict whose "text" key holds
+    # a list of answer strings; keep the first answer
+    dataset = dataset.map(lambda ele: {answer_col: [t["text"][0] for t in ele[answer_col]]}, batched=True)
+
+    # tokenize inputs + labels in teacher-forcing format
+    dataset = dataset.map(
+        lambda ele: tokenize_inputs(config, tokenizer, ele, question_col, answer_col),
+        batched=True,
+        **kwargs
+    )
+
+    # rename encoder hidden states if not already called that
+    if encoder_column != "encoder_hidden_states":
+        dataset = dataset.rename_column(encoder_column, "encoder_hidden_states")
+
+    columns_to_keep = ["input_ids", "labels", "encoder_hidden_states"]
+
+    col_names_to_rm = [col for col in dataset.column_names if col not in columns_to_keep]
+    dataset = dataset.remove_columns(col_names_to_rm)
+
+    if split_dataset:
+        dataset = dataset.train_test_split(test_size=config["pct_test"], seed=config["seed"])
+        train_dataset, val_dataset = dataset["train"], dataset["test"]
+
+        train_dataloader = DataLoader(
+            train_dataset,
+            batch_size=config["batch_size"],
+            collate_fn=DefaultDataCollator(),
+        )
+
+        val_dataloader = DataLoader(
+            val_dataset,
+            batch_size=config["batch_size"],
+            collate_fn=DefaultDataCollator(),
+        )
+
+        return train_dataloader, val_dataloader
+
+    else:
+        dataloader = DataLoader(
+            dataset,
+            batch_size=config["batch_size"],
+            collate_fn=DefaultDataCollator(),
+        )
+
+        return dataloader
+
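And a sketch of how the new loader might be wired up, again not part of the diff. The dataset path and column names below are hypothetical; the loader expects a SQuAD-style dataset (a local `save_to_disk` directory or a Hub id) whose rows already carry a column of retrieved-document encodings to expose as `encoder_hidden_states`:

```python
from transformers import AutoTokenizer

from gpt4all.data.retrieval_dataloader import load_retrieval_augmented_data

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# hypothetical config covering every key the loader reads
config = {
    "dataset_path": "path/to/squad_with_neighbor_encodings",  # local dir or hub id
    "q_column": "question",
    "a_column": "answers",
    "encoder_column": "neighbor_embeddings",  # renamed to encoder_hidden_states
    "streaming": False,
    "num_proc": 4,
    "max_length": 512,
    "pct_test": 0.05,
    "seed": 42,
    "batch_size": 8,
}

train_dl, val_dl = load_retrieval_augmented_data(config, tokenizer)
batch = next(iter(train_dl))
# each batch keeps only input_ids, labels, and encoder_hidden_states
print(batch["input_ids"].shape, batch["labels"].shape)
```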