mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 09:07:51 +00:00
[ColossalChat] Update RLHF V2 (#5286)
* Add dpo. Fix sft, ppo, lora. Refactor all * fix and tested ppo * 2 nd round refactor * add ci tests * fix ci * fix ci * fix readme, style * fix readme style * fix style, fix benchmark * reproduce benchmark result, remove useless files * rename to ColossalChat * use new image * fix ci workflow * fix ci * use local model/tokenizer for ci tests * fix ci * fix ci * fix ci * fix ci timeout * fix rm progress bar. fix ci timeout * fix ci * fix ci typo * remove 3d plugin from ci temporary * test environment * cannot save optimizer * support chat template * fix readme * fix path * test ci locally * restore build_or_pr * fix ci data path * fix benchmark * fix ci, move ci tests to 3080, disable fast tokenizer * move ci to 85 * support flash attention 2 * add all-in-one data preparation script. Fix colossal-llama2-chat chat template * add hardware requirements * move ci test data * fix save_model, add unwrap * fix missing bos * fix missing bos; support grad accumulation with gemini * fix ci * fix ci * fix ci * fix llama2 chat template config * debug sft * debug sft * fix colossalai version requirement * fix ci * add sanity check to prevent NaN loss * fix requirements * add dummy data generation script * add dummy data generation script * add dummy data generation script * add dummy data generation script * update readme * update readme * update readme and ignore * fix logger bug * support parallel_output * modify data preparation logic * fix tokenization * update lr * fix inference * run pre-commit --------- Co-authored-by: Tong Li <tong.li352711588@gmail.com>
This commit is contained in:
1
applications/ColossalChat/examples/training_scripts/hostfile
Executable file
1
applications/ColossalChat/examples/training_scripts/hostfile
Executable file
@@ -0,0 +1 @@
|
||||
10.20.1.82
|
326
applications/ColossalChat/examples/training_scripts/train_dpo.py
Executable file
326
applications/ColossalChat/examples/training_scripts/train_dpo.py
Executable file
@@ -0,0 +1,326 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import resource
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
from coati.dataset import (
|
||||
DataCollatorForPreferenceDataset,
|
||||
StatefulDistributedSampler,
|
||||
load_tokenized_dataset,
|
||||
setup_distributed_dataloader,
|
||||
)
|
||||
from coati.models import convert_to_lora_module, disable_dropout
|
||||
from coati.trainer import DPOTrainer
|
||||
from coati.utils import load_checkpoint
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
logger = get_dist_logger()
|
||||
|
||||
def train(args):
|
||||
# check lora compatibility
|
||||
if "gemini" in args.plugin and args.lora_rank > 0:
|
||||
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
|
||||
if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
|
||||
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
|
||||
|
||||
# ==============================
|
||||
# Initialize Distributed Training
|
||||
# ==============================
|
||||
colossalai.launch_from_torch({})
|
||||
coordinator = DistCoordinator()
|
||||
|
||||
# ==============================
|
||||
# Initialize Booster
|
||||
# ==============================
|
||||
if args.plugin == "ddp":
|
||||
"""
|
||||
Default torch ddp plugin without any acceleration, for
|
||||
debugging purpose acceleration, for debugging purpose
|
||||
"""
|
||||
plugin = TorchDDPPlugin(find_unused_parameters=True)
|
||||
elif args.plugin == "gemini":
|
||||
plugin = GeminiPlugin(
|
||||
precision=args.mixed_precision,
|
||||
placement_policy="static",
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
enable_gradient_accumulation=True,
|
||||
)
|
||||
elif args.plugin == "gemini_auto":
|
||||
plugin = GeminiPlugin(
|
||||
precision=args.mixed_precision,
|
||||
placement_policy="auto",
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
)
|
||||
elif args.plugin == "zero2":
|
||||
plugin = LowLevelZeroPlugin(
|
||||
stage=2,
|
||||
precision=args.mixed_precision,
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
)
|
||||
elif args.plugin == "zero2_cpu":
|
||||
plugin = LowLevelZeroPlugin(
|
||||
stage=2,
|
||||
precision=args.mixed_precision,
|
||||
initial_scale=2**16,
|
||||
cpu_offload=True,
|
||||
max_norm=args.grad_clip,
|
||||
)
|
||||
elif args.plugin == "3d":
|
||||
plugin = HybridParallelPlugin(
|
||||
tp_size=args.tp,
|
||||
pp_size=1,
|
||||
zero_stage=0,
|
||||
parallel_output=False,
|
||||
precision=args.mixed_precision,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown plugin {args.plugin}")
|
||||
|
||||
booster = Booster(plugin=plugin)
|
||||
ref_booster = Booster(plugin=plugin)
|
||||
|
||||
# ======================================================
|
||||
# Initialize Model, Objective, Optimizer and LR Scheduler
|
||||
# ======================================================
|
||||
# Temp Fix: Disable lazy init due to version conflict
|
||||
# init_ctx = (
|
||||
# LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
|
||||
# )
|
||||
|
||||
init_ctx = nullcontext()
|
||||
with init_ctx:
|
||||
if args.use_flash_attn:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.pretrain,
|
||||
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
|
||||
use_flash_attention_2=True,
|
||||
)
|
||||
coordinator.print_on_master(msg="Flash-attention enabled successfully")
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(args.pretrain)
|
||||
disable_dropout(model)
|
||||
if args.enable_reference_model:
|
||||
if args.use_flash_attn:
|
||||
ref_model = AutoModelForCausalLM.from_pretrained(
|
||||
args.pretrain,
|
||||
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
|
||||
use_flash_attention_2=True,
|
||||
)
|
||||
else:
|
||||
ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain)
|
||||
disable_dropout(ref_model)
|
||||
else:
|
||||
ref_model = None
|
||||
|
||||
if args.lora_rank > 0:
|
||||
model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
|
||||
|
||||
if args.grad_checkpoint and args.lora_rank == 0:
|
||||
model.gradient_checkpointing_enable()
|
||||
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
|
||||
elif args.lora_rank > 0:
|
||||
coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled")
|
||||
|
||||
# configure tokenizer
|
||||
tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=False, trust_remote_code=True)
|
||||
if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None:
|
||||
try:
|
||||
# Some tokenizers doesn't allow to set pad_token mannually e.g., Qwen
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
except AttributeError as e:
|
||||
logger.warning(f"Unable to set pad token to eos token, {str(e)}")
|
||||
if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
|
||||
logger.warning(
|
||||
"The tokenizer does not have a pad token which is required. May lead to unintended behavior in training, Please consider manually set them."
|
||||
)
|
||||
|
||||
tokenizer.add_bos_token = False
|
||||
tokenizer.add_eos_token = False
|
||||
|
||||
# configure optimizer
|
||||
optim = HybridAdam(
|
||||
model_params=model.parameters(),
|
||||
lr=args.lr,
|
||||
betas=(0.9, 0.95),
|
||||
weight_decay=args.weight_decay,
|
||||
adamw_mode=True,
|
||||
)
|
||||
|
||||
# configure dataset
|
||||
coordinator.print_on_master(f"Load dataset: {args.dataset}")
|
||||
mode_map = {"train": "train", "valid": "validation", "test": "test"}
|
||||
train_dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train", mode_map=mode_map)
|
||||
data_collator = DataCollatorForPreferenceDataset(tokenizer=tokenizer, max_length=args.max_length)
|
||||
train_dataloader = setup_distributed_dataloader(
|
||||
dataset=train_dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
collate_fn=data_collator,
|
||||
use_tp=args.tp > 1,
|
||||
)
|
||||
|
||||
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
|
||||
if args.warmup_steps is None:
|
||||
args.warmup_steps = int(args.max_epochs * 0.025 * (len(train_dataloader) // args.accumulation_steps))
|
||||
coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")
|
||||
|
||||
lr_scheduler = CosineAnnealingWarmupLR(
|
||||
optimizer=optim,
|
||||
total_steps=args.max_epochs * num_update_steps_per_epoch,
|
||||
warmup_steps=args.warmup_steps,
|
||||
eta_min=0.1 * args.lr,
|
||||
)
|
||||
|
||||
default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
|
||||
torch.set_default_dtype(default_dtype)
|
||||
model, optim, _, train_dataloader, lr_scheduler = booster.boost(
|
||||
model=model,
|
||||
optimizer=optim,
|
||||
lr_scheduler=lr_scheduler,
|
||||
dataloader=train_dataloader,
|
||||
)
|
||||
if ref_model is not None:
|
||||
ref_model, _, _, _, _ = ref_booster.boost(model=ref_model, dataloader=train_dataloader)
|
||||
torch.set_default_dtype(torch.float)
|
||||
|
||||
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
|
||||
coordinator.print_on_master(
|
||||
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
|
||||
)
|
||||
|
||||
start_epoch = 0
|
||||
sampler_start_idx = 0
|
||||
start_step = 0
|
||||
if args.checkpoint_path is not None:
|
||||
if "modeling" in args.checkpoint_path:
|
||||
coordinator.print_on_master(f"Continued pretrain from checkpoint {args.checkpoint_path}")
|
||||
booster.load_model(model, args.checkpoint_path)
|
||||
else:
|
||||
coordinator.print_on_master(f"Load model checkpoint from {args.checkpoint_path}")
|
||||
start_epoch, start_step, sampler_start_idx = load_checkpoint(
|
||||
load_dir=args.checkpoint_path,
|
||||
booster=booster,
|
||||
model=model,
|
||||
optimizer=optim,
|
||||
lr_scheduler=lr_scheduler,
|
||||
)
|
||||
assert isinstance(train_dataloader.sampler, StatefulDistributedSampler)
|
||||
train_dataloader.sampler.set_start_index(start_index=sampler_start_idx)
|
||||
|
||||
coordinator.print_on_master(
|
||||
f"Loaded checkpoint {args.checkpoint_path} at epoch {start_epoch} step {start_step}"
|
||||
)
|
||||
coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")
|
||||
|
||||
coordinator.print_on_master(
|
||||
f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
|
||||
)
|
||||
|
||||
trainer = DPOTrainer(
|
||||
actor=model,
|
||||
ref_model=ref_model,
|
||||
booster=booster,
|
||||
actor_optim=optim,
|
||||
actor_lr_scheduler=lr_scheduler,
|
||||
tokenizer=tokenizer,
|
||||
max_epochs=args.max_epochs,
|
||||
accumulation_steps=args.accumulation_steps,
|
||||
start_epoch=start_epoch,
|
||||
save_interval=args.save_interval,
|
||||
save_dir=args.save_dir,
|
||||
coordinator=coordinator,
|
||||
)
|
||||
|
||||
trainer.fit(
|
||||
train_preference_dataloader=train_dataloader,
|
||||
eval_preference_dataloader=None,
|
||||
log_dir=args.log_dir,
|
||||
use_wandb=args.use_wandb,
|
||||
)
|
||||
|
||||
if args.lora_rank > 0 and args.merge_lora_weights:
|
||||
from coati.models.lora import LORA_MANAGER
|
||||
|
||||
# NOTE: set model to eval to merge LoRA weights
|
||||
LORA_MANAGER.merge_weights = True
|
||||
model.eval()
|
||||
# save model checkpoint after fitting on only rank0
|
||||
coordinator.print_on_master("Start saving final model checkpoint")
|
||||
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
|
||||
coordinator.print_on_master(f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_dir}")
|
||||
|
||||
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# ==============================
|
||||
# Parse Arguments
|
||||
# ==============================
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--plugin",
|
||||
type=str,
|
||||
default="gemini",
|
||||
choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
|
||||
help="Choose which plugin to use",
|
||||
)
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
|
||||
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
|
||||
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
|
||||
parser.add_argument("--tp", type=int, default=1)
|
||||
parser.add_argument("--pretrain", type=str, default=None)
|
||||
parser.add_argument("--model_type", type=str, default=None)
|
||||
parser.add_argument("--tokenizer_dir", type=str, default=None)
|
||||
parser.add_argument("--dataset", nargs="+", default=[])
|
||||
parser.add_argument(
|
||||
"--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
|
||||
)
|
||||
parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
|
||||
parser.add_argument("--save_dir", type=str, default="output")
|
||||
parser.add_argument("--max_length", type=int, default=2048, help="Model max length")
|
||||
parser.add_argument("--max_epochs", type=int, default=3)
|
||||
parser.add_argument("--batch_size", type=int, default=4)
|
||||
parser.add_argument("--enable_reference_model", type=bool, default=True)
|
||||
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
|
||||
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
|
||||
parser.add_argument(
|
||||
"--lora_train_bias",
|
||||
type=str,
|
||||
default="none",
|
||||
help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
|
||||
)
|
||||
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
|
||||
parser.add_argument("--merge_lora_weights", type=bool, default=True)
|
||||
parser.add_argument("--lr", type=float, default=5e-6)
|
||||
parser.add_argument("--accumulation_steps", type=int, default=8)
|
||||
parser.add_argument("--log_dir", default="logs", type=str)
|
||||
parser.add_argument("--use_wandb", default=False, action="store_true")
|
||||
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
|
||||
parser.add_argument("--use_flash_attn", default=False, action="store_true")
|
||||
args = parser.parse_args()
|
||||
os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
|
||||
with open(args.config_file, "w") as f:
|
||||
json.dump(args.__dict__, f, indent=4)
|
||||
train(args)
|
62
applications/ColossalChat/examples/training_scripts/train_dpo.sh
Executable file
62
applications/ColossalChat/examples/training_scripts/train_dpo.sh
Executable file
@@ -0,0 +1,62 @@
|
||||
#!/bin/bash
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES() {
|
||||
local n=${1:-"9999"}
|
||||
echo "GPU Memory Usage:"
|
||||
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
|
||||
tail -n +2 |
|
||||
nl -v 0 |
|
||||
tee /dev/tty |
|
||||
sort -g -k 2 |
|
||||
awk '{print $1}' |
|
||||
head -n $n)
|
||||
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
|
||||
echo "Now CUDA_VISIBLE_DEVICES is set to:"
|
||||
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
||||
}
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES 8
|
||||
# export CUDA_VISIBLE_DEVICES=6
|
||||
|
||||
PROJECT_NAME="dpo"
|
||||
PARENT_SAVE_DIR="" # Path to a folder to save checkpoints
|
||||
PARENT_TENSORBOARD_DIR="" # Path to a folder to save logs
|
||||
PARENT_CONFIG_FILE="" # Path to a folder to save training config logs
|
||||
PRETRAINED_MODEL_PATH="" # huggingface or local model path
|
||||
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
|
||||
|
||||
declare -a dataset=(
|
||||
YOUR/DATA/DIR/arrow/part-00000
|
||||
YOUR/DATA/DIR/arrow/part-00001
|
||||
YOUR/DATA/DIR/arrow/part-00002
|
||||
YOUR/DATA/DIR/arrow/part-00003
|
||||
YOUR/DATA/DIR/arrow/part-00004
|
||||
YOUR/DATA/DIR/arrow/part-00005
|
||||
YOUR/DATA/DIR/arrow/part-00006
|
||||
YOUR/DATA/DIR/arrow/part-00007
|
||||
YOUR/DATA/DIR/arrow/part-00008
|
||||
YOUR/DATA/DIR/arrow/part-00009
|
||||
)
|
||||
|
||||
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
|
||||
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
|
||||
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
|
||||
CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
|
||||
|
||||
colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 31312 train_dpo.py \
|
||||
--pretrain $PRETRAINED_MODEL_PATH \
|
||||
--checkpoint_path $PRETRAINED_MODEL_PATH \
|
||||
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
|
||||
--dataset ${dataset[@]} \
|
||||
--plugin "zero2" \
|
||||
--save_interval 1000 \
|
||||
--save_dir $SAVE_DIR \
|
||||
--config_file $CONFIG_FILE \
|
||||
--max_epochs 1 \
|
||||
--accumulation_steps 4 \
|
||||
--batch_size 2 \
|
||||
--lr 1e-6 \
|
||||
--mixed_precision "bf16" \
|
||||
--grad_clip 1.0 \
|
||||
--weight_decay 0.01 \
|
||||
--warmup_steps 100 \
|
||||
--grad_checkpoint \
|
||||
--use_wandb
|
506
applications/ColossalChat/examples/training_scripts/train_ppo.py
Executable file
506
applications/ColossalChat/examples/training_scripts/train_ppo.py
Executable file
@@ -0,0 +1,506 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import resource
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from coati.dataset import (
|
||||
DataCollatorForPromptDataset,
|
||||
DataCollatorForSupervisedDataset,
|
||||
StatefulDistributedSampler,
|
||||
load_tokenized_dataset,
|
||||
setup_conversation_template,
|
||||
setup_distributed_dataloader,
|
||||
)
|
||||
from coati.models import Critic, RewardModel, convert_to_lora_module, disable_dropout
|
||||
from coati.trainer import PPOTrainer
|
||||
from coati.utils import load_checkpoint
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
|
||||
logger = get_dist_logger()
|
||||
|
||||
|
||||
def train(args):
|
||||
# check lora compatibility
|
||||
if "gemini" in args.plugin and args.lora_rank > 0:
|
||||
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
|
||||
if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
|
||||
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
|
||||
# ==============================
|
||||
# Initialize Distributed Training
|
||||
# ==============================
|
||||
colossalai.launch_from_torch({})
|
||||
coordinator = DistCoordinator()
|
||||
|
||||
# ======================================================
|
||||
# Initialize Model, Objective, Optimizer and LR Scheduler
|
||||
# ======================================================
|
||||
# Temp Fix: Disable lazy init due to version conflict
|
||||
# init_ctx = (
|
||||
# LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
|
||||
# )
|
||||
|
||||
init_ctx = nullcontext()
|
||||
booster_policy = None
|
||||
with init_ctx:
|
||||
if args.use_flash_attn:
|
||||
actor = AutoModelForCausalLM.from_pretrained(
|
||||
args.pretrain,
|
||||
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
|
||||
use_flash_attention_2=True,
|
||||
local_files_only=True,
|
||||
)
|
||||
ref_model = AutoModelForCausalLM.from_pretrained(
|
||||
args.pretrain,
|
||||
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
|
||||
use_flash_attention_2=True,
|
||||
local_files_only=True,
|
||||
)
|
||||
reward_model = RewardModel(
|
||||
args.rm_pretrain,
|
||||
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
|
||||
use_flash_attention_2=True,
|
||||
)
|
||||
critic = Critic(
|
||||
args.rm_pretrain,
|
||||
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
|
||||
use_flash_attention_2=True,
|
||||
)
|
||||
coordinator.print_on_master(msg="Flash-attention enabled successfully")
|
||||
else:
|
||||
actor = AutoModelForCausalLM.from_pretrained(args.pretrain, local_files_only=True)
|
||||
ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain, local_files_only=True)
|
||||
reward_model = RewardModel(args.rm_pretrain)
|
||||
critic = Critic(args.rm_pretrain)
|
||||
# Disable dropout
|
||||
disable_dropout(actor)
|
||||
disable_dropout(critic)
|
||||
|
||||
if args.tp > 1:
|
||||
if reward_model.model.config.architectures[0] != critic.model.config.architectures[0]:
|
||||
raise ValueError("Reward model and critic model must have the same architecture")
|
||||
if reward_model.model.config.architectures[0] == "BloomForCausalLM":
|
||||
from colossalai.shardformer.policies.bloom import BloomPolicy
|
||||
|
||||
booster_policy = BloomPolicy()
|
||||
elif reward_model.model.config.architectures[0] == "LlamaForCausalLM":
|
||||
from colossalai.shardformer.policies.llama import LlamaPolicy
|
||||
|
||||
booster_policy = LlamaPolicy()
|
||||
elif reward_model.model.config.architectures[0] == "GPT2LMHeadModel":
|
||||
from colossalai.shardformer.policies.gpt2 import GPT2Policy
|
||||
|
||||
booster_policy = GPT2Policy()
|
||||
elif reward_model.model.config.architectures[0] == "ChatGLMModel":
|
||||
from colossalai.shardformer.policies.chatglm2 import ChatGLMPolicy
|
||||
|
||||
booster_policy = ChatGLMPolicy()
|
||||
elif reward_model.model.config.architectures[0] == "OPTForCausalLM":
|
||||
from colossalai.shardformer.policies.opt import OPTPolicy
|
||||
|
||||
booster_policy = OPTPolicy()
|
||||
else:
|
||||
raise ValueError("Unknown model architecture for policy")
|
||||
|
||||
if args.lora_rank > 0:
|
||||
actor = convert_to_lora_module(actor, args.lora_rank, lora_train_bias=args.lora_train_bias)
|
||||
critic = convert_to_lora_module(critic, args.lora_rank, lora_train_bias=args.lora_train_bias)
|
||||
|
||||
if args.grad_checkpoint and args.lora_rank == 0:
|
||||
actor.gradient_checkpointing_enable()
|
||||
critic.model.gradient_checkpointing_enable()
|
||||
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
|
||||
elif args.lora_rank > 0:
|
||||
coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled")
|
||||
|
||||
# configure tokenizer
|
||||
tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=False, trust_remote_code=True)
|
||||
if os.path.exists(args.conversation_template_config):
|
||||
with open(args.conversation_template_config, "r", encoding="utf8") as f:
|
||||
conversation_template_config = json.load(f)
|
||||
dist.barrier()
|
||||
conversation_template = setup_conversation_template(
|
||||
tokenizer, chat_template_config=conversation_template_config, save_path=args.conversation_template_config
|
||||
)
|
||||
stop_ids = conversation_template.stop_ids if len(conversation_template.stop_ids) > 0 else None
|
||||
else:
|
||||
raise ValueError("Conversation template config is not provided or incorrect")
|
||||
if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None:
|
||||
try:
|
||||
# Some tokenizers doesn't allow to set pad_token mannually e.g., Qwen
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
except AttributeError as e:
|
||||
logger.warning(f"Unable to set pad token to eos token, {str(e)}")
|
||||
if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
|
||||
logger.warning(
|
||||
"The tokenizer does not have a pad token which is required. May lead to unintended behavior in training, Please consider manually set them."
|
||||
)
|
||||
|
||||
tokenizer.add_bos_token = False
|
||||
tokenizer.add_eos_token = False
|
||||
tokenizer.padding_side = "left" # left padding for generation (online learning)
|
||||
|
||||
# configure generation config
|
||||
actor.generation_config.update(
|
||||
pad_token_id=tokenizer.eos_token_id, bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id
|
||||
)
|
||||
|
||||
# configure optimizer
|
||||
coordinator.print_on_master(f"setting up optimizer for actor: lr={args.lr}, weight_decay={args.weight_decay}")
|
||||
actor_optim = HybridAdam(
|
||||
model_params=actor.parameters(),
|
||||
lr=args.lr,
|
||||
betas=(0.9, 0.95),
|
||||
weight_decay=args.weight_decay,
|
||||
adamw_mode=True,
|
||||
)
|
||||
|
||||
coordinator.print_on_master(f"setting up optimizer for critic: lr={args.lr}, weight_decay={args.weight_decay}")
|
||||
critic_optim = HybridAdam(
|
||||
model_params=critic.parameters(),
|
||||
lr=args.critic_lr,
|
||||
betas=(0.9, 0.95),
|
||||
weight_decay=args.weight_decay,
|
||||
adamw_mode=True,
|
||||
)
|
||||
|
||||
# configure dataset
|
||||
coordinator.print_on_master(f"Load dataset: {args.prompt_dataset}")
|
||||
mode_map = {"train": "train", "valid": "validation", "test": "test"}
|
||||
train_prompt_dataset = load_tokenized_dataset(dataset_paths=args.prompt_dataset, mode="train", mode_map=mode_map)
|
||||
data_collator = DataCollatorForPromptDataset(tokenizer=tokenizer, max_length=args.max_length - args.max_seq_len)
|
||||
train_prompt_dataloader = setup_distributed_dataloader(
|
||||
dataset=train_prompt_dataset,
|
||||
batch_size=args.experience_batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
collate_fn=data_collator,
|
||||
use_tp=args.tp > 1,
|
||||
)
|
||||
|
||||
if len(args.ptx_dataset) > 0:
|
||||
train_ptx_dataset = load_tokenized_dataset(dataset_paths=args.ptx_dataset, mode="train", mode_map=mode_map)
|
||||
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)
|
||||
train_pretrain_dataloader = setup_distributed_dataloader(
|
||||
dataset=train_ptx_dataset,
|
||||
batch_size=args.ptx_batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
collate_fn=data_collator,
|
||||
use_tp=args.tp > 1,
|
||||
)
|
||||
else:
|
||||
train_pretrain_dataloader = None
|
||||
|
||||
if args.warmup_steps is None:
|
||||
args.warmup_steps = int(0.025 * args.num_episodes)
|
||||
coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")
|
||||
|
||||
actor_lr_scheduler = CosineAnnealingWarmupLR(
|
||||
optimizer=actor_optim,
|
||||
total_steps=args.num_episodes,
|
||||
warmup_steps=args.warmup_steps,
|
||||
eta_min=0.1 * args.lr,
|
||||
)
|
||||
|
||||
critic_lr_scheduler = CosineAnnealingWarmupLR(
|
||||
optimizer=critic_optim,
|
||||
total_steps=args.num_episodes,
|
||||
warmup_steps=args.warmup_steps,
|
||||
eta_min=0.1 * args.lr,
|
||||
)
|
||||
|
||||
# ==============================
|
||||
# Initialize Booster
|
||||
# ==============================
|
||||
if args.plugin == "ddp":
|
||||
"""
|
||||
Default torch ddp plugin without any acceleration, for
|
||||
debugging purpose acceleration, for debugging purpose
|
||||
"""
|
||||
plugin = TorchDDPPlugin(find_unused_parameters=True)
|
||||
elif args.plugin == "gemini":
|
||||
plugin = GeminiPlugin(
|
||||
precision=args.mixed_precision,
|
||||
placement_policy="static",
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
enable_gradient_accumulation=True,
|
||||
)
|
||||
elif args.plugin == "gemini_auto":
|
||||
plugin = GeminiPlugin(
|
||||
precision=args.mixed_precision,
|
||||
placement_policy="auto",
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
)
|
||||
elif args.plugin == "zero2":
|
||||
plugin = LowLevelZeroPlugin(
|
||||
stage=2,
|
||||
precision=args.mixed_precision,
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
)
|
||||
elif args.plugin == "zero2_cpu":
|
||||
plugin = LowLevelZeroPlugin(
|
||||
stage=2,
|
||||
precision=args.mixed_precision,
|
||||
initial_scale=2**16,
|
||||
cpu_offload=True,
|
||||
max_norm=args.grad_clip,
|
||||
)
|
||||
elif args.plugin == "3d":
|
||||
plugin = HybridParallelPlugin(
|
||||
tp_size=args.tp,
|
||||
pp_size=1,
|
||||
zero_stage=0,
|
||||
parallel_output=False,
|
||||
precision=args.mixed_precision,
|
||||
)
|
||||
custom_plugin = HybridParallelPlugin(
|
||||
tp_size=args.tp,
|
||||
pp_size=1,
|
||||
zero_stage=0,
|
||||
parallel_output=False,
|
||||
precision=args.mixed_precision,
|
||||
custom_policy=booster_policy,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown plugin {args.plugin}")
|
||||
|
||||
if args.plugin != "3d":
|
||||
custom_plugin = plugin
|
||||
|
||||
actor_booster = Booster(plugin=plugin)
|
||||
ref_booster = Booster(plugin=plugin)
|
||||
rm_booster = Booster(plugin=custom_plugin)
|
||||
critic_booster = Booster(plugin=custom_plugin)
|
||||
|
||||
default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
|
||||
torch.set_default_dtype(default_dtype)
|
||||
actor, actor_optim, _, train_prompt_dataloader, actor_lr_scheduler = actor_booster.boost(
|
||||
model=actor,
|
||||
optimizer=actor_optim,
|
||||
lr_scheduler=actor_lr_scheduler,
|
||||
dataloader=train_prompt_dataloader,
|
||||
)
|
||||
|
||||
critic, critic_optim, _, _, critic_lr_scheduler = critic_booster.boost(
|
||||
model=critic,
|
||||
optimizer=critic_optim,
|
||||
lr_scheduler=critic_lr_scheduler,
|
||||
dataloader=train_prompt_dataloader,
|
||||
)
|
||||
reward_model, _, _, _, _ = rm_booster.boost(model=reward_model, dataloader=train_prompt_dataloader)
|
||||
ref_model, _, _, _, _ = ref_booster.boost(model=ref_model, dataloader=train_prompt_dataloader)
|
||||
|
||||
torch.set_default_dtype(torch.float)
|
||||
|
||||
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
|
||||
coordinator.print_on_master(
|
||||
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
|
||||
)
|
||||
|
||||
sampler_start_idx = 0
|
||||
start_step = 0
|
||||
|
||||
if args.rm_checkpoint_path is not None:
|
||||
if "modeling" in args.rm_checkpoint_path:
|
||||
rm_booster.load_model(reward_model, args.rm_checkpoint_path)
|
||||
else:
|
||||
_, _, _ = load_checkpoint(
|
||||
load_dir=args.rm_checkpoint_path,
|
||||
booster=rm_booster,
|
||||
model=reward_model,
|
||||
optimizer=None,
|
||||
lr_scheduler=None,
|
||||
)
|
||||
coordinator.print_on_master(f"Loaded reward model checkpoint {args.rm_checkpoint_path}")
|
||||
|
||||
if args.checkpoint_path is not None:
|
||||
if "modeling" in args.checkpoint_path:
|
||||
actor_booster.load_model(actor, args.checkpoint_path)
|
||||
ref_booster.load_model(ref_model, args.checkpoint_path)
|
||||
coordinator.print_on_master(f"Loaded actor and reference model {args.checkpoint_path}")
|
||||
else:
|
||||
_, start_step, sampler_start_idx = load_checkpoint(
|
||||
load_dir=args.checkpoint_path,
|
||||
booster=actor_booster,
|
||||
model=actor,
|
||||
optimizer=actor_optim,
|
||||
lr_scheduler=actor_lr_scheduler,
|
||||
)
|
||||
_, _, _ = load_checkpoint(
|
||||
load_dir=args.checkpoint_path,
|
||||
booster=ref_booster,
|
||||
model=ref_model,
|
||||
optimizer=critic_optim,
|
||||
lr_scheduler=critic_lr_scheduler,
|
||||
)
|
||||
assert isinstance(train_prompt_dataloader.sampler, StatefulDistributedSampler)
|
||||
train_prompt_dataloader.sampler.set_start_index(start_index=sampler_start_idx)
|
||||
|
||||
coordinator.print_on_master(
|
||||
f"Loaded actor and reference model checkpoint {args.checkpoint_path} at spisode {start_step}"
|
||||
)
|
||||
coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")
|
||||
|
||||
coordinator.print_on_master(
|
||||
f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
|
||||
)
|
||||
|
||||
if args.critic_checkpoint_path is not None:
|
||||
if "modeling" in args.critic_checkpoint_path:
|
||||
critic_booster.load_model(critic, args.critic_checkpoint_path)
|
||||
else:
|
||||
_, _, _ = load_checkpoint(
|
||||
load_dir=args.critic_checkpoint_path,
|
||||
booster=critic_booster,
|
||||
model=critic,
|
||||
optimizer=critic_optim,
|
||||
lr_scheduler=critic_lr_scheduler,
|
||||
)
|
||||
coordinator.print_on_master(f"Loaded critic checkpoint {args.critic_checkpoint_path}")
|
||||
coordinator.print_on_master(
|
||||
f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
|
||||
)
|
||||
|
||||
# configure trainer
|
||||
trainer = PPOTrainer(
|
||||
actor_booster,
|
||||
critic_booster,
|
||||
actor,
|
||||
critic,
|
||||
reward_model,
|
||||
ref_model,
|
||||
actor_optim,
|
||||
critic_optim,
|
||||
actor_lr_scheduler,
|
||||
critic_lr_scheduler,
|
||||
tokenizer=tokenizer,
|
||||
stop_token_ids=stop_ids,
|
||||
kl_coef=args.kl_coef,
|
||||
ptx_coef=args.ptx_coef,
|
||||
train_batch_size=args.train_batch_size,
|
||||
buffer_limit=args.num_collect_steps * args.experience_batch_size,
|
||||
max_length=args.max_length,
|
||||
max_new_tokens=args.max_seq_len,
|
||||
use_cache=True,
|
||||
do_sample=True,
|
||||
temperature=0.7,
|
||||
accumulation_steps=args.accumulation_steps,
|
||||
save_dir=args.save_path,
|
||||
save_interval=args.save_interval,
|
||||
top_k=50,
|
||||
use_tp=args.tp > 1,
|
||||
offload_inference_models="gemini" not in args.plugin,
|
||||
coordinator=coordinator,
|
||||
)
|
||||
|
||||
trainer.fit(
|
||||
num_episodes=args.num_episodes,
|
||||
num_collect_steps=args.num_collect_steps,
|
||||
num_update_steps=args.num_update_steps,
|
||||
prompt_dataloader=train_prompt_dataloader,
|
||||
pretrain_dataloader=train_pretrain_dataloader,
|
||||
log_dir=args.log_dir,
|
||||
use_wandb=args.use_wandb,
|
||||
)
|
||||
|
||||
if args.lora_rank > 0 and args.merge_lora_weights:
|
||||
from coati.models.lora import LORA_MANAGER
|
||||
|
||||
# NOTE: set model to eval to merge LoRA weights
|
||||
LORA_MANAGER.merge_weights = True
|
||||
actor.eval()
|
||||
critic.eval()
|
||||
# save model checkpoint after fitting on only rank0
|
||||
coordinator.print_on_master("Start saving final actor model checkpoint")
|
||||
actor_booster.save_model(actor, os.path.join(trainer.actor_save_dir, "modeling"), shard=True)
|
||||
coordinator.print_on_master(
|
||||
f"Saved final actor model checkpoint at episodes {args.num_episodes} at folder {args.save_path}"
|
||||
)
|
||||
coordinator.print_on_master("Start saving final critic model checkpoint")
|
||||
critic_booster.save_model(critic, os.path.join(trainer.critic_save_dir, "modeling"), shard=True)
|
||||
coordinator.print_on_master(
|
||||
f"Saved final critic model checkpoint at episodes {args.num_episodes} at folder {args.save_path}"
|
||||
)
|
||||
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--prompt_dataset", nargs="+", default=[])
|
||||
parser.add_argument("--ptx_dataset", nargs="+", default=[])
|
||||
parser.add_argument(
|
||||
"--plugin",
|
||||
type=str,
|
||||
default="gemini",
|
||||
choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
|
||||
help="Choose which plugin to use",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--conversation_template_config",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path \
|
||||
to save conversation template config files.",
|
||||
)
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
|
||||
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
|
||||
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
|
||||
parser.add_argument("--tokenizer_dir", type=str, default=None)
|
||||
parser.add_argument("--tp", type=int, default=1)
|
||||
parser.add_argument("--pretrain", type=str, default=None)
|
||||
parser.add_argument("--rm_pretrain", type=str, default=None)
|
||||
parser.add_argument("--checkpoint_path", type=str, default=None)
|
||||
parser.add_argument("--critic_checkpoint_path", type=str, default=None)
|
||||
parser.add_argument("--rm_checkpoint_path", type=str, help="Reward model checkpoint path")
|
||||
parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts")
|
||||
parser.add_argument("--num_episodes", type=int, default=1)
|
||||
parser.add_argument("--num_collect_steps", type=int, default=2)
|
||||
parser.add_argument("--num_update_steps", type=int, default=5)
|
||||
parser.add_argument("--save_interval", type=int, default=1000)
|
||||
parser.add_argument("--train_batch_size", type=int, default=16)
|
||||
parser.add_argument("--experience_batch_size", type=int, default=16)
|
||||
parser.add_argument("--ptx_batch_size", type=int, default=4)
|
||||
parser.add_argument("--lora_train_bias", type=str, default="none")
|
||||
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
|
||||
parser.add_argument("--accumulation_steps", type=int, default=8)
|
||||
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
|
||||
parser.add_argument("--merge_lora_weights", type=bool, default=True)
|
||||
parser.add_argument("--lr", type=float, default=9e-6)
|
||||
parser.add_argument("--critic_lr", type=float, default=9e-6)
|
||||
parser.add_argument("--kl_coef", type=float, default=0.1)
|
||||
parser.add_argument("--ptx_coef", type=float, default=0.0)
|
||||
parser.add_argument("--max_length", type=int, default=2048)
|
||||
parser.add_argument("--max_seq_len", type=int, default=256)
|
||||
parser.add_argument("--log_dir", default="logs", type=str)
|
||||
parser.add_argument("--use_wandb", default=False, action="store_true")
|
||||
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
|
||||
parser.add_argument("--use_flash_attn", default=False, action="store_true")
|
||||
args = parser.parse_args()
|
||||
train(args)
|
82
applications/ColossalChat/examples/training_scripts/train_ppo.sh
Executable file
82
applications/ColossalChat/examples/training_scripts/train_ppo.sh
Executable file
@@ -0,0 +1,82 @@
|
||||
#!/bin/bash
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES() {
|
||||
local n=${1:-"9999"}
|
||||
echo "GPU Memory Usage:"
|
||||
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
|
||||
tail -n +2 |
|
||||
nl -v 0 |
|
||||
tee /dev/tty |
|
||||
sort -g -k 2 |
|
||||
awk '{print $1}' |
|
||||
head -n $n)
|
||||
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
|
||||
echo "Now CUDA_VISIBLE_DEVICES is set to:"
|
||||
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
||||
}
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES 8
|
||||
|
||||
PROJECT_NAME="ppo"
|
||||
|
||||
PARENT_SAVE_DIR="" # Path to a folder to save checkpoints
|
||||
PARENT_TENSORBOARD_DIR="" # Path to a folder to save logs
|
||||
PARENT_CONFIG_FILE="" # Path to a folder to save training config logs
|
||||
PRETRAINED_MODEL_PATH="" # local pretrained model path (from RLHF step 1: SFT)
|
||||
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
|
||||
REWARD_MODEL_PATH="" # local reward model path (from RLHF step 2: Train Reward Model)
|
||||
CONVERSATION_TEMPLATE_CONFIG_PATH="" # path to the conversation config file
|
||||
|
||||
declare -a prompt_dataset=(
|
||||
YOUR/PROMPT/DATA/DIR/arrow/part-00000
|
||||
YOUR/PROMPT/DATA/DIR/arrow/part-00001
|
||||
YOUR/PROMPT/DATA/DIR/arrow/part-00002
|
||||
YOUR/PROMPT/DATA/DIR/arrow/part-00003
|
||||
YOUR/PROMPT/DATA/DIR/arrow/part-00004
|
||||
YOUR/PROMPT/DATA/DIR/arrow/part-00005
|
||||
YOUR/PROMPT/DATA/DIR/arrow/part-00006
|
||||
YOUR/PROMPT/DATA/DIR/arrow/part-00007
|
||||
YOUR/PROMPT/DATA/DIR/arrow/part-00008
|
||||
YOUR/PROMPT/DATA/DIR/arrow/part-00009
|
||||
)
|
||||
|
||||
declare -a ptx_dataset=(
|
||||
YOUR/SFT/DATA/DIR/arrow/part-00000
|
||||
YOUR/SFT/DATA/DIR/arrow/part-00001
|
||||
YOUR/SFT/DATA/DIR/arrow/part-00002
|
||||
YOUR/SFT/DATA/DIR/arrow/part-00003
|
||||
YOUR/SFT/DATA/DIR/arrow/part-00004
|
||||
YOUR/SFT/DATA/DIR/arrow/part-00005
|
||||
YOUR/SFT/DATA/DIR/arrow/part-00006
|
||||
YOUR/SFT/DATA/DIR/arrow/part-00007
|
||||
YOUR/SFT/DATA/DIR/arrow/part-00008
|
||||
YOUR/SFT/DATA/DIR/arrow/part-00009
|
||||
)
|
||||
|
||||
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
|
||||
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
|
||||
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
|
||||
CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
|
||||
|
||||
colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 31312 train_ppo.py \
|
||||
--pretrain $PRETRAINED_MODEL_PATH \
|
||||
--rm_pretrain $PRETRAINED_MODEL_PATH \
|
||||
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
|
||||
--rm_checkpoint_path $REWARD_MODEL_PATH \
|
||||
--prompt_dataset ${prompt_dataset[@]} \
|
||||
--conversation_template_config $CONVERSATION_TEMPLATE_CONFIG_PATH \
|
||||
--ptx_coef 0.0 \
|
||||
--plugin "zero2" \
|
||||
--save_interval 500 \
|
||||
--save_path $SAVE_DIR \
|
||||
--num_episodes 2000 \
|
||||
--num_collect_steps 2 \
|
||||
--num_update_steps 1 \
|
||||
--experience_batch_size 4 \
|
||||
--train_batch_size 4 \
|
||||
--accumulation_steps 2 \
|
||||
--lr 9e-6 \
|
||||
--mixed_precision "bf16" \
|
||||
--grad_clip 0.1\
|
||||
--weight_decay 0.01 \
|
||||
--warmup_steps 40 \
|
||||
--grad_checkpoint \
|
||||
--use_wandb
|
342
applications/ColossalChat/examples/training_scripts/train_rm.py
Executable file
342
applications/ColossalChat/examples/training_scripts/train_rm.py
Executable file
@@ -0,0 +1,342 @@
|
||||
import argparse
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import resource
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
from coati.dataset import (
|
||||
DataCollatorForPreferenceDataset,
|
||||
StatefulDistributedSampler,
|
||||
load_tokenized_dataset,
|
||||
setup_distributed_dataloader,
|
||||
)
|
||||
from coati.models import LogExpLoss, LogSigLoss, RewardModel, convert_to_lora_module
|
||||
from coati.trainer import RewardModelTrainer
|
||||
from coati.utils import load_checkpoint
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
|
||||
|
||||
def train(args):
|
||||
# check lora compatibility
|
||||
if "gemini" in args.plugin and args.lora_rank > 0:
|
||||
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
|
||||
if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
|
||||
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
|
||||
# ==============================
|
||||
# Initialize Distributed Training
|
||||
# ==============================
|
||||
colossalai.launch_from_torch({})
|
||||
coordinator = DistCoordinator()
|
||||
|
||||
# ======================================================
|
||||
# Initialize Model, Objective, Optimizer and LR Scheduler
|
||||
# ======================================================
|
||||
# Temp Fix: Disable lazy init due to version conflict
|
||||
# init_ctx = (
|
||||
# LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
|
||||
# )
|
||||
|
||||
init_ctx = nullcontext()
|
||||
booster_policy = None
|
||||
with init_ctx:
|
||||
if args.use_flash_attn:
|
||||
model = RewardModel(
|
||||
args.pretrain,
|
||||
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
|
||||
use_flash_attention_2=True,
|
||||
)
|
||||
coordinator.print_on_master(msg="Flash-attention enabled successfully")
|
||||
else:
|
||||
model = RewardModel(args.pretrain)
|
||||
|
||||
if args.tp > 1:
|
||||
if model.model.config.architectures[0] == "BloomForCausalLM":
|
||||
from colossalai.shardformer.policies.bloom import BloomPolicy
|
||||
|
||||
booster_policy = BloomPolicy()
|
||||
elif model.model.config.architectures[0] == "LlamaForCausalLM":
|
||||
from colossalai.shardformer.policies.llama import LlamaPolicy
|
||||
|
||||
booster_policy = LlamaPolicy()
|
||||
elif model.model.config.architectures[0] == "GPT2LMHeadModel":
|
||||
from colossalai.shardformer.policies.gpt2 import GPT2Policy
|
||||
|
||||
booster_policy = GPT2Policy()
|
||||
elif model.model.config.architectures[0] == "ChatGLMModel":
|
||||
from colossalai.shardformer.policies.chatglm2 import ChatGLMPolicy
|
||||
|
||||
booster_policy = ChatGLMPolicy()
|
||||
elif model.model.config.architectures[0] == "OPTForCausalLM":
|
||||
from colossalai.shardformer.policies.opt import OPTPolicy
|
||||
|
||||
booster_policy = OPTPolicy()
|
||||
else:
|
||||
raise ValueError("Unknown model architecture for policy")
|
||||
|
||||
if args.lora_rank > 0:
|
||||
model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
|
||||
|
||||
# ==============================
|
||||
# Initialize Booster
|
||||
# ==============================
|
||||
if args.plugin == "ddp":
|
||||
"""
|
||||
Default torch ddp plugin without any acceleration, for
|
||||
debugging purpose acceleration, for debugging purpose
|
||||
"""
|
||||
plugin = TorchDDPPlugin(find_unused_parameters=True)
|
||||
elif args.plugin == "gemini":
|
||||
plugin = GeminiPlugin(
|
||||
precision=args.mixed_precision,
|
||||
placement_policy="static",
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
enable_gradient_accumulation=True,
|
||||
)
|
||||
elif args.plugin == "gemini_auto":
|
||||
plugin = GeminiPlugin(
|
||||
precision=args.mixed_precision,
|
||||
placement_policy="auto",
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
)
|
||||
elif args.plugin == "zero2":
|
||||
plugin = LowLevelZeroPlugin(
|
||||
stage=2,
|
||||
precision=args.mixed_precision,
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
)
|
||||
elif args.plugin == "zero2_cpu":
|
||||
plugin = LowLevelZeroPlugin(
|
||||
stage=2,
|
||||
precision=args.mixed_precision,
|
||||
initial_scale=2**16,
|
||||
cpu_offload=True,
|
||||
max_norm=args.grad_clip,
|
||||
)
|
||||
elif args.plugin == "3d":
|
||||
plugin = HybridParallelPlugin(
|
||||
tp_size=args.tp,
|
||||
pp_size=1,
|
||||
zero_stage=0,
|
||||
parallel_output=False,
|
||||
precision=args.mixed_precision,
|
||||
custom_policy=booster_policy,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown plugin {args.plugin}")
|
||||
|
||||
booster = Booster(plugin=plugin)
|
||||
|
||||
if args.grad_checkpoint and args.lora_rank == 0:
|
||||
model.model.gradient_checkpointing_enable() # TODO: support gradient checkpoint for the last linear layer
|
||||
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
|
||||
elif args.lora_rank > 0:
|
||||
coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled")
|
||||
|
||||
# configure tokenizer
|
||||
tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=False, trust_remote_code=True)
|
||||
if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None:
|
||||
try:
|
||||
# Some tokenizers doesn't allow to set pad_token mannually e.g., Qwen
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
except AttributeError as e:
|
||||
logger.warning(f"Unable to set pad token to eos token, {str(e)}")
|
||||
if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
|
||||
logger.warning(
|
||||
"The tokenizer does not have a pad token which is required. May lead to unintended behavior in training, Please consider manually set them."
|
||||
)
|
||||
tokenizer.padding_side = "right"
|
||||
tokenizer.add_bos_token = False
|
||||
tokenizer.add_eos_token = False
|
||||
|
||||
# configure loss function
|
||||
if args.loss_fn == "log_sig":
|
||||
loss_fn = LogSigLoss()
|
||||
elif args.loss_fn == "log_exp":
|
||||
loss_fn = LogExpLoss()
|
||||
else:
|
||||
raise ValueError(f'Unsupported loss function "{args.loss_fn}"')
|
||||
|
||||
# configure optimizer
|
||||
optim = HybridAdam(
|
||||
model_params=model.parameters(),
|
||||
lr=args.lr,
|
||||
betas=(0.9, 0.95),
|
||||
weight_decay=args.weight_decay,
|
||||
adamw_mode=True,
|
||||
)
|
||||
|
||||
# configure dataset
|
||||
coordinator.print_on_master(f"Load dataset: {args.dataset}")
|
||||
mode_map = {"train": "train", "valid": "validation", "test": "test"}
|
||||
train_dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train", mode_map=mode_map)
|
||||
data_collator = DataCollatorForPreferenceDataset(tokenizer=tokenizer, max_length=args.max_length)
|
||||
train_dataloader = setup_distributed_dataloader(
|
||||
dataset=train_dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
collate_fn=data_collator,
|
||||
use_tp=args.tp > 1,
|
||||
)
|
||||
|
||||
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
|
||||
math.ceil(args.max_epochs * num_update_steps_per_epoch)
|
||||
|
||||
if args.warmup_steps is None:
|
||||
args.warmup_steps = int(args.max_epochs * 0.025 * (len(train_dataloader) // args.accumulation_steps))
|
||||
coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")
|
||||
|
||||
lr_scheduler = CosineAnnealingWarmupLR(
|
||||
optimizer=optim,
|
||||
total_steps=args.max_epochs * num_update_steps_per_epoch,
|
||||
warmup_steps=args.warmup_steps,
|
||||
eta_min=0.1 * args.lr,
|
||||
)
|
||||
|
||||
default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
|
||||
torch.set_default_dtype(default_dtype)
|
||||
model, optim, _, train_dataloader, lr_scheduler = booster.boost(
|
||||
model=model,
|
||||
optimizer=optim,
|
||||
lr_scheduler=lr_scheduler,
|
||||
dataloader=train_dataloader,
|
||||
)
|
||||
torch.set_default_dtype(torch.float)
|
||||
|
||||
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
|
||||
coordinator.print_on_master(
|
||||
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
|
||||
)
|
||||
|
||||
start_epoch = 0
|
||||
sampler_start_idx = 0
|
||||
start_step = 0
|
||||
if args.checkpoint_path is not None:
|
||||
if "modeling" in args.checkpoint_path:
|
||||
coordinator.print_on_master(f"Continued pretrain from checkpoint {args.checkpoint_path}")
|
||||
booster.load_model(model, args.checkpoint_path)
|
||||
else:
|
||||
coordinator.print_on_master(f"Load model checkpoint from {args.checkpoint_path}")
|
||||
start_epoch, start_step, sampler_start_idx = load_checkpoint(
|
||||
load_dir=args.checkpoint_path,
|
||||
booster=booster,
|
||||
model=model,
|
||||
optimizer=optim,
|
||||
lr_scheduler=lr_scheduler,
|
||||
)
|
||||
assert isinstance(train_dataloader.sampler, StatefulDistributedSampler)
|
||||
train_dataloader.sampler.set_start_index(start_index=sampler_start_idx)
|
||||
|
||||
coordinator.print_on_master(
|
||||
f"Loaded checkpoint {args.checkpoint_path} at epoch {start_epoch} step {start_step}"
|
||||
)
|
||||
coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")
|
||||
|
||||
coordinator.print_on_master(
|
||||
f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
|
||||
)
|
||||
|
||||
trainer = RewardModelTrainer(
|
||||
model,
|
||||
booster,
|
||||
optim,
|
||||
lr_scheduler,
|
||||
tokenizer,
|
||||
loss_fn=loss_fn,
|
||||
max_epochs=args.max_epochs,
|
||||
accumulation_steps=args.accumulation_steps,
|
||||
start_epoch=start_epoch,
|
||||
save_interval=args.save_interval,
|
||||
save_dir=args.save_dir,
|
||||
coordinator=coordinator,
|
||||
)
|
||||
|
||||
trainer.fit(
|
||||
train_preference_dataloader=train_dataloader,
|
||||
eval_preference_dataloader=None,
|
||||
log_dir=args.log_dir,
|
||||
use_wandb=args.use_wandb,
|
||||
)
|
||||
|
||||
if args.lora_rank > 0 and args.merge_lora_weights:
|
||||
from coati.models.lora import LORA_MANAGER
|
||||
|
||||
# NOTE: set model to eval to merge LoRA weights
|
||||
LORA_MANAGER.merge_weights = True
|
||||
model.eval()
|
||||
# save model checkpoint after fitting on only rank0
|
||||
coordinator.print_on_master("Start saving final model checkpoint")
|
||||
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
|
||||
coordinator.print_on_master(f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_dir}")
|
||||
|
||||
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# ==============================
|
||||
# Parse Arguments
|
||||
# ==============================
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--plugin",
|
||||
type=str,
|
||||
default="gemini",
|
||||
choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d", "ddp"],
|
||||
help="Choose which plugin to use",
|
||||
)
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
|
||||
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
|
||||
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
|
||||
parser.add_argument("--tp", type=int, default=1)
|
||||
parser.add_argument("--pretrain", type=str, default=None)
|
||||
parser.add_argument("--tokenizer_dir", type=str, default=None)
|
||||
parser.add_argument("--dataset", nargs="+", default=[])
|
||||
parser.add_argument(
|
||||
"--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
|
||||
)
|
||||
parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
|
||||
parser.add_argument("--save_dir", type=str, default="output")
|
||||
parser.add_argument("--max_length", type=int, default=2048, help="Model max length")
|
||||
parser.add_argument("--max_epochs", type=int, default=3)
|
||||
parser.add_argument("--batch_size", type=int, default=4)
|
||||
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
|
||||
parser.add_argument("--loss_fn", type=str, default="log_sig", choices=["log_sig", "log_exp"], help="Loss function")
|
||||
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
|
||||
parser.add_argument(
|
||||
"--lora_train_bias",
|
||||
type=str,
|
||||
default="none",
|
||||
help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
|
||||
)
|
||||
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
|
||||
parser.add_argument("--merge_lora_weights", type=bool, default=True)
|
||||
parser.add_argument("--lr", type=float, default=5e-6)
|
||||
parser.add_argument("--accumulation_steps", type=int, default=8)
|
||||
parser.add_argument("--log_dir", default="logs", type=str)
|
||||
parser.add_argument("--use_wandb", default=False, action="store_true")
|
||||
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
|
||||
parser.add_argument("--use_flash_attn", default=False, action="store_true")
|
||||
args = parser.parse_args()
|
||||
os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
|
||||
with open(args.config_file, "w") as f:
|
||||
json.dump(args.__dict__, f, indent=4)
|
||||
train(args)
|
61
applications/ColossalChat/examples/training_scripts/train_rm.sh
Executable file
61
applications/ColossalChat/examples/training_scripts/train_rm.sh
Executable file
@@ -0,0 +1,61 @@
|
||||
#!/bin/bash
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES() {
|
||||
local n=${1:-"9999"}
|
||||
echo "GPU Memory Usage:"
|
||||
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
|
||||
tail -n +2 |
|
||||
nl -v 0 |
|
||||
tee /dev/tty |
|
||||
sort -g -k 2 |
|
||||
awk '{print $1}' |
|
||||
head -n $n)
|
||||
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
|
||||
echo "Now CUDA_VISIBLE_DEVICES is set to:"
|
||||
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
||||
}
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES 8
|
||||
|
||||
PROJECT_NAME="rm"
|
||||
PARENT_SAVE_DIR="" # Path to a folder to save checkpoints
|
||||
PARENT_TENSORBOARD_DIR="" # Path to a folder to save logs
|
||||
PARENT_CONFIG_FILE="" # Path to a folder to save training config logs
|
||||
PRETRAINED_MODEL_PATH="" # huggingface or local model path
|
||||
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
|
||||
|
||||
declare -a dataset=(
|
||||
YOUR/PREFERENCE/DATA/DIR/arrow/part-00000
|
||||
YOUR/PREFERENCE/DATA/DIR/arrow/part-00001
|
||||
YOUR/PREFERENCE/DATA/DIR/arrow/part-00002
|
||||
YOUR/PREFERENCE/DATA/DIR/arrow/part-00003
|
||||
YOUR/PREFERENCE/DATA/DIR/arrow/part-00004
|
||||
YOUR/PREFERENCE/DATA/DIR/arrow/part-00005
|
||||
YOUR/PREFERENCE/DATA/DIR/arrow/part-00006
|
||||
YOUR/PREFERENCE/DATA/DIR/arrow/part-00007
|
||||
YOUR/PREFERENCE/DATA/DIR/arrow/part-00008
|
||||
YOUR/PREFERENCE/DATA/DIR/arrow/part-00009
|
||||
)
|
||||
|
||||
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
|
||||
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
|
||||
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
|
||||
CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
|
||||
|
||||
colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 31312 train_rm.py \
|
||||
--pretrain $PRETRAINED_MODEL_PATH \
|
||||
--checkpoint_path /home/yeanbang/data/experiments/rm/hhh_aligh/ckptllama2-rm-2024-01-17-14-43-24/epoch-1_step-1317/modeling \
|
||||
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
|
||||
--dataset ${dataset[@]} \
|
||||
--plugin "zero2" \
|
||||
--save_interval 1000 \
|
||||
--save_dir $SAVE_DIR \
|
||||
--config_file $CONFIG_FILE \
|
||||
--max_epochs 3 \
|
||||
--accumulation_steps 1 \
|
||||
--batch_size 8 \
|
||||
--lr 5e-6 \
|
||||
--mixed_precision "bf16" \
|
||||
--grad_clip 1.0 \
|
||||
--weight_decay 0.01 \
|
||||
--warmup_steps 40 \
|
||||
--grad_checkpoint \
|
||||
--use_wandb
|
311
applications/ColossalChat/examples/training_scripts/train_sft.py
Executable file
311
applications/ColossalChat/examples/training_scripts/train_sft.py
Executable file
@@ -0,0 +1,311 @@
|
||||
import argparse
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import resource
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
from coati.dataset import DataCollatorForSupervisedDataset, load_tokenized_dataset, setup_distributed_dataloader
|
||||
from coati.models import convert_to_lora_module
|
||||
from coati.trainer import SFTTrainer
|
||||
from coati.utils import load_checkpoint
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
|
||||
|
||||
def train(args):
|
||||
# check lora compatibility
|
||||
if "gemini" in args.plugin and args.lora_rank > 0:
|
||||
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
|
||||
if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
|
||||
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
|
||||
# ==============================
|
||||
# Initialize Distributed Training
|
||||
# ==============================
|
||||
colossalai.launch_from_torch({})
|
||||
coordinator = DistCoordinator()
|
||||
|
||||
# ==============================
|
||||
# Initialize Booster
|
||||
# ==============================
|
||||
if args.plugin == "ddp":
|
||||
"""
|
||||
Default torch ddp plugin without any acceleration, for
|
||||
debugging purpose acceleration, for debugging purpose
|
||||
"""
|
||||
plugin = TorchDDPPlugin(find_unused_parameters=True)
|
||||
elif args.plugin == "gemini":
|
||||
plugin = GeminiPlugin(
|
||||
precision=args.mixed_precision,
|
||||
placement_policy="static",
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
enable_gradient_accumulation=True,
|
||||
)
|
||||
elif args.plugin == "gemini_auto":
|
||||
plugin = GeminiPlugin(
|
||||
precision=args.mixed_precision,
|
||||
placement_policy="auto",
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
)
|
||||
elif args.plugin == "zero2":
|
||||
plugin = LowLevelZeroPlugin(
|
||||
stage=2,
|
||||
precision=args.mixed_precision,
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
)
|
||||
elif args.plugin == "zero2_cpu":
|
||||
plugin = LowLevelZeroPlugin(
|
||||
stage=2,
|
||||
precision=args.mixed_precision,
|
||||
initial_scale=2**16,
|
||||
cpu_offload=True,
|
||||
max_norm=args.grad_clip,
|
||||
)
|
||||
elif args.plugin == "3d":
|
||||
plugin = HybridParallelPlugin(
|
||||
tp_size=args.tp,
|
||||
pp_size=1,
|
||||
zero_stage=0,
|
||||
parallel_output=False,
|
||||
max_norm=args.grad_clip,
|
||||
precision=args.mixed_precision,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown plugin {args.plugin}")
|
||||
|
||||
booster = Booster(plugin=plugin)
|
||||
|
||||
# ======================================================
|
||||
# Initialize Model, Objective, Optimizer and LR Scheduler
|
||||
# ======================================================
|
||||
# Temp Fix: Disable lazy init due to version conflict
|
||||
# init_ctx = (
|
||||
# LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
|
||||
# )
|
||||
|
||||
init_ctx = nullcontext()
|
||||
with init_ctx:
|
||||
if args.use_flash_attn:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.pretrain,
|
||||
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
|
||||
use_flash_attention_2=True,
|
||||
)
|
||||
coordinator.print_on_master(msg="Flash-attention enabled successfully")
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(args.pretrain)
|
||||
if args.lora_rank > 0:
|
||||
model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
|
||||
|
||||
if args.grad_checkpoint and args.lora_rank == 0:
|
||||
# lora layers are not supported by gradient checkpointing
|
||||
model.gradient_checkpointing_enable()
|
||||
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
|
||||
elif args.lora_rank > 0:
|
||||
coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled")
|
||||
|
||||
# configure tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.tokenizer_dir or args.pretrain, use_fast=False, trust_remote_code=True
|
||||
)
|
||||
if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None:
|
||||
try:
|
||||
# Some tokenizers doesn't allow to set pad_token mannually e.g., Qwen
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
except AttributeError as e:
|
||||
logger.warning(f"Unable to set pad token to eos token, {str(e)}")
|
||||
if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
|
||||
logger.warning(
|
||||
"The tokenizer does not have a pad token which is required. May lead to unintended behavior in training, Please consider manually set them."
|
||||
)
|
||||
|
||||
tokenizer.add_bos_token = False
|
||||
tokenizer.add_eos_token = False
|
||||
|
||||
coordinator.print_on_master(f"Configuration file will be saved at: {args.config_file}")
|
||||
coordinator.print_on_master(f"Model checkpoint will be saved at: {args.save_path}")
|
||||
|
||||
# configure optimizer
|
||||
optim = HybridAdam(
|
||||
model_params=model.parameters(),
|
||||
lr=args.lr,
|
||||
betas=(0.9, 0.95),
|
||||
weight_decay=args.weight_decay,
|
||||
adamw_mode=True,
|
||||
)
|
||||
|
||||
# configure dataset
|
||||
coordinator.print_on_master(
|
||||
f"Max CUDA memory before data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
)
|
||||
dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train")
|
||||
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_len)
|
||||
train_dataloader = setup_distributed_dataloader(
|
||||
dataset=dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
collate_fn=data_collator,
|
||||
use_tp=args.tp > 1,
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
)
|
||||
|
||||
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
|
||||
math.ceil(args.max_epochs * num_update_steps_per_epoch)
|
||||
|
||||
if args.warmup_steps is None:
|
||||
args.warmup_steps = int(args.max_epochs * 0.025 * (len(train_dataloader) // args.accumulation_steps))
|
||||
coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")
|
||||
|
||||
lr_scheduler = CosineAnnealingWarmupLR(
|
||||
optimizer=optim,
|
||||
total_steps=args.max_epochs * num_update_steps_per_epoch,
|
||||
warmup_steps=args.warmup_steps,
|
||||
eta_min=0.1 * args.lr,
|
||||
)
|
||||
|
||||
# Flash attention will be disabled because it does NOT support fp32.
|
||||
default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
|
||||
torch.set_default_dtype(default_dtype)
|
||||
model, optim, _, train_dataloader, lr_scheduler = booster.boost(
|
||||
model=model,
|
||||
optimizer=optim,
|
||||
lr_scheduler=lr_scheduler,
|
||||
dataloader=train_dataloader,
|
||||
)
|
||||
# model = model.to(get_current_device())
|
||||
torch.set_default_dtype(torch.float)
|
||||
|
||||
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
|
||||
coordinator.print_on_master(
|
||||
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
|
||||
)
|
||||
|
||||
start_epoch = 0
|
||||
sampler_start_idx = 0
|
||||
start_step = 0
|
||||
if args.checkpoint_path is not None:
|
||||
if "modeling" in args.checkpoint_path:
|
||||
coordinator.print_on_master(f"Continued pretrain from checkpoint {args.checkpoint_path}")
|
||||
booster.load_model(model, args.checkpoint_path)
|
||||
else:
|
||||
coordinator.print_on_master(f"Load model checkpoint from {args.checkpoint_path}")
|
||||
start_epoch, start_step, sampler_start_idx = load_checkpoint(
|
||||
load_dir=args.checkpoint_path,
|
||||
booster=booster,
|
||||
model=model,
|
||||
optimizer=optim,
|
||||
lr_scheduler=lr_scheduler,
|
||||
)
|
||||
train_dataloader.sampler.set_start_index(start_index=sampler_start_idx)
|
||||
|
||||
coordinator.print_on_master(
|
||||
f"Loaded checkpoint {args.checkpoint_path} at epoch {start_epoch} step {start_step}"
|
||||
)
|
||||
coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")
|
||||
|
||||
coordinator.print_on_master(
|
||||
f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
|
||||
)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
booster=booster,
|
||||
optim=optim,
|
||||
lr_scheduler=lr_scheduler,
|
||||
max_epochs=args.max_epochs,
|
||||
accumulation_steps=args.accumulation_steps,
|
||||
start_epoch=start_epoch,
|
||||
save_interval=args.save_interval,
|
||||
save_dir=args.save_path,
|
||||
coordinator=coordinator,
|
||||
)
|
||||
|
||||
trainer.fit(
|
||||
train_dataloader=train_dataloader,
|
||||
eval_dataloader=None,
|
||||
log_dir=args.log_dir,
|
||||
use_wandb=args.use_wandb,
|
||||
)
|
||||
|
||||
if args.lora_rank > 0 and args.merge_lora_weights:
|
||||
from coati.models.lora import LORA_MANAGER
|
||||
|
||||
# NOTE: set model to eval to merge LoRA weights
|
||||
LORA_MANAGER.merge_weights = True
|
||||
model.eval()
|
||||
# save model checkpoint after fitting on only rank0
|
||||
coordinator.print_on_master("Start saving final model checkpoint")
|
||||
|
||||
booster.save_model(model, os.path.join(args.save_path, "modeling"), shard=True)
|
||||
coordinator.print_on_master(f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_path}")
|
||||
|
||||
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# ==============================
|
||||
# Parse Arguments
|
||||
# ==============================
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--plugin",
|
||||
type=str,
|
||||
default="gemini",
|
||||
choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d", "ddp"],
|
||||
help="Choose which plugin to use",
|
||||
)
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
|
||||
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
|
||||
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
|
||||
parser.add_argument("--tp", type=int, default=1)
|
||||
parser.add_argument("--pretrain", type=str, default=None)
|
||||
parser.add_argument("--tokenizer_dir", type=str, default=None)
|
||||
parser.add_argument("--dataset", nargs="+", default=[])
|
||||
parser.add_argument(
|
||||
"--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
|
||||
)
|
||||
parser.add_argument("--save_path", type=str, default="output")
|
||||
parser.add_argument("--max_epochs", type=int, default=3)
|
||||
parser.add_argument("--batch_size", type=int, default=4)
|
||||
parser.add_argument("--max_len", type=int, default=512)
|
||||
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
|
||||
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
|
||||
parser.add_argument(
|
||||
"--lora_train_bias",
|
||||
type=str,
|
||||
default="none",
|
||||
help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
|
||||
)
|
||||
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
|
||||
parser.add_argument("--merge_lora_weights", type=bool, default=True)
|
||||
parser.add_argument("--lr", type=float, default=5e-6)
|
||||
parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
|
||||
parser.add_argument("--accumulation_steps", type=int, default=8)
|
||||
parser.add_argument("--log_dir", default="logs", type=str)
|
||||
parser.add_argument("--use_wandb", default=False, action="store_true")
|
||||
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
|
||||
parser.add_argument("--use_flash_attn", default=False, action="store_true")
|
||||
args = parser.parse_args()
|
||||
os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
|
||||
with open(args.config_file, "w") as f:
|
||||
json.dump(args.__dict__, f, indent=4)
|
||||
train(args)
|
59
applications/ColossalChat/examples/training_scripts/train_sft.sh
Executable file
59
applications/ColossalChat/examples/training_scripts/train_sft.sh
Executable file
@@ -0,0 +1,59 @@
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES() {
|
||||
local n=${1:-"9999"}
|
||||
echo "GPU Memory Usage:"
|
||||
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
|
||||
tail -n +2 |
|
||||
nl -v 0 |
|
||||
tee /dev/tty |
|
||||
sort -g -k 2 |
|
||||
awk '{print $1}' |
|
||||
head -n $n)
|
||||
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
|
||||
echo "Now CUDA_VISIBLE_DEVICES is set to:"
|
||||
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
||||
}
|
||||
|
||||
|
||||
# export CUDA_VISIBLE_DEVICES=4,5,6
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES 4
|
||||
PROJECT_NAME="sft"
|
||||
PARENT_SAVE_DIR="" # Path to a folder to save checkpoints
|
||||
PARENT_TENSORBOARD_DIR="" # Path to a folder to save logs
|
||||
PARENT_CONFIG_FILE="" # Path to a folder to save training config logs
|
||||
PRETRAINED_MODEL_PATH="" # huggingface or local model path
|
||||
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
|
||||
declare -a dataset=(
|
||||
YOUR/SFT/DATA/DIR/arrow/part-00000
|
||||
YOUR/SFT/DATA/DIR/arrow/part-00001
|
||||
YOUR/SFT/DATA/DIR/arrow/part-00002
|
||||
YOUR/SFT/DATA/DIR/arrow/part-00003
|
||||
YOUR/SFT/DATA/DIR/arrow/part-00004
|
||||
YOUR/SFT/DATA/DIR/arrow/part-00005
|
||||
YOUR/SFT/DATA/DIR/arrow/part-00006
|
||||
YOUR/SFT/DATA/DIR/arrow/part-00007
|
||||
YOUR/SFT/DATA/DIR/arrow/part-00008
|
||||
YOUR/SFT/DATA/DIR/arrow/part-00009
|
||||
)
|
||||
|
||||
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
|
||||
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
|
||||
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
|
||||
CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
|
||||
|
||||
# the real batch size for gradient descent is number_of_node_in_hostfile * nproc_per_node * train_batch_size
|
||||
colossalai run --nproc_per_node 4 --master_port 31312 --hostfile ./hostfile train_sft.py \
|
||||
--pretrain $PRETRAINED_MODEL_PATH \
|
||||
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
|
||||
--save_interval 4000 \
|
||||
--dataset ${dataset[@]} \
|
||||
--save_path $SAVE_DIR \
|
||||
--config_file $CONFIG_FILE \
|
||||
--lora_rank 0 \
|
||||
--plugin zero2 \
|
||||
--batch_size 8 \
|
||||
--max_epochs 1 \
|
||||
--accumulation_steps 1 \
|
||||
--lr 2e-5 \
|
||||
--max_len 2048 \
|
||||
--grad_checkpoint \
|
||||
--use_wandb
|
Reference in New Issue
Block a user