diff --git a/data.py b/data.py
index 8227de00..915a4dea 100644
--- a/data.py
+++ b/data.py
@@ -61,7 +61,23 @@ def tokenize_inputs(config, tokenizer, examples):
 def load_data(config, tokenizer):
     dataset_path = config["dataset_path"]
 
-    if os.path.exists(dataset_path):
+    if isinstance(dataset_path, list):
+        all_datasets = []
+        for path in dataset_path:
+            dataset = load_dataset(path, split="train")
+
+            current_columns = dataset.column_names
+            columns_to_keep = ["prompt", "response"]
+            to_remove = set(current_columns) - set(columns_to_keep)
+
+            dataset = dataset.remove_columns(to_remove)
+            if "source" not in current_columns:
+                dataset = dataset.add_column("source", [path.split("/")[-1]] * len(dataset))
+            all_datasets.append(dataset)
+
+        dataset = concatenate_datasets(all_datasets)
+
+    elif os.path.exists(dataset_path):
         if os.path.isdir(dataset_path):
             files = glob.glob(os.path.join(dataset_path, "*_clean.jsonl"))
         else:
@@ -70,7 +86,7 @@ def load_data(config, tokenizer):
         print(f"Reading files {files}")
 
         dataset = load_dataset("json", data_files=files, split="train")
-
+
     else:
         dataset = load_dataset(dataset_path, split="train")
 
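
With this change, config["dataset_path"] can be either a single path / dataset name (as before) or a list of Hugging Face dataset names: each entry is loaded, reduced to the prompt and response columns, tagged with a source column derived from the dataset name, and the pieces are concatenated into one training set. (This assumes concatenate_datasets is imported from datasets elsewhere in the patch.) Below is a minimal sketch of that flow using small in-memory datasets instead of hub downloads; the dataset names are made up for illustration.

    from datasets import Dataset, concatenate_datasets

    # Stand-ins for load_dataset(path, split="train") calls against the hub.
    toy_datasets = {
        "example/dataset_a": Dataset.from_dict(
            {"prompt": ["hi"], "response": ["hello"], "id": [0]}
        ),
        "example/dataset_b": Dataset.from_dict(
            {"prompt": ["2+2?"], "response": ["4"], "score": [1.0]}
        ),
    }

    all_datasets = []
    for path, dataset in toy_datasets.items():
        current_columns = dataset.column_names
        columns_to_keep = ["prompt", "response"]
        # Drop everything except prompt/response, mirroring the patch.
        to_remove = [c for c in current_columns if c not in columns_to_keep]
        dataset = dataset.remove_columns(to_remove)
        if "source" not in current_columns:
            # Tag each row with the short dataset name, e.g. "dataset_a".
            dataset = dataset.add_column("source", [path.split("/")[-1]] * len(dataset))
        all_datasets.append(dataset)

    combined = concatenate_datasets(all_datasets)
    print(combined.column_names)  # ['prompt', 'response', 'source']
    print(combined["source"])     # ['dataset_a', 'dataset_b']

concatenate_datasets requires every piece to expose the same columns, which is why everything outside prompt/response is removed before the source tag is added and the datasets are merged.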