YeAnbang
2024-07-18 07:54:11 +00:00
parent b3594d4d68
commit 09d5ffca1a
27 changed files with 1739 additions and 63 deletions

View File

@@ -40,7 +40,13 @@ import random
import time
from multiprocessing import cpu_count
from coati.dataset import setup_conversation_template, supervised_tokenize_sft, tokenize_prompt_dataset, tokenize_rlhf
from coati.dataset import (
setup_conversation_template,
supervised_tokenize_sft,
tokenize_kto,
tokenize_prompt_dataset,
tokenize_rlhf,
)
from datasets import dataset_dict, load_dataset
from transformers import AutoTokenizer
@@ -56,8 +62,8 @@ def main():
type=str,
required=True,
default=None,
choices=["sft", "prompt", "preference"],
help="Type of dataset, chose from 'sft', 'prompt', 'preference'.",
choices=["sft", "prompt", "preference", "kto"],
help="Type of dataset, chose from 'sft', 'prompt', 'preference'. 'kto'",
)
parser.add_argument(
"--data_input_dirs",
@@ -204,6 +210,8 @@ def main():
preparation_function = tokenize_prompt_dataset
elif args.type == "preference":
preparation_function = tokenize_rlhf
elif args.type == "kto":
preparation_function = tokenize_kto
else:
raise ValueError("Unknow dataset type. Please choose one from ['sft', 'prompt', 'preference']")
@@ -228,10 +236,13 @@ def main():
keep_in_memory=False,
num_proc=min(len(dataset), cpu_count()),
)
dataset = dataset.filter(
lambda data: data["chosen_input_ids" if args.type == "preference" else "input_ids"] is not None
)
if args.type == "kto":
filter_by = "completion"
elif args.type == "preference":
filter_by = "chosen_input_ids"
else:
filter_by = "input_ids"
dataset = dataset.filter(lambda data: data[filter_by] is not None)
# Save each jsonl spliced dataset.
output_index = "0" * (5 - len(str(index))) + str(index)
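
The new filter key implies tokenized KTO samples carry a "completion" column that is None when tokenization failed, mirroring "chosen_input_ids" for preference data and "input_ids" otherwise. A minimal, self-contained sketch of what the filter does; the sample values below are made up for illustration:

from datasets import Dataset

ds = Dataset.from_list([
    {"completion": [1, 2, 3]},  # tokenization succeeded: kept
    {"completion": None},       # tokenization failed: dropped
])
ds = ds.filter(lambda data: data["completion"] is not None)
assert len(ds) == 1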

View File

@@ -0,0 +1,14 @@
SAVE_DIR="/home/nvme-share/home/yeanbang/data/experiments/kto"
rm -rf $SAVE_DIR/cache
rm -rf $SAVE_DIR/jsonl
rm -rf $SAVE_DIR/arrow
python prepare_dataset.py --type kto \
--data_input_dirs /home/nvme-share/home/yeanbang/data/dataset/hh_rlhf/kto_format/data \
--conversation_template_config /home/nvme-share/home/yeanbang/ColossalAI/applications/ColossalChat/config/conversation_template/llama2.json \
--tokenizer_dir "/home/nvme-share/share/models/Sheared-LLaMA-1.3B" \
--data_cache_dir $SAVE_DIR/cache \
--data_jsonl_output_dir $SAVE_DIR/jsonl \
--data_arrow_output_dir $SAVE_DIR/arrow \
--max_length 1024
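
To sanity-check the script's output, the arrow shards can be read back with datasets. The shard name "part-00000" is an assumption based on the zero-padded output_index logic in the diff above, not a documented layout:

from datasets import load_from_disk

# Assumed shard path; prepare_dataset.py appears to pad the split index to five digits.
ds = load_from_disk(
    "/home/nvme-share/home/yeanbang/data/experiments/kto/arrow/part-00000"
)
print(len(ds), ds.column_names)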

View File

@@ -1,13 +1,13 @@
SAVE_DIR=""
SAVE_DIR="/home/nvme-share/home/yeanbang/data/experiments/sft"
rm -rf $SAVE_DIR/cache
rm -rf $SAVE_DIR/jsonl
rm -rf $SAVE_DIR/arrow
python prepare_dataset.py --type sft \
--data_input_dirs /PATH/TO/SFT/DATASET \
--conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \
--tokenizer_dir "" \
--data_input_dirs /home/nvme-share/home/yeanbang/data/dataset/hh_rlhf/sft \
--conversation_template_config /home/nvme-share/home/yeanbang/ColossalAI/applications/ColossalChat/config/conversation_template/llama2.json \
--tokenizer_dir "/home/nvme-share/share/models/Sheared-LLaMA-1.3B" \
--data_cache_dir $SAVE_DIR/cache \
--data_jsonl_output_dir $SAVE_DIR/jsonl \
--data_arrow_output_dir $SAVE_DIR/arrow \