mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-09 03:47:57 +00:00
add eval
This commit is contained in:
parent
fbcb0149bd
commit
3ab2f6fa32
@ -91,7 +91,7 @@ SKIPPED_TESTS=(
|
|||||||
llama-gemini_auto-20 # gemini_auto plugin doesn't support lora
|
llama-gemini_auto-20 # gemini_auto plugin doesn't support lora
|
||||||
llama-gemini-20 # gemini doesn't support lora
|
llama-gemini-20 # gemini doesn't support lora
|
||||||
)
|
)
|
||||||
skip_eval=false
|
|
||||||
GRAD_CKPTS=('--grad_checkpoint')
|
GRAD_CKPTS=('--grad_checkpoint')
|
||||||
for lora_rank in ${LORA_RANK[@]}; do
|
for lora_rank in ${LORA_RANK[@]}; do
|
||||||
for model in ${MODELS[@]}; do
|
for model in ${MODELS[@]}; do
|
||||||
@ -134,13 +134,11 @@ for lora_rank in ${LORA_RANK[@]}; do
|
|||||||
bs='8'
|
bs='8'
|
||||||
pp='2'
|
pp='2'
|
||||||
plugin='3d'
|
plugin='3d'
|
||||||
skip_eval=true
|
|
||||||
fi
|
fi
|
||||||
if [[ $plugin == "pp" ]]; then
|
if [[ $plugin == "pp" ]]; then
|
||||||
bs='8'
|
bs='8'
|
||||||
pp='2'
|
pp='2'
|
||||||
plugin='3d'
|
plugin='3d'
|
||||||
skip_eval=true
|
|
||||||
fi
|
fi
|
||||||
if [[ $plugin == "sp_split_gather" ]]; then
|
if [[ $plugin == "sp_split_gather" ]]; then
|
||||||
enable_sequence_parallelism='--enable_sequence_parallelism'
|
enable_sequence_parallelism='--enable_sequence_parallelism'
|
||||||
@ -178,53 +176,29 @@ for lora_rank in ${LORA_RANK[@]}; do
|
|||||||
for split in $(seq -f "%05g" 0 0); do
|
for split in $(seq -f "%05g" 0 0); do
|
||||||
dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split")
|
dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split")
|
||||||
done
|
done
|
||||||
|
colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_sft.py \
|
||||||
if [[ $skip_eval ]]; then
|
--pretrain $pretrain \
|
||||||
colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_sft.py \
|
--tokenizer_dir $tokenizer_dir \
|
||||||
--pretrain $pretrain \
|
--dataset ${dataset[@]} \
|
||||||
--tokenizer_dir $tokenizer_dir \
|
--eval_dataset ${dataset[@]} \
|
||||||
--dataset ${dataset[@]} \
|
--save_path $MODEL_SAVE_PATH \
|
||||||
--save_path $MODEL_SAVE_PATH \
|
--config_file $MODELS_DIR/config.jsonl \
|
||||||
--config_file $MODELS_DIR/config.jsonl \
|
$lora_config \
|
||||||
$lora_config \
|
--plugin $plugin \
|
||||||
--plugin $plugin \
|
--batch_size $bs \
|
||||||
--batch_size $bs \
|
--max_epochs 1 \
|
||||||
--max_epochs 1 \
|
--accumulation_steps $grad_accu \
|
||||||
--accumulation_steps $grad_accu \
|
--tp $tp \
|
||||||
--tp $tp \
|
--pp $pp \
|
||||||
--pp $pp \
|
--zero_stage $zero_stage \
|
||||||
--zero_stage $zero_stage \
|
--sp $sp \
|
||||||
--sp $sp \
|
--sp_mode $sp_mode \
|
||||||
--sp_mode $sp_mode \
|
$enable_sequence_parallelism \
|
||||||
$enable_sequence_parallelism \
|
--lr 2e-5 \
|
||||||
--lr 2e-5 \
|
$grad_ckpt \
|
||||||
$grad_ckpt \
|
--max_len 400 \
|
||||||
--max_len 400 \
|
--use_flash_attn
|
||||||
--use_flash_attn
|
# fi
|
||||||
else
|
|
||||||
colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_sft.py \
|
|
||||||
--pretrain $pretrain \
|
|
||||||
--tokenizer_dir $tokenizer_dir \
|
|
||||||
--dataset ${dataset[@]} \
|
|
||||||
--eval_dataset ${dataset[@]} \
|
|
||||||
--save_path $MODEL_SAVE_PATH \
|
|
||||||
--config_file $MODELS_DIR/config.jsonl \
|
|
||||||
$lora_config \
|
|
||||||
--plugin $plugin \
|
|
||||||
--batch_size $bs \
|
|
||||||
--max_epochs 1 \
|
|
||||||
--accumulation_steps $grad_accu \
|
|
||||||
--tp $tp \
|
|
||||||
--pp $pp \
|
|
||||||
--zero_stage $zero_stage \
|
|
||||||
--sp $sp \
|
|
||||||
--sp_mode $sp_mode \
|
|
||||||
$enable_sequence_parallelism \
|
|
||||||
--lr 2e-5 \
|
|
||||||
$grad_ckpt \
|
|
||||||
--max_len 400 \
|
|
||||||
--use_flash_attn
|
|
||||||
fi
|
|
||||||
passed=$?
|
passed=$?
|
||||||
if [ $passed -eq 0 ]; then
|
if [ $passed -eq 0 ]; then
|
||||||
rm -rf ${MODEL_SAVE_PATH:?}/*
|
rm -rf ${MODEL_SAVE_PATH:?}/*
|
||||||
|
Loading…
Reference in New Issue
Block a user