mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-15 06:00:07 +00:00
[ColossalChat] Add PP support (#6001)
* support pp training
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* update rm
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* refactor
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* update test case
* fix
* change to 4
* fix eval
* test
* add pp
* hotfix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* support pp training
* update rm
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* refactor
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* update test case
* fix
* change to 4
* fix eval
* test
* add pp
* hotfix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* update
* skip pp eval
* update all reduce
* update sft
* update ignore
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* update no cache
* add eval
* remove fi
* remove debug
* remove parentheses to avoid warning
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Revert "add eval"
This reverts commit 3ab2f6fa32
.
* add all reduce
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -61,7 +61,7 @@ def test_overfit():
|
||||
_, predicted = torch.max(outputs.data, 1)
|
||||
total = labels.size(0)
|
||||
correct = (predicted == Y).sum().item()
|
||||
assert (correct / total > 0.95, "The model has not overfitted to the synthesized dataset")
|
||||
assert correct / total > 0.95
|
||||
assert (weight_to_compare - model.fc1.weight).sum() < 0.01
|
||||
|
||||
|
||||
|
@@ -15,7 +15,7 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
|
||||
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
||||
}
|
||||
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES 2
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES 4
|
||||
|
||||
set -xu
|
||||
|
||||
@@ -30,7 +30,7 @@ MODEL_SAVE_PATH=$TEMP_DIR/rlhf_models
|
||||
MODELS_DIR=$TEMP_DIR/models_config
|
||||
# Skip those tests due to CI tests timeout
|
||||
MODELS=('llama')
|
||||
ADVANCED_PLUGINS=('zero2' 'sp_split_gather' 'sp_ring' 'sp_all_to_all' 'tp_zero2' '3d' 'gemini' 'gemini_auto' 'zero2_cpu') # pp is still buggy
|
||||
ADVANCED_PLUGINS=('zero2' 'sp_split_gather' 'sp_ring' 'sp_all_to_all' 'tp_zero2' '3d' 'gemini' 'gemini_auto' 'zero2_cpu' 'pp' 'tp_pp')
|
||||
PLUGINS=('zero2' '3d' 'gemini' 'gemini_auto' 'zero2_cpu')
|
||||
LORA_RANK=('0') # skip to reduce CI execution time, can pass all locally
|
||||
LORA_CONFIG_ENABLE="--lora_config $BASE_DIR/examples/training_scripts/lora_config.json"
|
||||
@@ -91,7 +91,7 @@ SKIPPED_TESTS=(
|
||||
llama-gemini_auto-20 # gemini_auto plugin doesn't support lora
|
||||
llama-gemini-20 # gemini doesn't support lora
|
||||
)
|
||||
|
||||
skip_eval=false
|
||||
GRAD_CKPTS=('--grad_checkpoint')
|
||||
for lora_rank in ${LORA_RANK[@]}; do
|
||||
for model in ${MODELS[@]}; do
|
||||
@@ -129,15 +129,18 @@ for lora_rank in ${LORA_RANK[@]}; do
|
||||
plugin='3d'
|
||||
fi
|
||||
if [[ $plugin == "tp_pp" ]]; then
|
||||
echo "Here"
|
||||
tp='2'
|
||||
bs='8'
|
||||
pp='2'
|
||||
plugin='3d'
|
||||
skip_eval=true
|
||||
fi
|
||||
if [[ $plugin == "pp" ]]; then
|
||||
bs='8'
|
||||
pp='2'
|
||||
plugin='3d'
|
||||
skip_eval=true
|
||||
fi
|
||||
if [[ $plugin == "sp_split_gather" ]]; then
|
||||
enable_sequence_parallelism='--enable_sequence_parallelism'
|
||||
@@ -175,28 +178,53 @@ for lora_rank in ${LORA_RANK[@]}; do
|
||||
for split in $(seq -f "%05g" 0 0); do
|
||||
dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split")
|
||||
done
|
||||
colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_sft.py \
|
||||
--pretrain $pretrain \
|
||||
--tokenizer_dir $tokenizer_dir \
|
||||
--dataset ${dataset[@]} \
|
||||
--eval_dataset ${dataset[@]} \
|
||||
--save_path $MODEL_SAVE_PATH \
|
||||
--config_file $MODELS_DIR/config.jsonl \
|
||||
$lora_config \
|
||||
--plugin $plugin \
|
||||
--batch_size $bs \
|
||||
--max_epochs 1 \
|
||||
--accumulation_steps $grad_accu \
|
||||
--tp $tp \
|
||||
--pp $pp \
|
||||
--zero_stage $zero_stage \
|
||||
--sp $sp \
|
||||
--sp_mode $sp_mode \
|
||||
$enable_sequence_parallelism \
|
||||
--lr 2e-5 \
|
||||
$grad_ckpt \
|
||||
--max_len 400 \
|
||||
--use_flash_attn
|
||||
|
||||
if [[ $skip_eval ]]; then
|
||||
colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_sft.py \
|
||||
--pretrain $pretrain \
|
||||
--tokenizer_dir $tokenizer_dir \
|
||||
--dataset ${dataset[@]} \
|
||||
--save_path $MODEL_SAVE_PATH \
|
||||
--config_file $MODELS_DIR/config.jsonl \
|
||||
$lora_config \
|
||||
--plugin $plugin \
|
||||
--batch_size $bs \
|
||||
--max_epochs 1 \
|
||||
--accumulation_steps $grad_accu \
|
||||
--tp $tp \
|
||||
--pp $pp \
|
||||
--zero_stage $zero_stage \
|
||||
--sp $sp \
|
||||
--sp_mode $sp_mode \
|
||||
$enable_sequence_parallelism \
|
||||
--lr 2e-5 \
|
||||
$grad_ckpt \
|
||||
--max_len 400 \
|
||||
--use_flash_attn
|
||||
else
|
||||
colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_sft.py \
|
||||
--pretrain $pretrain \
|
||||
--tokenizer_dir $tokenizer_dir \
|
||||
--dataset ${dataset[@]} \
|
||||
--eval_dataset ${dataset[@]} \
|
||||
--save_path $MODEL_SAVE_PATH \
|
||||
--config_file $MODELS_DIR/config.jsonl \
|
||||
$lora_config \
|
||||
--plugin $plugin \
|
||||
--batch_size $bs \
|
||||
--max_epochs 1 \
|
||||
--accumulation_steps $grad_accu \
|
||||
--tp $tp \
|
||||
--pp $pp \
|
||||
--zero_stage $zero_stage \
|
||||
--sp $sp \
|
||||
--sp_mode $sp_mode \
|
||||
$enable_sequence_parallelism \
|
||||
--lr 2e-5 \
|
||||
$grad_ckpt \
|
||||
--max_len 400 \
|
||||
--use_flash_attn
|
||||
fi
|
||||
passed=$?
|
||||
if [ $passed -eq 0 ]; then
|
||||
rm -rf ${MODEL_SAVE_PATH:?}/*
|
||||
|
Reference in New Issue
Block a user