mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 13:00:52 +00:00
[ColossalChat] Update RLHF V2 (#5286)
* Add dpo. Fix sft, ppo, lora. Refactor all * fix and tested ppo * 2 nd round refactor * add ci tests * fix ci * fix ci * fix readme, style * fix readme style * fix style, fix benchmark * reproduce benchmark result, remove useless files * rename to ColossalChat * use new image * fix ci workflow * fix ci * use local model/tokenizer for ci tests * fix ci * fix ci * fix ci * fix ci timeout * fix rm progress bar. fix ci timeout * fix ci * fix ci typo * remove 3d plugin from ci temporary * test environment * cannot save optimizer * support chat template * fix readme * fix path * test ci locally * restore build_or_pr * fix ci data path * fix benchmark * fix ci, move ci tests to 3080, disable fast tokenizer * move ci to 85 * support flash attention 2 * add all-in-one data preparation script. Fix colossal-llama2-chat chat template * add hardware requirements * move ci test data * fix save_model, add unwrap * fix missing bos * fix missing bos; support grad accumulation with gemini * fix ci * fix ci * fix ci * fix llama2 chat template config * debug sft * debug sft * fix colossalai version requirement * fix ci * add sanity check to prevent NaN loss * fix requirements * add dummy data generation script * add dummy data generation script * add dummy data generation script * add dummy data generation script * update readme * update readme * update readme and ignore * fix logger bug * support parallel_output * modify data preparation logic * fix tokenization * update lr * fix inference * run pre-commit --------- Co-authored-by: Tong Li <tong.li352711588@gmail.com>
This commit is contained in:
0
applications/ColossalChat/tests/__init__.py
Executable file
0
applications/ColossalChat/tests/__init__.py
Executable file
@@ -0,0 +1,72 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
|
||||
sft_seed = {
|
||||
"messages": [
|
||||
{"from": "human", "content": "Give three tips for staying healthy."},
|
||||
{
|
||||
"from": "assistant",
|
||||
"content": "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule.",
|
||||
},
|
||||
]
|
||||
}
|
||||
prompt_seed = {
|
||||
"messages": [
|
||||
{"from": "human", "content": "Describe the impacts of climate change on communities living in coastal areas."},
|
||||
{
|
||||
"from": "assistant",
|
||||
"content": "Climate change has caused an increase in sea levels, which has caused coastal erosion and flooding of low-lying areas. This has led to displacement of people from their homes, as well as increased risk of epidemics of waterborne illnesses. Coastal cities have also seen an increase in extreme weather events such as hurricanes and tropical storms, which can cause extensive damage to infrastructure, homes, and businesses. As a result of climate change, some coastal areas are becoming uninhabitable, forcing communities to seek alternative living arrangements.",
|
||||
},
|
||||
]
|
||||
}
|
||||
preference_seed = {
|
||||
"context": [
|
||||
{"from": "human", "content": "What kind of noises did dinosaurs make?"},
|
||||
{
|
||||
"from": "assistant",
|
||||
"content": "Humans and dinosaurs didn't live at the same time, so it's really hard to say. The best place to find out what noises dinosaurs made would be",
|
||||
},
|
||||
{"from": "human", "content": "yes they did"},
|
||||
{
|
||||
"from": "assistant",
|
||||
"content": "to guess, and that would probably require lots of reading and a certain amount of imagination, so we're not really prepared to do that.",
|
||||
},
|
||||
{"from": "human", "content": "you cant read"},
|
||||
],
|
||||
"chosen": [{"from": "assistant", "content": "You can read?"}],
|
||||
"rejected": [{"from": "assistant", "content": "there's a lot of stuff humans don't know"}],
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
type=str,
|
||||
required=True,
|
||||
default=None,
|
||||
help="The output dir",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data_type",
|
||||
type=str,
|
||||
required=True,
|
||||
default=None,
|
||||
help="The type of data",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
if args.data_type == "sft":
|
||||
seed = sft_seed
|
||||
elif args.data_type == "prompt":
|
||||
seed = prompt_seed
|
||||
elif args.data_type == "preference":
|
||||
seed = preference_seed
|
||||
else:
|
||||
raise ValueError(f"Unknown data type {args.data_type}")
|
||||
|
||||
line = json.dumps(seed, ensure_ascii=False) + "\n"
|
||||
for idx in [1, 2, 3]:
|
||||
with open(os.path.join(args.data_dir, f"{idx}.jsonl"), "w", encoding="utf8") as f:
|
||||
for i in range(1000):
|
||||
f.write(line)
|
||||
f.write(line)
|
8
applications/ColossalChat/tests/llama.json
Normal file
8
applications/ColossalChat/tests/llama.json
Normal file
@@ -0,0 +1,8 @@
|
||||
{
|
||||
"chat_template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{% if message['role'] == 'user' %}{{'Human: ' + bos_token + message['content'].strip() + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'].strip() + '\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + bos_token + message['content'].strip() + eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant: ' + bos_token }}{% endif %}",
|
||||
"system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
|
||||
"stop_ids": [
|
||||
29871,
|
||||
2
|
||||
]
|
||||
}
|
17
applications/ColossalChat/tests/opt.json
Normal file
17
applications/ColossalChat/tests/opt.json
Normal file
@@ -0,0 +1,17 @@
|
||||
{
|
||||
"chat_template": "{% for message in messages %}{% if message['role'] == 'user' %}{{'Human: ' + bos_token + message['content'].strip() + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'].strip() + '\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + bos_token + message['content'].strip() + eos_token }}{% endif %}{% endfor %}",
|
||||
"system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
|
||||
"human_line_start": [
|
||||
2
|
||||
],
|
||||
"human_line_end": [
|
||||
2
|
||||
],
|
||||
"assistant_line_start": [
|
||||
2
|
||||
],
|
||||
"assistant_line_end": [
|
||||
2
|
||||
],
|
||||
"end_of_system_line_position": 0
|
||||
}
|
@@ -0,0 +1 @@
|
||||
{"context": [{"from": "human", "content": "If I were to give you some coordinates, could you tell me where the location is?"}], "chosen": [{"from": "assistant", "content": "Certainly! I can tell you if the given coordinates are within any geographic location, such as a country or city. However, I cannot give accurate directions to any specific location without more information. Are you looking for a more general geographic location, or do you want more precise directions?"}], "rejected": [{"from": "assistant", "content": "Yes, I can determine the location and give you some details about the area. Please provide the latitude and longitude coordinates, and I will return a description of the location based on these coordinates. Please note that precise geographical coordinates are often imprecise, so please provide the best possible coordinates you have available. Here are some examples:\n-The coordinates 41.543554,-71.09471 locate the Boston Commons in Boston, Massachusetts\n-The coordinates 37.42242, -122.083733 locate the Palace of Fine Arts in San Francisco, California \n-The coordinates 25.774145, -80.190262 locate Canal Street in Miami Beach, Florida\n-The coordinates 35. journalisten nl 48.056944, -106.277778 locate the Wyoming State Penitentiary in Rawlins, Wyoming\n-The coordinates 51.507351,-0.127758 locate Tower Bridge in London, England\n-The coordinates 44.885618, -79.422594 locate the Stratosphere Casino, Hotel and Tower in Las Vegas, Nevada"}]}
|
@@ -0,0 +1 @@
|
||||
{"messages": [{"from": "human", "content": "what are some pranks with a pen I can do?"}, {"from": "assistant", "content": "Are you looking for practical joke ideas?"}, {"from": "human", "content": "Yes, exactly!"}, {"from": "assistant", "content": "You could try attaching a small buzzer to a pen. When someone tries to use it, they'll be surprised by the sound!"}, {"from": "human", "content": "That's a funny one! Any other ideas?"}, {"from": "assistant", "content": "Another idea is the classic 'ink explosion' prank. Replace the ink of a pen with disappearing ink, and watch the confusion when the ink vanishes from the paper."}]}
|
260
applications/ColossalChat/tests/test_data_preparation.sh
Executable file
260
applications/ColossalChat/tests/test_data_preparation.sh
Executable file
@@ -0,0 +1,260 @@
|
||||
#!/usr/bin/env bash
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES() {
|
||||
local n=${1:-"9999"}
|
||||
echo "GPU Memory Usage:"
|
||||
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
|
||||
tail -n +2 |
|
||||
nl -v 0 |
|
||||
tee /dev/tty |
|
||||
sort -g -k 2 |
|
||||
awk '{print $1}' |
|
||||
head -n $n)
|
||||
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
|
||||
echo "Now CUDA_VISIBLE_DEVICES is set to:"
|
||||
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
||||
}
|
||||
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES 4
|
||||
|
||||
set -xu
|
||||
|
||||
if [ -z "$SFT_DATASET" ]; then
|
||||
echo "Please set \$SFT_DATASET to the path to sft dataset."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ -z "$PROMPT_DATASET" ]; then
|
||||
echo "Please set \$PROMPT_DATASET to the path to prompts."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ -z "$PREFERENCE_DATASET" ]; then
|
||||
echo "Please set \$SFT_DATASET to the path to sft dataset."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
NUM_RETRY=3
|
||||
BASE_DIR=$(dirname $(dirname $(realpath $BASH_SOURCE)))
|
||||
BASE_TEMP_DIR=$BASE_DIR/temp
|
||||
TEST_DIR=$BASE_DIR/tests
|
||||
EXAMPLES_DIR=$BASE_DIR/examples
|
||||
DATA_SAVE_PATH=$BASE_TEMP_DIR/rlhf_data
|
||||
CONFIG_DIR=$BASE_DIR/config
|
||||
# Skip those tests due to CI tests timeout
|
||||
MODELS=('llama')
|
||||
|
||||
if [ ! -d "$BASE_TEMP_DIR" ]; then
|
||||
mkdir "$BASE_TEMP_DIR"
|
||||
echo "Directory created successfully"
|
||||
else
|
||||
echo "Directory already exists"
|
||||
fi
|
||||
|
||||
if [ ! -d "$DATA_SAVE_PATH" ]; then
|
||||
mkdir "$DATA_SAVE_PATH"
|
||||
echo "Directory created successfully"
|
||||
else
|
||||
echo "Directory already exists"
|
||||
fi
|
||||
|
||||
|
||||
export OMP_NUM_THREADS=8
|
||||
|
||||
# install requirements
|
||||
pip install -r $EXAMPLES_DIR/requirements.txt
|
||||
|
||||
get_data_input_dirs() {
|
||||
local data_type=$1
|
||||
if [[ $data_type == "sft" ]]; then
|
||||
echo "$SFT_DATASET"
|
||||
elif [[ $data_type == "prompt" ]]; then
|
||||
echo "$PROMPT_DATASET"
|
||||
elif [[ $data_type == "preference" ]]; then
|
||||
echo "$PREFERENCE_DATASET"
|
||||
else
|
||||
echo "Unknown data type $data_type"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
get_conversation_template_config() {
|
||||
local model=$1
|
||||
if [[ $model == "llama" ]]; then
|
||||
echo "$TEST_DIR/llama.json"
|
||||
elif [[ $model == "opt" ]]; then
|
||||
echo "$TEST_DIR/opt.json"
|
||||
else
|
||||
echo "Unknown model $model"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
get_tokenizer_dirs() {
|
||||
local model=$1
|
||||
if [[ $model == "llama" ]]; then
|
||||
echo "hf-internal-testing/llama-tokenizer"
|
||||
elif [[ $model == "opt" ]]; then
|
||||
echo "facebook/opt-125m"
|
||||
else
|
||||
echo "Unknown model $model"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
random_choice() {
|
||||
local arr=("$@")
|
||||
local len=${#arr[@]}
|
||||
local idx=$((RANDOM % len))
|
||||
echo ${arr[$idx]}
|
||||
}
|
||||
|
||||
echo "Prepare dummy data for testing..."
|
||||
python $TEST_DIR/generate_dummy_datasets_for_testing.py \
|
||||
--data_dir $(get_data_input_dirs sft) \
|
||||
--data_type "sft"
|
||||
|
||||
python $TEST_DIR/generate_dummy_datasets_for_testing.py \
|
||||
--data_dir $(get_data_input_dirs preference) \
|
||||
--data_type "preference"
|
||||
|
||||
python $TEST_DIR/generate_dummy_datasets_for_testing.py \
|
||||
--data_dir $(get_data_input_dirs prompt) \
|
||||
--data_type "prompt"
|
||||
|
||||
echo "[Test]: testing prepare_preference_dataset.py ..."
|
||||
|
||||
# FIXME: This is a hack to skip tests that are not working
|
||||
SKIPPED_TESTS=(
|
||||
)
|
||||
|
||||
# test prepare_preference_dataset
|
||||
for model in ${MODELS[@]}; do
|
||||
data_type="preference"
|
||||
if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$data_type " ]]; then
|
||||
echo "[Test]: Skipped $model-$data_type"
|
||||
continue
|
||||
fi
|
||||
cache_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/cache
|
||||
jsonl_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/jsonl
|
||||
arrow_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/arrow
|
||||
rm -rf $cache_dir
|
||||
rm -rf $jsonl_dir
|
||||
rm -rf $arrow_dir
|
||||
data_input_dirs=$(get_data_input_dirs $data_type)
|
||||
tokenizer_dir=$(get_tokenizer_dirs $model)
|
||||
conversation_template=$(get_conversation_template_config $model)
|
||||
for i in $(seq $NUM_RETRY); do
|
||||
echo "[Test]: $model-$data_type, attempt $i"
|
||||
python $EXAMPLES_DIR/data_preparation_scripts/prepare_dataset.py \
|
||||
--type preference \
|
||||
--data_input_dirs $data_input_dirs \
|
||||
--conversation_template_config $conversation_template \
|
||||
--tokenizer_dir $tokenizer_dir \
|
||||
--data_cache_dir $cache_dir \
|
||||
--data_jsonl_output_dir $jsonl_dir \
|
||||
--data_arrow_output_dir $arrow_dir \
|
||||
--max_length 400 \
|
||||
--num_samples_per_datafile 100 \
|
||||
--num_spliced_dataset_bins 1
|
||||
passed=$?
|
||||
if [ $passed -eq 0 ]; then
|
||||
break
|
||||
fi
|
||||
done
|
||||
if [ $passed -ne 0 ]; then
|
||||
echo "[Test]: Failed $model-$data_type"
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
|
||||
echo "[Test]: testing prepare_sft_dataset.py ..."
|
||||
|
||||
# FIXME: This is a hack to skip tests that are not working
|
||||
SKIPPED_TESTS=(
|
||||
)
|
||||
|
||||
# test prepare_sft_dataset
|
||||
for model in ${MODELS[@]}; do
|
||||
data_type="sft"
|
||||
if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$data_type " ]]; then
|
||||
echo "[Test]: Skipped $model-$data_type"
|
||||
continue
|
||||
fi
|
||||
cache_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/cache
|
||||
jsonl_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/jsonl
|
||||
arrow_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/arrow
|
||||
data_input_dirs=$(get_data_input_dirs $data_type)
|
||||
tokenizer_dir=$(get_tokenizer_dirs $model)
|
||||
conversation_template=$(get_conversation_template_config $model)
|
||||
for i in $(seq $NUM_RETRY); do
|
||||
rm -rf $cache_dir
|
||||
rm -rf $jsonl_dir
|
||||
rm -rf $arrow_dir
|
||||
echo "[Test]: $model-$data_type, attempt $i"
|
||||
python $EXAMPLES_DIR/data_preparation_scripts/prepare_dataset.py \
|
||||
--type sft \
|
||||
--data_input_dirs $data_input_dirs \
|
||||
--conversation_template_config $conversation_template \
|
||||
--tokenizer_dir $tokenizer_dir \
|
||||
--data_cache_dir $cache_dir \
|
||||
--data_jsonl_output_dir $jsonl_dir \
|
||||
--data_arrow_output_dir $arrow_dir \
|
||||
--max_length 400 \
|
||||
--num_samples_per_datafile 100 \
|
||||
--num_spliced_dataset_bins 1
|
||||
passed=$?
|
||||
if [ $passed -eq 0 ]; then
|
||||
break
|
||||
fi
|
||||
done
|
||||
if [ $passed -ne 0 ]; then
|
||||
echo "[Test]: Failed $model-$data_type"
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
|
||||
echo "[Test]: testing prepare_prompt_dataset.py ..."
|
||||
|
||||
# FIXME: This is a hack to skip tests that are not working
|
||||
SKIPPED_TESTS=(
|
||||
)
|
||||
|
||||
# test prepare_prompt_dataset
|
||||
for model in ${MODELS[@]}; do
|
||||
data_type="prompt"
|
||||
if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$data_type " ]]; then
|
||||
echo "[Test]: Skipped $model-$data_type"
|
||||
continue
|
||||
fi
|
||||
cache_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/cache
|
||||
jsonl_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/jsonl
|
||||
arrow_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/arrow
|
||||
data_input_dirs=$(get_data_input_dirs $data_type)
|
||||
tokenizer_dir=$(get_tokenizer_dirs $model)
|
||||
conversation_template=$(get_conversation_template_config $model)
|
||||
for i in $(seq $NUM_RETRY); do
|
||||
rm -rf $cache_dir
|
||||
rm -rf $jsonl_dir
|
||||
rm -rf $arrow_dir
|
||||
echo "[Test]: $model-$data_type, attempt $i"
|
||||
python $EXAMPLES_DIR/data_preparation_scripts/prepare_dataset.py \
|
||||
--type prompt \
|
||||
--data_input_dirs $data_input_dirs \
|
||||
--conversation_template_config $conversation_template \
|
||||
--tokenizer_dir $tokenizer_dir \
|
||||
--data_cache_dir $cache_dir \
|
||||
--data_jsonl_output_dir $jsonl_dir \
|
||||
--data_arrow_output_dir $arrow_dir \
|
||||
--max_length 400 \
|
||||
--num_samples_per_datafile 100 \
|
||||
--num_spliced_dataset_bins 1
|
||||
passed=$?
|
||||
if [ $passed -eq 0 ]; then
|
||||
break
|
||||
fi
|
||||
done
|
||||
if [ $passed -ne 0 ]; then
|
||||
echo "[Test]: Failed $model-$data_type"
|
||||
exit 1
|
||||
fi
|
||||
done
|
69
applications/ColossalChat/tests/test_lora.py
Executable file
69
applications/ColossalChat/tests/test_lora.py
Executable file
@@ -0,0 +1,69 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from coati.models import convert_to_lora_module
|
||||
from torch.utils.data import DataLoader, TensorDataset
|
||||
|
||||
|
||||
class SimpleNN(nn.Module):
|
||||
def __init__(self, input_size, hidden_size, num_classes):
|
||||
super(SimpleNN, self).__init__()
|
||||
self.fc1 = nn.Linear(input_size, hidden_size)
|
||||
self.relu = nn.ReLU()
|
||||
self.fc2 = nn.Linear(hidden_size, num_classes)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.fc1(x)
|
||||
out = self.relu(out)
|
||||
out = self.fc2(out)
|
||||
return out
|
||||
|
||||
|
||||
def test_overfit():
|
||||
input_size = 1000
|
||||
hidden_size = 200
|
||||
num_classes = 5
|
||||
batch_size = 64
|
||||
learning_rate = 0.01
|
||||
num_epochs = 200
|
||||
|
||||
# Synthesized dataset
|
||||
X = torch.randn(batch_size, input_size)
|
||||
Y = torch.randint(0, num_classes, (batch_size,))
|
||||
|
||||
# Convert to DataLoader
|
||||
dataset = TensorDataset(X, Y)
|
||||
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
|
||||
|
||||
# Build and convert model
|
||||
model = SimpleNN(input_size, hidden_size, num_classes)
|
||||
weight_to_compare = model.fc1.weight.detach().clone()
|
||||
model = convert_to_lora_module(model, lora_rank=30)
|
||||
|
||||
# Loss and optimizer
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
|
||||
|
||||
# Train the model
|
||||
for _ in range(num_epochs):
|
||||
for i, (inputs, labels) in enumerate(loader):
|
||||
# Forward pass
|
||||
outputs = model(inputs)
|
||||
loss = criterion(outputs, labels)
|
||||
print(loss)
|
||||
# Backward and optimize
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# Check if model has overfitted
|
||||
outputs = model(X)
|
||||
_, predicted = torch.max(outputs.data, 1)
|
||||
total = labels.size(0)
|
||||
correct = (predicted == Y).sum().item()
|
||||
assert (correct / total > 0.95, "The model has not overfitted to the synthesized dataset")
|
||||
assert (weight_to_compare - model.fc1.weight).sum() < 0.01
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_overfit()
|
97
applications/ColossalChat/tests/test_templating.sh
Executable file
97
applications/ColossalChat/tests/test_templating.sh
Executable file
@@ -0,0 +1,97 @@
|
||||
|
||||
BASE_DIR=$(dirname $(dirname $(realpath $BASH_SOURCE)))
|
||||
BASE_TEMP_DIR=$BASE_DIR/temp
|
||||
EXAMPLES_DIR=$BASE_DIR/examples
|
||||
TEST_DATA_DIR=$BASE_DIR/tests/test_data
|
||||
DATA_SAVE_PATH=$BASE_TEMP_DIR/tests
|
||||
CONFIG_DIR=$BASE_DIR/config
|
||||
|
||||
MODELS=("colossal-llama2" "llama2" "zephyr" "mistral" "chatGLM2" "Qwen" "Vicuna" "Yi")
|
||||
|
||||
get_pretrain() {
|
||||
local model=$1
|
||||
if [[ $model == "colossal-llama2" ]]; then
|
||||
echo "hpcai-tech/Colossal-LLaMA-2-7b-base"
|
||||
elif [[ $model == "llama2" ]]; then
|
||||
echo "hf-internal-testing/llama-tokenizer"
|
||||
elif [[ $model == "zephyr" ]]; then
|
||||
echo "HuggingFaceH4/zephyr-7b-beta"
|
||||
elif [[ $model == "mistral" ]]; then
|
||||
echo "mistralai/Mistral-7B-Instruct-v0.2"
|
||||
elif [[ $model == "chatGLM2" ]]; then
|
||||
echo "THUDM/chatglm2-6b"
|
||||
elif [[ $model == "Qwen" ]]; then
|
||||
echo "Qwen/Qwen-7B-Chat"
|
||||
elif [[ $model == "Vicuna" ]]; then
|
||||
echo "lmsys/vicuna-7b-v1.5"
|
||||
elif [[ $model == "Yi" ]]; then
|
||||
echo "01-ai/Yi-6B-Chat"
|
||||
else
|
||||
echo "Unknown model $model"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
get_conversation_template_config() {
|
||||
local model=$1
|
||||
echo "$CONFIG_DIR/conversation_template/$model.json"
|
||||
}
|
||||
|
||||
# Test SFT data Preparation
|
||||
for model in ${MODELS[@]}; do
|
||||
echo "Testing SFT data templating for $model"
|
||||
SAVE_DIR=$DATA_SAVE_PATH/sft/$model
|
||||
rm -rf $SAVE_DIR/cache
|
||||
rm -rf $SAVE_DIR/jsonl
|
||||
rm -rf $SAVE_DIR/arrow
|
||||
pretrain=$(get_pretrain $model)
|
||||
conversation_template_config=$(get_conversation_template_config $model)
|
||||
python $EXAMPLES_DIR/data_preparation_scripts/prepare_dataset.py --type sft --data_input_dirs $TEST_DATA_DIR/sft \
|
||||
--tokenizer_dir $pretrain \
|
||||
--conversation_template_config $conversation_template_config \
|
||||
--data_cache_dir $SAVE_DIR/cache \
|
||||
--data_jsonl_output_dir $SAVE_DIR/jsonl \
|
||||
--data_arrow_output_dir $SAVE_DIR/arrow
|
||||
passed=$?
|
||||
if [ $passed -ne 0 ]; then
|
||||
echo "[Test]: Failed in the SFT data templating for $model"
|
||||
exit 1
|
||||
fi
|
||||
python $BASE_DIR/tests/verify_chat_data.py --data_source $TEST_DATA_DIR/sft/test_sft_data.jsonl \
|
||||
--to_verify_file $SAVE_DIR/jsonl/part-00005.jsonl --data_type sft
|
||||
passed=$?
|
||||
if [ $passed -ne 0 ]; then
|
||||
echo "[Test]: Failed in the SFT data templating test for $model"
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
|
||||
|
||||
# Test DPO/PPO data Preparation
|
||||
for model in ${MODELS[@]}; do
|
||||
echo "Testing DPO/PPO data templating for $model"
|
||||
SAVE_DIR=$DATA_SAVE_PATH/dpo/$model
|
||||
rm -rf $SAVE_DIR/cache
|
||||
rm -rf $SAVE_DIR/jsonl
|
||||
rm -rf $SAVE_DIR/arrow
|
||||
pretrain=$(get_pretrain $model)
|
||||
conversation_template_config=$(get_conversation_template_config $model)
|
||||
python $EXAMPLES_DIR/data_preparation_scripts/prepare_dataset.py --type preference --data_input_dirs $TEST_DATA_DIR/dpo \
|
||||
--tokenizer_dir $pretrain \
|
||||
--conversation_template_config $conversation_template_config \
|
||||
--data_cache_dir $SAVE_DIR/cache \
|
||||
--data_jsonl_output_dir $SAVE_DIR/jsonl \
|
||||
--data_arrow_output_dir $SAVE_DIR/arrow
|
||||
passed=$?
|
||||
if [ $passed -ne 0 ]; then
|
||||
echo "[Test]: Failed in the DPO data templating for $model"
|
||||
exit 1
|
||||
fi
|
||||
python $BASE_DIR/tests/verify_chat_data.py --data_source $TEST_DATA_DIR/dpo/test_dpo_data.jsonl \
|
||||
--to_verify_file $SAVE_DIR/jsonl/part-00005.jsonl --data_type dpo
|
||||
passed=$?
|
||||
if [ $passed -ne 0 ]; then
|
||||
echo "[Test]: Failed in the DPO data templating test for $model"
|
||||
exit 1
|
||||
fi
|
||||
done
|
397
applications/ColossalChat/tests/test_train.sh
Executable file
397
applications/ColossalChat/tests/test_train.sh
Executable file
@@ -0,0 +1,397 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES() {
|
||||
local n=${1:-"9999"}
|
||||
echo "GPU Memory Usage:"
|
||||
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
|
||||
tail -n +2 |
|
||||
nl -v 0 |
|
||||
tee /dev/tty |
|
||||
sort -g -k 2 |
|
||||
awk '{print $1}' |
|
||||
head -n $n)
|
||||
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
|
||||
echo "Now CUDA_VISIBLE_DEVICES is set to:"
|
||||
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
||||
}
|
||||
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES 4
|
||||
|
||||
set -xu
|
||||
|
||||
|
||||
NUM_RETRY=3
|
||||
BASE_DIR=$(dirname $(dirname $(realpath $BASH_SOURCE)))
|
||||
EXAMPLES_DIR=$BASE_DIR/examples
|
||||
CONFIG_DIR=$BASE_DIR/config
|
||||
TEMP_DIR=$BASE_DIR/temp
|
||||
TEST_DIR=$BASE_DIR/tests
|
||||
MODEL_SAVE_PATH=$TEMP_DIR/rlhf_models
|
||||
MODELS_DIR=$TEMP_DIR/models_config
|
||||
# Skip those tests due to CI tests timeout
|
||||
MODELS=('llama')
|
||||
PLUGINS=('gemini' 'gemini_auto' 'zero2' 'zero2_cpu' '3d')
|
||||
LORA_RANK=('0') # skip to reduce CI execution time, can pass all locally
|
||||
|
||||
export OMP_NUM_THREADS=8
|
||||
|
||||
get_pretrain() {
|
||||
local model=$1
|
||||
if [[ $model == "llama" ]]; then
|
||||
echo "nickypro/tinyllama-110M"
|
||||
elif [[ $model == "opt" ]]; then
|
||||
echo "facebook/opt-125m"
|
||||
else
|
||||
echo "Unknown model $model"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
get_tokenizer_dirs() {
|
||||
local model=$1
|
||||
if [[ $model == "llama" ]]; then
|
||||
echo "hf-internal-testing/llama-tokenizer"
|
||||
elif [[ $model == "opt" ]]; then
|
||||
echo "facebook/opt-125m"
|
||||
else
|
||||
echo "Unknown model $model"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
|
||||
get_conversation_template_config() {
|
||||
local model=$1
|
||||
if [[ $model == "llama" ]]; then
|
||||
echo "$TEST_DIR/llama.json"
|
||||
elif [[ $model == "opt" ]]; then
|
||||
echo "$TEST_DIR/opt.json"
|
||||
else
|
||||
echo "Unknown model $model"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
random_choice() {
|
||||
local arr=("$@")
|
||||
local len=${#arr[@]}
|
||||
local idx=$((RANDOM % len))
|
||||
echo ${arr[$idx]}
|
||||
}
|
||||
|
||||
|
||||
echo "[Test]: testing sft ..."
|
||||
|
||||
SKIPPED_TESTS=(
|
||||
llama-3d-20 # 3d plugin doesn't support lora
|
||||
llama-gemini_auto-20 # gemini_auto plugin doesn't support lora
|
||||
llama-gemini-20 # gemini doesn't support lora
|
||||
)
|
||||
|
||||
GRAD_CKPTS=('--grad_checkpoint')
|
||||
for lora_rank in ${LORA_RANK[@]}; do
|
||||
for model in ${MODELS[@]}; do
|
||||
for plugin in ${PLUGINS[@]}; do
|
||||
if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then
|
||||
echo "[Test]: Skipped $model-$plugin-$lora_rank"
|
||||
continue
|
||||
elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin " ]]; then
|
||||
echo "[Test]: Skipped $model-$plugin"
|
||||
continue
|
||||
fi
|
||||
pretrain=$(get_pretrain $model)
|
||||
tokenizer_dir=$(get_tokenizer_dirs $model)
|
||||
grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
|
||||
tp='1'
|
||||
bs='2'
|
||||
if [[ $plugin == "3d" ]]; then
|
||||
tp='4'
|
||||
bs='8'
|
||||
fi
|
||||
grad_accu='2'
|
||||
# Check if the plugin is either "gemini_auto" or "gemini" and set grad_accu to '1'
|
||||
if [[ $plugin == "gemini_auto" ]]; then
|
||||
grad_accu='1'
|
||||
fi
|
||||
|
||||
for i in $(seq $NUM_RETRY); do
|
||||
echo "[Test]: $model-$plugin-$lora_rank, attempt $i"
|
||||
declare -a dataset=()
|
||||
for split in $(seq -f "%05g" 0 0); do
|
||||
dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split")
|
||||
done
|
||||
colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_sft.py \
|
||||
--pretrain $pretrain \
|
||||
--tokenizer_dir $tokenizer_dir \
|
||||
--dataset ${dataset[@]} \
|
||||
--save_path $MODEL_SAVE_PATH \
|
||||
--config_file $MODELS_DIR/config.jsonl \
|
||||
--lora_rank $lora_rank \
|
||||
--plugin $plugin \
|
||||
--batch_size $bs \
|
||||
--max_epochs 1 \
|
||||
--accumulation_steps $grad_accu \
|
||||
--tp $tp \
|
||||
--lr 2e-5 \
|
||||
$grad_ckpt \
|
||||
--max_len 400 \
|
||||
--use_flash_attn
|
||||
passed=$?
|
||||
if [ $passed -eq 0 ]; then
|
||||
rm -rf $MODEL_SAVE_PATH/*
|
||||
rm -rf $MODELS_DIR/*
|
||||
break
|
||||
fi
|
||||
done
|
||||
if [ $passed -ne 0 ]; then
|
||||
echo "[Test]: Failed $model-$plugin-$lora_rank"
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
echo "[Test]: testing reward model ..."
|
||||
|
||||
SKIPPED_TESTS=(
|
||||
llama-3d-20 # 3d plugin doesn't support lora
|
||||
llama-gemini_auto-20 # gemini_auto plugin doesn't support lora
|
||||
llama-gemini-20 # gemini doesn't support lora
|
||||
)
|
||||
|
||||
GRAD_CKPTS=('--grad_checkpoint')
|
||||
for lora_rank in ${LORA_RANK[@]}; do
|
||||
for model in ${MODELS[@]}; do
|
||||
for plugin in ${PLUGINS[@]}; do
|
||||
if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then
|
||||
echo "[Test]: Skipped $model-$plugin-$lora_rank"
|
||||
continue
|
||||
elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin " ]]; then
|
||||
echo "[Test]: Skipped $model-$plugin"
|
||||
continue
|
||||
fi
|
||||
pretrain=$(get_pretrain $model)
|
||||
tokenizer_dir=$(get_tokenizer_dirs $model)
|
||||
grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
|
||||
tp='1'
|
||||
bs='2'
|
||||
if [[ $plugin == "3d" ]]; then
|
||||
tp='4'
|
||||
bs='8'
|
||||
fi
|
||||
grad_accu='2'
|
||||
# gemini_auto and gemini doesn't support gradient accumulation
|
||||
if [[ $plugin == "gemini_auto" ]]; then
|
||||
grad_accu='1'
|
||||
fi
|
||||
for i in $(seq $NUM_RETRY); do
|
||||
echo "[Test]: $model-$plugin-$lora_rank, attempt $i"
|
||||
declare -a dataset=()
|
||||
for split in $(seq -f "%05g" 0 0); do
|
||||
dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_preference/arrow/part-$split")
|
||||
done
|
||||
colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_rm.py \
|
||||
--pretrain $pretrain \
|
||||
--tokenizer_dir $tokenizer_dir \
|
||||
--dataset ${dataset[@]} \
|
||||
--save_dir $MODEL_SAVE_PATH \
|
||||
--config_file $MODELS_DIR/config.jsonl \
|
||||
--lora_rank $lora_rank \
|
||||
--plugin $plugin \
|
||||
--batch_size $bs \
|
||||
--max_epochs 1 \
|
||||
--accumulation_steps $grad_accu \
|
||||
--tp $tp \
|
||||
--lr 2e-5 \
|
||||
$grad_ckpt \
|
||||
--max_len 400 \
|
||||
--use_flash_attn
|
||||
passed=$?
|
||||
if [ $passed -eq 0 ]; then
|
||||
rm -rf $MODEL_SAVE_PATH/*
|
||||
rm -rf $MODELS_DIR/*
|
||||
break
|
||||
fi
|
||||
done
|
||||
if [ $passed -ne 0 ]; then
|
||||
echo "[Test]: Failed $model-$plugin-$lora_rank"
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
|
||||
echo "[Test]: testing ppo ..."
|
||||
|
||||
|
||||
SKIPPED_TESTS=(
|
||||
llama-3d-20 # 3d plugin doesn't support lora
|
||||
llama-gemini-20 # gemini doesn't support lora
|
||||
)
|
||||
|
||||
GRAD_CKPTS=('--grad_checkpoint')
|
||||
for lora_rank in ${LORA_RANK[@]}; do
|
||||
for model in ${MODELS[@]}; do
|
||||
for plugin in ${PLUGINS[@]}; do
|
||||
if [[ $plugin == "gemini_auto" ]]; then
|
||||
echo "[Test]: Skipped $model-$plugin"
|
||||
continue # gemini_auto plugin doesn't support generation
|
||||
fi
|
||||
if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then
|
||||
echo "[Test]: Skipped $model-$plugin-$lora_rank"
|
||||
continue
|
||||
elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin " ]]; then
|
||||
echo "[Test]: Skipped $model-$plugin"
|
||||
continue
|
||||
fi
|
||||
pretrain=$(get_pretrain $model)
|
||||
tokenizer_dir=$(get_tokenizer_dirs $model)
|
||||
grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
|
||||
tp='1'
|
||||
bs='4'
|
||||
ebs='8'
|
||||
conversation_template=$(get_conversation_template_config $model)
|
||||
if [[ $plugin == "3d" ]]; then
|
||||
tp='4'
|
||||
bs='16'
|
||||
ebs='32'
|
||||
fi
|
||||
grad_accu='2'
|
||||
# gemini_auto and gemini doesn't support gradient accumulation
|
||||
if [[ $plugin == "gemini_auto" ]]; then
|
||||
grad_accu='1'
|
||||
fi
|
||||
# gemini_auto and gemini doesn't support generation
|
||||
if [[ $plugin == "gemini_auto" ]]; then
|
||||
# gemini-auto doesn't support generation
|
||||
echo "[Test]: Skipped $model-$plugin"
|
||||
continue
|
||||
fi
|
||||
for i in $(seq $NUM_RETRY); do
|
||||
echo "[Test]: $model-$plugin-$lora_rank, attempt $i"
|
||||
declare -a prompt_dataset=()
|
||||
for split in $(seq -f "%05g" 0 0); do
|
||||
prompt_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_prompt/arrow/part-$split")
|
||||
done
|
||||
declare -a ptx_dataset=()
|
||||
for split in $(seq -f "%05g" 0 0); do
|
||||
ptx_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split")
|
||||
done
|
||||
colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_ppo.py \
|
||||
--pretrain $pretrain \
|
||||
--rm_pretrain $pretrain \
|
||||
--tokenizer_dir $tokenizer_dir \
|
||||
--conversation_template_config $conversation_template \
|
||||
--prompt_dataset ${prompt_dataset[@]} \
|
||||
--ptx_dataset ${ptx_dataset[@]} \
|
||||
--ptx_batch_size 1 \
|
||||
--ptx_coef 0.2 \
|
||||
--save_path $MODEL_SAVE_PATH \
|
||||
--lora_rank $lora_rank \
|
||||
--plugin $plugin \
|
||||
--num_episodes 5 \
|
||||
--num_collect_steps 1 \
|
||||
--num_update_steps 1 \
|
||||
--experience_batch_size $ebs \
|
||||
--train_batch_size $bs \
|
||||
--accumulation_steps $grad_accu \
|
||||
--lr 9e-6 \
|
||||
--mixed_precision "bf16" \
|
||||
--grad_clip 1.0 \
|
||||
--tp $tp \
|
||||
--lr 2e-5 \
|
||||
$grad_ckpt \
|
||||
--max_len 400 \
|
||||
--max_seq_len 10 \
|
||||
--use_flash_attn
|
||||
passed=$?
|
||||
if [ $passed -eq 0 ]; then
|
||||
rm -rf $MODEL_SAVE_PATH/*
|
||||
rm -rf $MODELS_DIR/*
|
||||
break
|
||||
fi
|
||||
done
|
||||
if [ $passed -ne 0 ]; then
|
||||
echo "[Test]: Failed $model-$plugin-$lora_rank"
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
|
||||
echo "[Test]: testing DPO ..."
|
||||
|
||||
SKIPPED_TESTS=(
|
||||
llama-3d-20 # 3d plugin doesn't support lora
|
||||
llama-gemini_auto-20 # gemini_auto plugin doesn't support lora
|
||||
llama-gemini-20 # gemini doesn't support lora
|
||||
)
|
||||
GRAD_CKPTS=('--grad_checkpoint')
|
||||
for lora_rank in ${LORA_RANK[@]}; do
|
||||
for model in ${MODELS[@]}; do
|
||||
for plugin in ${PLUGINS[@]}; do
|
||||
if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then
|
||||
echo "[Test]: Skipped $model-$plugin-$lora_rank"
|
||||
continue
|
||||
elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin " ]]; then
|
||||
echo "[Test]: Skipped $model-$plugin"
|
||||
continue
|
||||
fi
|
||||
pretrain=$(get_pretrain $model)
|
||||
tokenizer_dir=$(get_tokenizer_dirs $model)
|
||||
grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
|
||||
tp='1'
|
||||
bs='2'
|
||||
if [[ $plugin == "3d" ]]; then
|
||||
tp='4'
|
||||
bs='8'
|
||||
fi
|
||||
grad_accu='2'
|
||||
# gemini_auto and gemini doesn't support gradient accumulation
|
||||
if [[ $plugin == "gemini_auto" ]]; then
|
||||
grad_accu='1'
|
||||
fi
|
||||
# gemini_auto doesn't support generation
|
||||
# (need to calculate ref_model logits through forwarding in inference mode)
|
||||
if [[ $plugin == "gemini_auto" ]]; then
|
||||
echo "[Test]: Skipped $model-$plugin"
|
||||
continue
|
||||
fi
|
||||
for i in $(seq $NUM_RETRY); do
|
||||
echo "[Test]: $model-$plugin-$lora_rank, attempt $i"
|
||||
declare -a dataset=()
|
||||
for split in $(seq -f "%05g" 0 0); do
|
||||
dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_preference/arrow/part-$split")
|
||||
done
|
||||
colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_dpo.py \
|
||||
--pretrain $pretrain \
|
||||
--tokenizer_dir $tokenizer_dir \
|
||||
--dataset ${dataset[@]} \
|
||||
--save_dir $MODEL_SAVE_PATH \
|
||||
--config_file $MODELS_DIR/config.jsonl \
|
||||
--lora_rank $lora_rank \
|
||||
--plugin $plugin \
|
||||
--batch_size $bs \
|
||||
--max_epochs 1 \
|
||||
--accumulation_steps $grad_accu \
|
||||
--tp $tp \
|
||||
--lr 2e-5 \
|
||||
$grad_ckpt \
|
||||
--max_len 400 \
|
||||
--use_flash_attn
|
||||
passed=$?
|
||||
if [ $passed -eq 0 ]; then
|
||||
rm -rf $MODEL_SAVE_PATH/*
|
||||
rm -rf $MODELS_DIR/*
|
||||
break
|
||||
fi
|
||||
done
|
||||
if [ $passed -ne 0 ]; then
|
||||
echo "[Test]: Failed $model-$plugin-$lora_rank"
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
done
|
||||
done
|
64
applications/ColossalChat/tests/verify_chat_data.py
Normal file
64
applications/ColossalChat/tests/verify_chat_data.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import argparse
|
||||
import json
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--data_source",
|
||||
type=str,
|
||||
required=True,
|
||||
default=None,
|
||||
help="The raw data file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--to_verify_file",
|
||||
type=str,
|
||||
required=True,
|
||||
default=None,
|
||||
help="The file that contains the data to be verified",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data_type",
|
||||
type=str,
|
||||
required=True,
|
||||
default=None,
|
||||
help="The data type",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Read data
|
||||
data = []
|
||||
with open(args.data_source, "r", encoding="utf8") as f:
|
||||
for line in f.readlines():
|
||||
data.append(json.loads(line))
|
||||
to_verify_data = []
|
||||
with open(args.to_verify_file, "r", encoding="utf8") as f:
|
||||
for line in f.readlines():
|
||||
to_verify_data.append(json.loads(line))
|
||||
|
||||
if args.data_type == "sft":
|
||||
target_lable = [msg["content"].strip() for msg in data[0]["messages"] if msg["from"] == "assistant"]
|
||||
target_negative_label = [msg["content"].strip() for msg in data[0]["messages"] if msg["from"] == "human"]
|
||||
|
||||
# Read to verify file
|
||||
|
||||
to_verify_lable = to_verify_data[0]["labels_decode"]
|
||||
for label in target_lable:
|
||||
assert any([label in s for s in to_verify_lable]), f"Label {label} not in target label {to_verify_lable}"
|
||||
for label in target_negative_label:
|
||||
assert all(
|
||||
[label not in s for s in to_verify_lable]
|
||||
), f"Negative label {label} in target label {to_verify_lable}"
|
||||
elif args.data_type == "dpo":
|
||||
chosen_lable = data[0]["chosen"][0]["content"].strip()
|
||||
rejected_lable = data[0]["rejected"][0]["content"].strip()
|
||||
|
||||
# Read to verify file
|
||||
to_verify_lable_chosen = to_verify_data[0]["chosen_label_decode"]
|
||||
to_verify_lable_rejected = to_verify_data[0]["rejected_label_decode"]
|
||||
assert any(
|
||||
[chosen_lable in s for s in to_verify_lable_chosen]
|
||||
), f"Chosen label {chosen_lable} not in target chosen label {to_verify_lable_chosen}"
|
||||
assert any(
|
||||
[rejected_lable in s for s in to_verify_lable_rejected]
|
||||
), f"Rejected label {rejected_lable} not in target rejected label {to_verify_lable_chosen}"
|
Reference in New Issue
Block a user