mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-15 06:00:07 +00:00
upgrade ppo dpo rm script
This commit is contained in:
@@ -4,5 +4,6 @@
|
||||
"stop_ids": [
|
||||
29871,
|
||||
2
|
||||
]
|
||||
}
|
||||
],
|
||||
"end_of_assistant": "</s>"
|
||||
}
|
@@ -9,14 +9,13 @@ model_data_mapping = {
|
||||
'THUDM/chatglm2-6b': 'THUDM_chatglm2-6b.json',
|
||||
'THUDM/chatglm3-6b': 'THUDM_chatglm3-6b.json',
|
||||
'baichuan-inc/Baichuan2-13B-Chat': 'baichuan-inc_Baichuan2-13B-Chat.json',
|
||||
'Qwen/Qwen-7B-Chat': 'Qwen_Qwen-7B-Chat.json',
|
||||
'01-ai/Yi-1.5-9B-Chat': '01-ai_Yi-1.5-9B-Chat.json',
|
||||
'01-ai/Yi-34B': '01-ai_Yi-34B.json',
|
||||
'deepseek-ai/DeepSeek-V2-Lite': 'deepseek-ai_DeepSeek-V2-Lite.json',
|
||||
'microsoft/phi-2': 'microsoft_phi-2.json',
|
||||
'mistralai/Mixtral-8x7B-Instruct-v0.1': 'mistralai_Mixtral-8x7B-Instruct-v0.1.json'
|
||||
}
|
||||
chat_template_config_path = '../config/conversation_template'
|
||||
chat_template_config_path = './config/conversation_template'
|
||||
|
||||
|
||||
def test_tokenization_sft():
|
||||
@@ -34,5 +33,5 @@ def test_tokenization_sft():
|
||||
)
|
||||
|
||||
output = supervised_tokenize_sft({"messages": messages}, tokenizer, conversation_template)
|
||||
with open(f"./test_data/chat_template/{model_data_mapping[model]}", "r", encoding="utf8") as f:
|
||||
with open(f"./tests/test_data/chat_template/{model_data_mapping[model]}", "r", encoding="utf8") as f:
|
||||
assert json.dumps(json.load(f)) == json.dumps(output), f"model: {model} failed"
|
||||
|
@@ -6,35 +6,59 @@ TEST_DATA_DIR=$BASE_DIR/tests/test_data
|
||||
DATA_SAVE_PATH=$BASE_TEMP_DIR/tests
|
||||
CONFIG_DIR=$BASE_DIR/config
|
||||
|
||||
MODELS=("colossal-llama2" "llama2" "zephyr" "mistral" "chatGLM2" "Qwen" "Vicuna" "Yi")
|
||||
|
||||
MODELS=("colossal-llama2" "llama2" "mistral" "chatGLM2" "chatGLM3" "deepseek" "Yi" "baichuan")
|
||||
#
|
||||
get_pretrain() {
|
||||
local model=$1
|
||||
if [[ $model == "colossal-llama2" ]]; then
|
||||
echo "hpcai-tech/Colossal-LLaMA-2-7b-base"
|
||||
elif [[ $model == "llama2" ]]; then
|
||||
echo "hf-internal-testing/llama-tokenizer"
|
||||
elif [[ $model == "zephyr" ]]; then
|
||||
echo "HuggingFaceH4/zephyr-7b-beta"
|
||||
elif [[ $model == "phi" ]]; then
|
||||
echo "microsoft/phi-2"
|
||||
elif [[ $model == "mistral" ]]; then
|
||||
echo "mistralai/Mistral-7B-Instruct-v0.2"
|
||||
echo "mistralai/Mistral-7B-Instruct-v0.3"
|
||||
elif [[ $model == "chatGLM2" ]]; then
|
||||
echo "THUDM/chatglm2-6b"
|
||||
elif [[ $model == "Qwen" ]]; then
|
||||
echo "Qwen/Qwen-7B-Chat"
|
||||
elif [[ $model == "Vicuna" ]]; then
|
||||
echo "lmsys/vicuna-7b-v1.5"
|
||||
elif [[ $model == "chatGLM3" ]]; then
|
||||
echo "THUDM/chatglm3-6b"
|
||||
elif [[ $model == "deepseek" ]]; then
|
||||
echo "deepseek-ai/DeepSeek-V2-Lite"
|
||||
elif [[ $model == "Yi" ]]; then
|
||||
echo "01-ai/Yi-6B-Chat"
|
||||
echo "01-ai/Yi-1.5-9B-Chat"
|
||||
elif [[ $model == "baichuan" ]]; then
|
||||
echo "baichuan-inc/Baichuan2-13B-Chat"
|
||||
else
|
||||
echo "Unknown model $model"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
|
||||
get_conversation_template_config() {
|
||||
local model=$1
|
||||
echo "$CONFIG_DIR/conversation_template/$model.json"
|
||||
if [[ $model == "colossal-llama2" ]]; then
|
||||
echo "$CONFIG_DIR/conversation_template/colossal-llama2.json"
|
||||
elif [[ $model == "llama2" ]]; then
|
||||
echo "$CONFIG_DIR/conversation_template/llama2.json"
|
||||
elif [[ $model == "deepseek" ]]; then
|
||||
echo "$CONFIG_DIR/conversation_template/deepseek-ai_DeepSeek-V2-Lite.json"
|
||||
elif [[ $model == "mistral" ]]; then
|
||||
echo "$CONFIG_DIR/conversation_template/mistralai_Mixtral-8x7B-Instruct-v0.1.json"
|
||||
elif [[ $model == "chatGLM2" ]]; then
|
||||
echo "$CONFIG_DIR/conversation_template/THUDM_chatglm2-6b.json"
|
||||
elif [[ $model == "chatGLM3" ]]; then
|
||||
echo "$CONFIG_DIR/conversation_template/THUDM_chatglm3-6b.json"
|
||||
elif [[ $model == "phi" ]]; then
|
||||
echo "$CONFIG_DIR/conversation_template/microsoft_phi-2.json"
|
||||
elif [[ $model == "Yi" ]]; then
|
||||
echo "$CONFIG_DIR/conversation_template/01-ai_Yi-1.5-9B-Chat.json"
|
||||
elif [[ $model == "baichuan" ]]; then
|
||||
echo "$CONFIG_DIR/conversation_template/baichuan-inc_Baichuan2-13B-Chat.json"
|
||||
else
|
||||
echo "Unknown model $model"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
# Test SFT data Preparation
|
||||
|
@@ -30,7 +30,7 @@ MODEL_SAVE_PATH=$TEMP_DIR/rlhf_models
|
||||
MODELS_DIR=$TEMP_DIR/models_config
|
||||
# Skip those tests due to CI tests timeout
|
||||
MODELS=('llama')
|
||||
PLUGINS=('gemini' 'gemini_auto' 'zero2' 'zero2_cpu' '3d')
|
||||
PLUGINS=('3d' 'gemini' 'gemini_auto' 'zero2' 'zero2_cpu')
|
||||
LORA_RANK=('0') # skip to reduce CI execution time, can pass all locally
|
||||
|
||||
export OMP_NUM_THREADS=8
|
||||
|
Reference in New Issue
Block a user