[Coati] Train DPO using PP (#6054)

* update dpo

* remove unsupported plugin

* update msg

* update template

* update dataset

* add pp for dpo

* update dpo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add dpo fn

* update dpo

* update dpo

* update dpo

* update dpo

* minor update

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update loss

* update help

* polish code

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Author: Tong Li
Date: 2024-10-11 19:32:00 +08:00
Committed by: GitHub
Parent: dc2cdaf3e8
Commit: 4c8e85ee0d
8 changed files with 529 additions and 236 deletions
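The commit log above mentions adding a DPO loss function ("add dpo fn", "update loss") alongside pipeline-parallel (PP) support. For orientation, below is a minimal sketch of the standard DPO objective from Rafailov et al. (2023) in plain PyTorch; the function name, argument shapes, and the beta default are illustrative assumptions and do not reflect the actual Coati implementation or how it is wired into pipeline parallelism.

# Minimal sketch of the standard DPO loss (Rafailov et al., 2023).
# All names, shapes, and defaults here are illustrative assumptions, not Coati's API.
import torch
import torch.nn.functional as F


def dpo_loss(
    policy_chosen_logps: torch.Tensor,    # (batch,) summed log-probs of chosen responses under the policy
    policy_rejected_logps: torch.Tensor,  # (batch,) summed log-probs of rejected responses under the policy
    ref_chosen_logps: torch.Tensor,       # (batch,) same quantities under the frozen reference model
    ref_rejected_logps: torch.Tensor,
    beta: float = 0.1,
) -> torch.Tensor:
    """DPO objective: -log sigmoid(beta * (policy log-ratio - reference log-ratio))."""
    pi_logratio = policy_chosen_logps - policy_rejected_logps
    ref_logratio = ref_chosen_logps - ref_rejected_logps
    logits = beta * (pi_logratio - ref_logratio)
    # Mean over the batch; reward margins can be derived from `logits` for logging.
    return -F.logsigmoid(logits).mean()


# Usage sketch: in practice the log-probs come from policy / reference forward passes.
if __name__ == "__main__":
    b = 4
    policy_chosen = torch.randn(b, requires_grad=True)
    loss = dpo_loss(policy_chosen, torch.randn(b), torch.randn(b), torch.randn(b))
    loss.backward()
    print(loss.item())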


@@ -4,7 +4,7 @@ BASE_TEMP_DIR=$BASE_DIR/temp
 EXAMPLES_DIR=$BASE_DIR/examples
 TEST_DATA_DIR=$BASE_DIR/tests/test_data
 DATA_SAVE_PATH=$BASE_TEMP_DIR/tests
-CONFIG_DIR=$BASE_DIR/config
+CONFIG_DIR=$BASE_DIR/conversation_template
 # MODELS=("colossal-llama2" "llama2" "mistral" "chatGLM2" "chatGLM3" "deepseek" "Yi" "baichuan") # for local test
 MODELS=("colossal-llama2" "llama2" "chatGLM2" "chatGLM3" "deepseek" "Yi")
@@ -39,23 +39,23 @@ get_pretrain() {
 get_conversation_template_config() {
     local model=$1
     if [[ $model == "colossal-llama2" ]]; then
-        echo "$CONFIG_DIR/conversation_template/colossal-llama2.json"
+        echo "$CONFIG_DIR/colossal-llama2.json"
     elif [[ $model == "llama2" ]]; then
-        echo "$CONFIG_DIR/conversation_template/llama2.json"
+        echo "$CONFIG_DIR/llama2.json"
     elif [[ $model == "deepseek" ]]; then
-        echo "$CONFIG_DIR/conversation_template/deepseek-ai_DeepSeek-V2-Lite.json"
+        echo "$CONFIG_DIR/deepseek-ai_DeepSeek-V2-Lite.json"
     elif [[ $model == "mistral" ]]; then
-        echo "$CONFIG_DIR/conversation_template/mistralai_Mixtral-8x7B-Instruct-v0.1.json"
+        echo "$CONFIG_DIR/mistralai_Mixtral-8x7B-Instruct-v0.1.json"
     elif [[ $model == "chatGLM2" ]]; then
-        echo "$CONFIG_DIR/conversation_template/THUDM_chatglm2-6b.json"
+        echo "$CONFIG_DIR/THUDM_chatglm2-6b.json"
     elif [[ $model == "chatGLM3" ]]; then
-        echo "$CONFIG_DIR/conversation_template/THUDM_chatglm3-6b.json"
+        echo "$CONFIG_DIR/THUDM_chatglm3-6b.json"
     elif [[ $model == "phi" ]]; then
-        echo "$CONFIG_DIR/conversation_template/microsoft_phi-2.json"
+        echo "$CONFIG_DIR/microsoft_phi-2.json"
     elif [[ $model == "Yi" ]]; then
-        echo "$CONFIG_DIR/conversation_template/01-ai_Yi-1.5-9B-Chat.json"
+        echo "$CONFIG_DIR/01-ai_Yi-1.5-9B-Chat.json"
     elif [[ $model == "baichuan" ]]; then
-        echo "$CONFIG_DIR/conversation_template/baichuan-inc_Baichuan2-13B-Chat.json"
+        echo "$CONFIG_DIR/baichuan-inc_Baichuan2-13B-Chat.json"
     else
         echo "Unknown model $model"
         exit 1
@@ -71,6 +71,7 @@ for model in ${MODELS[@]}; do
     rm -rf $SAVE_DIR/arrow
     pretrain=$(get_pretrain $model)
    conversation_template_config=$(get_conversation_template_config $model)
+    echo $conversation_template_config
     python $EXAMPLES_DIR/data_preparation_scripts/prepare_dataset.py --type sft --data_input_dirs $TEST_DATA_DIR/sft \
         --tokenizer_dir $pretrain \
         --conversation_template_config $conversation_template_config \