From 6518fa1461b2c3fe824f4a58f26704f58533f7e7 Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Wed, 19 Apr 2023 18:40:58 +0000 Subject: [PATCH] feat: load dataset from revision --- configs/train/finetune.yaml | 1 + configs/train/finetune_gptj.yaml | 1 + configs/train/finetune_gptj_lora.yaml | 1 + configs/train/finetune_lora.yaml | 1 + data.py | 5 ++++- 5 files changed, 8 insertions(+), 1 deletion(-) diff --git a/configs/train/finetune.yaml b/configs/train/finetune.yaml index ac9d2129..3816ba78 100644 --- a/configs/train/finetune.yaml +++ b/configs/train/finetune.yaml @@ -8,6 +8,7 @@ save_name: # CHANGE streaming: false num_proc: 64 dataset_path: # update +revision: null max_length: 1024 batch_size: 32 diff --git a/configs/train/finetune_gptj.yaml b/configs/train/finetune_gptj.yaml index ef9802d6..265e0031 100644 --- a/configs/train/finetune_gptj.yaml +++ b/configs/train/finetune_gptj.yaml @@ -8,6 +8,7 @@ save_name: # CHANGE streaming: false num_proc: 64 dataset_path: # CHANGE +revision: null max_length: 1024 batch_size: 32 diff --git a/configs/train/finetune_gptj_lora.yaml b/configs/train/finetune_gptj_lora.yaml index c2668ddd..9513fe0b 100644 --- a/configs/train/finetune_gptj_lora.yaml +++ b/configs/train/finetune_gptj_lora.yaml @@ -8,6 +8,7 @@ save_name: # CHANGE streaming: false num_proc: 64 dataset_path: # CHANGE +revision: null max_length: 1024 batch_size: 1 diff --git a/configs/train/finetune_lora.yaml b/configs/train/finetune_lora.yaml index e316fb11..69612e4a 100644 --- a/configs/train/finetune_lora.yaml +++ b/configs/train/finetune_lora.yaml @@ -8,6 +8,7 @@ save_name: # CHANGE streaming: false num_proc: 64 dataset_path: # CHANGE +revision: null max_length: 1024 batch_size: 4 diff --git a/data.py b/data.py index 915a4dea..a3beea24 100644 --- a/data.py +++ b/data.py @@ -77,6 +77,7 @@ def load_data(config, tokenizer): dataset = concatenate_datasets(all_datasets) + # load local json dataset elif os.path.exists(dataset_path): if os.path.isdir(dataset_path): 
files = glob.glob(os.path.join(dataset_path, "*_clean.jsonl")) @@ -87,8 +88,10 @@ def load_data(config, tokenizer): dataset = load_dataset("json", data_files=files, split="train") + # read from huggingface else: - dataset = load_dataset(dataset_path, split="train") + revision = config["revision"] + dataset = load_dataset(dataset_path, split="train", revision=revision) dataset = dataset.train_test_split(test_size=.05, seed=config["seed"])