Add GRPO and Support RLVR for PPO (#6186)

* add grpo, support rlvr

* add grpo, support rlvr

* tested deepseek r1 pipeline

* add ci

* verify grpo r1

* verify grpo r1

* update readme, remove unused code

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove path

* clean code

* fix circular import

* fix ci OOM

* fix ci OOM

* skip kto tp, fix qwen generation

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
YeAnbang
2025-02-18 09:43:36 +08:00
committed by GitHub
parent ce0ec40811
commit d20c8ffd97
39 changed files with 1995 additions and 277 deletions

View File

@@ -27,6 +27,7 @@
- [Reward](#reward)
- [KL Divergence](#approximate-kl-divergence)
- [Note on PPO Training](#note-on-ppo-training)
- [GRPO Training and DeepSeek R1 reproduction]
- [Alternative Option For RLHF: Direct Preference Optimization](#alternative-option-for-rlhf-direct-preference-optimization)
- [DPO Stage 1: Supervised Instruction Tuning](#dpo-training-stage1---supervised-instructs-tuning)
- [DPO Stage 2: DPO Training](#dpo-training-stage2---dpo-training)
@@ -725,6 +726,75 @@ Answer: The causes of this problem are two-fold. Check your reward model, make s
#### Q4: Generation is garbage
Answer: Yes, this happens and is well documented by other implementations. After training for too many episodes, the actor gradually deviate from its original state, which may leads to decrease in language modeling capabilities. A way to fix this is to add supervised loss during PPO. Set ptx_coef to an non-zero value (between 0 and 1), which balances PPO loss and sft loss.
## GRPO Training and DeepSeek R1 reproduction
We support GRPO (Group Relative Policy Optimization), which is the reinforcement learning algorithm used in DeepSeek R1 paper. In this section, we will walk through GRPO training with an example trying to reproduce Deepseek R1's results in mathematical problem solving.
### GRPO Model Selection
We finally select the base version of [Qwen2.5-3B](https://huggingface.co/Qwen/Qwen2.5-3B). We also did experiments on the instruct version [Qwen2.5-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-3B-Instruct) but the later one fails to explore more diversed output. We recommend to use base models (without SFT) and use a few SFT steps (see [SFT section](#rlhf-training-stage1---supervised-instructs-tuning)) to correct the base model's output format before GRPO.
### Reinforcement Learning with Verifiable Reward
Both the PPO and the GRPO support reinforcement learning with verifiable reward (RLVR). In this experiment on mathematical problem solving, we define the reward function as following, in the following definition, forward is correct if there are exactly one pair of <think></think>, <answer></answer> tags in the response and the order of the tags is correct.
- reward=0, if format is incorrect.
- reward=1, if format is correct but the answer doesn't match the ground truth answer exactly.
- reward=10, if format is correct and the answer match the ground truth answer exactly.
### Step 1: Data Collection & Preparation
For GPRO training, you only need the prompt dataset. Please follow the instruction in the [prompt dataset preparation](#rlhf-training-stage3---proximal-policy-optimization) to prepare the prompt data for GPRO training. In our reproduction experiment, we use the [qwedsacf/competition_math dataset](https://huggingface.co/datasets/qwedsacf/competition_math), which is available on Huggingface.
### Step 2: Training
You can run the [train_grpo.sh](./training_scripts/train_grpo.sh) to start GRPO training. The script share most of its arguments with the PPO script (please refer to the [PPO training section](#step-3-training) for more details). Here are some unique arguments for GRPO.
```bash
--num_generations 8 \ # number of roll outs to collect for each prompt
--inference_batch_size 8 \ # batch size used during roll out
--logits_forward_batch_size 1 \ # batch size used to calculate logits for GRPO training
--initial_temperature \ # initial temperature for annealing algorithm
--final_temperature \ # final temperature for annealing algorithm
```
As the GRPO requires to collect a group of response from each prompt (usually greater than 8), the effective batch size will satisfy the following constraints,
- Without tensor parallelism,
```
experience buffer size
= num_process * num_collect_steps * experience_batch_size * num_generations
= train_batch_size * accumulation_steps * num_process
```
- With tensor parallelism,
```
num_tp_group = num_process / tp
experience buffer size
= num_tp_group * num_collect_steps * experience_batch_size * num_generations
= train_batch_size * accumulation_steps * num_tp_group
```
During roll out, we perform rebatching to prevent out of memory both before roll out and before calculating logits. Please choose a proper setting for the "inference_batch_size" and the "logits_forward_batch_size" based on your device.
### GRPO Result
#### Reward
<p align="center">
<img width="1000" alt="image" src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/grpo/reward.png">
</p>
#### Response Length
<p align="center">
<img width="1000" alt="image" src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/grpo/token_cost.png">
</p>
#### Response Length Distribution (After Training)
<p align="center">
<img width="1000" alt="image" src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/grpo/token_cost_eval.png">
</p>
#### Sample Response
<p align="center">
<img width="1000" alt="image" src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/grpo/res.png">
</p>
#### Note of Speed
Currently, our PPO and GRPO pipeline are still under development. The speed is largely limited by the roll out speed as we use naive generation without any acceleration.
## Alternative Option For RLHF: Direct Preference Optimization

View File

@@ -11,4 +11,4 @@ python prepare_dataset.py --type prompt \
--data_cache_dir $SAVE_DIR/cache \
--data_jsonl_output_dir $SAVE_DIR/jsonl \
--data_arrow_output_dir $SAVE_DIR/arrow \
--max_length 1024
--max_length 300

View File

@@ -1,4 +1,4 @@
pandas>=1.4.1
sentencepiece
colossalai==0.4.0
colossalai==0.4.7
prompt_toolkit

View File

@@ -0,0 +1,494 @@
import argparse
import json
import os
import resource
from contextlib import nullcontext
import torch
import torch.distributed as dist
from coati.dataset import (
DataCollatorForPromptDataset,
DataCollatorForSupervisedDataset,
StatefulDistributedSampler,
load_tokenized_dataset,
setup_conversation_template,
)
from coati.models import LoraConfig, RewardModel, RLVRRewardModel, convert_to_lora_module, disable_dropout, lora_manager
from coati.trainer import GRPOTrainer
from coati.utils import load_checkpoint
from coati.utils.reward_score import *
from transformers import AutoModelForCausalLM, AutoTokenizer
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
from colossalai.cluster import DistCoordinator
from colossalai.logging import get_dist_logger
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from colossalai.shardformer.policies.auto_policy import get_autopolicy
logger = get_dist_logger()
# default settings for response format tags, overwrite it in chat_template definition if needed
response_format_tags = {
"think_start": {"text": "<think>", "num_occur": 1},
"think_end": {"text": "</think>", "num_occur": 1},
"answer_start": {"text": "<answer>", "num_occur": 1},
"answer_end": {"text": "</answer>", "num_occur": 1},
}
def train(args):
global response_format_tags
lora_config = None
if args.lora_config is not None:
lora_config = LoraConfig.from_file(args.lora_config)
# check lora compatibility
if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
# ==============================
# Initialize Distributed Training
# ==============================
colossalai.launch_from_torch()
coordinator = DistCoordinator()
# ======================================================
# Initialize Model, Objective, Optimizer and LR Scheduler
# ======================================================
# Temp Fix: Disable lazy init due to version conflict
# init_ctx = (
# LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
# )
init_ctx = nullcontext()
with init_ctx:
if args.use_flash_attn:
actor = AutoModelForCausalLM.from_pretrained(
args.pretrain,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
use_flash_attention_2=True,
local_files_only=True,
trust_remote_code=True,
)
ref_model = AutoModelForCausalLM.from_pretrained(
args.pretrain,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
use_flash_attention_2=True,
local_files_only=True,
trust_remote_code=True,
)
if args.rm_pretrain:
reward_model = RewardModel(
args.rm_pretrain,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
use_flash_attention_2=True,
trust_remote_code=True,
)
coordinator.print_on_master(msg="Flash-attention enabled successfully")
else:
actor = AutoModelForCausalLM.from_pretrained(args.pretrain, local_files_only=True, trust_remote_code=True)
if args.rm_pretrain:
reward_model = RewardModel(args.rm_pretrain, trust_remote_code=True)
ref_model = AutoModelForCausalLM.from_pretrained(
args.pretrain, local_files_only=True, trust_remote_code=True
)
if args.lora_config is not None:
actor = convert_to_lora_module(actor, lora_config=lora_config)
for name, module in actor.named_modules():
if "norm" in name or "gate" in name:
module = module.to(torch.float32)
lora_manager.able_to_merge = False
# Disable dropout
disable_dropout(actor)
if args.grad_checkpoint:
actor.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
# configure tokenizer
tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=False, trust_remote_code=True)
if os.path.exists(args.conversation_template_config):
with open(args.conversation_template_config, "r", encoding="utf8") as f:
conversation_template_config = json.load(f)
dist.barrier()
if "response_format_tags" in conversation_template_config:
logger.warning(f"Overwrite default response format tags with {args.conversation_template_config}")
response_format_tags = conversation_template_config.get("response_format_tags", response_format_tags)
conversation_template = setup_conversation_template(
tokenizer, chat_template_config=conversation_template_config, save_path=args.conversation_template_config
)
stop_ids = conversation_template.stop_ids if len(conversation_template.stop_ids) > 0 else None
else:
raise ValueError("Conversation template config is not provided or incorrect")
if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None:
try:
# Some tokenizers doesn't allow to set pad_token mannually e.g., Qwen
tokenizer.pad_token = tokenizer.eos_token
except AttributeError as e:
logger.warning(f"Unable to set pad token to eos token, {str(e)}")
if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
logger.warning(
"The tokenizer does not have a pad token which is required. May lead to unintended behavior in training, Please consider manually set them."
)
tokenizer.add_bos_token = False
tokenizer.add_eos_token = False
tokenizer.padding_side = "left" # left padding for generation (online learning)
# configure generation config
actor.generation_config.update(
pad_token_id=tokenizer.eos_token_id, bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id
)
# configure optimizer
coordinator.print_on_master(f"setting up optimizer for actor: lr={args.lr}, weight_decay={args.weight_decay}")
actor_optim = HybridAdam(
model_params=actor.parameters(),
lr=args.lr,
betas=(0.9, 0.95),
weight_decay=args.weight_decay,
adamw_mode=True,
)
if args.warmup_steps is None:
args.warmup_steps = int(0.025 * args.num_episodes)
coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")
actor_lr_scheduler = CosineAnnealingWarmupLR(
optimizer=actor_optim,
total_steps=args.num_episodes,
warmup_steps=args.warmup_steps,
eta_min=0.1 * args.lr,
)
# ==============================
# Initialize Booster
# ==============================
if args.plugin == "ddp":
"""
Default torch ddp plugin without any acceleration, for
debugging purpose acceleration, for debugging purpose
"""
plugin = TorchDDPPlugin(find_unused_parameters=True)
elif args.plugin == "gemini":
plugin = GeminiPlugin(
precision=args.mixed_precision,
placement_policy="static",
initial_scale=2**16,
max_norm=args.grad_clip,
enable_gradient_accumulation=True,
enable_flash_attention=args.use_flash_attn,
)
elif args.plugin == "gemini_auto":
plugin = GeminiPlugin(
precision=args.mixed_precision,
placement_policy="auto",
initial_scale=2**16,
max_norm=args.grad_clip,
enable_flash_attention=args.use_flash_attn,
)
elif args.plugin == "zero2":
plugin = LowLevelZeroPlugin(
stage=2,
precision=args.mixed_precision,
initial_scale=2**16,
max_norm=args.grad_clip,
)
elif args.plugin == "zero2_cpu":
plugin = LowLevelZeroPlugin(
stage=2,
precision=args.mixed_precision,
initial_scale=2**16,
cpu_offload=True,
max_norm=args.grad_clip,
)
elif args.plugin == "3d":
if args.use_flash_attn and (args.tp > 1 or args.pp > 1 or args.sp > 1 or args.enable_sequence_parallelism):
logger.warning("Flash attention cannot be used with 3D parallelism for PPO training. Disabling it.")
args.use_flash_attn = False
plugin = HybridParallelPlugin(
tp_size=args.tp,
pp_size=args.pp,
sp_size=args.sp,
sequence_parallelism_mode=args.sp_mode,
zero_stage=args.zero_stage,
enable_flash_attention=args.use_flash_attn,
enable_sequence_parallelism=args.enable_sequence_parallelism,
cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
parallel_output=False,
max_norm=args.grad_clip,
precision=args.mixed_precision,
)
if args.rm_pretrain:
custom_plugin = HybridParallelPlugin(
tp_size=args.tp,
pp_size=args.pp,
sp_size=args.sp,
sequence_parallelism_mode=args.sp_mode,
zero_stage=args.zero_stage,
enable_flash_attention=args.use_flash_attn,
enable_sequence_parallelism=args.enable_sequence_parallelism,
cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
parallel_output=False,
max_norm=args.grad_clip,
precision=args.mixed_precision,
custom_policy=get_autopolicy(reward_model.model),
)
else:
raise ValueError(f"Unknown plugin {args.plugin}")
if args.plugin != "3d" and args.rm_pretrain:
custom_plugin = plugin
# configure dataset
coordinator.print_on_master(f"Load dataset: {args.prompt_dataset}")
mode_map = {"train": "train", "valid": "validation", "test": "test"}
train_prompt_dataset = load_tokenized_dataset(dataset_paths=args.prompt_dataset, mode="train", mode_map=mode_map)
data_collator = DataCollatorForPromptDataset(tokenizer=tokenizer, max_length=args.max_length - args.max_seq_len)
train_prompt_dataloader = plugin.prepare_dataloader(
dataset=train_prompt_dataset,
batch_size=args.experience_batch_size,
shuffle=True,
drop_last=True,
collate_fn=data_collator,
distributed_sampler_cls=StatefulDistributedSampler,
)
if len(args.ptx_dataset) > 0:
train_ptx_dataset = load_tokenized_dataset(dataset_paths=args.ptx_dataset, mode="train", mode_map=mode_map)
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)
train_pretrain_dataloader = plugin.prepare_dataloader(
dataset=train_ptx_dataset,
batch_size=args.ptx_batch_size,
shuffle=True,
drop_last=True,
collate_fn=data_collator,
distributed_sampler_cls=StatefulDistributedSampler,
)
else:
train_pretrain_dataloader = None
actor_booster = Booster(plugin=plugin)
ref_booster = Booster(plugin=plugin)
if args.rm_pretrain:
rm_booster = Booster(plugin=custom_plugin)
default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
torch.set_default_dtype(default_dtype)
actor, actor_optim, _, train_prompt_dataloader, actor_lr_scheduler = actor_booster.boost(
model=actor,
optimizer=actor_optim,
lr_scheduler=actor_lr_scheduler,
dataloader=train_prompt_dataloader,
)
if args.rm_pretrain:
reward_model, _, _, _, _ = rm_booster.boost(model=reward_model, dataloader=train_prompt_dataloader)
else:
if args.reward_functions:
reward_fn_list = []
for reward_fn in args.reward_functions:
"""
To define custom reward function, you can define your functions under:
colossalai/applications/ColossalChat/coati/utils/reward_score/__init__.py
and use it here by mofiying the following line:
"""
if reward_fn == "gsm8k_reward_fn":
reward_fn_list.append(gsm8k_reward_fn)
elif reward_fn == "math_competition_reward_fn":
reward_fn_list.append(math_competition_reward_fn)
else:
raise ValueError(f"Unknown reward function {reward_fn}")
reward_model = RLVRRewardModel(
reward_fn_list=reward_fn_list, tokenizer=tokenizer, tags=response_format_tags
)
ref_model, _, _, _, _ = ref_booster.boost(model=ref_model, dataloader=train_prompt_dataloader)
torch.set_default_dtype(torch.float)
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
coordinator.print_on_master(
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
)
sampler_start_idx = 0
start_step = 0
if args.rm_checkpoint_path is not None:
if "modeling" in args.rm_checkpoint_path:
rm_booster.load_model(reward_model, args.rm_checkpoint_path)
else:
_, _, _ = load_checkpoint(
load_dir=args.rm_checkpoint_path,
booster=rm_booster,
model=reward_model,
optimizer=None,
lr_scheduler=None,
)
coordinator.print_on_master(f"Loaded reward model checkpoint {args.rm_checkpoint_path}")
if args.checkpoint_path is not None:
if "modeling" in args.checkpoint_path:
actor_booster.load_model(actor, args.checkpoint_path)
ref_booster.load_model(ref_model, args.checkpoint_path)
coordinator.print_on_master(f"Loaded actor and reference model {args.checkpoint_path}")
else:
_, start_step, sampler_start_idx = load_checkpoint(
load_dir=args.checkpoint_path,
booster=actor_booster,
model=actor,
optimizer=actor_optim,
lr_scheduler=actor_lr_scheduler,
)
_, _, _ = load_checkpoint(load_dir=args.checkpoint_path, booster=ref_booster, model=ref_model)
assert isinstance(train_prompt_dataloader.sampler, StatefulDistributedSampler)
train_prompt_dataloader.sampler.set_start_index(start_index=sampler_start_idx)
coordinator.print_on_master(
f"Loaded actor and reference model checkpoint {args.checkpoint_path} at spisode {start_step}"
)
coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")
coordinator.print_on_master(
f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
)
coordinator.print_on_master(
f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
)
coordinator.print_on_master(
f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
)
# configure trainer
trainer = GRPOTrainer(
actor_booster,
actor,
reward_model,
ref_model,
actor_optim,
actor_lr_scheduler,
tokenizer=tokenizer,
stop_token_ids=[stop_ids],
kl_coef=args.kl_coef,
ptx_coef=args.ptx_coef,
train_batch_size=args.train_batch_size,
buffer_limit=args.num_collect_steps * args.experience_batch_size * args.num_generations,
max_length=args.max_length,
use_cache=True,
do_sample=True,
apply_loss_mask=not args.disable_loss_mask,
accumulation_steps=args.accumulation_steps,
save_dir=args.save_path,
save_interval=args.save_interval,
top_k=50,
use_tp=args.tp > 1,
num_generations=args.num_generations,
inference_batch_size=args.inference_batch_size,
logits_forward_batch_size=args.logits_forward_batch_size,
offload_inference_models="gemini" not in args.plugin,
coordinator=coordinator,
max_tokens_thinking=args.max_tokens_thinking if args.max_tokens_thinking else args.max_length - 100,
temperature_annealing_config={
"start_temperature": args.initial_temperature,
"end_temperature": args.final_temperature,
"annealing_warmup_steps": min(100, int(args.num_episodes / 6)),
"annealing_steps": min(600, int(args.num_episodes / 2)),
},
# Hack: some old model's default update_model_kwargs_fn/prepare_inputs_fn may doesn't work due to version conflict with transformers, you can overwrite them
# update_model_kwargs_fn=update_model_kwargs_fn,
# prepare_inputs_fn = None
)
trainer.fit(
num_episodes=args.num_episodes,
num_collect_steps=args.num_collect_steps,
num_update_steps=args.num_update_steps,
prompt_dataloader=train_prompt_dataloader,
pretrain_dataloader=train_pretrain_dataloader,
log_dir=args.log_dir,
use_wandb=args.use_wandb,
)
if lora_config is not None and lora_config.r > 0:
# NOTE: set model to eval to merge LoRA weights
lora_manager.able_to_merge = True
actor.eval()
# save model checkpoint after fitting on only rank0
coordinator.print_on_master("Start saving final actor model checkpoint")
actor_booster.save_model(actor, os.path.join(trainer.actor_save_dir, "modeling"), shard=True)
coordinator.print_on_master(
f"Saved final actor model checkpoint at episodes {args.num_episodes} at folder {args.save_path}"
)
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--prompt_dataset", nargs="+", default=[])
parser.add_argument("--ptx_dataset", nargs="+", default=[])
parser.add_argument(
"--plugin",
type=str,
default="gemini",
choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
help="Choose which plugin to use",
)
parser.add_argument(
"--conversation_template_config",
type=str,
default=None,
help="Path \
to save conversation template config files.",
)
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
parser.add_argument("--tokenizer_dir", type=str, default=None)
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--pp", type=int, default=1)
parser.add_argument("--sp", type=int, default=1)
parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--rm_pretrain", type=str, default=None)
parser.add_argument("--checkpoint_path", type=str, default=None)
parser.add_argument("--rm_checkpoint_path", type=str, help="Reward model checkpoint path")
parser.add_argument("--reward_functions", type=str, nargs="+", default=None, help="Reward functions to use")
parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts")
parser.add_argument("--num_episodes", type=int, default=1)
parser.add_argument("--num_collect_steps", type=int, default=2)
parser.add_argument("--num_update_steps", type=int, default=5)
parser.add_argument("--num_generations", type=int, default=8)
parser.add_argument("--inference_batch_size", type=int, default=None)
parser.add_argument("--save_interval", type=int, default=1000)
parser.add_argument("--train_batch_size", type=int, default=16)
parser.add_argument("--logits_forward_batch_size", type=int, default=1)
parser.add_argument("--experience_batch_size", type=int, default=16)
parser.add_argument("--ptx_batch_size", type=int, default=4)
parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
parser.add_argument("--accumulation_steps", type=int, default=8)
parser.add_argument("--lr", type=float, default=1e-6)
parser.add_argument("--kl_coef", type=float, default=0.7)
parser.add_argument("--ptx_coef", type=float, default=0.0)
parser.add_argument("--disable_loss_mask", default=False, action="store_true")
parser.add_argument("--max_length", type=int, default=2048)
parser.add_argument("--max_tokens_thinking", type=int, default=2000)
parser.add_argument("--max_seq_len", type=int, default=256)
parser.add_argument("--initial_temperature", type=float, default=1.0)
parser.add_argument("--final_temperature", type=float, default=0.9)
parser.add_argument("--log_dir", default=None, type=str)
parser.add_argument("--use_wandb", default=False, action="store_true")
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
parser.add_argument("--use_flash_attn", default=False, action="store_true")
args = parser.parse_args()
train(args)

View File

@@ -0,0 +1,86 @@
#!/bin/bash
set_n_least_used_CUDA_VISIBLE_DEVICES() {
local n=${1:-"9999"}
echo "GPU Memory Usage:"
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
tail -n +2 |
nl -v 0 |
tee /dev/tty |
sort -g -k 2 |
awk '{print $1}' |
head -n $n)
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
echo "Now CUDA_VISIBLE_DEVICES is set to:"
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
set_n_least_used_CUDA_VISIBLE_DEVICES 8
PROJECT_NAME="PPO-RLVR"
PARENT_SAVE_DIR="" # Path to a folder to save checkpoints
PARENT_CONFIG_FILE="" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="" # local pretrained model path (from RLHF step 1: SFT)
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
CONVERSATION_TEMPLATE_CONFIG_PATH="" # path to the conversation config file
LOGDIR=""
declare -a prompt_dataset=(
YOUR/PROMPT/DATA/DIR/arrow/part-00000
YOUR/PROMPT/DATA/DIR/arrow/part-00001
YOUR/PROMPT/DATA/DIR/arrow/part-00002
YOUR/PROMPT/DATA/DIR/arrow/part-00003
YOUR/PROMPT/DATA/DIR/arrow/part-00004
YOUR/PROMPT/DATA/DIR/arrow/part-00005
YOUR/PROMPT/DATA/DIR/arrow/part-00006
YOUR/PROMPT/DATA/DIR/arrow/part-00007
YOUR/PROMPT/DATA/DIR/arrow/part-00008
YOUR/PROMPT/DATA/DIR/arrow/part-00009
)
declare -a ptx_dataset=(
YOUR/SFT/DATA/DIR/arrow/part-00000
YOUR/SFT/DATA/DIR/arrow/part-00001
YOUR/SFT/DATA/DIR/arrow/part-00002
YOUR/SFT/DATA/DIR/arrow/part-00003
YOUR/SFT/DATA/DIR/arrow/part-00004
YOUR/SFT/DATA/DIR/arrow/part-00005
YOUR/SFT/DATA/DIR/arrow/part-00006
YOUR/SFT/DATA/DIR/arrow/part-00007
YOUR/SFT/DATA/DIR/arrow/part-00008
YOUR/SFT/DATA/DIR/arrow/part-00009
)
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json"
colossalai run --nproc_per_node 8 --num_nodes 1 --hostfile ./hostfile train_grpo.py \
--pretrain $PRETRAINED_MODEL_PATH \
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
--prompt_dataset ${prompt_dataset[@]} \
--conversation_template_config $CONVERSATION_TEMPLATE_CONFIG_PATH \
--ptx_coef 0.0 \
--plugin "zero2_cpu" \
--reward_functions math_competition_reward_fn \
--save_interval 250 \
--save_path $SAVE_DIR \
--num_episodes 100 \
--num_collect_steps 8 \
--num_update_steps 1 \
--experience_batch_size 1 \
--train_batch_size 4 \
--inference_batch_size 8 \
--logits_forward_batch_size 2 \
--accumulation_steps 4 \
--lr 1e-6 \
--mixed_precision "bf16" \
--grad_clip 0.1\
--weight_decay 0.01 \
--kl_coef 0.01 \
--warmup_steps 40 \
--max_length 2000 \
--max_seq_len 1700 \
--log_dir $LOGDIR \
--use_flash_attn \
--grad_checkpoint

View File

@@ -13,9 +13,18 @@ from coati.dataset import (
load_tokenized_dataset,
setup_conversation_template,
)
from coati.models import Critic, LoraConfig, RewardModel, convert_to_lora_module, disable_dropout, lora_manager
from coati.models import (
Critic,
LoraConfig,
RewardModel,
RLVRRewardModel,
convert_to_lora_module,
disable_dropout,
lora_manager,
)
from coati.trainer import PPOTrainer
from coati.utils import load_checkpoint
from coati.utils.reward_score import *
from transformers import AutoModelForCausalLM, AutoTokenizer
import colossalai
@@ -29,8 +38,17 @@ from colossalai.shardformer.policies.auto_policy import get_autopolicy
logger = get_dist_logger()
# default settings for response format tags, overwrite it in chat_template definition if needed
response_format_tags = {
"think_start": {"text": "<think>", "num_occur": 1},
"think_end": {"text": "</think>", "num_occur": 1},
"answer_start": {"text": "<answer>", "num_occur": 1},
"answer_end": {"text": "</answer>", "num_occur": 1},
}
def train(args):
global response_format_tags
lora_config = None
if args.lora_config is not None:
lora_config = LoraConfig.from_file(args.lora_config)
@@ -61,28 +79,36 @@ def train(args):
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
use_flash_attention_2=True,
local_files_only=True,
trust_remote_code=True,
)
ref_model = AutoModelForCausalLM.from_pretrained(
args.pretrain,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
use_flash_attention_2=True,
local_files_only=True,
trust_remote_code=True,
)
reward_model = RewardModel(
args.rm_pretrain,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
use_flash_attention_2=True,
)
if not args.no_neural_reward_model:
reward_model = RewardModel(
args.rm_pretrain,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
use_flash_attention_2=True,
trust_remote_code=True,
)
critic = Critic(
args.rm_pretrain,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
use_flash_attention_2=True,
trust_remote_code=True,
)
coordinator.print_on_master(msg="Flash-attention enabled successfully")
else:
actor = AutoModelForCausalLM.from_pretrained(args.pretrain, local_files_only=True)
ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain, local_files_only=True)
reward_model = RewardModel(args.rm_pretrain)
actor = AutoModelForCausalLM.from_pretrained(args.pretrain, local_files_only=True, trust_remote_code=True)
ref_model = AutoModelForCausalLM.from_pretrained(
args.pretrain, local_files_only=True, trust_remote_code=True
)
if not args.no_neural_reward_model:
reward_model = RewardModel(args.rm_pretrain, trust_remote_code=True)
critic = Critic(args.rm_pretrain)
if args.lora_config is not None:
@@ -112,6 +138,9 @@ def train(args):
with open(args.conversation_template_config, "r", encoding="utf8") as f:
conversation_template_config = json.load(f)
dist.barrier()
if "response_format_tags" in conversation_template_config:
logger.warning(f"Overwrite default response format tags with {args.conversation_template_config}")
response_format_tags = conversation_template_config.get("response_format_tags", response_format_tags)
conversation_template = setup_conversation_template(
tokenizer, chat_template_config=conversation_template_config, save_path=args.conversation_template_config
)
@@ -245,7 +274,7 @@ def train(args):
parallel_output=False,
max_norm=args.grad_clip,
precision=args.mixed_precision,
custom_policy=get_autopolicy(reward_model.model),
custom_policy=get_autopolicy(critic.model),
)
else:
raise ValueError(f"Unknown plugin {args.plugin}")
@@ -284,7 +313,8 @@ def train(args):
actor_booster = Booster(plugin=plugin)
ref_booster = Booster(plugin=plugin)
rm_booster = Booster(plugin=custom_plugin)
if not args.no_neural_reward_model:
rm_booster = Booster(plugin=custom_plugin)
critic_booster = Booster(plugin=custom_plugin)
default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
@@ -302,7 +332,28 @@ def train(args):
lr_scheduler=critic_lr_scheduler,
dataloader=train_prompt_dataloader,
)
reward_model, _, _, _, _ = rm_booster.boost(model=reward_model, dataloader=train_prompt_dataloader)
if not args.no_neural_reward_model:
reward_model, _, _, _, _ = rm_booster.boost(model=reward_model, dataloader=train_prompt_dataloader)
else:
if args.reward_functions:
reward_fn_list = []
for reward_fn in args.reward_functions:
"""
To define custom reward function, you can define your functions under:
colossalai/applications/ColossalChat/coati/utils/reward_score/__init__.py
and use it here by mofiying the following line:
"""
if reward_fn == "gsm8k_reward_fn":
reward_fn_list.append(gsm8k_reward_fn)
elif reward_fn == "math_competition_reward_fn":
reward_fn_list.append(math_competition_reward_fn)
else:
raise ValueError(f"Unknown reward function {reward_fn}")
reward_fn_list.append(eval(reward_fn))
reward_model = RLVRRewardModel(
reward_fn_list=reward_fn_list, tokenizer=tokenizer, tags=response_format_tags
)
ref_model, _, _, _, _ = ref_booster.boost(model=ref_model, dataloader=train_prompt_dataloader)
torch.set_default_dtype(torch.float)
@@ -481,9 +532,11 @@ if __name__ == "__main__":
parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--rm_pretrain", type=str, default=None)
parser.add_argument("--no_neural_reward_model", default=False, action="store_true")
parser.add_argument("--checkpoint_path", type=str, default=None)
parser.add_argument("--critic_checkpoint_path", type=str, default=None)
parser.add_argument("--rm_checkpoint_path", type=str, help="Reward model checkpoint path")
parser.add_argument("--reward_functions", type=str, nargs="+", default=None, help="Reward functions to use")
parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts")
parser.add_argument("--num_episodes", type=int, default=1)
parser.add_argument("--num_collect_steps", type=int, default=2)