mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-15 14:12:02 +00:00
Add GRPO and Support RLVR for PPO (#6186)
* add grpo, support rlvr * add grpo, support rlvr * tested deepseek r1 pipeline * add ci * verify grpo r1 * verify grpo r1 * update readme, remove unused code * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove path * clean code * fix circular import * fix ci OOM * fix ci OOM * skip kto tp, fix qwen generation --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -20,6 +20,15 @@ prompt_seed = {
|
||||
},
|
||||
]
|
||||
}
|
||||
# Seed sample used for the "prompt_rlvr" data type: one user message plus a
# "gt_answer" field holding the expected answer for that prompt.
prompt_rlvr_seed = {
    "messages": [
        {
            "from": "user",
            # Raw string: the LaTeX backslashes (\pi, \sqrt) are literal characters.
            # In a normal string literal they are invalid escape sequences, which
            # newer Python versions warn about.
            "content": r"What is the degree of the polynomial $(4 +5x^3 +100 +2\pi x^4 + \sqrt{10}x^4 +9)$?",
        },
    ],
    "gt_answer": "4",
}
|
||||
preference_seed = {
|
||||
"context": [
|
||||
{"from": "user", "content": "What kind of noises did dinosaurs make?"},
|
||||
@@ -72,6 +81,8 @@ if __name__ == "__main__":
|
||||
seed = sft_seed
|
||||
elif args.data_type == "prompt":
|
||||
seed = prompt_seed
|
||||
elif args.data_type == "prompt_rlvr":
|
||||
seed = prompt_rlvr_seed
|
||||
elif args.data_type == "preference":
|
||||
seed = preference_seed
|
||||
elif args.data_type == "kto":
|
||||
|
16
applications/ColossalChat/tests/prepare_test_env.sh
Executable file
16
applications/ColossalChat/tests/prepare_test_env.sh
Executable file
@@ -0,0 +1,16 @@
|
||||
# Exports the environment variables read by the ColossalChat test scripts and
# creates the directories they point to.
# run under /ColossalAI/applications/ColossalChat
export NCCL_SHM_DISABLE=1
export MAX_JOBS=1
export PRETRAINED_MODEL_PATH=./models
export SFT_DATASET=./sft_data
export PROMPT_DATASET=./prompt_data
# NOTE(review): points at the same directory as PROMPT_DATASET, so the
# "prompt" and "prompt_rlvr" dummy data share one folder - confirm this is
# intended rather than a missing ./prompt_rlvr_data directory.
export PROMPT_RLVR_DATASET=./prompt_data
export PREFERENCE_DATASET=./preference_data
export KTO_DATASET=./kto_data
# -p keeps the script re-runnable: plain mkdir errors out on every directory
# that already exists from a previous run.
mkdir -p -- \
    "$PRETRAINED_MODEL_PATH" \
    "$SFT_DATASET" \
    "$PROMPT_DATASET" \
    "$PROMPT_RLVR_DATASET" \
    "$PREFERENCE_DATASET" \
    "$KTO_DATASET"
# ./tests/test_data_preparation.sh
# ./tests/test_train.sh
|
@@ -24,7 +24,12 @@ if [ -z "$SFT_DATASET" ]; then
|
||||
fi
|
||||
|
||||
if [ -z "$PROMPT_DATASET" ]; then
|
||||
echo "Please set \$PROMPT_DATASET to the path to prompts."
|
||||
echo "Please set \$PROMPT_DATASET to the path to prompts dataset."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ -z "$PROMPT_RLVR_DATASET" ]; then
|
||||
echo "Please set \$PROMPT_RLVR_DATASET to the path to prompts dataset with gt_answer labels."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
@@ -69,6 +74,8 @@ get_data_input_dirs() {
|
||||
echo "$SFT_DATASET"
|
||||
elif [[ $data_type == "prompt" ]]; then
|
||||
echo "$PROMPT_DATASET"
|
||||
elif [[ $data_type == "prompt_rlvr" ]]; then
|
||||
echo "$PROMPT_RLVR_DATASET"
|
||||
elif [[ $data_type == "preference" ]]; then
|
||||
echo "$PREFERENCE_DATASET"
|
||||
elif [[ $data_type == "kto" ]]; then
|
||||
@@ -123,6 +130,10 @@ python $TEST_DIR/generate_dummy_datasets_for_testing.py \
|
||||
--data_dir $(get_data_input_dirs prompt) \
|
||||
--data_type "prompt"
|
||||
|
||||
python $TEST_DIR/generate_dummy_datasets_for_testing.py \
|
||||
--data_dir $(get_data_input_dirs prompt_rlvr) \
|
||||
--data_type "prompt_rlvr"
|
||||
|
||||
python $TEST_DIR/generate_dummy_datasets_for_testing.py \
|
||||
--data_dir $(get_data_input_dirs kto) \
|
||||
--data_type "kto"
|
||||
@@ -266,6 +277,52 @@ for model in ${MODELS[@]}; do
|
||||
done
|
||||
|
||||
|
||||
echo "[Test]: testing prepare_prompt_dataset.py (with verifiable reward)..."

# FIXME: This is a hack to skip tests that are not working
SKIPPED_TESTS=(
)

# Tokenize the "prompt_rlvr" input data once per model with prepare_dataset.py
# (--type prompt). Each model is retried up to $NUM_RETRY times; the script
# exits with status 1 as soon as one model never succeeds.
for model in ${MODELS[@]}; do
    data_type="prompt_rlvr"
    case_id="$model-$data_type"
    if [[ " ${SKIPPED_TESTS[*]} " == *" $case_id "* ]]; then
        echo "[Test]: Skipped $case_id"
        continue
    fi
    # All outputs of one (model, data type) pair live under a common root.
    out_root=$DATA_SAVE_PATH/tokenized_${model}_${data_type}
    cache_dir=$out_root/cache
    jsonl_dir=$out_root/jsonl
    arrow_dir=$out_root/arrow
    data_input_dirs=$(get_data_input_dirs $data_type)
    tokenizer_dir=$(get_tokenizer_dirs $model)
    conversation_template=$(get_conversation_template_config $model)
    for attempt in $(seq $NUM_RETRY); do
        # Start every attempt from empty output directories.
        rm -rf $cache_dir $jsonl_dir $arrow_dir
        echo "[Test]: $case_id, attempt $attempt"
        python $EXAMPLES_DIR/data_preparation_scripts/prepare_dataset.py \
            --type prompt \
            --data_input_dirs $data_input_dirs \
            --conversation_template_config $conversation_template \
            --tokenizer_dir $tokenizer_dir \
            --data_cache_dir $cache_dir \
            --data_jsonl_output_dir $jsonl_dir \
            --data_arrow_output_dir $arrow_dir \
            --max_length 400 \
            --num_samples_per_datafile 100 \
            --num_spliced_dataset_bins 1
        passed=$?
        if [[ $passed -eq 0 ]]; then
            break
        fi
    done
    if [[ $passed -ne 0 ]]; then
        echo "[Test]: Failed $case_id"
        exit 1
    fi
done
|
||||
|
||||
echo "[Test]: testing prepare_kto_dataset.py ..."
|
||||
|
||||
# FIXME: This is a hack to skip tests that are not working
|
||||
|
@@ -81,8 +81,242 @@ random_choice() {
|
||||
echo ${arr[$idx]}
|
||||
}
|
||||
|
||||
# GRPO training smoke test.
# Runs train_grpo.py for every (reward type, lora rank, model, plugin)
# combination that is not skipped. Each combination is retried up to
# $NUM_RETRY times; the script exits with status 1 on the first combination
# that never succeeds.
echo "[Test]: testing grpo ..."


SKIPPED_TESTS=(
    llama-3d # 3d plugin doesn't support lora
    llama-gemini # gemini doesn't support lora
)

GRAD_CKPTS=('--grad_checkpoint')
# 'nn' passes --rm_pretrain; 'vr' drops it, passes --reward_functions
# gsm8k_reward_fn and trains on the prompt_rlvr dataset instead.
REWARD_FLAG=('nn' 'vr')
for reward_type in "${REWARD_FLAG[@]}"; do
    for lora_rank in ${LORA_RANK[@]}; do
        for model in ${MODELS[@]}; do
            for plugin in ${PLUGINS[@]}; do
                # gemini_auto plugin doesn't support generation. Because of this
                # early skip, no later gemini_auto-specific handling is needed.
                if [[ $plugin == "gemini_auto" ]]; then
                    echo "[Test]: Skipped $model-$plugin"
                    continue
                fi
                if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then
                    echo "[Test]: Skipped $model-$plugin-$lora_rank"
                    continue
                elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin " ]]; then
                    echo "[Test]: Skipped $model-$plugin"
                    continue
                fi
                pretrain=$(get_pretrain $model)
                # rm_pretrain / reward_fn / lora_config / grad_ckpt hold zero or
                # more CLI words and are deliberately expanded unquoted below.
                rm_pretrain="--rm_pretrain $pretrain"
                reward_fn=""
                if [[ $reward_type == "vr" ]]; then
                    rm_pretrain=""
                    reward_fn="--reward_functions gsm8k_reward_fn"
                fi
                tokenizer_dir=$(get_tokenizer_dirs $model)
                grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
                tp='1'
                bs='2'
                ebs='1'
                conversation_template=$(get_conversation_template_config $model)
                if [[ $plugin == "zero2" ]]; then
                    lora_config=$LORA_CONFIG_ENABLE
                else
                    lora_config=""
                fi
                if [[ $plugin == "3d" ]]; then
                    tp='2'
                    bs='2'
                    ebs='1'
                fi
                grad_accu='2'
                for i in $(seq $NUM_RETRY); do
                    echo "[Test]: $model-$plugin-$lora_rank-$reward_type, attempt $i"
                    declare -a prompt_dataset=()
                    for split in $(seq -f "%05g" 0 0); do
                        if [[ $reward_type == "vr" ]]; then
                            prompt_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_prompt_rlvr/arrow/part-$split")
                        else
                            prompt_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_prompt/arrow/part-$split")
                        fi
                    done
                    declare -a ptx_dataset=()
                    for split in $(seq -f "%05g" 0 0); do
                        ptx_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split")
                    done
                    # NOTE(review): --lr is given twice (9e-6, then 2e-5); which one
                    # wins depends on train_grpo.py's argument parser - confirm and
                    # drop the unused one.
                    # The line after --max_len used to end in "\ \", which passed a
                    # stray single-space argument to the training script.
                    colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_grpo.py \
                        --pretrain $pretrain \
                        $rm_pretrain \
                        --tokenizer_dir $tokenizer_dir \
                        --conversation_template_config $conversation_template \
                        --prompt_dataset "${prompt_dataset[@]}" \
                        --ptx_dataset "${ptx_dataset[@]}" \
                        --ptx_batch_size 1 \
                        --num_generations 2 \
                        --ptx_coef 0.2 \
                        --save_path $MODEL_SAVE_PATH \
                        $lora_config \
                        --plugin $plugin \
                        --num_episodes 5 \
                        --num_collect_steps 1 \
                        --num_update_steps 1 \
                        --experience_batch_size $ebs \
                        --train_batch_size $bs \
                        --accumulation_steps $grad_accu \
                        --lr 9e-6 \
                        --mixed_precision "bf16" \
                        --grad_clip 1.0 \
                        --tp $tp \
                        --lr 2e-5 \
                        $grad_ckpt \
                        --max_len 200 \
                        --max_seq_len 10 \
                        $reward_fn
                        # --use_flash_attn
                    passed=$?
                    if [ $passed -eq 0 ]; then
                        # Clear saved outputs so the next combination starts clean;
                        # ":?" aborts instead of expanding to "/*" if a path is unset.
                        rm -rf "${MODEL_SAVE_PATH:?}"/*
                        rm -rf "${MODELS_DIR:?}"/*
                        break
                    fi
                done
                if [ $passed -ne 0 ]; then
                    echo "[Test]: Failed $model-$plugin-$lora_rank-$reward_type"
                    exit 1
                fi
            done
        done
    done
done
|
||||
|
||||
|
||||
# PPO training smoke test.
# Runs train_ppo.py for every (reward type, lora rank, model, plugin)
# combination that is not skipped. Each combination is retried up to
# $NUM_RETRY times; the script exits with status 1 on the first combination
# that never succeeds.
echo "[Test]: testing ppo ..."


SKIPPED_TESTS=(
    llama-3d # 3d plugin doesn't support lora
    llama-gemini # gemini doesn't support lora
)

GRAD_CKPTS=('--grad_checkpoint')
# 'vr' adds --reward_functions gsm8k_reward_fn and --no_neural_reward_model and
# trains on the prompt_rlvr dataset; 'nn' adds neither and uses the plain
# prompt dataset.
REWARD_FLAG=('vr' 'nn')
for reward_type in "${REWARD_FLAG[@]}"; do
    for lora_rank in ${LORA_RANK[@]}; do
        for model in ${MODELS[@]}; do
            for plugin in ${PLUGINS[@]}; do
                # gemini_auto plugin doesn't support generation. Because of this
                # early skip, no later gemini_auto-specific handling is needed.
                if [[ $plugin == "gemini_auto" ]]; then
                    echo "[Test]: Skipped $model-$plugin"
                    continue
                fi
                if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then
                    echo "[Test]: Skipped $model-$plugin-$lora_rank"
                    continue
                elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin " ]]; then
                    echo "[Test]: Skipped $model-$plugin"
                    continue
                fi
                pretrain=$(get_pretrain $model)
                # reward_fn / no_nn / lora_config / grad_ckpt hold zero or more CLI
                # words and are deliberately expanded unquoted below.
                reward_fn=""
                no_nn=""
                if [[ $reward_type == "vr" ]]; then
                    reward_fn="--reward_functions gsm8k_reward_fn"
                    no_nn="--no_neural_reward_model"
                fi
                tokenizer_dir=$(get_tokenizer_dirs $model)
                grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
                tp='1'
                bs='2'
                ebs='2'
                conversation_template=$(get_conversation_template_config $model)
                if [[ $plugin == "zero2" ]]; then
                    lora_config=$LORA_CONFIG_ENABLE
                else
                    lora_config=""
                fi
                if [[ $plugin == "3d" ]]; then
                    tp='2'
                    bs='2'
                    ebs='2'
                fi
                grad_accu='2'
                for i in $(seq $NUM_RETRY); do
                    echo "[Test]: $model-$plugin-$lora_rank-$reward_type, attempt $i"
                    declare -a prompt_dataset=()
                    for split in $(seq -f "%05g" 0 0); do
                        if [[ $reward_type == "vr" ]]; then
                            prompt_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_prompt_rlvr/arrow/part-$split")
                        else
                            prompt_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_prompt/arrow/part-$split")
                        fi
                    done
                    declare -a ptx_dataset=()
                    for split in $(seq -f "%05g" 0 0); do
                        ptx_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split")
                    done
                    # NOTE(review): --lr is given twice (9e-6, then 2e-5); which one
                    # wins depends on train_ppo.py's argument parser - confirm and
                    # drop the unused one.
                    colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_ppo.py \
                        --pretrain $pretrain \
                        --rm_pretrain $pretrain \
                        --tokenizer_dir $tokenizer_dir \
                        --conversation_template_config $conversation_template \
                        --prompt_dataset "${prompt_dataset[@]}" \
                        --ptx_dataset "${ptx_dataset[@]}" \
                        --ptx_batch_size 1 \
                        --ptx_coef 0.2 \
                        --save_path $MODEL_SAVE_PATH \
                        $lora_config \
                        --plugin $plugin \
                        --num_episodes 5 \
                        --num_collect_steps 1 \
                        --num_update_steps 1 \
                        --experience_batch_size $ebs \
                        --train_batch_size $bs \
                        --accumulation_steps $grad_accu \
                        --lr 9e-6 \
                        --mixed_precision "bf16" \
                        --grad_clip 1.0 \
                        --tp $tp \
                        --lr 2e-5 \
                        $grad_ckpt \
                        --max_len 400 \
                        --max_seq_len 10 \
                        $reward_fn \
                        $no_nn
                        # --use_flash_attn
                    passed=$?
                    if [ $passed -eq 0 ]; then
                        # Clear saved outputs so the next combination starts clean;
                        # ":?" aborts instead of expanding to "/*" if a path is unset.
                        rm -rf "${MODEL_SAVE_PATH:?}"/*
                        rm -rf "${MODELS_DIR:?}"/*
                        break
                    fi
                done
                if [ $passed -ne 0 ]; then
                    echo "[Test]: Failed $model-$plugin-$lora_rank-$reward_type"
                    exit 1
                fi
            done
        done
    done
done
|
||||
|
||||
echo "[Test]: testing sft ..."
|
||||
|
||||
@@ -316,111 +550,6 @@ for lora_rank in ${LORA_RANK[@]}; do
|
||||
done
|
||||
done
|
||||
|
||||
|
||||
echo "[Test]: testing ppo ..."
|
||||
|
||||
|
||||
SKIPPED_TESTS=(
|
||||
llama-3d # 3d plugin doesn't support lora
|
||||
llama-gemini # gemini doesn't support lora
|
||||
)
|
||||
|
||||
GRAD_CKPTS=('--grad_checkpoint')
|
||||
for lora_rank in ${LORA_RANK[@]}; do
|
||||
for model in ${MODELS[@]}; do
|
||||
for plugin in ${PLUGINS[@]}; do
|
||||
if [[ $plugin == "gemini_auto" ]]; then
|
||||
echo "[Test]: Skipped $model-$plugin"
|
||||
continue # gemini_auto plugin doesn't support generation
|
||||
fi
|
||||
if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then
|
||||
echo "[Test]: Skipped $model-$plugin-$lora_rank"
|
||||
continue
|
||||
elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin " ]]; then
|
||||
echo "[Test]: Skipped $model-$plugin"
|
||||
continue
|
||||
fi
|
||||
pretrain=$(get_pretrain $model)
|
||||
tokenizer_dir=$(get_tokenizer_dirs $model)
|
||||
grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
|
||||
tp='1'
|
||||
bs='4'
|
||||
ebs='8'
|
||||
conversation_template=$(get_conversation_template_config $model)
|
||||
if [[ $plugin == "zero2" ]]; then
|
||||
lora_config=$LORA_CONFIG_ENABLE
|
||||
else
|
||||
lora_config=""
|
||||
fi
|
||||
if [[ $plugin == "3d" ]]; then
|
||||
tp='2'
|
||||
bs='16'
|
||||
ebs='32'
|
||||
fi
|
||||
grad_accu='2'
|
||||
# gemini_auto and gemini doesn't support gradient accumulation
|
||||
if [[ $plugin == "gemini_auto" ]]; then
|
||||
grad_accu='1'
|
||||
fi
|
||||
# gemini_auto and gemini doesn't support generation
|
||||
if [[ $plugin == "gemini_auto" ]]; then
|
||||
# gemini-auto doesn't support generation
|
||||
echo "[Test]: Skipped $model-$plugin"
|
||||
continue
|
||||
fi
|
||||
for i in $(seq $NUM_RETRY); do
|
||||
echo "[Test]: $model-$plugin-$lora_rank, attempt $i"
|
||||
declare -a prompt_dataset=()
|
||||
for split in $(seq -f "%05g" 0 0); do
|
||||
prompt_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_prompt/arrow/part-$split")
|
||||
done
|
||||
declare -a ptx_dataset=()
|
||||
for split in $(seq -f "%05g" 0 0); do
|
||||
ptx_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split")
|
||||
done
|
||||
colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_ppo.py \
|
||||
--pretrain $pretrain \
|
||||
--rm_pretrain $pretrain \
|
||||
--tokenizer_dir $tokenizer_dir \
|
||||
--conversation_template_config $conversation_template \
|
||||
--prompt_dataset ${prompt_dataset[@]} \
|
||||
--ptx_dataset ${ptx_dataset[@]} \
|
||||
--ptx_batch_size 1 \
|
||||
--ptx_coef 0.2 \
|
||||
--save_path $MODEL_SAVE_PATH \
|
||||
$lora_config \
|
||||
--plugin $plugin \
|
||||
--num_episodes 5 \
|
||||
--num_collect_steps 1 \
|
||||
--num_update_steps 1 \
|
||||
--experience_batch_size $ebs \
|
||||
--train_batch_size $bs \
|
||||
--accumulation_steps $grad_accu \
|
||||
--lr 9e-6 \
|
||||
--mixed_precision "bf16" \
|
||||
--grad_clip 1.0 \
|
||||
--tp $tp \
|
||||
--lr 2e-5 \
|
||||
$grad_ckpt \
|
||||
--max_len 400 \
|
||||
--max_seq_len 10 \
|
||||
# --use_flash_attn
|
||||
passed=$?
|
||||
if [ $passed -eq 0 ]; then
|
||||
rm -rf ${MODEL_SAVE_PATH:?}/*
|
||||
rm -rf ${MODELS_DIR:?}/*
|
||||
break
|
||||
fi
|
||||
done
|
||||
if [ $passed -ne 0 ]; then
|
||||
echo "[Test]: Failed $model-$plugin-$lora_rank"
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
|
||||
echo "[Test]: testing DPO ..."
|
||||
|
||||
SKIPPED_TESTS=(
|
||||
@@ -446,7 +575,7 @@ for lora_rank in ${LORA_RANK[@]}; do
|
||||
bs='2'
|
||||
if [[ $plugin == "3d" ]]; then
|
||||
tp='2'
|
||||
bs='8'
|
||||
bs='2'
|
||||
fi
|
||||
if [[ $plugin == "zero2" ]]; then
|
||||
lora_config=$LORA_CONFIG_ENABLE
|
||||
@@ -503,10 +632,10 @@ for lora_rank in ${LORA_RANK[@]}; do
|
||||
done
|
||||
|
||||
|
||||
|
||||
echo "[Test]: testing ORPO ..."
|
||||
|
||||
SKIPPED_TESTS=(
|
||||
llama-3d-0
|
||||
llama-3d-20 # 3d plugin doesn't support lora
|
||||
llama-gemini_auto-20 # gemini_auto plugin doesn't support lora
|
||||
llama-gemini-20 # gemini doesn't support lora
|
||||
@@ -529,7 +658,7 @@ for lora_rank in ${LORA_RANK[@]}; do
|
||||
bs='2'
|
||||
if [[ $plugin == "3d" ]]; then
|
||||
tp='2'
|
||||
bs='8'
|
||||
bs='2'
|
||||
fi
|
||||
if [[ $plugin == "zero2" ]]; then
|
||||
lora_config=$LORA_CONFIG_ENABLE
|
||||
@@ -585,11 +714,10 @@ for lora_rank in ${LORA_RANK[@]}; do
|
||||
done
|
||||
done
|
||||
|
||||
|
||||
|
||||
echo "[Test]: testing KTO ..."
|
||||
|
||||
SKIPPED_TESTS=(
|
||||
llama-3d-0
|
||||
llama-3d-20 # 3d plugin doesn't support lora
|
||||
llama-gemini_auto-20 # gemini_auto plugin doesn't support lora
|
||||
llama-gemini-20 # gemini doesn't support lora
|
||||
@@ -612,7 +740,7 @@ for lora_rank in ${LORA_RANK[@]}; do
|
||||
bs='2'
|
||||
if [[ $plugin == "3d" ]]; then
|
||||
tp='2'
|
||||
bs='8'
|
||||
bs='2'
|
||||
fi
|
||||
if [[ $plugin == "zero2" ]]; then
|
||||
lora_config=$LORA_CONFIG_ENABLE
|
||||
|
Reference in New Issue
Block a user