diff --git a/.gitignore b/.gitignore
index d48035a5b..0503f3c95 100644
--- a/.gitignore
+++ b/.gitignore
@@ -166,3 +166,4 @@ applications/ColossalChat/tests/logs
 applications/ColossalChat/wandb
 applications/ColossalChat/model
 applications/ColossalChat/eval
+applications/ColossalChat/rollouts
diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index 531cf363c..8a2221b92 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -113,7 +113,6 @@ class BaseConsumer:
             ) as pbar:
                 for step in pbar:
                     i = 0
-                    allow_sync_model = True
                     for _ in range(self.num_recv_per_update):
                         # receive data from producers
                         for r in range(self.num_producers):
@@ -140,7 +139,6 @@ class BaseConsumer:
                             loss = self.step(i, pbar, **batch)
                             self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
                             if loss is not None:
-                                allow_sync_model = True
                                 pbar.set_postfix({"loss": loss})
                             i += 1
                 if self.lr_scheduler is not None:
@@ -154,31 +152,29 @@ class BaseConsumer:
                         print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")

                     if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
-                        if allow_sync_model:
-                            if self.pp_size > 1:
-                                print(
-                                    f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
+                        if self.pp_size > 1:
+                            print(
+                                f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
+                            )
+                        else:
+                            print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
+                        torch.cuda.empty_cache()
+                        state_dict = self.state_dict()
+                        if self.pp_size > 1:
+                            if self.tp_rank == 0 and self.dp_rank == 0:
+                                ray_broadcast_tensor_dict(
+                                    state_dict,
+                                    src=self.num_producers,
+                                    device=self.device,
+                                    group_name=f"sync_model_{self.pp_rank}",
                                 )
-                            else:
-                                print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
-                            torch.cuda.empty_cache()
-                            state_dict = self.state_dict()
-                            if self.pp_size > 1:
-                                if self.tp_rank == 0 and self.dp_rank == 0:
-                                    ray_broadcast_tensor_dict(
-                                        state_dict,
-                                        src=self.num_producers,
-                                        device=self.device,
-                                        group_name=f"sync_model_{self.pp_rank}",
-                                    )
-                            else:
-                                if self.rank == 0:
-                                    ray_broadcast_tensor_dict(
-                                        state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
-                                    )
-                            del state_dict
-                            torch.cuda.empty_cache()
-                        allow_sync_model = True
+                        else:
+                            if self.rank == 0:
+                                ray_broadcast_tensor_dict(
+                                    state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
+                                )
+                        del state_dict
+                        torch.cuda.empty_cache()


 @ray.remote
diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index f73f2d96f..bc2bd45d1 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -120,14 +120,20 @@ class GRPOConsumer(BaseConsumer):
                 "either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config."
             )
         # Initialize verifiable reward.
-        response_format_tags = {
-            "think_start": {"text": "<think>", "num_occur": 1},
-            "think_end": {"text": "</think>", "num_occur": 1},
-            "answer_start": {"text": "<answer>", "num_occur": 1},
-            "answer_end": {"text": "</answer>", "num_occur": 1},
-        }
+        response_format_tags = (
+            {
+                "think_start": {"text": "<think>", "num_occur": 1},
+                "think_end": {"text": "</think>", "num_occur": 1},
+                "answer_start": {"text": "<answer>", "num_occur": 1},
+                "answer_end": {"text": "</answer>", "num_occur": 1},
+            }
+            if grpo_config.get("reward_fn_type") == "think_answer_tags"
+            else None
+        )
         reward_model_kwargs = {
-            k: v for k, v in grpo_config.items() if k in ["soft_over_length_punishment", "max_length", "cache_length"]
+            k: v
+            for k, v in grpo_config.items()
+            if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length"]
         }
         self.reward_model = VerifiableReward(
             reward_fns=[
diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py
index e6367d2d1..6eeb5d379 100644
--- a/applications/ColossalChat/coati/distributed/launch.py
+++ b/applications/ColossalChat/coati/distributed/launch.py
@@ -55,6 +55,8 @@ def launch_distributed(
     eval_interval: int = 100,
     eval_save_dir: Optional[str] = None,
     eval_generation_config: Optional[Dict[str, Any]] = None,
+    log_rollout_interval: int = 20,
+    rollout_log_file: str = "./rollout_log.jsonl",
 ):
     if core_algo not in ALGO_MAP:
         raise NotImplementedError(f"{core_algo} is not supported yet.")
@@ -98,6 +100,8 @@ def launch_distributed(
             project_name=project_name,
             run_name=run_name,
             wandb_group_name=wandb_group_name,
+            log_rollout_interval=log_rollout_interval,
+            rollout_log_file=rollout_log_file,
         )
         procs.append(producer)
     generate_config_consumer = copy.deepcopy(generate_config)
diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py
index d796bff48..75879278c 100644
--- a/applications/ColossalChat/coati/distributed/producer.py
+++ b/applications/ColossalChat/coati/distributed/producer.py
@@ -1,4 +1,5 @@
 import copy
+import json
 import os
 from typing import Any, Dict, Optional

@@ -49,6 +50,8 @@ class BaseProducer:
         project_name: str = None,
         run_name: str = None,
         wandb_group_name: str = None,
+        log_rollout_interval: int = 20,
+        rollout_log_file: str = "./rollout_log.jsonl",
     ):
         self.producer_idx = producer_idx
         self.num_producers = num_producers
@@ -58,7 +61,7 @@ class BaseProducer:
         self.microbatch_size = microbatch_size
         assert batch_size % microbatch_size == 0
         self.num_microbatches = batch_size // microbatch_size
-        self.lastest_eval_step = -1
+        self.latest_eval_step = -1

         self.train_dataset_config = train_dataset_config
         self.model_config = model_config
@@ -68,6 +71,17 @@ class BaseProducer:
         self.eval_interval = eval_interval
         self.eval_save_dir = eval_save_dir
         self.consumer_global_step = 0
+        self.eval_mode = False
+        self.log_rollout_interval = log_rollout_interval
+        self.latest_rollout_log_step = -1
+        if producer_idx == 0:
+            if os.path.exists(rollout_log_file):
+                raise ValueError(
+                    f"Rollout log file {rollout_log_file} already exists. Please delete it or change the name."
+                )
+            else:
+                os.makedirs(os.path.dirname(rollout_log_file), exist_ok=True)
+                self.rollout_log_file = open(rollout_log_file, "w", encoding="utf8")
         if self.producer_idx == 0:
             self.wandb_run = wandb.init(
                 project=project_name,
@@ -77,7 +91,7 @@ class BaseProducer:
                 group=wandb_group_name,
             )

-        if os.path.exists(self.eval_save_dir):
+        if os.path.exists(self.eval_save_dir) and self.eval_interval > 0:
             raise ValueError(f"Eval save dir {self.eval_save_dir} already exists. Please delete it or change the name.")

         # init tokenizer
@@ -180,14 +194,15 @@ class BaseProducer:
                     break
                 if self.eval_interval > 0 and self.eval_dataset_config is not None:
                     if (
-                        self.consumer_global_step % self.eval_interval == 0
-                        and self.consumer_global_step > self.lastest_eval_step
-                    ):
+                        self.consumer_global_step - self.latest_eval_step >= self.eval_interval
+                        and self.consumer_global_step > self.latest_eval_step
+                    ) or self.latest_eval_step == -1:
                         to_log_msg = {}
+                        self.eval_mode = True
                         for eval_task_name in self.eval_dataloaders:
                             if self.producer_idx == 0:
                                 print(
-                                    f"[P{self.producer_idx}] Evaluate consumer step {self.consumer_global_step} on task {eval_task_name}"
+                                    f"[P{self.producer_idx}] Evaluate model at training step {self.consumer_global_step} on task {eval_task_name}"
                                 )
                             eval_results = []
                             eval_statistics_tensor = torch.zeros((2,), dtype=torch.float32).to(self.device)
@@ -220,14 +235,15 @@ class BaseProducer:
                             safe_append_to_jsonl_file(
                                 os.path.join(
                                     self.eval_save_dir,
-                                    f"{eval_task_name}_episode_{episode}_step_{self.consumer_global_step}.jsonl",
+                                    f"{eval_task_name}_training_step_{self.consumer_global_step}.jsonl",
                                 ),
                                 eval_results,
                             )

                         if self.producer_idx == 0:
                             self.wandb_run.log(to_log_msg, step=self.consumer_global_step)
-                        self.lastest_eval_step = self.consumer_global_step
+                        self.eval_mode = False
+                        self.latest_eval_step = self.consumer_global_step
                 outputs = self.rollout(**batch)

                 print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
@@ -256,6 +272,8 @@ class BaseProducer:
                     state_dict = ray_broadcast_tensor_dict(
                         None, self.num_producers, device=self.device, group_name=f"sync_model_{pp_idx}"
                     )
+                    if "consumer_global_step" in state_dict:
+                        self.consumer_global_step = state_dict.pop("consumer_global_step").item()
                     self.load_state_dict(state_dict)
                 else:
                     print(
@@ -311,6 +329,8 @@ class SimpleProducer(BaseProducer):
         project_name: str = None,
         run_name: str = None,
         wandb_group_name: str = None,
+        log_rollout_interval: int = 20,
+        rollout_log_file: str = "./rollout_log.jsonl",
     ):
         super().__init__(
             producer_idx,
@@ -333,6 +353,8 @@ class SimpleProducer(BaseProducer):
             project_name=project_name,
             run_name=run_name,
             wandb_group_name=wandb_group_name,
+            log_rollout_interval=log_rollout_interval,
+            rollout_log_file=rollout_log_file,
         )
         self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
         self.eval_generation_config = copy.deepcopy(self.model.generate_config)
@@ -343,10 +365,32 @@ class SimpleProducer(BaseProducer):
     @torch.no_grad()
     def rollout(self, input_ids, attention_mask, **kwargs):
         rollouts = self.model.generate(input_ids, attention_mask, **kwargs)
-        # if self.producer_idx == 1:
-        #     print("Rollout example:\n", self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True))
-
+        if self.producer_idx == 0 and not self.eval_mode:
+            if (
+                self.consumer_global_step - self.latest_rollout_log_step >= self.log_rollout_interval
+                or self.latest_rollout_log_step == -1
+            ):
+                new_record = (
+                    json.dumps(
+                        {
+                            "train_step": self.consumer_global_step,
+                            "rollout": self.tokenizer.batch_decode(
+                                rollouts["input_ids"][:, 0], skip_special_tokens=True
+                            ),
+                        }
+                    )
+                    + "\n"
+                )
+                self.rollout_log_file.write(new_record)
+                self.rollout_log_file.flush()
+                self.latest_rollout_log_step = self.consumer_global_step
         return rollouts

+    def __del__(self):
+        if self.producer_idx == 0:
+            self.wandb_run.finish()
+        if hasattr(self, "rollout_log_file"):
+            self.rollout_log_file.close()
+
     def load_state_dict(self, state_dict):
         self.model.load_state_dict(state_dict)
diff --git a/applications/ColossalChat/coati/distributed/reward/reward_fn.py b/applications/ColossalChat/coati/distributed/reward/reward_fn.py
index 14d340dc4..a4042ae97 100644
--- a/applications/ColossalChat/coati/distributed/reward/reward_fn.py
+++ b/applications/ColossalChat/coati/distributed/reward/reward_fn.py
@@ -1,7 +1,77 @@
 import torch
+from latex2sympy2_extended import NormalizationConfig
+from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify

 from .reward_utils import extract_boxed_solution, extract_solution, validate_response_structure

+CANNOT_PARSE_GT_ANSWER = -1
+CANNOT_PARSE_PREDICTION = -2
+SUCCESS = 1
+MATCHING_FAIL = 0
+
+
+def verify_math_representation(completion, gt_answer):
+    """
+    Verify if the completion is a valid math representation of the gt_answer.
+    """
+    if not completion.startswith("\\boxed{"):
+        completion = "\\boxed{" + completion + "}"
+    if not gt_answer.startswith("\\boxed{"):
+        gt_answer = "\\boxed{" + gt_answer + "}"
+    target = (
+        ExprExtractionConfig(),
+        LatexExtractionConfig(
+            normalization_config=NormalizationConfig(
+                nits=False,
+                malformed_operators=False,
+                basic_latex=True,
+                boxed="all",
+                units=True,
+            ),
+            boxed_match_priority=0,
+        ),
+    )
+    if not isinstance(gt_answer, str) or len(gt_answer) == 0:
+        raise ValueError("gt_answer should be a string, please verify your training data.")
+    if not isinstance(completion, str) or len(completion) == 0:
+        return MATCHING_FAIL
+    try:
+        parsed_gt_answer = parse(gt_answer, extraction_config=target)
+        if len(parsed_gt_answer) == 0:
+            return CANNOT_PARSE_GT_ANSWER
+        parsed_completion = parse(completion, extraction_config=target)
+        if len(parsed_completion) == 0:
+            return CANNOT_PARSE_PREDICTION
+        if verify(parsed_gt_answer, parsed_completion):
+            return SUCCESS
+        else:
+            return MATCHING_FAIL
+    except Exception:
+        return MATCHING_FAIL
+
+
+def verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, reward):
+    math_verify_result = verify_math_representation(decoded_final_answer, gt_answer)
+    exact_match_result = (
+        SUCCESS
+        if decoded_final_answer.strip().replace(" ", "").replace("{", "").replace("}", "").replace(",", "")
+        == gt_answer.strip().replace(" ", "").replace("{", "").replace("}", "").replace(",", "")
+        else MATCHING_FAIL
+    )
+    if math_verify_result == SUCCESS:
+        ans_acc += 1
+        reward += acc_score
+    elif exact_match_result == SUCCESS:
+        # sometimes for answers that's not a (valid) math expression, math_verify will fail
+        ans_acc += 1
+        if math_verify_result == CANNOT_PARSE_PREDICTION:
+            reward += (
+                acc_score / 2
+            )  # not a valid latex math representation, but the answer is correct, receive half of the score
+        else:
+            reward += acc_score
+    return reward, ans_acc
+

 def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     tokenizer = kwargs["tokenizer"]
@@ -14,15 +84,18 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     s, e = response_idx[0], response_idx[1]

     length_reward = 0.0
-    if soft_over_length_punishment:
-        max_length = kwargs.get("max_length", 1024 * 4)
-        cache_length = kwargs.get("cache_length", 512)
-        res_length = e.item() - s.item() + 1
-        if max_length - cache_length < res_length < max_length:
-            length_reward = ((max_length - cache_length) - res_length) / cache_length * acc_score
+    res_length = e.item() - s.item() + 1
+    if not eval_mode:
+        max_new_tokens = kwargs["max_new_tokens"]
+    else:
+        max_new_tokens = -1  # for eval mode, we don't need to check the length
+    if not eval_mode and soft_over_length_punishment:
+        cache_length = kwargs["cache_length"]
+        if max_new_tokens - cache_length < res_length < max_new_tokens:
+            length_reward = ((max_new_tokens - cache_length) - res_length) / cache_length * acc_score

     if gt_answer is None:
-        return reward
+        raise ValueError("no gt_answer is provided, please check your training dataset.")
     decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
     gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)

@@ -35,15 +108,15 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
         format_acc += 1

     # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
-    if (
-        format_valid
-        and final_answer is not None
-        and gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower()
-    ):
-        ans_acc += 1
-        reward += acc_score
+    if final_answer is not None:
+        if eval_mode or format_valid:
+            reward, ans_acc = verify_model_answer(final_answer, gt_answer, ans_acc, acc_score, reward)
+        if not eval_mode:
+            reward = reward + length_reward

-    reward = reward + length_reward
+    # Check if the sequence is over length
+    if not eval_mode and res_length >= max_new_tokens:
+        reward *= 0.0

     if not eval_mode:
         return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
@@ -56,6 +129,8 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
             "parsed": final_answer,
             "format_valid": format_acc.item(),
             "ans_valid": ans_acc.item(),
+            "response_length": res_length,
+            "reward": reward.item(),
         }


@@ -71,31 +146,45 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     s, e = response_idx[0], response_idx[1]

     length_reward = 0.0
-    if soft_over_length_punishment:
-        max_length = kwargs.get("max_length", 1024 * 4)
-        cache_length = kwargs.get("cache_length", 512)
-        res_length = e.item() - s.item() + 1
-        if max_length - cache_length < res_length < max_length:
-            length_reward = ((max_length - cache_length) - res_length) / cache_length * acc_score
+    res_length = e.item() - s.item() + 1
+    if not eval_mode:
+        max_new_tokens = kwargs["max_new_tokens"]
+    else:
+        max_new_tokens = -1  # for eval mode, we don't need to check the length
+    if not eval_mode and soft_over_length_punishment:
+        cache_length = kwargs["cache_length"]
+        if max_new_tokens - cache_length < res_length < max_new_tokens:
+            length_reward = ((max_new_tokens - cache_length) - res_length) / cache_length * acc_score

     if gt_answer is None:
-        return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
+        raise ValueError("no gt_answer is provided, please check your training dataset.")

     decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
+    gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
     final_answer = extract_boxed_solution(decoded_final_answer)
     format_valid = final_answer is not None
+    if "tags" in kwargs and kwargs["tags"]:
+        tags = kwargs["tags"]
+        format_valid = format_valid and all(
+            [decoded_final_answer.count(tags[tag]["text"]) == tags[tag]["num_occur"] for tag in tags]
+        )

     # Check format accuracy
     if format_valid:
         format_acc += 1
         reward += format_score

     # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
-    if format_valid and final_answer is not None and gt_answer.strip().lower() == final_answer.strip().lower():
-        ans_acc += 1
-        reward += acc_score
+    if final_answer is not None:
+        if eval_mode or format_valid:
+            reward, ans_acc = verify_model_answer(final_answer, gt_answer, ans_acc, acc_score, reward)
+        if not eval_mode:
+            reward = reward + length_reward
+
+    # Check if the sequence is over length
+    if not eval_mode and res_length >= max_new_tokens:
+        reward *= 0.0

-    reward = reward + length_reward
     if not eval_mode:
         return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
     else:
@@ -107,4 +196,6 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
             "parsed": final_answer,
             "format_valid": format_acc.item(),
             "ans_valid": ans_acc.item(),
+            "response_length": res_length,
+            "reward": reward.item(),
         }
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index 9f89eb546..a97c2b210 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -109,7 +109,7 @@ if __name__ == "__main__":
         "--eval-interval",
         type=int,
         default=100,
-        help="Interval for evaluation. Evaluate every ei consumer steps.",
+        help="Interval for evaluation. Evaluate every ei training steps.",
     )

     # Logging/Checkpointing parameters
@@ -118,6 +118,9 @@ if __name__ == "__main__":
     parser.add_argument(
         "-esd", "--eval-save-dir", type=str, default="./eval", help="Directory for saving evaluation results."
     )
+    parser.add_argument(
+        "-rsd", "--rollout-save-dir", type=str, default="./rollouts", help="Directory for saving rollout loggings."
+    )

     args = parser.parse_args()
     if args.train_minibatch_size is None:
@@ -198,6 +201,8 @@ if __name__ == "__main__":
             "beta": args.kl_coeff,  # KL penalty coefficient
             "loss_variation": "sample_level",
             "reward_fn_type": args.reward_type,
+            "max_length": args.max_new_tokens + args.max_prompt_tokens,
+            "max_new_tokens": args.max_new_tokens,
         }
     elif args.algo == "DAPO":
         # DAPO variant settings
@@ -213,6 +218,7 @@ if __name__ == "__main__":
             "loss_variation": "token_level",
             "soft_over_length_punishment": True,
             "max_length": args.max_new_tokens + args.max_prompt_tokens,
+            "max_new_tokens": args.max_new_tokens,
             "cache_length": min(1024, int(args.max_new_tokens / 4)),
             "filter_truncated_response": True,
             "reward_fn_type": args.reward_type,
@@ -266,4 +272,6 @@ if __name__ == "__main__":
         eval_interval=args.eval_interval,
         eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
         eval_generation_config=eval_generation_config,
+        log_rollout_interval=20,
+        rollout_log_file=os.path.join(args.rollout_save_dir, args.project.replace(" ", "_") + ".jsonl"),
     )
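Note on the new answer check in reward_fn.py: correctness is now decided in two tiers. verify_math_representation asks math_verify for symbolic equivalence between the extracted answer and the ground truth; if the prediction cannot be parsed as math but matches the ground truth after whitespace/brace/comma stripping, verify_model_answer still counts it correct and grants half of acc_score. A minimal usage sketch, assuming math_verify and latex2sympy2_extended are installed and the ColossalChat application directory is on PYTHONPATH; the sample answers and the acc_score value are made up for illustration:

```python
import torch

from coati.distributed.reward.reward_fn import SUCCESS, verify_math_representation, verify_model_answer

# Tier 1: symbolic equivalence via math_verify; "1/2" and "0.5" should verify as equal.
print(verify_math_representation("\\frac{1}{2}", "0.5") == SUCCESS)

# Tier 2: exact-string fallback. For a non-math answer the prediction may be unparseable
# (CANNOT_PARSE_PREDICTION), in which case only acc_score / 2 is granted.
reward, ans_acc = verify_model_answer(
    decoded_final_answer="strawberry",
    gt_answer="strawberry",
    ans_acc=torch.tensor(0.0),
    acc_score=10.0,
    reward=torch.tensor(0.0),
)
print(reward, ans_acc)  # half or full acc_score, depending on whether math_verify parses the text
```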
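Note on the length shaping: both reward functions now window the soft over-length penalty on max_new_tokens (the response budget) rather than max_length (prompt plus response), and they zero the whole reward once the response reaches max_new_tokens. A small standalone sketch of the window, with illustrative numbers (4096 new tokens, 512 cache length, acc_score 10):

```python
def length_shaping(res_length: int, max_new_tokens: int, cache_length: int, acc_score: float) -> float:
    # Mirrors the penalty window used in math_reward_fn / boxed_math_reward_fn.
    length_reward = 0.0
    if max_new_tokens - cache_length < res_length < max_new_tokens:
        # Ramps linearly from 0 down to -acc_score over the last cache_length tokens.
        length_reward = ((max_new_tokens - cache_length) - res_length) / cache_length * acc_score
    return length_reward


print(length_shaping(3000, 4096, 512, 10.0))  # 0.0  -> well inside the budget, no penalty
print(length_shaping(3840, 4096, 512, 10.0))  # -5.0 -> halfway through the penalty window
# At res_length >= max_new_tokens the reward functions additionally multiply the whole
# reward by 0.0, so a truncated response earns nothing.
```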
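Note on the evaluation trigger in BaseProducer.loop: the producer only learns the consumer step when weights are synced, so the step it observes can jump past a multiple of eval_interval; the old modulo test would then skip that evaluation entirely, while the new delta test cannot, and latest_eval_step == -1 forces one evaluation before any training progress. A standalone sketch of the predicate:

```python
def should_eval(consumer_global_step: int, latest_eval_step: int, eval_interval: int) -> bool:
    # Same predicate as in BaseProducer.loop after this change.
    return (
        consumer_global_step - latest_eval_step >= eval_interval
        and consumer_global_step > latest_eval_step
    ) or latest_eval_step == -1


print(should_eval(0, -1, 10))   # True  -> evaluate the initial weights before training
print(should_eval(9, 0, 10))    # False -> fewer than eval_interval steps since the last eval
print(should_eval(23, 0, 10))   # True  -> the old `% eval_interval == 0` test would skip this jump
```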
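Note on the rollout log: producer 0 appends one JSON object per logged step to rollout_log_file, with the consumer step under "train_step" and the decoded first generation of every prompt in the micro-batch under "rollout". A small reader sketch; the path is only an example of what rl_example.py would produce for a project named "GRPO-Training" with the default --rollout-save-dir:

```python
import json

with open("./rollouts/GRPO-Training.jsonl", encoding="utf8") as f:
    for line in f:
        record = json.loads(line)
        step = record["train_step"]
        samples = record["rollout"]  # one decoded string per prompt (first generation only)
        print(f"step {step}: {len(samples)} prompts, first sample starts with {samples[0][:80]!r}")
```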
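Note on consumer_global_step: the producer now pops a "consumer_global_step" tensor out of the broadcast state dict and uses it to track the consumer's training step. The consumer-side line that inserts that entry is not part of this excerpt; the sketch below is a hypothetical illustration of what such an insertion could look like, not the actual change:

```python
import torch


def attach_consumer_step(state_dict: dict, global_step: int, device: str = "cuda") -> dict:
    # Hypothetical helper: piggyback the current training step on the weight sync so that
    # producers can recover it via state_dict.pop("consumer_global_step").item().
    state_dict["consumer_global_step"] = torch.tensor(global_step, device=device)
    return state_dict
```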