diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index d6894a7da..cd6c3cfa4 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -67,7 +67,7 @@ class BaseConsumer:
             and "num_microbatches" not in self.plugin_config
             and "microbatch_size" not in self.plugin_config
         ):
-            plugin_config["microbatch_size"] = self.minibatch_size
+            plugin_config["microbatch_size"] = max(1, self.minibatch_size // plugin_config.get("pp_size", 1))
         plugin_config.update(self.plugin_config)
         self.plugin = HybridParallelPlugin(**plugin_config)
         self.booster = Booster(plugin=self.plugin)
diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 3da4a4f47..877ff98ec 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -7,7 +7,7 @@ import torch
 import wandb
 from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
-from coati.distributed.reward.reward_fn import math_reward_fn
+from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
 from coati.distributed.utils import calc_action_log_probs
 from coati.trainer.utils import all_reduce_mean, all_reduce_sum
@@ -54,7 +54,9 @@ class GRPOConsumer(BaseConsumer):
             and "num_microbatches" not in plugin_config
             and "microbatch_size" not in plugin_config
         ):
-            plugin_config["microbatch_size"] = max(1, grpo_config.get("train_microbatch_size") // 2)
+            plugin_config["microbatch_size"] = max(
+                1, grpo_config.get("train_microbatch_size") // plugin_config.get("pp_size", 1)
+            )
         super().__init__(
             num_producers,
             num_episodes,
@@ -131,7 +133,12 @@ class GRPOConsumer(BaseConsumer):
             k: v for k, v in grpo_config.items() if k in ["soft_over_length_punishment", "max_length", "cache_length"]
         }
         self.reward_model = VerifiableReward(
-            reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags, **reward_model_kwargs
+            reward_fns=[
+                math_reward_fn if grpo_config.get("reward_fn_type") == "think_answer_tags" else boxed_math_reward_fn
+            ],
+            tokenizer=self.tokenizer,
+            tags=response_format_tags,
+            **reward_model_kwargs,
         )
         self.global_step = 0
         self.use_wandb = use_wandb
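Reviewer note: both microbatch-size changes above apply the same rule, namely dividing the training microbatch by the pipeline-parallel size and clamping the result to at least 1. A minimal sketch of that arithmetic (the helper name and the example values are illustrative only, not part of the patch):

```python
# Illustrative sketch: mirrors the max(1, batch // pp_size) logic added above.
def derive_microbatch_size(train_microbatch_size: int, plugin_config: dict) -> int:
    # Each pipeline stage processes train_microbatch_size // pp_size samples per
    # microbatch; max(1, ...) guards against a zero microbatch when pp_size is large.
    return max(1, train_microbatch_size // plugin_config.get("pp_size", 1))

assert derive_microbatch_size(8, {"pp_size": 2}) == 4
assert derive_microbatch_size(8, {}) == 8  # no pipeline parallelism configured
assert derive_microbatch_size(2, {"pp_size": 4}) == 1  # clamped to 1
```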
diff --git a/applications/ColossalChat/coati/distributed/reward/reward_fn.py b/applications/ColossalChat/coati/distributed/reward/reward_fn.py
index fd9129130..b68c1a92f 100644
--- a/applications/ColossalChat/coati/distributed/reward/reward_fn.py
+++ b/applications/ColossalChat/coati/distributed/reward/reward_fn.py
@@ -1,6 +1,6 @@
 import torch
 
-from .reward_utils import extract_solution, validate_response_structure
+from .reward_utils import extract_boxed_solution, extract_solution, validate_response_structure
 
 
 def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
@@ -70,3 +70,43 @@ def gsm8k_reward_fn(input_ids, **kwargs):
     if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower():
         reward = reward + 9.0
     return reward
+
+
+def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
+    tokenizer = kwargs["tokenizer"]
+    soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False)
+    format_score = 0.0
+    acc_score = 10.0
+    reward = torch.tensor(0.0)
+    format_acc = torch.tensor(0.0)
+    ans_acc = torch.tensor(0.0)
+    s, e = response_idx[0], response_idx[1]
+
+    length_reward = 0.0
+    if soft_over_length_punishment:
+        max_length = kwargs.get("max_length", 1024 * 4)
+        cache_length = kwargs.get("cache_length", 512)
+        res_length = e.item() - s.item() + 1
+        if max_length - cache_length < res_length < max_length:
+            length_reward = ((max_length - cache_length) - res_length) / cache_length * acc_score
+
+    if gt_answer is None:
+        return reward
+
+    decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
+    gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
+    final_answer = extract_boxed_solution(decoded_final_answer)
+    format_valid = final_answer is not None
+    # Check format accuracy
+    if format_valid:
+        format_acc += 1
+        reward += format_score
+
+    # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
+    if format_valid and final_answer is not None and gt_answer.strip().lower() == final_answer.strip().lower():
+        ans_acc += 1
+        reward += acc_score
+
+    reward = reward + length_reward
+
+    return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
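Reviewer note: the soft over-length penalty in boxed_math_reward_fn above is linear in how far the response reaches into the final cache_length tokens. A worked example using the defaults from the hunk above (max_length=4096, cache_length=512, acc_score=10.0); the standalone helper below is illustrative only:

```python
# Illustrative sketch: the penalty formula is copied from boxed_math_reward_fn above.
MAX_LENGTH, CACHE_LENGTH, ACC_SCORE = 4096, 512, 10.0

def length_penalty(res_length: int) -> float:
    # Responses inside (max_length - cache_length, max_length) lose reward linearly,
    # approaching -acc_score as they near max_length.
    if MAX_LENGTH - CACHE_LENGTH < res_length < MAX_LENGTH:
        return ((MAX_LENGTH - CACHE_LENGTH) - res_length) / CACHE_LENGTH * ACC_SCORE
    return 0.0

print(length_penalty(3584))  # 0.0, at the threshold there is no penalty yet
print(length_penalty(3840))  # -5.0, halfway through the cache window
print(length_penalty(4095))  # about -9.98, nearly cancels a correct answer's +10.0
```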
diff --git a/applications/ColossalChat/coati/distributed/reward/reward_utils.py b/applications/ColossalChat/coati/distributed/reward/reward_utils.py
index c1e73d4b9..ffc220846 100644
--- a/applications/ColossalChat/coati/distributed/reward/reward_utils.py
+++ b/applications/ColossalChat/coati/distributed/reward/reward_utils.py
@@ -74,3 +74,51 @@ def extract_solution(solution_str: str) -> Tuple[Optional[str], str]:
     final_answer = matches[-1].group(1).strip()
 
     return final_answer, solution_str
+
+
+def extract_boxed_solution(text: str) -> Optional[str]:
+    """
+    Modified from: https://gist.github.com/lewtun/9c2ce1937b741404090a3dc4c7c022b3
+    Retrieves the content from the last occurrence of `\boxed{}` in a LaTeX-like string.
+
+    Args:
+        text (str): A string potentially containing LaTeX-style boxed expressions.
+
+    Returns:
+        Optional[str]: The text inside the final `\boxed{}` if successfully extracted;
+        returns `None` if no properly closed box is found.
+
+    Examples:
+        >>> extract_boxed_solution("The answer is \\boxed{42}.")
+        '42'
+        >>> extract_boxed_solution("Here is an unmatched \\boxed{42")
+        None
+    """
+    try:
+        # Find the last occurrence of "\boxed{"
+        start_idx = text.rindex("\\boxed{")
+        # Move past "\boxed{" to find the start of the content
+        content_start = start_idx + len("\\boxed{")
+        open_braces = 1
+        pos = content_start
+
+        # Traverse the string to find the matching closing brace
+        while open_braces > 0 and pos < len(text):
+            if text[pos] == "{":
+                open_braces += 1
+            elif text[pos] == "}":
+                open_braces -= 1
+            pos += 1
+
+        # If all braces are matched, extract and return the content
+        if open_braces == 0:
+            return text[content_start : pos - 1].strip()
+        else:
+            return None
+
+    except ValueError:
+        # "\boxed{" not found
+        return None
+    except Exception:
+        # Any other unexpected error
+        return None
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index 2dcd8c0db..a3ed00f88 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -86,6 +86,14 @@ if __name__ == "__main__":
     parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO"])
     parser.add_argument("-lr", "--learning-rate", type=float, default=1e-6, help="Learning rate for GRPO.")
     parser.add_argument("-kl", "--kl-coeff", type=float, default=0.01, help="KL penalty coefficient for GRPO.")
+    parser.add_argument(
+        "-rt",
+        "--reward-type",
+        type=str,
+        default="think_answer_tags",
+        choices=["think_answer_tags", "boxed"],
+        help="Reward type for GRPO.",
+    )
 
     # Logging/Checkpointing parameters
     parser.add_argument("-si", "--save-interval", type=int, default=100, help="Interval for saving checkpoints.")
@@ -136,8 +144,8 @@ if __name__ == "__main__":
                 max_length=args.max_new_tokens + args.max_prompt_tokens,
                 do_sample=True,
                 max_new_tokens=None,
-                early_stopping=False,
-                stop_strings=["</answer>"],
+                early_stopping=False if args.reward_type == "think_answer_tags" else True,
+                stop_strings=["</answer>"] if args.reward_type == "think_answer_tags" else None,
             )
         )
     elif args.backend == "vllm":
@@ -153,9 +161,9 @@ if __name__ == "__main__":
         generate_config.update(
             dict(
                 max_tokens=args.max_new_tokens,  # max new tokens
-                ignore_eos=True,
+                ignore_eos=True if args.reward_type == "think_answer_tags" else False,
                 include_stop_str_in_output=True,
-                stop=["</answer>"],
+                stop=["</answer>"] if args.reward_type == "think_answer_tags" else None,
             )
         )
     else:
@@ -168,6 +176,7 @@
             "train_microbatch_size": args.train_microbatch_size,
             "beta": args.kl_coeff,  # KL penalty coefficient
             "loss_variation": "sample_level",
+            "reward_fn_type": args.reward_type,
         }
     elif args.algo == "DAPO":
         # DAPO variant settings
@@ -185,6 +194,7 @@
             "max_length": args.max_new_tokens + args.max_prompt_tokens,
             "cache_length": min(1024, int(args.max_new_tokens / 4)),
             "filter_truncated_response": True,
+            "reward_fn_type": args.reward_type,
         }
     else:
         raise ValueError(f"Unsupported algorithm: {args.algo}")
@@ -210,12 +220,17 @@
         train_model_config=train_model_config,
         grpo_config=grpo_config,
         plugin_config={
-            "pp_size": 2,
-            "tp_size": 2,
-            "microbatch_size": args.train_microbatch_size // 2,
-            "zero_stage": 1,
-            "max_norm": 1.0,
-        },  # for tp + pp
+            "zero_stage": 2,
+        },  # for zero
+        # plugin_config={
+        #     "tp_size": 2,
+        #     "pp_size": 2,
+        #     "microbatch_size": max(
+        #         1, args.train_microbatch_size // 2
+        #     ),  # microbatch size should be set to train_microbatch_size // pp_size
+        #     "zero_stage": 0,
+        #     "max_norm": 1.0,
+        # },  # for pp, tp
         inference_backend=args.backend,
         master_addr="localhost",
         master_port=args.master_port,