From 03b41d6fb5f49b724e0a50d7b6bb3affee173794 Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Fri, 16 May 2025 18:04:38 +0800 Subject: [PATCH 1/3] upgrade reward functions --- .../coati/distributed/grpo_consumer.py | 4 +- .../coati/distributed/reward/reward_fn.py | 143 ++++++++++++++---- applications/ColossalChat/rl_example.py | 3 + 3 files changed, 123 insertions(+), 27 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index eae4ff54e..8de8b774e 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -127,7 +127,9 @@ class GRPOConsumer(BaseConsumer): "answer_end": {"text": "", "num_occur": 1}, } reward_model_kwargs = { - k: v for k, v in grpo_config.items() if k in ["soft_over_length_punishment", "max_length", "cache_length"] + k: v + for k, v in grpo_config.items() + if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length"] } self.reward_model = VerifiableReward( reward_fns=[ diff --git a/applications/ColossalChat/coati/distributed/reward/reward_fn.py b/applications/ColossalChat/coati/distributed/reward/reward_fn.py index 14d340dc4..a4042ae97 100644 --- a/applications/ColossalChat/coati/distributed/reward/reward_fn.py +++ b/applications/ColossalChat/coati/distributed/reward/reward_fn.py @@ -1,7 +1,77 @@ import torch +from latex2sympy2_extended import NormalizationConfig +from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify from .reward_utils import extract_boxed_solution, extract_solution, validate_response_structure +CANNOT_PARSE_GT_ANSWER = -1 +CANNOT_PARSE_PREDICTION = -2 +SUCCESS = 1 +MATCHING_FAIL = 0 + + +def verify_math_representation(completion, gt_answer): + """ + Verify if the completion is a valid math representation of the gt_answer. + """ + if not completion.startswith("\\boxed{"): + completion = "\\boxed{" + completion + "}" + if not gt_answer.startswith("\\boxed{"): + gt_answer = "\\boxed{" + gt_answer + "}" + target = ( + ExprExtractionConfig(), + LatexExtractionConfig( + normalization_config=NormalizationConfig( + nits=False, + malformed_operators=False, + basic_latex=True, + boxed="all", + units=True, + ), + boxed_match_priority=0, + ), + ) + if not isinstance(gt_answer, str) or len(gt_answer) == 0: + raise ValueError("gt_answer should be a string, please verify your training data.") + if not isinstance(completion, str) or len(completion) == 0: + return MATCHING_FAIL + try: + parsed_gt_answer = parse(gt_answer, extraction_config=target) + if len(parsed_gt_answer) == 0: + return CANNOT_PARSE_GT_ANSWER + parsed_completion = parse(completion, extraction_config=target) + if len(parsed_completion) == 0: + return CANNOT_PARSE_PREDICTION + if verify(parsed_gt_answer, parsed_completion): + return SUCCESS + else: + return MATCHING_FAIL + except Exception: + return MATCHING_FAIL + + +def verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, reward): + math_verify_result = verify_math_representation(decoded_final_answer, gt_answer) + exact_match_result = ( + SUCCESS + if decoded_final_answer.strip().replace(" ", "").replace("{", "").replace("}", "").replace(",", "") + == gt_answer.strip().replace(" ", "").replace("{", "").replace("}", "").replace(",", "") + else MATCHING_FAIL + ) + if math_verify_result == SUCCESS: + ans_acc += 1 + reward += acc_score + elif exact_match_result == SUCCESS: + # sometimes for answers that's not a (valid) math expression, math_verify will fail + ans_acc += 1 + if math_verify_result == CANNOT_PARSE_PREDICTION: + reward += ( + acc_score / 2 + ) # not a valid latex math representation, but the answer is correct, receive half of the score + else: + reward += acc_score + return reward, ans_acc + def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): tokenizer = kwargs["tokenizer"] @@ -14,15 +84,18 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): s, e = response_idx[0], response_idx[1] length_reward = 0.0 - if soft_over_length_punishment: - max_length = kwargs.get("max_length", 1024 * 4) - cache_length = kwargs.get("cache_length", 512) - res_length = e.item() - s.item() + 1 - if max_length - cache_length < res_length < max_length: - length_reward = ((max_length - cache_length) - res_length) / cache_length * acc_score + res_length = e.item() - s.item() + 1 + if not eval_mode: + max_new_tokens = kwargs["max_new_tokens"] + else: + max_new_tokens = -1 # for eval mode, we don't need to check the length + if not eval_mode and soft_over_length_punishment: + cache_length = kwargs["cache_length"] + if max_new_tokens - cache_length < res_length < max_new_tokens: + length_reward = ((max_new_tokens - cache_length) - res_length) / cache_length * acc_score if gt_answer is None: - return reward + raise ValueError("no gt_answer is provided, please check your training dataset.") decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True) gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True) @@ -35,15 +108,15 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): format_acc += 1 # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid - if ( - format_valid - and final_answer is not None - and gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower() - ): - ans_acc += 1 - reward += acc_score + if final_answer is not None: + if eval_mode or format_valid: + reward, ans_acc = verify_model_answer(final_answer, gt_answer, ans_acc, acc_score, reward) + if not eval_mode: + reward = reward + length_reward - reward = reward + length_reward + # Check if the sequence is over length + if not eval_mode and res_length >= max_new_tokens: + reward *= 0.0 if not eval_mode: return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device) @@ -56,6 +129,8 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): "parsed": final_answer, "format_valid": format_acc.item(), "ans_valid": ans_acc.item(), + "response_length": res_length, + "reward": reward.item(), } @@ -71,31 +146,45 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): s, e = response_idx[0], response_idx[1] length_reward = 0.0 - if soft_over_length_punishment: - max_length = kwargs.get("max_length", 1024 * 4) - cache_length = kwargs.get("cache_length", 512) - res_length = e.item() - s.item() + 1 - if max_length - cache_length < res_length < max_length: - length_reward = ((max_length - cache_length) - res_length) / cache_length * acc_score + res_length = e.item() - s.item() + 1 + if not eval_mode: + max_new_tokens = kwargs["max_new_tokens"] + else: + max_new_tokens = -1 # for eval mode, we don't need to check the length + if not eval_mode and soft_over_length_punishment: + cache_length = kwargs["cache_length"] + if max_new_tokens - cache_length < res_length < max_new_tokens: + length_reward = ((max_new_tokens - cache_length) - res_length) / cache_length * acc_score if gt_answer is None: - return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device) + raise ValueError("no gt_answer is provided, please check your training dataset.") decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True) + gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True) final_answer = extract_boxed_solution(decoded_final_answer) format_valid = final_answer is not None + if "tags" in kwargs and kwargs["tags"]: + tags = kwargs["tags"] + format_valid = format_valid and all( + [decoded_final_answer.count(tags[tag]["text"]) == tags[tag]["num_occur"] for tag in tags] + ) # Check format accuracy if format_valid: format_acc += 1 reward += format_score # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid - if format_valid and final_answer is not None and gt_answer.strip().lower() == final_answer.strip().lower(): - ans_acc += 1 - reward += acc_score + if final_answer is not None: + if eval_mode or format_valid: + reward, ans_acc = verify_model_answer(final_answer, gt_answer, ans_acc, acc_score, reward) + if not eval_mode: + reward = reward + length_reward + + # Check if the sequence is over length + if not eval_mode and res_length >= max_new_tokens: + reward *= 0.0 - reward = reward + length_reward if not eval_mode: return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device) else: @@ -107,4 +196,6 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): "parsed": final_answer, "format_valid": format_acc.item(), "ans_valid": ans_acc.item(), + "response_length": res_length, + "reward": reward.item(), } diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index 8d1f25e74..071912ddf 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -198,6 +198,8 @@ if __name__ == "__main__": "beta": args.kl_coeff, # KL penalty coefficient "loss_variation": "sample_level", "reward_fn_type": args.reward_type, + "max_length": args.max_new_tokens + args.max_prompt_tokens, + "max_new_tokens": args.max_new_tokens, } elif args.algo == "DAPO": # DAPO variant settings @@ -213,6 +215,7 @@ if __name__ == "__main__": "loss_variation": "token_level", "soft_over_length_punishment": True, "max_length": args.max_new_tokens + args.max_prompt_tokens, + "max_new_tokens": args.max_new_tokens, "cache_length": min(1024, int(args.max_new_tokens / 4)), "filter_truncated_response": True, "reward_fn_type": args.reward_type, From 107470a36092996bbd2fcf564156ce3bb8416971 Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Sat, 17 May 2025 21:12:58 +0800 Subject: [PATCH 2/3] fix logging rollouts --- .gitignore | 1 + .../coati/distributed/grpo_consumer.py | 16 +++--- .../ColossalChat/coati/distributed/launch.py | 4 ++ .../coati/distributed/producer.py | 54 ++++++++++++------- applications/ColossalChat/rl_example.py | 5 ++ 5 files changed, 56 insertions(+), 24 deletions(-) diff --git a/.gitignore b/.gitignore index d48035a5b..0503f3c95 100644 --- a/.gitignore +++ b/.gitignore @@ -166,3 +166,4 @@ applications/ColossalChat/tests/logs applications/ColossalChat/wandb applications/ColossalChat/model applications/ColossalChat/eval +applications/ColossalChat/rollouts diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index 8de8b774e..70e2201fe 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -120,12 +120,16 @@ class GRPOConsumer(BaseConsumer): "either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config." ) # Initialize verifiable reward. - response_format_tags = { - "think_start": {"text": "", "num_occur": 1}, - "think_end": {"text": "", "num_occur": 1}, - "answer_start": {"text": "", "num_occur": 1}, - "answer_end": {"text": "", "num_occur": 1}, - } + response_format_tags = ( + { + "think_start": {"text": "", "num_occur": 1}, + "think_end": {"text": "", "num_occur": 1}, + "answer_start": {"text": "", "num_occur": 1}, + "answer_end": {"text": "", "num_occur": 1}, + } + if grpo_config.get("reward_fn_type") == "think_answer_tags" + else None + ) reward_model_kwargs = { k: v for k, v in grpo_config.items() diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py index e6367d2d1..6eeb5d379 100644 --- a/applications/ColossalChat/coati/distributed/launch.py +++ b/applications/ColossalChat/coati/distributed/launch.py @@ -55,6 +55,8 @@ def launch_distributed( eval_interval: int = 100, eval_save_dir: Optional[str] = None, eval_generation_config: Optional[Dict[str, Any]] = None, + log_rollout_interval: int = 20, + rollout_log_file: str = "./rollout_log.jsonl", ): if core_algo not in ALGO_MAP: raise NotImplementedError(f"{core_algo} is not supported yet.") @@ -98,6 +100,8 @@ def launch_distributed( project_name=project_name, run_name=run_name, wandb_group_name=wandb_group_name, + log_rollout_interval=log_rollout_interval, + rollout_log_file=rollout_log_file, ) procs.append(producer) generate_config_consumer = copy.deepcopy(generate_config) diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py index 0d91f43f1..75879278c 100644 --- a/applications/ColossalChat/coati/distributed/producer.py +++ b/applications/ColossalChat/coati/distributed/producer.py @@ -1,4 +1,5 @@ import copy +import json import os from typing import Any, Dict, Optional @@ -49,7 +50,8 @@ class BaseProducer: project_name: str = None, run_name: str = None, wandb_group_name: str = None, - wandb_log_rollout_interval: int = 20, + log_rollout_interval: int = 20, + rollout_log_file: str = "./rollout_log.jsonl", ): self.producer_idx = producer_idx self.num_producers = num_producers @@ -70,9 +72,16 @@ class BaseProducer: self.eval_save_dir = eval_save_dir self.consumer_global_step = 0 self.eval_mode = False - self.wandb_rollout_data = [] - self.wandb_log_rollout_interval = wandb_log_rollout_interval + self.log_rollout_interval = log_rollout_interval self.latest_rollout_log_step = -1 + if producer_idx == 0: + if os.path.exists(rollout_log_file): + raise ValueError( + f"Rollout log file {rollout_log_file} already exists. Please delete it or change the name." + ) + else: + os.makedirs(os.path.dirname(rollout_log_file), exist_ok=True) + self.rollout_log_file = open(rollout_log_file, "w", encoding="utf8") if self.producer_idx == 0: self.wandb_run = wandb.init( project=project_name, @@ -320,6 +329,8 @@ class SimpleProducer(BaseProducer): project_name: str = None, run_name: str = None, wandb_group_name: str = None, + log_rollout_interval: int = 20, + rollout_log_file: str = "./rollout_log.jsonl", ): super().__init__( producer_idx, @@ -342,6 +353,8 @@ class SimpleProducer(BaseProducer): project_name=project_name, run_name=run_name, wandb_group_name=wandb_group_name, + log_rollout_interval=log_rollout_interval, + rollout_log_file=rollout_log_file, ) self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations) self.eval_generation_config = copy.deepcopy(self.model.generate_config) @@ -353,26 +366,31 @@ class SimpleProducer(BaseProducer): def rollout(self, input_ids, attention_mask, **kwargs): rollouts = self.model.generate(input_ids, attention_mask, **kwargs) if self.producer_idx == 0 and not self.eval_mode: - wandb_rollout_data = self.wandb_rollout_data + [ - [ - str(self.consumer_global_step), - str(self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True)), - ] - ] if ( - self.consumer_global_step - self.latest_rollout_log_step >= self.wandb_log_rollout_interval + self.consumer_global_step - self.latest_rollout_log_step >= self.log_rollout_interval or self.latest_rollout_log_step == -1 ): - self.wandb_rollout_data = wandb_rollout_data - self.latest_rollout_log_step = self.consumer_global_step - self.wandb_run.log( - { - "rollout/rollout_examples": wandb.Table( - columns=["train_step", "rollout_examples"], data=wandb_rollout_data + new_record = ( + json.dumps( + { + "train_step": self.consumer_global_step, + "rollout": self.tokenizer.batch_decode( + rollouts["input_ids"][:, 0], skip_special_tokens=True + ), + } ) - } - ) + + "\n" + ) + self.rollout_log_file.write(new_record) + self.rollout_log_file.flush() + self.latest_rollout_log_step = self.consumer_global_step return rollouts + def __del__(self): + if self.producer_idx == 0: + self.wandb_run.finish() + if hasattr(self, "rollout_log_file"): + self.rollout_log_file.close() + def load_state_dict(self, state_dict): self.model.load_state_dict(state_dict) diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index 071912ddf..98c139f14 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -118,6 +118,9 @@ if __name__ == "__main__": parser.add_argument( "-esd", "--eval-save-dir", type=str, default="./eval", help="Directory for saving evaluation results." ) + parser.add_argument( + "-rsd", "--rollout-save-dir", type=str, default="./rollouts", help="Directory for saving rollout loggings." + ) args = parser.parse_args() if args.train_minibatch_size is None: @@ -269,4 +272,6 @@ if __name__ == "__main__": eval_interval=args.eval_interval, eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")), eval_generation_config=eval_generation_config, + log_rollout_interval=20, + rollout_log_file=os.path.join(args.rollout_save_dir, args.project.replace(" ", "_") + ".jsonl"), ) From f8bd2db33fa3b42ef8b1ddcf2c96ce6535ab672c Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Tue, 20 May 2025 09:45:56 +0800 Subject: [PATCH 3/3] add uuid to rollout log --- applications/ColossalChat/coati/distributed/launch.py | 6 +++++- applications/ColossalChat/rl_example.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py index 6eeb5d379..ef81bcbdd 100644 --- a/applications/ColossalChat/coati/distributed/launch.py +++ b/applications/ColossalChat/coati/distributed/launch.py @@ -56,7 +56,7 @@ def launch_distributed( eval_save_dir: Optional[str] = None, eval_generation_config: Optional[Dict[str, Any]] = None, log_rollout_interval: int = 20, - rollout_log_file: str = "./rollout_log.jsonl", + rollout_save_dir: str = "./rollout", ): if core_algo not in ALGO_MAP: raise NotImplementedError(f"{core_algo} is not supported yet.") @@ -74,6 +74,10 @@ def launch_distributed( run_name = f"{inference_backend}_bs_{train_batch_size * train_dp_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}" wandb_group_name = str(uuid.uuid4()) + rollout_log_file = os.path.join( + rollout_save_dir, + f"{project_name.replace(' ','_')}_run_{wandb_group_name}.jsonl", + ) procs = [] for i in range(num_producers): diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index 98c139f14..bfa0ab7d0 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -273,5 +273,5 @@ if __name__ == "__main__": eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")), eval_generation_config=eval_generation_config, log_rollout_interval=20, - rollout_log_file=os.path.join(args.rollout_save_dir, args.project.replace(" ", "_") + ".jsonl"), + rollout_save_dir=args.rollout_save_dir, )