Add GRPO and Support RLVR for PPO (#6186)

* add grpo, support rlvr

* add grpo, support rlvr

* tested deepseek r1 pipeline

* add ci

* verify grpo r1

* verify grpo r1

* update readme, remove unused code

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove path

* clean code

* fix circular import

* fix ci OOM

* fix ci OOM

* skip kto tp, fix qwen generation

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Authored by YeAnbang on 2025-02-18 09:43:36 +08:00, committed by GitHub
parent ce0ec40811, commit d20c8ffd97
39 changed files with 1995 additions and 277 deletions


@@ -13,9 +13,18 @@ from coati.dataset import (
load_tokenized_dataset,
setup_conversation_template,
)
from coati.models import Critic, LoraConfig, RewardModel, convert_to_lora_module, disable_dropout, lora_manager
from coati.models import (
Critic,
LoraConfig,
RewardModel,
RLVRRewardModel,
convert_to_lora_module,
disable_dropout,
lora_manager,
)
from coati.trainer import PPOTrainer
from coati.utils import load_checkpoint
from coati.utils.reward_score import *
from transformers import AutoModelForCausalLM, AutoTokenizer
import colossalai
@@ -29,8 +38,17 @@ from colossalai.shardformer.policies.auto_policy import get_autopolicy
logger = get_dist_logger()
# default settings for response format tags, overwrite it in chat_template definition if needed
response_format_tags = {
"think_start": {"text": "<think>", "num_occur": 1},
"think_end": {"text": "</think>", "num_occur": 1},
"answer_start": {"text": "<answer>", "num_occur": 1},
"answer_end": {"text": "</answer>", "num_occur": 1},
}
def train(args):
global response_format_tags
lora_config = None
if args.lora_config is not None:
lora_config = LoraConfig.from_file(args.lora_config)
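For reference, the default tags above describe a DeepSeek-R1-style completion of the form <think>...</think><answer>...</answer>, where each tag string is expected to occur exactly num_occur times. The toy checker below only illustrates that expectation; it is not the coati implementation (the actual format verification lives in coati.utils.reward_score and may differ):

# Toy illustration (not the coati implementation) of what the default
# response_format_tags require from a generated completion.
response_format_tags = {
    "think_start": {"text": "<think>", "num_occur": 1},
    "think_end": {"text": "</think>", "num_occur": 1},
    "answer_start": {"text": "<answer>", "num_occur": 1},
    "answer_end": {"text": "</answer>", "num_occur": 1},
}

def matches_format(response: str, tags: dict) -> bool:
    # Every tag text must appear exactly the configured number of times.
    return all(response.count(t["text"]) == t["num_occur"] for t in tags.values())

print(matches_format("<think>2+2=4</think><answer>4</answer>", response_format_tags))  # True
print(matches_format("The answer is 4.", response_format_tags))                        # False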
@@ -61,28 +79,36 @@ def train(args):
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
use_flash_attention_2=True,
local_files_only=True,
trust_remote_code=True,
)
ref_model = AutoModelForCausalLM.from_pretrained(
args.pretrain,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
use_flash_attention_2=True,
local_files_only=True,
trust_remote_code=True,
)
reward_model = RewardModel(
args.rm_pretrain,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
use_flash_attention_2=True,
)
if not args.no_neural_reward_model:
reward_model = RewardModel(
args.rm_pretrain,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
use_flash_attention_2=True,
trust_remote_code=True,
)
critic = Critic(
args.rm_pretrain,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
use_flash_attention_2=True,
trust_remote_code=True,
)
coordinator.print_on_master(msg="Flash-attention enabled successfully")
else:
actor = AutoModelForCausalLM.from_pretrained(args.pretrain, local_files_only=True)
ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain, local_files_only=True)
reward_model = RewardModel(args.rm_pretrain)
actor = AutoModelForCausalLM.from_pretrained(args.pretrain, local_files_only=True, trust_remote_code=True)
ref_model = AutoModelForCausalLM.from_pretrained(
args.pretrain, local_files_only=True, trust_remote_code=True
)
if not args.no_neural_reward_model:
reward_model = RewardModel(args.rm_pretrain, trust_remote_code=True)
critic = Critic(args.rm_pretrain)
if args.lora_config is not None:
@@ -112,6 +138,9 @@ def train(args):
with open(args.conversation_template_config, "r", encoding="utf8") as f:
conversation_template_config = json.load(f)
dist.barrier()
if "response_format_tags" in conversation_template_config:
logger.warning(f"Overwrite default response format tags with {args.conversation_template_config}")
response_format_tags = conversation_template_config.get("response_format_tags", response_format_tags)
conversation_template = setup_conversation_template(
tokenizer, chat_template_config=conversation_template_config, save_path=args.conversation_template_config
)
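A conversation template config that carries a response_format_tags entry replaces the defaults defined earlier. The snippet below is a hypothetical example of what json.load(f) might return for a model whose template uses different reasoning markers; only the response_format_tags key is read by the override logic above, and the remaining template fields are omitted rather than guessed:

# Hypothetical conversation_template_config (as returned by json.load).
conversation_template_config = {
    # ... other conversation-template fields ...
    "response_format_tags": {
        "think_start": {"text": "<reasoning>", "num_occur": 1},
        "think_end": {"text": "</reasoning>", "num_occur": 1},
        "answer_start": {"text": "<answer>", "num_occur": 1},
        "answer_end": {"text": "</answer>", "num_occur": 1},
    },
}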
@@ -245,7 +274,7 @@ def train(args):
parallel_output=False,
max_norm=args.grad_clip,
precision=args.mixed_precision,
custom_policy=get_autopolicy(reward_model.model),
custom_policy=get_autopolicy(critic.model),
)
else:
raise ValueError(f"Unknown plugin {args.plugin}")
@@ -284,7 +313,8 @@ def train(args):
actor_booster = Booster(plugin=plugin)
ref_booster = Booster(plugin=plugin)
rm_booster = Booster(plugin=custom_plugin)
if not args.no_neural_reward_model:
rm_booster = Booster(plugin=custom_plugin)
critic_booster = Booster(plugin=custom_plugin)
default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
@@ -302,7 +332,28 @@ def train(args):
lr_scheduler=critic_lr_scheduler,
dataloader=train_prompt_dataloader,
)
reward_model, _, _, _, _ = rm_booster.boost(model=reward_model, dataloader=train_prompt_dataloader)
if not args.no_neural_reward_model:
reward_model, _, _, _, _ = rm_booster.boost(model=reward_model, dataloader=train_prompt_dataloader)
else:
if args.reward_functions:
reward_fn_list = []
for reward_fn in args.reward_functions:
"""
To define custom reward function, you can define your functions under:
colossalai/applications/ColossalChat/coati/utils/reward_score/__init__.py
and use it here by mofiying the following line:
"""
if reward_fn == "gsm8k_reward_fn":
reward_fn_list.append(gsm8k_reward_fn)
elif reward_fn == "math_competition_reward_fn":
reward_fn_list.append(math_competition_reward_fn)
else:
raise ValueError(f"Unknown reward function {reward_fn}")
reward_model = RLVRRewardModel(
reward_fn_list=reward_fn_list, tokenizer=tokenizer, tags=response_format_tags
)
ref_model, _, _, _, _ = ref_booster.boost(model=ref_model, dataloader=train_prompt_dataloader)
torch.set_default_dtype(torch.float)
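The comment in the hunk above points to coati/utils/reward_score/__init__.py as the place to register custom reward functions. The sketch below is only a shape illustration under assumed conventions: the parameter names (input_ids, tokenizer, gt_answer, tags) and the way RLVRRewardModel invokes its reward_fn_list are assumptions here, so check the signatures of the existing gsm8k_reward_fn / math_competition_reward_fn before reusing it.

import torch

# Hypothetical verifiable reward function. Argument names are assumptions;
# mirror the signatures of the built-in functions in coati/utils/reward_score.
def my_exact_match_reward_fn(input_ids, tokenizer=None, gt_answer=None, tags=None, **kwargs):
    response = tokenizer.decode(input_ids, skip_special_tokens=True)
    # Format reward: every tag must occur the expected number of times.
    format_ok = all(response.count(t["text"]) == t["num_occur"] for t in tags.values())
    # Accuracy reward: the text between the answer tags must match the ground truth.
    answer = response.split(tags["answer_start"]["text"])[-1].split(tags["answer_end"]["text"])[0].strip()
    reward = 0.0
    if format_ok:
        reward += 1.0
    if gt_answer is not None and answer == str(gt_answer).strip():
        reward += 9.0
    return torch.tensor(reward, dtype=torch.float32)

Wiring such a function in would then mean adding an elif reward_fn == "my_exact_match_reward_fn" branch to the dispatch above and passing --reward_functions my_exact_match_reward_fn on the command line.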
@@ -481,9 +532,11 @@ if __name__ == "__main__":
parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--rm_pretrain", type=str, default=None)
parser.add_argument("--no_neural_reward_model", default=False, action="store_true")
parser.add_argument("--checkpoint_path", type=str, default=None)
parser.add_argument("--critic_checkpoint_path", type=str, default=None)
parser.add_argument("--rm_checkpoint_path", type=str, help="Reward model checkpoint path")
parser.add_argument("--reward_functions", type=str, nargs="+", default=None, help="Reward functions to use")
parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts")
parser.add_argument("--num_episodes", type=int, default=1)
parser.add_argument("--num_collect_steps", type=int, default=2)