mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 20:40:34 +00:00
Add GRPO and Support RLVR for PPO (#6186)
* add grpo, support rlvr * add grpo, support rlvr * tested deepseek r1 pipeline * add ci * verify grpo r1 * verify grpo r1 * update readme, remove unused code * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove path * clean code * fix circular import * fix ci OOM * fix ci OOM * skip kto tp, fix qwen generation --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -81,8 +81,242 @@ random_choice() {
|
||||
echo ${arr[$idx]}
|
||||
}
|
||||
|
||||
echo "[Test]: testing grpo ..."
|
||||
|
||||
|
||||
SKIPPED_TESTS=(
|
||||
llama-3d # 3d plugin doesn't support lora
|
||||
llama-gemini # gemini doesn't support lora
|
||||
)
|
||||
|
||||
GRAD_CKPTS=('--grad_checkpoint')
|
||||
REWARD_FLAG=('nn' 'vr')
|
||||
for reward_type in ${REWARD_FLAG[@]}; do
|
||||
for lora_rank in ${LORA_RANK[@]}; do
|
||||
for model in ${MODELS[@]}; do
|
||||
for plugin in ${PLUGINS[@]}; do
|
||||
if [[ $plugin == "gemini_auto" ]]; then
|
||||
echo "[Test]: Skipped $model-$plugin"
|
||||
continue # gemini_auto plugin doesn't support generation
|
||||
fi
|
||||
if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then
|
||||
echo "[Test]: Skipped $model-$plugin-$lora_rank"
|
||||
continue
|
||||
elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin " ]]; then
|
||||
echo "[Test]: Skipped $model-$plugin"
|
||||
continue
|
||||
fi
|
||||
pretrain=$(get_pretrain $model)
|
||||
rm_pretrain="--rm_pretrain $pretrain"
|
||||
reward_fn=""
|
||||
if [[ $reward_type == "vr" ]]; then
|
||||
rm_pretrain=""
|
||||
reward_fn="--reward_functions gsm8k_reward_fn"
|
||||
fi
|
||||
tokenizer_dir=$(get_tokenizer_dirs $model)
|
||||
grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
|
||||
tp='1'
|
||||
bs='2'
|
||||
ebs='1'
|
||||
conversation_template=$(get_conversation_template_config $model)
|
||||
if [[ $plugin == "zero2" ]]; then
|
||||
lora_config=$LORA_CONFIG_ENABLE
|
||||
else
|
||||
lora_config=""
|
||||
fi
|
||||
if [[ $plugin == "3d" ]]; then
|
||||
tp='2'
|
||||
bs='2'
|
||||
ebs='1'
|
||||
fi
|
||||
grad_accu='2'
|
||||
# gemini_auto and gemini doesn't support gradient accumulation
|
||||
if [[ $plugin == "gemini_auto" ]]; then
|
||||
grad_accu='1'
|
||||
fi
|
||||
# gemini_auto and gemini doesn't support generation
|
||||
if [[ $plugin == "gemini_auto" ]]; then
|
||||
# gemini-auto doesn't support generation
|
||||
echo "[Test]: Skipped $model-$plugin"
|
||||
continue
|
||||
fi
|
||||
for i in $(seq $NUM_RETRY); do
|
||||
echo "[Test]: $model-$plugin-$lora_rank-$reward_type, attempt $i"
|
||||
declare -a prompt_dataset=()
|
||||
for split in $(seq -f "%05g" 0 0); do
|
||||
if [[ $reward_type == "vr" ]]; then
|
||||
prompt_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_prompt_rlvr/arrow/part-$split")
|
||||
else
|
||||
prompt_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_prompt/arrow/part-$split")
|
||||
fi
|
||||
done
|
||||
declare -a ptx_dataset=()
|
||||
for split in $(seq -f "%05g" 0 0); do
|
||||
ptx_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split")
|
||||
done
|
||||
colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_grpo.py \
|
||||
--pretrain $pretrain \
|
||||
$rm_pretrain \
|
||||
--tokenizer_dir $tokenizer_dir \
|
||||
--conversation_template_config $conversation_template \
|
||||
--prompt_dataset ${prompt_dataset[@]} \
|
||||
--ptx_dataset ${ptx_dataset[@]} \
|
||||
--ptx_batch_size 1 \
|
||||
--num_generations 2 \
|
||||
--ptx_coef 0.2 \
|
||||
--save_path $MODEL_SAVE_PATH \
|
||||
$lora_config \
|
||||
--plugin $plugin \
|
||||
--num_episodes 5 \
|
||||
--num_collect_steps 1 \
|
||||
--num_update_steps 1 \
|
||||
--experience_batch_size $ebs \
|
||||
--train_batch_size $bs \
|
||||
--accumulation_steps $grad_accu \
|
||||
--lr 9e-6 \
|
||||
--mixed_precision "bf16" \
|
||||
--grad_clip 1.0 \
|
||||
--tp $tp \
|
||||
--lr 2e-5 \
|
||||
$grad_ckpt \
|
||||
--max_len 200 \ \
|
||||
--max_seq_len 10 \
|
||||
$reward_fn
|
||||
# --use_flash_attn
|
||||
passed=$?
|
||||
if [ $passed -eq 0 ]; then
|
||||
rm -rf ${MODEL_SAVE_PATH:?}/*
|
||||
rm -rf ${MODELS_DIR:?}/*
|
||||
break
|
||||
fi
|
||||
done
|
||||
if [ $passed -ne 0 ]; then
|
||||
echo "[Test]: Failed $model-$plugin-$lora_rank-$reward_type"
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
|
||||
echo "[Test]: testing ppo ..."
|
||||
|
||||
|
||||
SKIPPED_TESTS=(
|
||||
llama-3d # 3d plugin doesn't support lora
|
||||
llama-gemini # gemini doesn't support lora
|
||||
)
|
||||
|
||||
GRAD_CKPTS=('--grad_checkpoint')
|
||||
REWARD_FLAG=('vr' 'nn')
|
||||
for reward_type in ${REWARD_FLAG[@]}; do
|
||||
for lora_rank in ${LORA_RANK[@]}; do
|
||||
for model in ${MODELS[@]}; do
|
||||
for plugin in ${PLUGINS[@]}; do
|
||||
if [[ $plugin == "gemini_auto" ]]; then
|
||||
echo "[Test]: Skipped $model-$plugin"
|
||||
continue # gemini_auto plugin doesn't support generation
|
||||
fi
|
||||
if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then
|
||||
echo "[Test]: Skipped $model-$plugin-$lora_rank"
|
||||
continue
|
||||
elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin " ]]; then
|
||||
echo "[Test]: Skipped $model-$plugin"
|
||||
continue
|
||||
fi
|
||||
pretrain=$(get_pretrain $model)
|
||||
reward_fn=""
|
||||
no_nn=""
|
||||
if [[ $reward_type == "vr" ]]; then
|
||||
reward_fn="--reward_functions gsm8k_reward_fn"
|
||||
no_nn="--no_neural_reward_model"
|
||||
fi
|
||||
tokenizer_dir=$(get_tokenizer_dirs $model)
|
||||
grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
|
||||
tp='1'
|
||||
bs='2'
|
||||
ebs='2'
|
||||
conversation_template=$(get_conversation_template_config $model)
|
||||
if [[ $plugin == "zero2" ]]; then
|
||||
lora_config=$LORA_CONFIG_ENABLE
|
||||
else
|
||||
lora_config=""
|
||||
fi
|
||||
if [[ $plugin == "3d" ]]; then
|
||||
tp='2'
|
||||
bs='2'
|
||||
ebs='2'
|
||||
fi
|
||||
grad_accu='2'
|
||||
# gemini_auto and gemini doesn't support gradient accumulation
|
||||
if [[ $plugin == "gemini_auto" ]]; then
|
||||
grad_accu='1'
|
||||
fi
|
||||
# gemini_auto and gemini doesn't support generation
|
||||
if [[ $plugin == "gemini_auto" ]]; then
|
||||
# gemini-auto doesn't support generation
|
||||
echo "[Test]: Skipped $model-$plugin"
|
||||
continue
|
||||
fi
|
||||
for i in $(seq $NUM_RETRY); do
|
||||
echo "[Test]: $model-$plugin-$lora_rank-$reward_type, attempt $i"
|
||||
declare -a prompt_dataset=()
|
||||
for split in $(seq -f "%05g" 0 0); do
|
||||
if [[ $reward_type == "vr" ]]; then
|
||||
prompt_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_prompt_rlvr/arrow/part-$split")
|
||||
else
|
||||
prompt_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_prompt/arrow/part-$split")
|
||||
fi
|
||||
done
|
||||
declare -a ptx_dataset=()
|
||||
for split in $(seq -f "%05g" 0 0); do
|
||||
ptx_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split")
|
||||
done
|
||||
colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_ppo.py \
|
||||
--pretrain $pretrain \
|
||||
--rm_pretrain $pretrain \
|
||||
--tokenizer_dir $tokenizer_dir \
|
||||
--conversation_template_config $conversation_template \
|
||||
--prompt_dataset ${prompt_dataset[@]} \
|
||||
--ptx_dataset ${ptx_dataset[@]} \
|
||||
--ptx_batch_size 1 \
|
||||
--ptx_coef 0.2 \
|
||||
--save_path $MODEL_SAVE_PATH \
|
||||
$lora_config \
|
||||
--plugin $plugin \
|
||||
--num_episodes 5 \
|
||||
--num_collect_steps 1 \
|
||||
--num_update_steps 1 \
|
||||
--experience_batch_size $ebs \
|
||||
--train_batch_size $bs \
|
||||
--accumulation_steps $grad_accu \
|
||||
--lr 9e-6 \
|
||||
--mixed_precision "bf16" \
|
||||
--grad_clip 1.0 \
|
||||
--tp $tp \
|
||||
--lr 2e-5 \
|
||||
$grad_ckpt \
|
||||
--max_len 400 \
|
||||
--max_seq_len 10 \
|
||||
$reward_fn \
|
||||
$no_nn
|
||||
# --use_flash_attn
|
||||
passed=$?
|
||||
if [ $passed -eq 0 ]; then
|
||||
rm -rf ${MODEL_SAVE_PATH:?}/*
|
||||
rm -rf ${MODELS_DIR:?}/*
|
||||
break
|
||||
fi
|
||||
done
|
||||
if [ $passed -ne 0 ]; then
|
||||
echo "[Test]: Failed $model-$plugin-$lora_rank-$reward_type"
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
echo "[Test]: testing sft ..."
|
||||
|
||||
@@ -316,111 +550,6 @@ for lora_rank in ${LORA_RANK[@]}; do
|
||||
done
|
||||
done
|
||||
|
||||
|
||||
echo "[Test]: testing ppo ..."
|
||||
|
||||
|
||||
SKIPPED_TESTS=(
|
||||
llama-3d # 3d plugin doesn't support lora
|
||||
llama-gemini # gemini doesn't support lora
|
||||
)
|
||||
|
||||
GRAD_CKPTS=('--grad_checkpoint')
|
||||
for lora_rank in ${LORA_RANK[@]}; do
|
||||
for model in ${MODELS[@]}; do
|
||||
for plugin in ${PLUGINS[@]}; do
|
||||
if [[ $plugin == "gemini_auto" ]]; then
|
||||
echo "[Test]: Skipped $model-$plugin"
|
||||
continue # gemini_auto plugin doesn't support generation
|
||||
fi
|
||||
if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then
|
||||
echo "[Test]: Skipped $model-$plugin-$lora_rank"
|
||||
continue
|
||||
elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin " ]]; then
|
||||
echo "[Test]: Skipped $model-$plugin"
|
||||
continue
|
||||
fi
|
||||
pretrain=$(get_pretrain $model)
|
||||
tokenizer_dir=$(get_tokenizer_dirs $model)
|
||||
grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
|
||||
tp='1'
|
||||
bs='4'
|
||||
ebs='8'
|
||||
conversation_template=$(get_conversation_template_config $model)
|
||||
if [[ $plugin == "zero2" ]]; then
|
||||
lora_config=$LORA_CONFIG_ENABLE
|
||||
else
|
||||
lora_config=""
|
||||
fi
|
||||
if [[ $plugin == "3d" ]]; then
|
||||
tp='2'
|
||||
bs='16'
|
||||
ebs='32'
|
||||
fi
|
||||
grad_accu='2'
|
||||
# gemini_auto and gemini doesn't support gradient accumulation
|
||||
if [[ $plugin == "gemini_auto" ]]; then
|
||||
grad_accu='1'
|
||||
fi
|
||||
# gemini_auto and gemini doesn't support generation
|
||||
if [[ $plugin == "gemini_auto" ]]; then
|
||||
# gemini-auto doesn't support generation
|
||||
echo "[Test]: Skipped $model-$plugin"
|
||||
continue
|
||||
fi
|
||||
for i in $(seq $NUM_RETRY); do
|
||||
echo "[Test]: $model-$plugin-$lora_rank, attempt $i"
|
||||
declare -a prompt_dataset=()
|
||||
for split in $(seq -f "%05g" 0 0); do
|
||||
prompt_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_prompt/arrow/part-$split")
|
||||
done
|
||||
declare -a ptx_dataset=()
|
||||
for split in $(seq -f "%05g" 0 0); do
|
||||
ptx_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split")
|
||||
done
|
||||
colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_ppo.py \
|
||||
--pretrain $pretrain \
|
||||
--rm_pretrain $pretrain \
|
||||
--tokenizer_dir $tokenizer_dir \
|
||||
--conversation_template_config $conversation_template \
|
||||
--prompt_dataset ${prompt_dataset[@]} \
|
||||
--ptx_dataset ${ptx_dataset[@]} \
|
||||
--ptx_batch_size 1 \
|
||||
--ptx_coef 0.2 \
|
||||
--save_path $MODEL_SAVE_PATH \
|
||||
$lora_config \
|
||||
--plugin $plugin \
|
||||
--num_episodes 5 \
|
||||
--num_collect_steps 1 \
|
||||
--num_update_steps 1 \
|
||||
--experience_batch_size $ebs \
|
||||
--train_batch_size $bs \
|
||||
--accumulation_steps $grad_accu \
|
||||
--lr 9e-6 \
|
||||
--mixed_precision "bf16" \
|
||||
--grad_clip 1.0 \
|
||||
--tp $tp \
|
||||
--lr 2e-5 \
|
||||
$grad_ckpt \
|
||||
--max_len 400 \
|
||||
--max_seq_len 10 \
|
||||
# --use_flash_attn
|
||||
passed=$?
|
||||
if [ $passed -eq 0 ]; then
|
||||
rm -rf ${MODEL_SAVE_PATH:?}/*
|
||||
rm -rf ${MODELS_DIR:?}/*
|
||||
break
|
||||
fi
|
||||
done
|
||||
if [ $passed -ne 0 ]; then
|
||||
echo "[Test]: Failed $model-$plugin-$lora_rank"
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
|
||||
echo "[Test]: testing DPO ..."
|
||||
|
||||
SKIPPED_TESTS=(
|
||||
@@ -446,7 +575,7 @@ for lora_rank in ${LORA_RANK[@]}; do
|
||||
bs='2'
|
||||
if [[ $plugin == "3d" ]]; then
|
||||
tp='2'
|
||||
bs='8'
|
||||
bs='2'
|
||||
fi
|
||||
if [[ $plugin == "zero2" ]]; then
|
||||
lora_config=$LORA_CONFIG_ENABLE
|
||||
@@ -503,10 +632,10 @@ for lora_rank in ${LORA_RANK[@]}; do
|
||||
done
|
||||
|
||||
|
||||
|
||||
echo "[Test]: testing ORPO ..."
|
||||
|
||||
SKIPPED_TESTS=(
|
||||
llama-3d-0
|
||||
llama-3d-20 # 3d plugin doesn't support lora
|
||||
llama-gemini_auto-20 # gemini_auto plugin doesn't support lora
|
||||
llama-gemini-20 # gemini doesn't support lora
|
||||
@@ -529,7 +658,7 @@ for lora_rank in ${LORA_RANK[@]}; do
|
||||
bs='2'
|
||||
if [[ $plugin == "3d" ]]; then
|
||||
tp='2'
|
||||
bs='8'
|
||||
bs='2'
|
||||
fi
|
||||
if [[ $plugin == "zero2" ]]; then
|
||||
lora_config=$LORA_CONFIG_ENABLE
|
||||
@@ -585,11 +714,10 @@ for lora_rank in ${LORA_RANK[@]}; do
|
||||
done
|
||||
done
|
||||
|
||||
|
||||
|
||||
echo "[Test]: testing KTO ..."
|
||||
|
||||
SKIPPED_TESTS=(
|
||||
llama-3d-0
|
||||
llama-3d-20 # 3d plugin doesn't support lora
|
||||
llama-gemini_auto-20 # gemini_auto plugin doesn't support lora
|
||||
llama-gemini-20 # gemini doesn't support lora
|
||||
@@ -612,7 +740,7 @@ for lora_rank in ${LORA_RANK[@]}; do
|
||||
bs='2'
|
||||
if [[ $plugin == "3d" ]]; then
|
||||
tp='2'
|
||||
bs='8'
|
||||
bs='2'
|
||||
fi
|
||||
if [[ $plugin == "zero2" ]]; then
|
||||
lora_config=$LORA_CONFIG_ENABLE
|
||||
|
Reference in New Issue
Block a user