Upgrade PPO, DPO, and RM (reward model) scripts

This commit is contained in:
YeAnbang
2024-05-28 03:04:39 +00:00
parent 7a7e86987d
commit 929e1e3da4
15 changed files with 169 additions and 139 deletions

View File

@@ -4,5 +4,6 @@
"stop_ids": [
29871,
2
]
}
],
"end_of_assistant": "</s>"
}

View File

@@ -9,14 +9,13 @@ model_data_mapping = {
'THUDM/chatglm2-6b': 'THUDM_chatglm2-6b.json',
'THUDM/chatglm3-6b': 'THUDM_chatglm3-6b.json',
'baichuan-inc/Baichuan2-13B-Chat': 'baichuan-inc_Baichuan2-13B-Chat.json',
'Qwen/Qwen-7B-Chat': 'Qwen_Qwen-7B-Chat.json',
'01-ai/Yi-1.5-9B-Chat': '01-ai_Yi-1.5-9B-Chat.json',
'01-ai/Yi-34B': '01-ai_Yi-34B.json',
'deepseek-ai/DeepSeek-V2-Lite': 'deepseek-ai_DeepSeek-V2-Lite.json',
'microsoft/phi-2': 'microsoft_phi-2.json',
'mistralai/Mixtral-8x7B-Instruct-v0.1': 'mistralai_Mixtral-8x7B-Instruct-v0.1.json'
}
# Path to the conversation-template configs; tests are launched from the
# repository root, so the path is './config', not '../config'.
chat_template_config_path = './config/conversation_template'
def test_tokenization_sft():
@@ -34,5 +33,5 @@ def test_tokenization_sft():
)
output = supervised_tokenize_sft({"messages": messages}, tokenizer, conversation_template)
with open(f"./test_data/chat_template/{model_data_mapping[model]}", "r", encoding="utf8") as f:
with open(f"./tests/test_data/chat_template/{model_data_mapping[model]}", "r", encoding="utf8") as f:
assert json.dumps(json.load(f)) == json.dumps(output), f"model: {model} failed"

View File

@@ -6,35 +6,59 @@ TEST_DATA_DIR=$BASE_DIR/tests/test_data
DATA_SAVE_PATH=$BASE_TEMP_DIR/tests
CONFIG_DIR=$BASE_DIR/config
# Model nicknames exercised by the data-preparation tests; each must be
# handled by get_pretrain and get_conversation_template_config below.
MODELS=("colossal-llama2" "llama2" "mistral" "chatGLM2" "chatGLM3" "deepseek" "Yi" "baichuan")
#######################################
# Map a model nickname to its Hugging Face pretrained-model repo id.
# Arguments: $1 - model nickname (an entry of MODELS)
# Outputs:   writes exactly one HF repo id to stdout
# Exits:     1 on an unknown nickname
#######################################
get_pretrain() {
  local model=$1
  case "$model" in
    colossal-llama2) echo "hpcai-tech/Colossal-LLaMA-2-7b-base" ;;
    llama2) echo "hf-internal-testing/llama-tokenizer" ;;
    zephyr) echo "HuggingFaceH4/zephyr-7b-beta" ;;
    phi) echo "microsoft/phi-2" ;;
    mistral) echo "mistralai/Mistral-7B-Instruct-v0.3" ;;
    chatGLM2) echo "THUDM/chatglm2-6b" ;;
    Qwen) echo "Qwen/Qwen-7B-Chat" ;;
    Vicuna) echo "lmsys/vicuna-7b-v1.5" ;;
    chatGLM3) echo "THUDM/chatglm3-6b" ;;
    deepseek) echo "deepseek-ai/DeepSeek-V2-Lite" ;;
    Yi) echo "01-ai/Yi-1.5-9B-Chat" ;;
    baichuan) echo "baichuan-inc/Baichuan2-13B-Chat" ;;
    *)
      echo "Unknown model $model"
      exit 1
      ;;
  esac
}
#######################################
# Map a model nickname to its conversation-template config file.
# Globals:   CONFIG_DIR (read)
# Arguments: $1 - model nickname (an entry of MODELS)
# Outputs:   writes exactly one config path to stdout
# Exits:     1 on an unknown nickname
#######################################
get_conversation_template_config() {
  local model=$1
  case "$model" in
    colossal-llama2) echo "$CONFIG_DIR/conversation_template/colossal-llama2.json" ;;
    llama2) echo "$CONFIG_DIR/conversation_template/llama2.json" ;;
    deepseek) echo "$CONFIG_DIR/conversation_template/deepseek-ai_DeepSeek-V2-Lite.json" ;;
    mistral) echo "$CONFIG_DIR/conversation_template/mistralai_Mixtral-8x7B-Instruct-v0.1.json" ;;
    chatGLM2) echo "$CONFIG_DIR/conversation_template/THUDM_chatglm2-6b.json" ;;
    chatGLM3) echo "$CONFIG_DIR/conversation_template/THUDM_chatglm3-6b.json" ;;
    phi) echo "$CONFIG_DIR/conversation_template/microsoft_phi-2.json" ;;
    Yi) echo "$CONFIG_DIR/conversation_template/01-ai_Yi-1.5-9B-Chat.json" ;;
    baichuan) echo "$CONFIG_DIR/conversation_template/baichuan-inc_Baichuan2-13B-Chat.json" ;;
    *)
      echo "Unknown model $model"
      exit 1
      ;;
  esac
}
# Test SFT data Preparation

View File

@@ -30,7 +30,7 @@ MODEL_SAVE_PATH=$TEMP_DIR/rlhf_models
MODELS_DIR=$TEMP_DIR/models_config
# Skip those tests due to CI tests timeout
MODELS=('llama')
# Booster plugins to exercise; '3d' now runs first per this commit's reorder.
PLUGINS=('3d' 'gemini' 'gemini_auto' 'zero2' 'zero2_cpu')
LORA_RANK=('0') # skip to reduce CI execution time, can pass all locally
export OMP_NUM_THREADS=8