replace the customized dataloader setup with the build-in one

This commit is contained in:
YeAnbang
2024-06-07 09:43:42 +00:00
parent 790e1362a6
commit 0d7ff10ea5
12 changed files with 79 additions and 218 deletions

View File

@@ -30,8 +30,7 @@ MODEL_SAVE_PATH=$TEMP_DIR/rlhf_models
MODELS_DIR=$TEMP_DIR/models_config
# Skip those tests due to CI tests timeout
MODELS=('llama')
# ADVANCED_PLUGINS=('pp' 'tp_zero2' 'tp_pp' 'sp_split_gather' 'sp_ring' 'sp_all_to_all' '3d' 'gemini' 'gemini_auto' 'zero2' 'zero2_cpu')
ADVANCED_PLUGINS=('tp_zero2' '3d' 'gemini' 'gemini_auto' 'zero2' 'zero2_cpu')
ADVANCED_PLUGINS=('sp_split_gather' 'sp_ring' 'sp_all_to_all' 'tp_zero2' '3d' 'gemini' 'gemini_auto' 'zero2' 'zero2_cpu') # pp is still buggy
PLUGINS=('3d' 'gemini' 'gemini_auto' 'zero2' 'zero2_cpu')
LORA_RANK=('0') # skip to reduce CI execution time, can pass all locally
@@ -281,7 +280,7 @@ echo "[Test]: testing ppo ..."
SKIPPED_TESTS=(
# llama-3d # 3d plugin doesn't support lora
llama-3d # 3d plugin doesn't support lora
llama-gemini # gemini doesn't support lora
)
@@ -359,7 +358,7 @@ for lora_rank in ${LORA_RANK[@]}; do
$grad_ckpt \
--max_len 400 \
--max_seq_len 10 \
--use_flash_attn
# --use_flash_attn
passed=$?
if [ $passed -eq 0 ]; then
rm -rf $MODEL_SAVE_PATH/*