mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 17:17:05 +00:00
add kto
This commit is contained in:
@@ -30,6 +30,8 @@
|
||||
- [DPO Stage 1: Supervised Instruction Tuning](#dpo-training-stage1---supervised-instructs-tuning)
|
||||
- [DPO Stage 2: DPO Training](#dpo-training-stage2---dpo-training)
|
||||
- [Alternative Option For RLHF: Simple Preference Optimization](#alternative-option-for-rlhf-simple-preference-optimization)
|
||||
- [Alternative Option For RLHF: Kahneman-Tversky Optimization (KTO)](#alternative-option-for-rlhf-kahneman-tversky-optimization-kto)
|
||||
- [Alternative Option For RLHF: Odds Ratio Preference Optimization](#alternative-option-for-rlhf-odds-ratio-preference-optimization)
|
||||
- [List of Supported Models](#list-of-supported-models)
|
||||
- [Hardware Requirements](#hardware-requirements)
|
||||
- [Inference example](#inference-example)
|
||||
@@ -744,13 +746,21 @@ with a Reference-Free Reward](https://arxiv.org/pdf/2405.14734) (SimPO). Which i
|
||||
|
||||
|
||||
### Alternative Option For RLHF: Odds Ratio Preference Optimization
|
||||
We support the method introduced in the paper [ORPO: Monolithic Preference Optimization without Reference Model](https://arxiv.org/abs/2403.07691) (ORPO). Which is a reference model free aligment method that mixes the SFT loss with a reinforcement learning loss that uses odds ratio as the implicit reward to enhance training stability and efficiency. Simply set the flag to disable the use of the reference model, set the reward target margin and enable length normalization in the DPO training script. To use ORPO in alignment, use the [train_orpo.sh](./examples/training_scripts/train_orpo.sh) script, You can set the value for `lambda` (which determine how strongly the reinforcement learning loss affect the training) but it is optional.
|
||||
We support the method introduced in the paper [ORPO: Monolithic Preference Optimization without Reference Model](https://arxiv.org/abs/2403.07691) (ORPO). Which is a reference model free aligment method that mixes the SFT loss with a reinforcement learning loss that uses odds ratio as the implicit reward to enhance training stability and efficiency. To use ORPO in alignment, use the [train_orpo.sh](./examples/training_scripts/train_orpo.sh) script, You can set the value for `lambda` (which determine how strongly the reinforcement learning loss affect the training) but it is optional.
|
||||
|
||||
#### ORPO Result
|
||||
<p align="center">
|
||||
<img width="1000" alt="image" src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/ORPO_margin.png">
|
||||
</p>
|
||||
|
||||
### Alternative Option For RLHF: Kahneman-Tversky Optimization (KTO)
|
||||
We support the method introduced in the paper [KTO:Model Alignment as Prospect Theoretic Optimization](https://arxiv.org/pdf/2402.01306) (KTO). Which is a aligment method that directly maximize "human utility" of generation results. To use KTO in alignment, use the [train_kto.sh](./examples/training_scripts/train_orpo.sh) script, You may need to set the value for `beta` (which determine how strongly the reinforcement learning loss affect the training), `desirable_weight` and `undesirable_weight` if your data is biased (has unequal number of chosen and rejected samples).
|
||||
|
||||
#### KTO Result
|
||||
<p align="center">
|
||||
<img width="1000" alt="image" src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/KTO.png">
|
||||
</p>
|
||||
|
||||
## Hardware Requirements
|
||||
|
||||
For SFT, we recommend using zero2 or zero2-cpu for 7B model and tp is your model is extra large. We tested the VRAM consumption on a dummy dataset with a sequence length of 2048. In all experiments, we use H800 GPUs with 80GB VRAM and enable gradient checkpointing and flash attention.
|
||||
@@ -801,6 +811,14 @@ For ORPO, we recommend using zero2 or zero2-cpu. We tested the VRAM consumption
|
||||
- zero2, micro batch size=4, VRAM Usage=45309.52 MB
|
||||
- zero2, micro batch size=8, VRAM Usage=58086.37 MB
|
||||
|
||||
For KTO, we recommend using zero2-cpu or zero2 plugin, We tested the VRAM consumption on a dummy dataset with 2048 sequence length.
|
||||
- 2 H800 GPU
|
||||
- zero2-cpu, micro batch size=2, VRAM Usage=35241.98 MB
|
||||
- zero2-cpu, micro batch size=4, VRAM Usage=38989.37 MB
|
||||
- 4 H800 GPUs
|
||||
- zero2_cpu, micro batch size=2, VRAM_USAGE=32443.22 MB
|
||||
- zero2, micro batch size=4, VRAM_USAGE=59307.97 MB
|
||||
|
||||
## List of Supported Models
|
||||
|
||||
For SFT, we support the following models/series:
|
||||
|
@@ -40,7 +40,13 @@ import random
|
||||
import time
|
||||
from multiprocessing import cpu_count
|
||||
|
||||
from coati.dataset import setup_conversation_template, supervised_tokenize_sft, tokenize_prompt_dataset, tokenize_rlhf
|
||||
from coati.dataset import (
|
||||
setup_conversation_template,
|
||||
supervised_tokenize_sft,
|
||||
tokenize_kto,
|
||||
tokenize_prompt_dataset,
|
||||
tokenize_rlhf,
|
||||
)
|
||||
from datasets import dataset_dict, load_dataset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
@@ -56,8 +62,8 @@ def main():
|
||||
type=str,
|
||||
required=True,
|
||||
default=None,
|
||||
choices=["sft", "prompt", "preference"],
|
||||
help="Type of dataset, chose from 'sft', 'prompt', 'preference'.",
|
||||
choices=["sft", "prompt", "preference", "kto"],
|
||||
help="Type of dataset, chose from 'sft', 'prompt', 'preference'. 'kto'",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data_input_dirs",
|
||||
@@ -204,6 +210,8 @@ def main():
|
||||
preparation_function = tokenize_prompt_dataset
|
||||
elif args.type == "preference":
|
||||
preparation_function = tokenize_rlhf
|
||||
elif args.type == "kto":
|
||||
preparation_function = tokenize_kto
|
||||
else:
|
||||
raise ValueError("Unknow dataset type. Please choose one from ['sft', 'prompt', 'preference']")
|
||||
|
||||
@@ -228,10 +236,13 @@ def main():
|
||||
keep_in_memory=False,
|
||||
num_proc=min(len(dataset), cpu_count()),
|
||||
)
|
||||
|
||||
dataset = dataset.filter(
|
||||
lambda data: data["chosen_input_ids" if args.type == "preference" else "input_ids"] is not None
|
||||
)
|
||||
if args.type == "kto":
|
||||
filter_by = "completion"
|
||||
elif args.type == "preference":
|
||||
filter_by = "chosen_input_ids"
|
||||
else:
|
||||
filter_by = "input_ids"
|
||||
dataset = dataset.filter(lambda data: data[filter_by] is not None)
|
||||
|
||||
# Save each jsonl spliced dataset.
|
||||
output_index = "0" * (5 - len(str(index))) + str(index)
|
||||
|
@@ -0,0 +1,14 @@
|
||||
SAVE_DIR="/home/nvme-share/home/yeanbang/data/experiments/kto"
|
||||
|
||||
rm -rf $SAVE_DIR/cache
|
||||
rm -rf $SAVE_DIR/jsonl
|
||||
rm -rf $SAVE_DIR/arrow
|
||||
|
||||
python prepare_dataset.py --type kto \
|
||||
--data_input_dirs /home/nvme-share/home/yeanbang/data/dataset/hh_rlhf/kto_format/data \
|
||||
--conversation_template_config /home/nvme-share/home/yeanbang/ColossalAI/applications/ColossalChat/config/conversation_template/llama2.json \
|
||||
--tokenizer_dir "/home/nvme-share/share/models/Sheared-LLaMA-1.3B" \
|
||||
--data_cache_dir $SAVE_DIR/cache \
|
||||
--data_jsonl_output_dir $SAVE_DIR/jsonl \
|
||||
--data_arrow_output_dir $SAVE_DIR/arrow \
|
||||
--max_length 1024
|
@@ -1,13 +1,13 @@
|
||||
SAVE_DIR=""
|
||||
SAVE_DIR="/home/nvme-share/home/yeanbang/data/experiments/sft"
|
||||
|
||||
rm -rf $SAVE_DIR/cache
|
||||
rm -rf $SAVE_DIR/jsonl
|
||||
rm -rf $SAVE_DIR/arrow
|
||||
|
||||
python prepare_dataset.py --type sft \
|
||||
--data_input_dirs /PATH/TO/SFT/DATASET \
|
||||
--conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \
|
||||
--tokenizer_dir "" \
|
||||
--data_input_dirs /home/nvme-share/home/yeanbang/data/dataset/hh_rlhf/sft \
|
||||
--conversation_template_config /home/nvme-share/home/yeanbang/ColossalAI/applications/ColossalChat/config/conversation_template/llama2.json \
|
||||
--tokenizer_dir "/home/nvme-share/share/models/Sheared-LLaMA-1.3B" \
|
||||
--data_cache_dir $SAVE_DIR/cache \
|
||||
--data_jsonl_output_dir $SAVE_DIR/jsonl \
|
||||
--data_arrow_output_dir $SAVE_DIR/arrow \
|
||||
|
104
applications/ColossalChat/examples/inference/round.txt
Normal file
104
applications/ColossalChat/examples/inference/round.txt
Normal file
@@ -0,0 +1,104 @@
|
||||
|
||||
|
||||
==========
|
||||
round 1:
|
||||
<s>[INST] <<SYS>>
|
||||
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
|
||||
|
||||
|
||||
<</SYS>>
|
||||
|
||||
tell me a story [/INST] Great, let’s hear a story. </s>
|
||||
|
||||
==========
|
||||
|
||||
|
||||
==========
|
||||
round 2:
|
||||
<s>[INST] <<SYS>>
|
||||
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
|
||||
|
||||
|
||||
<</SYS>>
|
||||
|
||||
tell me a story [/INST] Great, let’s hear a story. </s><s>[INST] calculate 1+1 [/INST] 1+1=2 </s>
|
||||
|
||||
==========
|
||||
|
||||
|
||||
==========
|
||||
round 3:
|
||||
<s>[INST] <<SYS>>
|
||||
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
|
||||
|
||||
|
||||
<</SYS>>
|
||||
|
||||
tell me a story [/INST] Great, let’s hear a story. </s><s>[INST] calculate 1+1 [/INST] 1+1=2 </s><s>[INST] who is the first president of the USA [/INST] The first president of the United States was George Washington. </s>
|
||||
|
||||
==========
|
||||
|
||||
|
||||
==========
|
||||
round 1:
|
||||
<s>[INST] <<SYS>>
|
||||
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
|
||||
|
||||
|
||||
<</SYS>>
|
||||
|
||||
who is the first president of the USA? [/INST] The first president of the United States was George Washington. </s>
|
||||
|
||||
==========
|
||||
|
||||
|
||||
==========
|
||||
round 2:
|
||||
<s>[INST] <<SYS>>
|
||||
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
|
||||
|
||||
|
||||
<</SYS>>
|
||||
|
||||
who is the first president of the USA? [/INST] The first president of the United States was George Washington. </s><s>[INST] tell me a story [/INST] One story that might be interesting is the story of how the United States was founded. In 1776, the Thirteen Colonies united together to form the new nation of America. The first president of the new nation was George Washington. The first state was Pennsylvania. The first capital was Philadelphia. The first presidential election was held in 1787, and was between George Read, a Federalist, and John Adams, a Republican. The Federalists wanted to keep the power of the federal government limited, while the Republicans wanted the power to be spread around to the states. The Federalists won, and George Washington was elected president. </s>
|
||||
|
||||
==========
|
||||
|
||||
|
||||
==========
|
||||
round 1:
|
||||
<s>[INST] <<SYS>>
|
||||
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
|
||||
|
||||
|
||||
<</SYS>>
|
||||
|
||||
tell me a story [/INST] I do enjoy telling stories. Is there a particular story you’d like to hear? </s>
|
||||
|
||||
==========
|
||||
|
||||
|
||||
==========
|
||||
round 2:
|
||||
<s>[INST] <<SYS>>
|
||||
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
|
||||
|
||||
|
||||
<</SYS>>
|
||||
|
||||
tell me a story [/INST] I do enjoy telling stories. Is there a particular story you’d like to hear? </s><s>[INST] about Donald Trump [/INST] I’d be happy to listen to a story about Donald Trump. I have no idea what a “story” is, but I do know a lot about human life and human behavior. I can try. Do you have a theme or a question you’d like to ask about Donald Trump? </s>
|
||||
|
||||
==========
|
||||
|
||||
|
||||
==========
|
||||
round 3:
|
||||
<s>[INST] <<SYS>>
|
||||
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
|
||||
|
||||
|
||||
<</SYS>>
|
||||
|
||||
tell me a story [/INST] I do enjoy telling stories. Is there a particular story you’d like to hear? </s><s>[INST] about Donald Trump [/INST] I’d be happy to listen to a story about Donald Trump. I have no idea what a “story” is, but I do know a lot about human life and human behavior. I can try. Do you have a theme or a question you’d like to ask about Donald Trump? </s><s>[INST] Is Donald Trump the president of the United States [/INST] Yes, Donald Trump became the 45th president of the United States in January of 2016. </s>
|
||||
|
||||
==========
|
372
applications/ColossalChat/examples/training_scripts/train_kto.py
Executable file
372
applications/ColossalChat/examples/training_scripts/train_kto.py
Executable file
@@ -0,0 +1,372 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import resource
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
from coati.dataset import DataCollatorForKTODataset, StatefulDistributedSampler, load_tokenized_dataset
|
||||
from coati.models import convert_to_lora_module, disable_dropout
|
||||
from coati.trainer import KTOTrainer
|
||||
from coati.utils import load_checkpoint
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
|
||||
logger = get_dist_logger()
|
||||
|
||||
|
||||
def train(args):
|
||||
# check lora compatibility
|
||||
if "gemini" in args.plugin and args.lora_rank > 0:
|
||||
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
|
||||
if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
|
||||
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
|
||||
|
||||
# ==============================
|
||||
# Initialize Distributed Training
|
||||
# ==============================
|
||||
colossalai.launch_from_torch()
|
||||
coordinator = DistCoordinator()
|
||||
|
||||
# ==============================
|
||||
# Initialize Booster
|
||||
# ==============================
|
||||
if args.plugin == "ddp":
|
||||
"""
|
||||
Default torch ddp plugin without any acceleration, for
|
||||
debugging purpose acceleration, for debugging purpose
|
||||
"""
|
||||
plugin = TorchDDPPlugin(find_unused_parameters=True)
|
||||
elif args.plugin == "gemini":
|
||||
plugin = GeminiPlugin(
|
||||
precision=args.mixed_precision,
|
||||
placement_policy="static",
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
enable_gradient_accumulation=True,
|
||||
enable_flash_attention=args.use_flash_attn,
|
||||
)
|
||||
elif args.plugin == "gemini_auto":
|
||||
plugin = GeminiPlugin(
|
||||
precision=args.mixed_precision,
|
||||
placement_policy="auto",
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
enable_flash_attention=args.use_flash_attn,
|
||||
)
|
||||
elif args.plugin == "zero2":
|
||||
plugin = LowLevelZeroPlugin(
|
||||
stage=2,
|
||||
precision=args.mixed_precision,
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
)
|
||||
elif args.plugin == "zero2_cpu":
|
||||
plugin = LowLevelZeroPlugin(
|
||||
stage=2,
|
||||
precision=args.mixed_precision,
|
||||
initial_scale=2**16,
|
||||
cpu_offload=True,
|
||||
max_norm=args.grad_clip,
|
||||
)
|
||||
elif args.plugin == "3d":
|
||||
plugin = HybridParallelPlugin(
|
||||
tp_size=args.tp,
|
||||
pp_size=args.pp,
|
||||
sp_size=args.sp,
|
||||
sequence_parallelism_mode=args.sp_mode,
|
||||
zero_stage=args.zero_stage,
|
||||
enable_flash_attention=args.use_flash_attn,
|
||||
enable_sequence_parallelism=args.enable_sequence_parallelism,
|
||||
cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
|
||||
parallel_output=False,
|
||||
max_norm=args.grad_clip,
|
||||
precision=args.mixed_precision,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown plugin {args.plugin}")
|
||||
|
||||
booster = Booster(plugin=plugin)
|
||||
ref_booster = Booster(plugin=plugin)
|
||||
|
||||
# ======================================================
|
||||
# Initialize Model, Objective, Optimizer and LR Scheduler
|
||||
# ======================================================
|
||||
# Temp Fix: Disable lazy init due to version conflict
|
||||
# init_ctx = (
|
||||
# LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
|
||||
# )
|
||||
|
||||
init_ctx = nullcontext()
|
||||
with init_ctx:
|
||||
if args.use_flash_attn:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.pretrain,
|
||||
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
|
||||
use_flash_attention_2=True,
|
||||
)
|
||||
coordinator.print_on_master(msg="Flash-attention enabled successfully")
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(args.pretrain)
|
||||
disable_dropout(model)
|
||||
if args.use_flash_attn:
|
||||
ref_model = AutoModelForCausalLM.from_pretrained(
|
||||
args.pretrain,
|
||||
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
|
||||
use_flash_attention_2=True,
|
||||
)
|
||||
else:
|
||||
ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain)
|
||||
disable_dropout(ref_model)
|
||||
if args.lora_rank > 0:
|
||||
model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
|
||||
|
||||
if args.grad_checkpoint:
|
||||
# Note, for some models, lora may not be compatible with gradient checkpointing
|
||||
model.gradient_checkpointing_enable()
|
||||
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
|
||||
|
||||
# configure tokenizer
|
||||
tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=False, trust_remote_code=True)
|
||||
if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None:
|
||||
try:
|
||||
# Some tokenizers doesn't allow to set pad_token mannually e.g., Qwen
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
except AttributeError as e:
|
||||
logger.warning(f"Unable to set pad token to eos token, {str(e)}")
|
||||
if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
|
||||
logger.warning(
|
||||
"The tokenizer does not have a pad token which is required. May lead to unintended behavior in training, Please consider manually set them."
|
||||
)
|
||||
|
||||
tokenizer.add_bos_token = False
|
||||
tokenizer.add_eos_token = False
|
||||
|
||||
# configure optimizer
|
||||
optim = HybridAdam(
|
||||
model_params=model.parameters(),
|
||||
lr=args.lr,
|
||||
betas=(0.9, 0.95),
|
||||
weight_decay=args.weight_decay,
|
||||
adamw_mode=True,
|
||||
)
|
||||
|
||||
# configure dataset
|
||||
coordinator.print_on_master(f"Load dataset: {args.dataset}")
|
||||
mode_map = {"train": "train", "valid": "validation", "test": "test"}
|
||||
train_dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train", mode_map=mode_map)
|
||||
num_desirable = 0
|
||||
num_undesirable = 0
|
||||
for i in range(len(train_dataset)):
|
||||
if train_dataset[i]["label"]:
|
||||
num_desirable += 1
|
||||
else:
|
||||
num_undesirable += 1
|
||||
logger.info(f"Dataset Statistics:\nDesirable: {num_desirable}\nUndesirable: {num_undesirable}")
|
||||
|
||||
# Check if the user specified weights fit into the theoratical lower and upper bounds from Eq. (8) of https://arxiv.org/abs/2402.01306
|
||||
actual_ratio = (args.desirable_weight * num_desirable) / (args.undesirable_weight * num_undesirable)
|
||||
if actual_ratio <= 1:
|
||||
raise AssertionError(
|
||||
f"Desirable weight and undesirable weight are not within the theoratical bounds, [1, 4/3]. Actual ratio: {actual_ratio}, please increase desirable weight or decrease undesirable weight."
|
||||
)
|
||||
elif actual_ratio > 4 / 3:
|
||||
raise AssertionError(
|
||||
f"Desirable weight and undesirable weight are not within the theoratical bounds, [1, 4/3]. Actual ratio: {actual_ratio}, please decrease desirable weight or increase undesirable weight."
|
||||
)
|
||||
|
||||
data_collator = DataCollatorForKTODataset(tokenizer=tokenizer, max_length=args.max_length)
|
||||
|
||||
train_dataloader = plugin.prepare_dataloader(
|
||||
dataset=train_dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
collate_fn=data_collator,
|
||||
distributed_sampler_cls=StatefulDistributedSampler,
|
||||
)
|
||||
eval_dataloader = None
|
||||
if args.eval_dataset:
|
||||
eval_dataset = load_tokenized_dataset(dataset_paths=args.eval_dataset, mode="dev")
|
||||
eval_data_collator = DataCollatorForKTODataset(tokenizer=tokenizer, max_length=args.max_length)
|
||||
|
||||
eval_dataloader = plugin.prepare_dataloader(
|
||||
dataset=eval_dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
collate_fn=eval_data_collator,
|
||||
distributed_sampler_cls=StatefulDistributedSampler,
|
||||
)
|
||||
else:
|
||||
logger.warning("No evaluation dataset is provided, skip evaluation")
|
||||
|
||||
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
|
||||
if args.warmup_steps is None:
|
||||
args.warmup_steps = int(args.max_epochs * 0.025 * (len(train_dataloader) // args.accumulation_steps))
|
||||
coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")
|
||||
|
||||
lr_scheduler = CosineAnnealingWarmupLR(
|
||||
optimizer=optim,
|
||||
total_steps=args.max_epochs * num_update_steps_per_epoch,
|
||||
warmup_steps=args.warmup_steps,
|
||||
eta_min=0.1 * args.lr,
|
||||
)
|
||||
|
||||
default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
|
||||
torch.set_default_dtype(default_dtype)
|
||||
model, optim, _, train_dataloader, lr_scheduler = booster.boost(
|
||||
model=model,
|
||||
optimizer=optim,
|
||||
lr_scheduler=lr_scheduler,
|
||||
dataloader=train_dataloader,
|
||||
)
|
||||
if ref_model is not None:
|
||||
ref_model, _, _, _, _ = ref_booster.boost(model=ref_model, dataloader=train_dataloader)
|
||||
torch.set_default_dtype(torch.float)
|
||||
|
||||
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
|
||||
coordinator.print_on_master(
|
||||
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
|
||||
)
|
||||
|
||||
start_epoch = 0
|
||||
sampler_start_idx = 0
|
||||
start_step = 0
|
||||
if args.checkpoint_path is not None:
|
||||
if "modeling" in args.checkpoint_path:
|
||||
coordinator.print_on_master(f"Continued pretrain from checkpoint {args.checkpoint_path}")
|
||||
booster.load_model(model, args.checkpoint_path)
|
||||
else:
|
||||
coordinator.print_on_master(f"Load model checkpoint from {args.checkpoint_path}")
|
||||
start_epoch, start_step, sampler_start_idx = load_checkpoint(
|
||||
load_dir=args.checkpoint_path,
|
||||
booster=booster,
|
||||
model=model,
|
||||
optimizer=optim,
|
||||
lr_scheduler=lr_scheduler,
|
||||
)
|
||||
assert isinstance(train_dataloader.sampler, StatefulDistributedSampler)
|
||||
train_dataloader.sampler.set_start_index(start_index=sampler_start_idx)
|
||||
|
||||
coordinator.print_on_master(
|
||||
f"Loaded checkpoint {args.checkpoint_path} at epoch {start_epoch} step {start_step}"
|
||||
)
|
||||
coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")
|
||||
|
||||
coordinator.print_on_master(
|
||||
f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
|
||||
)
|
||||
|
||||
trainer = KTOTrainer(
|
||||
actor=model,
|
||||
ref_model=ref_model,
|
||||
booster=booster,
|
||||
actor_optim=optim,
|
||||
actor_lr_scheduler=lr_scheduler,
|
||||
tokenizer=tokenizer,
|
||||
max_epochs=args.max_epochs,
|
||||
accumulation_steps=args.accumulation_steps,
|
||||
start_epoch=start_epoch,
|
||||
save_interval=args.save_interval,
|
||||
save_dir=args.save_dir,
|
||||
coordinator=coordinator,
|
||||
beta=args.beta,
|
||||
desirable_weight=args.desirable_weight,
|
||||
undesirable_weight=args.undesirable_weight,
|
||||
)
|
||||
|
||||
trainer.fit(
|
||||
train_preference_dataloader=train_dataloader,
|
||||
eval_preference_dataloader=eval_dataloader,
|
||||
log_dir=args.log_dir,
|
||||
use_wandb=args.use_wandb,
|
||||
)
|
||||
|
||||
if args.lora_rank > 0 and args.merge_lora_weights:
|
||||
from coati.models.lora import LORA_MANAGER
|
||||
|
||||
# NOTE: set model to eval to merge LoRA weights
|
||||
LORA_MANAGER.merge_weights = True
|
||||
model.eval()
|
||||
# save model checkpoint after fitting on only rank0
|
||||
coordinator.print_on_master("Start saving final model checkpoint")
|
||||
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
|
||||
coordinator.print_on_master(f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_dir}")
|
||||
|
||||
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# ==============================
|
||||
# Parse Arguments
|
||||
# ==============================
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--plugin",
|
||||
type=str,
|
||||
default="gemini",
|
||||
choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
|
||||
help="Choose which plugin to use",
|
||||
)
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
|
||||
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
|
||||
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
|
||||
parser.add_argument("--tp", type=int, default=1)
|
||||
parser.add_argument("--pp", type=int, default=1)
|
||||
parser.add_argument("--sp", type=int, default=1)
|
||||
parser.add_argument("--beta", type=float, default=0.1, help="beta in KTO loss")
|
||||
parser.add_argument("--desirable_weight", type=float, default=1.0, help="desirable_weight in KTO loss")
|
||||
parser.add_argument("--undesirable_weight", type=float, default=1.0, help="undesirable_weight in KTO loss")
|
||||
parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
|
||||
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
|
||||
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
|
||||
parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
|
||||
parser.add_argument("--pretrain", type=str, default=None)
|
||||
parser.add_argument("--tokenizer_dir", type=str, default=None)
|
||||
parser.add_argument("--dataset", nargs="+", default=[])
|
||||
parser.add_argument("--eval_dataset", nargs="+", default=[])
|
||||
parser.add_argument(
|
||||
"--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
|
||||
)
|
||||
parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
|
||||
parser.add_argument("--save_dir", type=str, default="output")
|
||||
parser.add_argument("--max_length", type=int, default=2048, help="Model max length")
|
||||
parser.add_argument("--max_epochs", type=int, default=3)
|
||||
parser.add_argument("--batch_size", type=int, default=4)
|
||||
|
||||
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
|
||||
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
|
||||
parser.add_argument(
|
||||
"--lora_train_bias",
|
||||
type=str,
|
||||
default="none",
|
||||
help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
|
||||
)
|
||||
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
|
||||
parser.add_argument("--merge_lora_weights", type=bool, default=True)
|
||||
parser.add_argument("--lr", type=float, default=5e-6)
|
||||
parser.add_argument("--accumulation_steps", type=int, default=8)
|
||||
parser.add_argument("--log_dir", default="logs", type=str)
|
||||
parser.add_argument("--use_wandb", default=False, action="store_true")
|
||||
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
|
||||
parser.add_argument("--use_flash_attn", default=False, action="store_true")
|
||||
args = parser.parse_args()
|
||||
os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
|
||||
with open(args.config_file, "w") as f:
|
||||
json.dump(args.__dict__, f, indent=4)
|
||||
train(args)
|
61
applications/ColossalChat/examples/training_scripts/train_kto.sh
Executable file
61
applications/ColossalChat/examples/training_scripts/train_kto.sh
Executable file
@@ -0,0 +1,61 @@
|
||||
#!/bin/bash
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES() {
|
||||
local n=${1:-"9999"}
|
||||
echo "GPU Memory Usage:"
|
||||
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
|
||||
tail -n +2 |
|
||||
nl -v 0 |
|
||||
tee /dev/tty |
|
||||
sort -g -k 2 |
|
||||
awk '{print $1}' |
|
||||
head -n $n)
|
||||
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
|
||||
echo "Now CUDA_VISIBLE_DEVICES is set to:"
|
||||
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
||||
}
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES 4
|
||||
|
||||
PROJECT_NAME="kto"
|
||||
PARENT_SAVE_DIR="/home/nvme-share/home/yeanbang/data/experiments/kto/checkpoint" # Path to a folder to save checkpoints
|
||||
PARENT_TENSORBOARD_DIR="/home/nvme-share/home/yeanbang/data/experiments/kto/log" # Path to a folder to save logs
|
||||
PARENT_CONFIG_FILE="/home/nvme-share/home/yeanbang/data/experiments/kto/log" # Path to a folder to save training config logs
|
||||
PRETRAINED_MODEL_PATH="/home/nvme-share/home/yeanbang/data/model/hh_rlhf_sheared_llamasft-2024-07-17-07-29-29/modeling" # huggingface or local model path
|
||||
PRETRAINED_TOKENIZER_PATH="/home/nvme-share/share/models/Sheared-LLaMA-1.3B" # huggingface or local tokenizer path
|
||||
|
||||
declare -a dataset=(
|
||||
/home/nvme-share/home/yeanbang/data/experiments/kto/arrow/part-00000
|
||||
/home/nvme-share/home/yeanbang/data/experiments/kto/arrow/part-00001
|
||||
/home/nvme-share/home/yeanbang/data/experiments/kto/arrow/part-00002
|
||||
/home/nvme-share/home/yeanbang/data/experiments/kto/arrow/part-00003
|
||||
/home/nvme-share/home/yeanbang/data/experiments/kto/arrow/part-00004
|
||||
/home/nvme-share/home/yeanbang/data/experiments/kto/arrow/part-00005
|
||||
/home/nvme-share/home/yeanbang/data/experiments/kto/arrow/part-00006
|
||||
/home/nvme-share/home/yeanbang/data/experiments/kto/arrow/part-00007
|
||||
/home/nvme-share/home/yeanbang/data/experiments/kto/arrow/part-00008
|
||||
/home/nvme-share/home/yeanbang/data/experiments/kto/arrow/part-00009
|
||||
)
|
||||
|
||||
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
|
||||
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
|
||||
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
|
||||
CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
|
||||
|
||||
colossalai run --nproc_per_node 4 --master_port 31313 train_kto.py \
|
||||
--pretrain $PRETRAINED_MODEL_PATH \
|
||||
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
|
||||
--dataset ${dataset[@]} \
|
||||
--plugin "zero2" \
|
||||
--save_interval 1000 \
|
||||
--save_dir $SAVE_DIR \
|
||||
--config_file $CONFIG_FILE \
|
||||
--max_epochs 1 \
|
||||
--accumulation_steps 1 \
|
||||
--batch_size 8 \
|
||||
--lr 1e-5 \
|
||||
--beta 0.1 \
|
||||
--mixed_precision "bf16" \
|
||||
--grad_clip 1.0 \
|
||||
--max_length 1024 \
|
||||
--weight_decay 0.01 \
|
||||
--warmup_steps 60 \
|
||||
--grad_checkpoint
|
@@ -42,7 +42,6 @@ CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
|
||||
|
||||
colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 31312 train_rm.py \
|
||||
--pretrain $PRETRAINED_MODEL_PATH \
|
||||
--checkpoint_path /home/yeanbang/data/experiments/rm/hhh_aligh/ckptllama2-rm-2024-01-17-14-43-24/epoch-1_step-1317/modeling \
|
||||
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
|
||||
--dataset ${dataset[@]} \
|
||||
--plugin "zero2" \
|
||||
|
@@ -15,22 +15,22 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
|
||||
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES 4
|
||||
PROJECT_NAME="sft"
|
||||
PARENT_SAVE_DIR="" # Path to a folder to save checkpoints
|
||||
PARENT_TENSORBOARD_DIR="" # Path to a folder to save logs
|
||||
PARENT_CONFIG_FILE="" # Path to a folder to save training config logs
|
||||
PRETRAINED_MODEL_PATH="" # huggingface or local model path
|
||||
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
|
||||
PARENT_SAVE_DIR="/home/nvme-share/home/yeanbang/data/model/hh_rlhf_sheared_llama" # Path to a folder to save checkpoints
|
||||
PARENT_TENSORBOARD_DIR="/home/nvme-share/home/yeanbang/data/experiments/sft/log" # Path to a folder to save logs
|
||||
PARENT_CONFIG_FILE="/home/nvme-share/home/yeanbang/data/experiments/kto/log" # Path to a folder to save training config logs
|
||||
PRETRAINED_MODEL_PATH="/home/nvme-share/share/models/Sheared-LLaMA-1.3B" # huggingface or local model path
|
||||
PRETRAINED_TOKENIZER_PATH="/home/nvme-share/share/models/Sheared-LLaMA-1.3B" # huggingface or local tokenizer path
|
||||
declare -a dataset=(
|
||||
/Your/SFT/Data/arrow/part-00000
|
||||
/Your/SFT/Data/arrow/part-00001
|
||||
/Your/SFT/Data/arrow/part-00002
|
||||
/Your/SFT/Data/arrow/part-00003
|
||||
/Your/SFT/Data/arrow/part-00004
|
||||
/Your/SFT/Data/arrow/part-00005
|
||||
/Your/SFT/Data/arrow/part-00006
|
||||
/Your/SFT/Data/arrow/part-00007
|
||||
/Your/SFT/Data/arrow/part-00008
|
||||
/Your/SFT/Data/arrow/part-00009
|
||||
/home/nvme-share/home/yeanbang/data/experiments/sft/arrow/part-00000
|
||||
/home/nvme-share/home/yeanbang/data/experiments/sft/arrow/part-00001
|
||||
/home/nvme-share/home/yeanbang/data/experiments/sft/arrow/part-00002
|
||||
/home/nvme-share/home/yeanbang/data/experiments/sft/arrow/part-00003
|
||||
/home/nvme-share/home/yeanbang/data/experiments/sft/arrow/part-00004
|
||||
/home/nvme-share/home/yeanbang/data/experiments/sft/arrow/part-00005
|
||||
/home/nvme-share/home/yeanbang/data/experiments/sft/arrow/part-00006
|
||||
/home/nvme-share/home/yeanbang/data/experiments/sft/arrow/part-00007
|
||||
/home/nvme-share/home/yeanbang/data/experiments/sft/arrow/part-00008
|
||||
/home/nvme-share/home/yeanbang/data/experiments/sft/arrow/part-00009
|
||||
)
|
||||
|
||||
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
|
||||
@@ -44,17 +44,16 @@ echo $(which python)
|
||||
colossalai run --nproc_per_node 4 --master_port 31312 --hostfile ./hostfile train_sft.py \
|
||||
--pretrain $PRETRAINED_MODEL_PATH \
|
||||
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
|
||||
--save_interval 4000 \
|
||||
--save_interval 2000 \
|
||||
--dataset ${dataset[@]} \
|
||||
--save_path $SAVE_DIR \
|
||||
--config_file $CONFIG_FILE \
|
||||
--lora_rank 0 \
|
||||
--plugin zero2 \
|
||||
--batch_size 4 \
|
||||
--batch_size 8 \
|
||||
--max_epochs 1 \
|
||||
--accumulation_steps 4 \
|
||||
--accumulation_steps 2 \
|
||||
--lr 5e-5 \
|
||||
--max_len 4096 \
|
||||
--grad_checkpoint \
|
||||
--use_wandb \
|
||||
--use_flash_attn
|
||||
|
Reference in New Issue
Block a user