feat: pull from multiple datasets

Zach Nussbaum 2023-04-17 20:00:19 +00:00
parent 0b4d45e57d
commit c76f6e33a9

data.py

@@ -61,7 +61,23 @@ def tokenize_inputs(config, tokenizer, examples):
 def load_data(config, tokenizer):
     dataset_path = config["dataset_path"]
 
-    if os.path.exists(dataset_path):
+    if isinstance(dataset_path, list):
+        all_datasets = []
+        for path in dataset_path:
+            dataset = load_dataset(path, split="train")
+
+            current_columns = dataset.column_names
+            columns_to_keep = ["prompt", "response"]
+            to_remove = set(current_columns) - set(columns_to_keep)
+            dataset = dataset.remove_columns(to_remove)
+
+            if "source" not in current_columns:
+                dataset = dataset.add_column("source", [path.split("/")[-1]] * len(dataset))
+
+            all_datasets.append(dataset)
+
+        dataset = concatenate_datasets(all_datasets)
+    elif os.path.exists(dataset_path):
         if os.path.isdir(dataset_path):
             files = glob.glob(os.path.join(dataset_path, "*_clean.jsonl"))
         else:
@@ -70,7 +86,7 @@ def load_data(config, tokenizer):
         print(f"Reading files {files}")
         dataset = load_dataset("json", data_files=files, split="train")
     else:
         dataset = load_dataset(dataset_path, split="train")
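For reference, below is a minimal standalone sketch of the prune/tag/concatenate logic that the new isinstance(dataset_path, list) branch performs, run against small in-memory datasets so it needs no network access. The merge_datasets helper and the toy dataset contents are illustrative assumptions, not part of this commit; only the loop body mirrors the new branch.

# Hypothetical sketch: merge_datasets and the toy data are assumptions;
# the loop body mirrors the new multi-dataset branch above.
from datasets import Dataset, concatenate_datasets

def merge_datasets(datasets_by_path):
    all_datasets = []
    for path, dataset in datasets_by_path.items():
        current_columns = dataset.column_names
        columns_to_keep = ["prompt", "response"]
        # Drop every column except the prompt/response pair.
        dataset = dataset.remove_columns(list(set(current_columns) - set(columns_to_keep)))
        # Tag each row with the dataset it came from, unless a "source"
        # column was already present upstream.
        if "source" not in current_columns:
            dataset = dataset.add_column("source", [path.split("/")[-1]] * len(dataset))
        all_datasets.append(dataset)
    return concatenate_datasets(all_datasets)

merged = merge_datasets({
    "org/dataset_a": Dataset.from_dict({"prompt": ["hi"], "response": ["hello"], "extra": [1]}),
    "org/dataset_b": Dataset.from_dict({"prompt": ["bye"], "response": ["goodbye"]}),
})
print(merged.column_names)  # ['prompt', 'response', 'source']
print(merged["source"])     # ['dataset_a', 'dataset_b']

Two caveats worth noting. First, the new branch calls concatenate_datasets, so data.py needs it imported alongside load_dataset (from datasets import load_dataset, concatenate_datasets); that import is outside the hunks shown here. Second, because columns_to_keep omits "source", a dataset that already ships a "source" column will have it stripped by remove_columns and then skipped by the add_column guard, leaving its schema inconsistent with the tagged datasets at concatenation time; adding "source" to columns_to_keep would avoid that.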