mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-15 06:00:07 +00:00
add kto
This commit is contained in:
@@ -37,6 +37,19 @@ preference_seed = {
|
||||
"chosen": [{"from": "assistant", "content": "You can read?"}],
|
||||
"rejected": [{"from": "assistant", "content": "there's a lot of stuff humans don't know"}],
|
||||
}
|
||||
kto_seed = {
|
||||
"prompt": [
|
||||
{"from": "human", "content": "What are some cuss words in english?"},
|
||||
{
|
||||
"from": "assistant",
|
||||
"content": "Here's an incomplete list.\n\nAss, dick, bugger, crap, fuck, shit, bitch, turd, shithead, shitbag, scrotum, cunt, whore, fucker, shit-eating, cum, cumbucket, fucknugget, butthole, poop, jackass, cocksucker, asshole, goddamn, piss, sperm, blow, wank, jism, cum-sucking, masturbate, faggot, queer, jizz, jizz-licking, prostitute, slut, cheater, fornicator, floozy, wetback, Mexican, Hispanic, sodomite, midget, mama's boy, faggot, pervert, queer, scumbag, bitch,",
|
||||
},
|
||||
{"from": "human", "content": "What's your favorite one?"},
|
||||
],
|
||||
"completion": {"from": "assistant", "content": "Ass."},
|
||||
"label": False,
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
@@ -61,12 +74,21 @@ if __name__ == "__main__":
|
||||
seed = prompt_seed
|
||||
elif args.data_type == "preference":
|
||||
seed = preference_seed
|
||||
elif args.data_type == "kto":
|
||||
seed = kto_seed
|
||||
else:
|
||||
raise ValueError(f"Unknown data type {args.data_type}")
|
||||
|
||||
line = json.dumps(seed, ensure_ascii=False) + "\n"
|
||||
for idx in [1, 2, 3]:
|
||||
with open(os.path.join(args.data_dir, f"{idx}.jsonl"), "w", encoding="utf8") as f:
|
||||
for i in range(1000):
|
||||
if args.data_type != "kto":
|
||||
line = json.dumps(seed, ensure_ascii=False) + "\n"
|
||||
for idx in [1, 2, 3]:
|
||||
with open(os.path.join(args.data_dir, f"{idx}.jsonl"), "w", encoding="utf8") as f:
|
||||
for i in range(1000):
|
||||
f.write(line)
|
||||
f.write(line)
|
||||
f.write(line)
|
||||
else:
|
||||
for idx in [1, 2, 3]:
|
||||
with open(os.path.join(args.data_dir, f"{idx}.jsonl"), "w", encoding="utf8") as f:
|
||||
for i in range(1000):
|
||||
seed["label"] = not seed["label"]
|
||||
line = json.dumps(seed, ensure_ascii=False) + "\n"
|
||||
f.write(line)
|
||||
|
@@ -71,6 +71,8 @@ get_data_input_dirs() {
|
||||
echo "$PROMPT_DATASET"
|
||||
elif [[ $data_type == "preference" ]]; then
|
||||
echo "$PREFERENCE_DATASET"
|
||||
elif [[ $data_type == "kto" ]]; then
|
||||
echo "$KTO_DATASET"
|
||||
else
|
||||
echo "Unknown data type $data_type"
|
||||
exit 1
|
||||
@@ -121,6 +123,10 @@ python $TEST_DIR/generate_dummy_datasets_for_testing.py \
|
||||
--data_dir $(get_data_input_dirs prompt) \
|
||||
--data_type "prompt"
|
||||
|
||||
python $TEST_DIR/generate_dummy_datasets_for_testing.py \
|
||||
--data_dir $(get_data_input_dirs kto) \
|
||||
--data_type "kto"
|
||||
|
||||
echo "[Test]: testing prepare_preference_dataset.py ..."
|
||||
|
||||
# FIXME: This is a hack to skip tests that are not working
|
||||
@@ -258,3 +264,50 @@ for model in ${MODELS[@]}; do
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
|
||||
|
||||
echo "[Test]: testing prepare_kto_dataset.py ..."
|
||||
|
||||
# FIXME: This is a hack to skip tests that are not working
|
||||
SKIPPED_TESTS=(
|
||||
)
|
||||
|
||||
# test prepare_kto_dataset
|
||||
for model in ${MODELS[@]}; do
|
||||
data_type="kto"
|
||||
if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$data_type " ]]; then
|
||||
echo "[Test]: Skipped $model-$data_type"
|
||||
continue
|
||||
fi
|
||||
cache_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/cache
|
||||
jsonl_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/jsonl
|
||||
arrow_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/arrow
|
||||
data_input_dirs=$(get_data_input_dirs $data_type)
|
||||
tokenizer_dir=$(get_tokenizer_dirs $model)
|
||||
conversation_template=$(get_conversation_template_config $model)
|
||||
for i in $(seq $NUM_RETRY); do
|
||||
rm -rf $cache_dir
|
||||
rm -rf $jsonl_dir
|
||||
rm -rf $arrow_dir
|
||||
echo "[Test]: $model-$data_type, attempt $i"
|
||||
python $EXAMPLES_DIR/data_preparation_scripts/prepare_dataset.py \
|
||||
--type kto \
|
||||
--data_input_dirs $data_input_dirs \
|
||||
--conversation_template_config $conversation_template \
|
||||
--tokenizer_dir $tokenizer_dir \
|
||||
--data_cache_dir $cache_dir \
|
||||
--data_jsonl_output_dir $jsonl_dir \
|
||||
--data_arrow_output_dir $arrow_dir \
|
||||
--max_length 400 \
|
||||
--num_samples_per_datafile 100 \
|
||||
--num_spliced_dataset_bins 1
|
||||
passed=$?
|
||||
if [ $passed -eq 0 ]; then
|
||||
break
|
||||
fi
|
||||
done
|
||||
if [ $passed -ne 0 ]; then
|
||||
echo "[Test]: Failed $model-$data_type"
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
|
@@ -193,8 +193,8 @@ for lora_rank in ${LORA_RANK[@]}; do
|
||||
--use_flash_attn
|
||||
passed=$?
|
||||
if [ $passed -eq 0 ]; then
|
||||
rm -rf $MODEL_SAVE_PATH/*
|
||||
rm -rf $MODELS_DIR/*
|
||||
rm -rf ${MODEL_SAVE_PATH:?}/*
|
||||
rm -rf ${MODELS_DIR:?}/*
|
||||
break
|
||||
fi
|
||||
done
|
||||
@@ -264,8 +264,8 @@ for lora_rank in ${LORA_RANK[@]}; do
|
||||
--use_flash_attn
|
||||
passed=$?
|
||||
if [ $passed -eq 0 ]; then
|
||||
rm -rf $MODEL_SAVE_PATH/*
|
||||
rm -rf $MODELS_DIR/*
|
||||
rm -rf ${MODEL_SAVE_PATH:?}/*
|
||||
rm -rf ${MODELS_DIR:?}/*
|
||||
break
|
||||
fi
|
||||
done
|
||||
@@ -363,8 +363,8 @@ for lora_rank in ${LORA_RANK[@]}; do
|
||||
# --use_flash_attn
|
||||
passed=$?
|
||||
if [ $passed -eq 0 ]; then
|
||||
rm -rf $MODEL_SAVE_PATH/*
|
||||
rm -rf $MODELS_DIR/*
|
||||
rm -rf ${MODEL_SAVE_PATH:?}/*
|
||||
rm -rf ${MODELS_DIR:?}/*
|
||||
break
|
||||
fi
|
||||
done
|
||||
@@ -440,8 +440,8 @@ for lora_rank in ${LORA_RANK[@]}; do
|
||||
--use_flash_attn
|
||||
passed=$?
|
||||
if [ $passed -eq 0 ]; then
|
||||
rm -rf $MODEL_SAVE_PATH/*
|
||||
rm -rf $MODELS_DIR/*
|
||||
rm -rf ${MODEL_SAVE_PATH:?}/*
|
||||
rm -rf ${MODELS_DIR:?}/*
|
||||
break
|
||||
fi
|
||||
done
|
||||
@@ -518,8 +518,87 @@ for lora_rank in ${LORA_RANK[@]}; do
|
||||
--use_flash_attn
|
||||
passed=$?
|
||||
if [ $passed -eq 0 ]; then
|
||||
rm -rf $MODEL_SAVE_PATH/*
|
||||
rm -rf $MODELS_DIR/*
|
||||
rm -rf ${MODEL_SAVE_PATH:?}/*
|
||||
rm -rf ${MODELS_DIR:?}/*
|
||||
break
|
||||
fi
|
||||
done
|
||||
if [ $passed -ne 0 ]; then
|
||||
echo "[Test]: Failed $model-$plugin-$lora_rank"
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
|
||||
|
||||
echo "[Test]: testing KTO ..."
|
||||
|
||||
SKIPPED_TESTS=(
|
||||
llama-3d-20 # 3d plugin doesn't support lora
|
||||
llama-gemini_auto-20 # gemini_auto plugin doesn't support lora
|
||||
llama-gemini-20 # gemini doesn't support lora
|
||||
)
|
||||
GRAD_CKPTS=('--grad_checkpoint')
|
||||
for lora_rank in ${LORA_RANK[@]}; do
|
||||
for model in ${MODELS[@]}; do
|
||||
for plugin in ${PLUGINS[@]}; do
|
||||
if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then
|
||||
echo "[Test]: Skipped $model-$plugin-$lora_rank"
|
||||
continue
|
||||
elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin " ]]; then
|
||||
echo "[Test]: Skipped $model-$plugin"
|
||||
continue
|
||||
fi
|
||||
pretrain=$(get_pretrain $model)
|
||||
tokenizer_dir=$(get_tokenizer_dirs $model)
|
||||
grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
|
||||
tp='1'
|
||||
bs='2'
|
||||
if [[ $plugin == "3d" ]]; then
|
||||
tp='4'
|
||||
bs='8'
|
||||
fi
|
||||
grad_accu='2'
|
||||
# gemini_auto and gemini doesn't support gradient accumulation
|
||||
if [[ $plugin == "gemini_auto" ]]; then
|
||||
grad_accu='1'
|
||||
fi
|
||||
# gemini_auto doesn't support generation
|
||||
# (need to calculate ref_model logits through forwarding in inference mode)
|
||||
if [[ $plugin == "gemini_auto" ]]; then
|
||||
echo "[Test]: Skipped $model-$plugin"
|
||||
continue
|
||||
fi
|
||||
for i in $(seq $NUM_RETRY); do
|
||||
echo "[Test]: $model-$plugin-$lora_rank, attempt $i"
|
||||
declare -a dataset=()
|
||||
for split in $(seq -f "%05g" 0 0); do
|
||||
dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_kto/arrow/part-$split")
|
||||
done
|
||||
colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_kto.py \
|
||||
--pretrain $pretrain \
|
||||
--tokenizer_dir $tokenizer_dir \
|
||||
--dataset ${dataset[@]} \
|
||||
--eval_dataset ${dataset[@]} \
|
||||
--save_dir $MODEL_SAVE_PATH \
|
||||
--config_file $MODELS_DIR/config.jsonl \
|
||||
--lora_rank $lora_rank \
|
||||
--plugin $plugin \
|
||||
--batch_size $bs \
|
||||
--max_epochs 1 \
|
||||
--accumulation_steps $grad_accu \
|
||||
--tp $tp \
|
||||
--lr 2e-5 \
|
||||
--desirable_weight 1.2 \
|
||||
$grad_ckpt \
|
||||
--max_len 400 \
|
||||
--use_flash_attn
|
||||
passed=$?
|
||||
if [ $passed -eq 0 ]; then
|
||||
rm -rf ${MODEL_SAVE_PATH:?}/*
|
||||
rm -rf ${MODELS_DIR:?}/*
|
||||
break
|
||||
fi
|
||||
done
|
||||
|
Reference in New Issue
Block a user