mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 09:38:05 +00:00
add kto
This commit is contained in:
@@ -40,7 +40,13 @@ import random
|
||||
import time
|
||||
from multiprocessing import cpu_count
|
||||
|
||||
from coati.dataset import setup_conversation_template, supervised_tokenize_sft, tokenize_prompt_dataset, tokenize_rlhf
|
||||
from coati.dataset import (
|
||||
setup_conversation_template,
|
||||
supervised_tokenize_sft,
|
||||
tokenize_kto,
|
||||
tokenize_prompt_dataset,
|
||||
tokenize_rlhf,
|
||||
)
|
||||
from datasets import dataset_dict, load_dataset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
@@ -56,8 +62,8 @@ def main():
|
||||
type=str,
|
||||
required=True,
|
||||
default=None,
|
||||
choices=["sft", "prompt", "preference"],
|
||||
help="Type of dataset, chose from 'sft', 'prompt', 'preference'.",
|
||||
choices=["sft", "prompt", "preference", "kto"],
|
||||
help="Type of dataset, chose from 'sft', 'prompt', 'preference'. 'kto'",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data_input_dirs",
|
||||
@@ -204,6 +210,8 @@ def main():
|
||||
preparation_function = tokenize_prompt_dataset
|
||||
elif args.type == "preference":
|
||||
preparation_function = tokenize_rlhf
|
||||
elif args.type == "kto":
|
||||
preparation_function = tokenize_kto
|
||||
else:
|
||||
raise ValueError("Unknow dataset type. Please choose one from ['sft', 'prompt', 'preference']")
|
||||
|
||||
@@ -228,10 +236,13 @@ def main():
|
||||
keep_in_memory=False,
|
||||
num_proc=min(len(dataset), cpu_count()),
|
||||
)
|
||||
|
||||
dataset = dataset.filter(
|
||||
lambda data: data["chosen_input_ids" if args.type == "preference" else "input_ids"] is not None
|
||||
)
|
||||
if args.type == "kto":
|
||||
filter_by = "completion"
|
||||
elif args.type == "preference":
|
||||
filter_by = "chosen_input_ids"
|
||||
else:
|
||||
filter_by = "input_ids"
|
||||
dataset = dataset.filter(lambda data: data[filter_by] is not None)
|
||||
|
||||
# Save each jsonl spliced dataset.
|
||||
output_index = "0" * (5 - len(str(index))) + str(index)
|
||||
|
@@ -0,0 +1,14 @@
|
||||
SAVE_DIR="/home/nvme-share/home/yeanbang/data/experiments/kto"
|
||||
|
||||
rm -rf $SAVE_DIR/cache
|
||||
rm -rf $SAVE_DIR/jsonl
|
||||
rm -rf $SAVE_DIR/arrow
|
||||
|
||||
python prepare_dataset.py --type kto \
|
||||
--data_input_dirs /home/nvme-share/home/yeanbang/data/dataset/hh_rlhf/kto_format/data \
|
||||
--conversation_template_config /home/nvme-share/home/yeanbang/ColossalAI/applications/ColossalChat/config/conversation_template/llama2.json \
|
||||
--tokenizer_dir "/home/nvme-share/share/models/Sheared-LLaMA-1.3B" \
|
||||
--data_cache_dir $SAVE_DIR/cache \
|
||||
--data_jsonl_output_dir $SAVE_DIR/jsonl \
|
||||
--data_arrow_output_dir $SAVE_DIR/arrow \
|
||||
--max_length 1024
|
@@ -1,13 +1,13 @@
|
||||
SAVE_DIR=""
|
||||
SAVE_DIR="/home/nvme-share/home/yeanbang/data/experiments/sft"
|
||||
|
||||
rm -rf $SAVE_DIR/cache
|
||||
rm -rf $SAVE_DIR/jsonl
|
||||
rm -rf $SAVE_DIR/arrow
|
||||
|
||||
python prepare_dataset.py --type sft \
|
||||
--data_input_dirs /PATH/TO/SFT/DATASET \
|
||||
--conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \
|
||||
--tokenizer_dir "" \
|
||||
--data_input_dirs /home/nvme-share/home/yeanbang/data/dataset/hh_rlhf/sft \
|
||||
--conversation_template_config /home/nvme-share/home/yeanbang/ColossalAI/applications/ColossalChat/config/conversation_template/llama2.json \
|
||||
--tokenizer_dir "/home/nvme-share/share/models/Sheared-LLaMA-1.3B" \
|
||||
--data_cache_dir $SAVE_DIR/cache \
|
||||
--data_jsonl_output_dir $SAVE_DIR/jsonl \
|
||||
--data_arrow_output_dir $SAVE_DIR/arrow \
|
||||
|
Reference in New Issue
Block a user