Add GRPO and Support RLVR for PPO (#6186)
* add grpo, support rlvr
* add grpo, support rlvr
* tested deepseek r1 pipeline
* add ci
* verify grpo r1
* verify grpo r1
* update readme, remove unused code
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* remove path
* clean code
* fix circular import
* fix ci OOM
* fix ci OOM
* skip kto tp, fix qwen generation

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -13,9 +13,18 @@ from coati.dataset import (
     load_tokenized_dataset,
     setup_conversation_template,
 )
-from coati.models import Critic, LoraConfig, RewardModel, convert_to_lora_module, disable_dropout, lora_manager
+from coati.models import (
+    Critic,
+    LoraConfig,
+    RewardModel,
+    RLVRRewardModel,
+    convert_to_lora_module,
+    disable_dropout,
+    lora_manager,
+)
 from coati.trainer import PPOTrainer
 from coati.utils import load_checkpoint
+from coati.utils.reward_score import *
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 import colossalai
@@ -29,8 +38,17 @@ from colossalai.shardformer.policies.auto_policy import get_autopolicy
 
 logger = get_dist_logger()
 
+# default settings for response format tags, overwrite it in chat_template definition if needed
+response_format_tags = {
+    "think_start": {"text": "<think>", "num_occur": 1},
+    "think_end": {"text": "</think>", "num_occur": 1},
+    "answer_start": {"text": "<answer>", "num_occur": 1},
+    "answer_end": {"text": "</answer>", "num_occur": 1},
+}
+
 
 def train(args):
+    global response_format_tags
     lora_config = None
     if args.lora_config is not None:
         lora_config = LoraConfig.from_file(args.lora_config)
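The num_occur field in these defaults suggests that the reward scoring checks how many times each tag appears in a generated response. The sketch below only illustrates that idea under that assumption; check_response_format is a hypothetical helper written for this note, not the validator shipped in coati.utils.reward_score.

# Hypothetical sketch of a count-based format check built on the tag settings above.
# The real scoring logic lives in coati.utils.reward_score; this only illustrates
# how the {"text": ..., "num_occur": ...} structure could be consumed.
def check_response_format(response: str, tags: dict) -> bool:
    # Every tag's text must occur exactly num_occur times in the response.
    return all(response.count(tag["text"]) == tag["num_occur"] for tag in tags.values())


tags = {
    "think_start": {"text": "<think>", "num_occur": 1},
    "think_end": {"text": "</think>", "num_occur": 1},
    "answer_start": {"text": "<answer>", "num_occur": 1},
    "answer_end": {"text": "</answer>", "num_occur": 1},
}
print(check_response_format("<think>2 + 2 = 4</think><answer>4</answer>", tags))  # True
print(check_response_format("no tags at all", tags))  # False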
@@ -61,28 +79,36 @@ def train(args):
                 torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
                 use_flash_attention_2=True,
                 local_files_only=True,
+                trust_remote_code=True,
             )
             ref_model = AutoModelForCausalLM.from_pretrained(
                 args.pretrain,
                 torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
                 use_flash_attention_2=True,
                 local_files_only=True,
+                trust_remote_code=True,
             )
-            reward_model = RewardModel(
-                args.rm_pretrain,
-                torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
-                use_flash_attention_2=True,
-            )
+            if not args.no_neural_reward_model:
+                reward_model = RewardModel(
+                    args.rm_pretrain,
+                    torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
+                    use_flash_attention_2=True,
+                    trust_remote_code=True,
+                )
             critic = Critic(
                 args.rm_pretrain,
                 torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
                 use_flash_attention_2=True,
+                trust_remote_code=True,
             )
             coordinator.print_on_master(msg="Flash-attention enabled successfully")
         else:
-            actor = AutoModelForCausalLM.from_pretrained(args.pretrain, local_files_only=True)
-            ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain, local_files_only=True)
-            reward_model = RewardModel(args.rm_pretrain)
+            actor = AutoModelForCausalLM.from_pretrained(args.pretrain, local_files_only=True, trust_remote_code=True)
+            ref_model = AutoModelForCausalLM.from_pretrained(
+                args.pretrain, local_files_only=True, trust_remote_code=True
+            )
+            if not args.no_neural_reward_model:
+                reward_model = RewardModel(args.rm_pretrain, trust_remote_code=True)
             critic = Critic(args.rm_pretrain)
 
         if args.lora_config is not None:
@@ -112,6 +138,9 @@ def train(args):
         with open(args.conversation_template_config, "r", encoding="utf8") as f:
             conversation_template_config = json.load(f)
         dist.barrier()
+        if "response_format_tags" in conversation_template_config:
+            logger.warning(f"Overwrite default response format tags with {args.conversation_template_config}")
+        response_format_tags = conversation_template_config.get("response_format_tags", response_format_tags)
         conversation_template = setup_conversation_template(
             tokenizer, chat_template_config=conversation_template_config, save_path=args.conversation_template_config
         )
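As the hunk above shows, a conversation template config may carry its own response_format_tags entry, which replaces the module-level defaults. The snippet below sketches what such a config could contain and mirrors the override logic; apart from the response_format_tags key, the contents and keys are illustrative assumptions, not taken from the repository.

import json

# Illustrative conversation template config. Only the response_format_tags key is
# grounded in the diff above; the chat_template value is a placeholder assumption.
example_config_text = json.dumps(
    {
        "chat_template": "...",
        "response_format_tags": {
            "think_start": {"text": "<think>", "num_occur": 1},
            "think_end": {"text": "</think>", "num_occur": 1},
            "answer_start": {"text": "<answer>", "num_occur": 1},
            "answer_end": {"text": "</answer>", "num_occur": 1},
        },
    }
)

default_tags = {"think_start": {"text": "<think>", "num_occur": 1}}  # trimmed defaults
conversation_template_config = json.loads(example_config_text)

# Same pattern as the hunk: use the config's tags when present, else keep the defaults.
response_format_tags = conversation_template_config.get("response_format_tags", default_tags)
print(sorted(response_format_tags.keys()))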
@@ -245,7 +274,7 @@ def train(args):
             parallel_output=False,
             max_norm=args.grad_clip,
             precision=args.mixed_precision,
-            custom_policy=get_autopolicy(reward_model.model),
+            custom_policy=get_autopolicy(critic.model),
         )
     else:
         raise ValueError(f"Unknown plugin {args.plugin}")
@@ -284,7 +313,8 @@ def train(args):
 
     actor_booster = Booster(plugin=plugin)
     ref_booster = Booster(plugin=plugin)
-    rm_booster = Booster(plugin=custom_plugin)
+    if not args.no_neural_reward_model:
+        rm_booster = Booster(plugin=custom_plugin)
     critic_booster = Booster(plugin=custom_plugin)
 
     default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
@@ -302,7 +332,28 @@ def train(args):
         lr_scheduler=critic_lr_scheduler,
         dataloader=train_prompt_dataloader,
     )
-    reward_model, _, _, _, _ = rm_booster.boost(model=reward_model, dataloader=train_prompt_dataloader)
+    if not args.no_neural_reward_model:
+        reward_model, _, _, _, _ = rm_booster.boost(model=reward_model, dataloader=train_prompt_dataloader)
+    else:
+        if args.reward_functions:
+            reward_fn_list = []
+            for reward_fn in args.reward_functions:
+                """
+                To define custom reward function, you can define your functions under:
+                    colossalai/applications/ColossalChat/coati/utils/reward_score/__init__.py
+                and use it here by modifying the following line:
+                """
+                if reward_fn == "gsm8k_reward_fn":
+                    reward_fn_list.append(gsm8k_reward_fn)
+                elif reward_fn == "math_competition_reward_fn":
+                    reward_fn_list.append(math_competition_reward_fn)
+                else:
+                    raise ValueError(f"Unknown reward function {reward_fn}")
+                reward_fn_list.append(eval(reward_fn))
+            reward_model = RLVRRewardModel(
+                reward_fn_list=reward_fn_list, tokenizer=tokenizer, tags=response_format_tags
+            )
+
     ref_model, _, _, _, _ = ref_booster.boost(model=ref_model, dataloader=train_prompt_dataloader)
 
     torch.set_default_dtype(torch.float)
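The docstring in the hunk above points custom reward functions at coati/utils/reward_score/__init__.py, but this diff does not show the signature RLVRRewardModel expects. The sketch below therefore assumes a plain callable that scores one decoded response against a ground-truth answer; both the name exact_match_reward_fn and the signature are illustrative assumptions, and the real interface should be checked in the reward_score package.

# Hypothetical verifiable reward function in the spirit of gsm8k_reward_fn.
# Assumed signature: decoded response text plus a ground-truth answer; the actual
# interface used by RLVRRewardModel may differ (e.g. it may receive token ids).
def exact_match_reward_fn(response: str, gt_answer: str, **kwargs) -> float:
    """Reward 1.0 when the text inside <answer>...</answer> matches the ground truth."""
    start, end = "<answer>", "</answer>"
    if start in response and end in response:
        predicted = response.split(start, 1)[1].split(end, 1)[0].strip()
        return 1.0 if predicted == gt_answer.strip() else 0.0
    return 0.0  # malformed responses earn no reward


print(exact_match_reward_fn("<think>3 * 4 = 12</think><answer>12</answer>", "12"))  # 1.0
print(exact_match_reward_fn("<answer>13</answer>", "12"))  # 0.0

Such a function would be exported from the reward_score package and selected by name through --reward_functions, following the same dispatch pattern as the gsm8k and math competition entries in the hunk above.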
@@ -481,9 +532,11 @@ if __name__ == "__main__":
     parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
     parser.add_argument("--pretrain", type=str, default=None)
     parser.add_argument("--rm_pretrain", type=str, default=None)
+    parser.add_argument("--no_neural_reward_model", default=False, action="store_true")
     parser.add_argument("--checkpoint_path", type=str, default=None)
     parser.add_argument("--critic_checkpoint_path", type=str, default=None)
     parser.add_argument("--rm_checkpoint_path", type=str, help="Reward model checkpoint path")
+    parser.add_argument("--reward_functions", type=str, nargs="+", default=None, help="Reward functions to use")
     parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts")
     parser.add_argument("--num_episodes", type=int, default=1)
     parser.add_argument("--num_collect_steps", type=int, default=2)
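A minimal sketch of how the two new flags interact once parsed; the stripped-down parser below stands in for the script's full argument list and is only meant to show the RLVR switch.

import argparse

# Stand-in parser containing only the two flags added in this hunk.
parser = argparse.ArgumentParser()
parser.add_argument("--no_neural_reward_model", default=False, action="store_true")
parser.add_argument("--reward_functions", type=str, nargs="+", default=None)

# Equivalent to launching with: --no_neural_reward_model --reward_functions gsm8k_reward_fn
args = parser.parse_args(["--no_neural_reward_model", "--reward_functions", "gsm8k_reward_fn"])

if args.no_neural_reward_model:
    # RLVR path: rewards come from rule-based functions instead of a neural reward model.
    print("verifiable reward functions:", args.reward_functions)
else:
    # Default path: a neural RewardModel is loaded from --rm_pretrain and boosted.
    print("using the neural reward model")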