Add GRPO and Support RLVR for PPO (#6186)

* add grpo, support rlvr

* add grpo, support rlvr

* tested deepseek r1 pipeline

* add ci

* verify grpo r1

* verify grpo r1

* update readme, remove unused code

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove path

* clean code

* fix circular import

* fix ci OOM

* fix ci OOM

* skip kto tp, fix qwen generation

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Author: YeAnbang
Date: 2025-02-18 09:43:36 +08:00
Committed by: GitHub
Parent: ce0ec40811
Commit: d20c8ffd97
39 changed files with 1995 additions and 277 deletions
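For context on what the new trainer adds: GRPO drops PPO's learned critic and instead samples a group of completions per prompt (the --num_generations flag used in the CI test below), using the group's reward statistics as the baseline. A minimal sketch of that group-relative advantage computation, assuming a flat reward tensor with the completions of each prompt stored contiguously — illustrative only, not the code added in this commit:

# Illustrative sketch of GRPO's group-relative advantage (not this commit's code).
import torch

def group_relative_advantages(rewards: torch.Tensor, num_generations: int, eps: float = 1e-6) -> torch.Tensor:
    # rewards: shape (num_prompts * num_generations,), completions for the same
    # prompt stored contiguously, as when sampling num_generations per prompt.
    grouped = rewards.view(-1, num_generations)
    baseline = grouped.mean(dim=1, keepdim=True)    # per-prompt mean reward replaces the critic
    scale = grouped.std(dim=1, keepdim=True) + eps  # normalize by in-group spread
    return ((grouped - baseline) / scale).view(-1)

# e.g. two prompts, two generations each (matches --num_generations 2 in the CI test)
print(group_relative_advantages(torch.tensor([1.0, 0.0, 0.0, 1.0]), num_generations=2))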

View File

@@ -20,6 +20,15 @@ prompt_seed = {
        },
    ]
}
prompt_rlvr_seed = {
    "messages": [
        {
            "from": "user",
            "content": "What is the degree of the polynomial $(4 +5x^3 +100 +2\\pi x^4 + \\sqrt{10}x^4 +9)$?",
        },
    ],
    "gt_answer": "4",
}
preference_seed = {
    "context": [
        {"from": "user", "content": "What kind of noises did dinosaurs make?"},
@@ -72,6 +81,8 @@ if __name__ == "__main__":
        seed = sft_seed
    elif args.data_type == "prompt":
        seed = prompt_seed
    elif args.data_type == "prompt_rlvr":
        seed = prompt_rlvr_seed
    elif args.data_type == "preference":
        seed = preference_seed
    elif args.data_type == "kto":
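The gt_answer field in the new prompt_rlvr samples exists so a rule-based verifier can score generations instead of a neural reward model; the training tests below select this path with --reward_functions gsm8k_reward_fn (and, for PPO, --no_neural_reward_model). A minimal sketch of such a verifiable reward function, assuming the last number in the response is compared against gt_answer — the signature and extraction regex are illustrative, not the repository's implementation:

import re

def gsm8k_style_reward(response: str, gt_answer: str) -> float:
    # Rule-based reward: 1.0 if the last number in the model's response
    # matches the ground-truth answer string, else 0.0. No reward model needed.
    numbers = re.findall(r"-?\d+(?:\.\d+)?", response)
    if not numbers:
        return 0.0
    return 1.0 if numbers[-1] == gt_answer.strip() else 0.0

# The dummy RLVR sample above carries gt_answer "4":
print(gsm8k_style_reward("The degree of the polynomial is 4.", "4"))  # 1.0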

View File

@@ -0,0 +1,16 @@
# run under /ColossalAI/applications/ColossalChat
export NCCL_SHM_DISABLE=1
export MAX_JOBS=1
export PRETRAINED_MODEL_PATH=./models
export SFT_DATASET=./sft_data
export PROMPT_DATASET=./prompt_data
export PROMPT_RLVR_DATASET=./prompt_data
export PREFERENCE_DATASET=./preference_data
export KTO_DATASET=./kto_data
mkdir -p models
mkdir -p sft_data
mkdir -p prompt_data
mkdir -p preference_data
mkdir -p kto_data
# ./tests/test_data_preparation.sh
# ./tests/test_train.sh

View File

@@ -24,7 +24,12 @@ if [ -z "$SFT_DATASET" ]; then
fi
if [ -z "$PROMPT_DATASET" ]; then
echo "Please set \$PROMPT_DATASET to the path to prompts."
echo "Please set \$PROMPT_DATASET to the path to prompts dataset."
exit 1
fi
if [ -z "$PROMPT_RLVR_DATASET" ]; then
echo "Please set \$PROMPT_RLVR_DATASET to the path to prompts dataset with gt_answer labels."
exit 1
fi
@@ -69,6 +74,8 @@ get_data_input_dirs() {
echo "$SFT_DATASET"
elif [[ $data_type == "prompt" ]]; then
echo "$PROMPT_DATASET"
elif [[ $data_type == "prompt_rlvr" ]]; then
echo "$PROMPT_RLVR_DATASET"
elif [[ $data_type == "preference" ]]; then
echo "$PREFERENCE_DATASET"
elif [[ $data_type == "kto" ]]; then
@@ -123,6 +130,10 @@ python $TEST_DIR/generate_dummy_datasets_for_testing.py \
--data_dir $(get_data_input_dirs prompt) \
--data_type "prompt"
python $TEST_DIR/generate_dummy_datasets_for_testing.py \
--data_dir $(get_data_input_dirs prompt_rlvr) \
--data_type "prompt_rlvr"
python $TEST_DIR/generate_dummy_datasets_for_testing.py \
--data_dir $(get_data_input_dirs kto) \
--data_type "kto"
@@ -266,6 +277,52 @@ for model in ${MODELS[@]}; do
done
echo "[Test]: testing prepare_prompt_dataset.py (with verifiable reward)..."
# FIXME: This is a hack to skip tests that are not working
SKIPPED_TESTS=(
)
# test prepare_prompt_dataset
for model in ${MODELS[@]}; do
data_type="prompt_rlvr"
if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$data_type " ]]; then
echo "[Test]: Skipped $model-$data_type"
continue
fi
cache_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/cache
jsonl_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/jsonl
arrow_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/arrow
data_input_dirs=$(get_data_input_dirs $data_type)
tokenizer_dir=$(get_tokenizer_dirs $model)
conversation_template=$(get_conversation_template_config $model)
for i in $(seq $NUM_RETRY); do
rm -rf $cache_dir
rm -rf $jsonl_dir
rm -rf $arrow_dir
echo "[Test]: $model-$data_type, attempt $i"
python $EXAMPLES_DIR/data_preparation_scripts/prepare_dataset.py \
--type prompt \
--data_input_dirs $data_input_dirs \
--conversation_template_config $conversation_template \
--tokenizer_dir $tokenizer_dir \
--data_cache_dir $cache_dir \
--data_jsonl_output_dir $jsonl_dir \
--data_arrow_output_dir $arrow_dir \
--max_length 400 \
--num_samples_per_datafile 100 \
--num_spliced_dataset_bins 1
passed=$?
if [ $passed -eq 0 ]; then
break
fi
done
if [ $passed -ne 0 ]; then
echo "[Test]: Failed $model-$data_type"
exit 1
fi
done
echo "[Test]: testing prepare_kto_dataset.py ..."
# FIXME: This is a hack to skip tests that are not working

View File

@@ -81,8 +81,242 @@ random_choice() {
echo ${arr[$idx]}
}
echo "[Test]: testing grpo ..."
SKIPPED_TESTS=(
llama-3d # 3d plugin doesn't support lora
llama-gemini # gemini doesn't support lora
)
GRAD_CKPTS=('--grad_checkpoint')
REWARD_FLAG=('nn' 'vr')
for reward_type in ${REWARD_FLAG[@]}; do
for lora_rank in ${LORA_RANK[@]}; do
for model in ${MODELS[@]}; do
for plugin in ${PLUGINS[@]}; do
if [[ $plugin == "gemini_auto" ]]; then
echo "[Test]: Skipped $model-$plugin"
continue # gemini_auto plugin doesn't support generation
fi
if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then
echo "[Test]: Skipped $model-$plugin-$lora_rank"
continue
elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin " ]]; then
echo "[Test]: Skipped $model-$plugin"
continue
fi
pretrain=$(get_pretrain $model)
rm_pretrain="--rm_pretrain $pretrain"
reward_fn=""
if [[ $reward_type == "vr" ]]; then
rm_pretrain=""
reward_fn="--reward_functions gsm8k_reward_fn"
fi
tokenizer_dir=$(get_tokenizer_dirs $model)
grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
tp='1'
bs='2'
ebs='1'
conversation_template=$(get_conversation_template_config $model)
if [[ $plugin == "zero2" ]]; then
lora_config=$LORA_CONFIG_ENABLE
else
lora_config=""
fi
if [[ $plugin == "3d" ]]; then
tp='2'
bs='2'
ebs='1'
fi
grad_accu='2'
# gemini_auto and gemini don't support gradient accumulation
if [[ $plugin == "gemini_auto" ]]; then
grad_accu='1'
fi
# gemini_auto and gemini don't support generation
if [[ $plugin == "gemini_auto" ]]; then
# gemini-auto doesn't support generation
echo "[Test]: Skipped $model-$plugin"
continue
fi
for i in $(seq $NUM_RETRY); do
echo "[Test]: $model-$plugin-$lora_rank-$reward_type, attempt $i"
declare -a prompt_dataset=()
for split in $(seq -f "%05g" 0 0); do
if [[ $reward_type == "vr" ]]; then
prompt_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_prompt_rlvr/arrow/part-$split")
else
prompt_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_prompt/arrow/part-$split")
fi
done
declare -a ptx_dataset=()
for split in $(seq -f "%05g" 0 0); do
ptx_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split")
done
colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_grpo.py \
--pretrain $pretrain \
$rm_pretrain \
--tokenizer_dir $tokenizer_dir \
--conversation_template_config $conversation_template \
--prompt_dataset ${prompt_dataset[@]} \
--ptx_dataset ${ptx_dataset[@]} \
--ptx_batch_size 1 \
--num_generations 2 \
--ptx_coef 0.2 \
--save_path $MODEL_SAVE_PATH \
$lora_config \
--plugin $plugin \
--num_episodes 5 \
--num_collect_steps 1 \
--num_update_steps 1 \
--experience_batch_size $ebs \
--train_batch_size $bs \
--accumulation_steps $grad_accu \
--mixed_precision "bf16" \
--grad_clip 1.0 \
--tp $tp \
--lr 2e-5 \
$grad_ckpt \
--max_len 200 \
--max_seq_len 10 \
$reward_fn
# --use_flash_attn
passed=$?
if [ $passed -eq 0 ]; then
rm -rf ${MODEL_SAVE_PATH:?}/*
rm -rf ${MODELS_DIR:?}/*
break
fi
done
if [ $passed -ne 0 ]; then
echo "[Test]: Failed $model-$plugin-$lora_rank-$reward_type"
exit 1
fi
done
done
done
done
echo "[Test]: testing ppo ..."
SKIPPED_TESTS=(
llama-3d # 3d plugin doesn't support lora
llama-gemini # gemini doesn't support lora
)
GRAD_CKPTS=('--grad_checkpoint')
REWARD_FLAG=('vr' 'nn')
for reward_type in ${REWARD_FLAG[@]}; do
for lora_rank in ${LORA_RANK[@]}; do
for model in ${MODELS[@]}; do
for plugin in ${PLUGINS[@]}; do
if [[ $plugin == "gemini_auto" ]]; then
echo "[Test]: Skipped $model-$plugin"
continue # gemini_auto plugin doesn't support generation
fi
if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then
echo "[Test]: Skipped $model-$plugin-$lora_rank"
continue
elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin " ]]; then
echo "[Test]: Skipped $model-$plugin"
continue
fi
pretrain=$(get_pretrain $model)
reward_fn=""
no_nn=""
if [[ $reward_type == "vr" ]]; then
reward_fn="--reward_functions gsm8k_reward_fn"
no_nn="--no_neural_reward_model"
fi
tokenizer_dir=$(get_tokenizer_dirs $model)
grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
tp='1'
bs='2'
ebs='2'
conversation_template=$(get_conversation_template_config $model)
if [[ $plugin == "zero2" ]]; then
lora_config=$LORA_CONFIG_ENABLE
else
lora_config=""
fi
if [[ $plugin == "3d" ]]; then
tp='2'
bs='2'
ebs='2'
fi
grad_accu='2'
# gemini_auto and gemini don't support gradient accumulation
if [[ $plugin == "gemini_auto" ]]; then
grad_accu='1'
fi
# gemini_auto and gemini don't support generation
if [[ $plugin == "gemini_auto" ]]; then
# gemini-auto doesn't support generation
echo "[Test]: Skipped $model-$plugin"
continue
fi
for i in $(seq $NUM_RETRY); do
echo "[Test]: $model-$plugin-$lora_rank-$reward_type, attempt $i"
declare -a prompt_dataset=()
for split in $(seq -f "%05g" 0 0); do
if [[ $reward_type == "vr" ]]; then
prompt_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_prompt_rlvr/arrow/part-$split")
else
prompt_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_prompt/arrow/part-$split")
fi
done
declare -a ptx_dataset=()
for split in $(seq -f "%05g" 0 0); do
ptx_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split")
done
colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_ppo.py \
--pretrain $pretrain \
--rm_pretrain $pretrain \
--tokenizer_dir $tokenizer_dir \
--conversation_template_config $conversation_template \
--prompt_dataset ${prompt_dataset[@]} \
--ptx_dataset ${ptx_dataset[@]} \
--ptx_batch_size 1 \
--ptx_coef 0.2 \
--save_path $MODEL_SAVE_PATH \
$lora_config \
--plugin $plugin \
--num_episodes 5 \
--num_collect_steps 1 \
--num_update_steps 1 \
--experience_batch_size $ebs \
--train_batch_size $bs \
--accumulation_steps $grad_accu \
--mixed_precision "bf16" \
--grad_clip 1.0 \
--tp $tp \
--lr 2e-5 \
$grad_ckpt \
--max_len 400 \
--max_seq_len 10 \
$reward_fn \
$no_nn
# --use_flash_attn
passed=$?
if [ $passed -eq 0 ]; then
rm -rf ${MODEL_SAVE_PATH:?}/*
rm -rf ${MODELS_DIR:?}/*
break
fi
done
if [ $passed -ne 0 ]; then
echo "[Test]: Failed $model-$plugin-$lora_rank-$reward_type"
exit 1
fi
done
done
done
done
echo "[Test]: testing sft ..."
@@ -316,111 +550,6 @@ for lora_rank in ${LORA_RANK[@]}; do
done
done
echo "[Test]: testing ppo ..."
SKIPPED_TESTS=(
llama-3d # 3d plugin doesn't support lora
llama-gemini # gemini doesn't support lora
)
GRAD_CKPTS=('--grad_checkpoint')
for lora_rank in ${LORA_RANK[@]}; do
for model in ${MODELS[@]}; do
for plugin in ${PLUGINS[@]}; do
if [[ $plugin == "gemini_auto" ]]; then
echo "[Test]: Skipped $model-$plugin"
continue # gemini_auto plugin doesn't support generation
fi
if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then
echo "[Test]: Skipped $model-$plugin-$lora_rank"
continue
elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin " ]]; then
echo "[Test]: Skipped $model-$plugin"
continue
fi
pretrain=$(get_pretrain $model)
tokenizer_dir=$(get_tokenizer_dirs $model)
grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
tp='1'
bs='4'
ebs='8'
conversation_template=$(get_conversation_template_config $model)
if [[ $plugin == "zero2" ]]; then
lora_config=$LORA_CONFIG_ENABLE
else
lora_config=""
fi
if [[ $plugin == "3d" ]]; then
tp='2'
bs='16'
ebs='32'
fi
grad_accu='2'
# gemini_auto and gemini doesn't support gradient accumulation
if [[ $plugin == "gemini_auto" ]]; then
grad_accu='1'
fi
# gemini_auto and gemini doesn't support generation
if [[ $plugin == "gemini_auto" ]]; then
# gemini-auto doesn't support generation
echo "[Test]: Skipped $model-$plugin"
continue
fi
for i in $(seq $NUM_RETRY); do
echo "[Test]: $model-$plugin-$lora_rank, attempt $i"
declare -a prompt_dataset=()
for split in $(seq -f "%05g" 0 0); do
prompt_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_prompt/arrow/part-$split")
done
declare -a ptx_dataset=()
for split in $(seq -f "%05g" 0 0); do
ptx_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split")
done
colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_ppo.py \
--pretrain $pretrain \
--rm_pretrain $pretrain \
--tokenizer_dir $tokenizer_dir \
--conversation_template_config $conversation_template \
--prompt_dataset ${prompt_dataset[@]} \
--ptx_dataset ${ptx_dataset[@]} \
--ptx_batch_size 1 \
--ptx_coef 0.2 \
--save_path $MODEL_SAVE_PATH \
$lora_config \
--plugin $plugin \
--num_episodes 5 \
--num_collect_steps 1 \
--num_update_steps 1 \
--experience_batch_size $ebs \
--train_batch_size $bs \
--accumulation_steps $grad_accu \
--lr 9e-6 \
--mixed_precision "bf16" \
--grad_clip 1.0 \
--tp $tp \
--lr 2e-5 \
$grad_ckpt \
--max_len 400 \
--max_seq_len 10 \
# --use_flash_attn
passed=$?
if [ $passed -eq 0 ]; then
rm -rf ${MODEL_SAVE_PATH:?}/*
rm -rf ${MODELS_DIR:?}/*
break
fi
done
if [ $passed -ne 0 ]; then
echo "[Test]: Failed $model-$plugin-$lora_rank"
exit 1
fi
done
done
done
echo "[Test]: testing DPO ..."
SKIPPED_TESTS=(
@@ -446,7 +575,7 @@ for lora_rank in ${LORA_RANK[@]}; do
bs='2'
if [[ $plugin == "3d" ]]; then
tp='2'
bs='8'
bs='2'
fi
if [[ $plugin == "zero2" ]]; then
lora_config=$LORA_CONFIG_ENABLE
@@ -503,10 +632,10 @@ for lora_rank in ${LORA_RANK[@]}; do
done
echo "[Test]: testing ORPO ..."
SKIPPED_TESTS=(
llama-3d-0
llama-3d-20 # 3d plugin doesn't support lora
llama-gemini_auto-20 # gemini_auto plugin doesn't support lora
llama-gemini-20 # gemini doesn't support lora
@@ -529,7 +658,7 @@ for lora_rank in ${LORA_RANK[@]}; do
bs='2'
if [[ $plugin == "3d" ]]; then
tp='2'
bs='8'
bs='2'
fi
if [[ $plugin == "zero2" ]]; then
lora_config=$LORA_CONFIG_ENABLE
@@ -585,11 +714,10 @@ for lora_rank in ${LORA_RANK[@]}; do
done
done
echo "[Test]: testing KTO ..."
SKIPPED_TESTS=(
llama-3d-0
llama-3d-20 # 3d plugin doesn't support lora
llama-gemini_auto-20 # gemini_auto plugin doesn't support lora
llama-gemini-20 # gemini doesn't support lora
@@ -612,7 +740,7 @@ for lora_rank in ${LORA_RANK[@]}; do
bs='2'
if [[ $plugin == "3d" ]]; then
tp='2'
bs='8'
bs='2'
fi
if [[ $plugin == "zero2" ]]; then
lora_config=$LORA_CONFIG_ENABLE