diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 5dde66435..05ed04270 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -127,7 +127,9 @@ class GRPOConsumer(BaseConsumer):
         )
         # Initialize verifiable reward.
         reward_model_kwargs = {
-            k: v for k, v in grpo_config.items() if k in ["soft_over_length_punishment", "max_length", "cache_length"]
+            k: v
+            for k, v in grpo_config.items()
+            if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length"]
         }
         self.reward_model = VerifiableReward(
             reward_fns=[
diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py
index 8766daf1a..4f87bd3d5 100644
--- a/applications/ColossalChat/coati/distributed/producer.py
+++ b/applications/ColossalChat/coati/distributed/producer.py
@@ -93,7 +93,7 @@ class BaseProducer:
         )

         self.eval_dataset_config = eval_dataset_config
-        if self.eval_dataset_config is not None:
+        if self.eval_dataset_config is not None and self.eval_interval > 0:
             self.eval_dataloaders = {}
             for eval_task_name in self.eval_dataset_config:
                 eval_dataset_path = eval_dataset_config[eval_task_name].pop("path")
diff --git a/applications/ColossalChat/coati/distributed/reward/reward_fn.py b/applications/ColossalChat/coati/distributed/reward/reward_fn.py
index 607b5eefc..ec72de4e1 100644
--- a/applications/ColossalChat/coati/distributed/reward/reward_fn.py
+++ b/applications/ColossalChat/coati/distributed/reward/reward_fn.py
@@ -81,12 +81,8 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     s, e = response_idx[0], response_idx[1]

     length_reward = 0.0
-    if soft_over_length_punishment:
-        max_length = kwargs.get("max_length", 1024 * 4)
-        cache_length = kwargs.get("cache_length", 512)
-        res_length = e.item() - s.item() + 1
-        if max_length - cache_length < res_length < max_length:
-            length_reward = ((max_length - cache_length) - res_length) / cache_length * acc_score
+    max_new_tokens = kwargs["max_new_tokens"]
+    res_length = e.item() - s.item() + 1

     if gt_answer is None:
         return reward
@@ -105,7 +101,16 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):

     if format_valid and final_answer is not None:
         reward, ans_acc = verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, reward)
-    reward = reward + length_reward
+        if soft_over_length_punishment:
+            cache_length = kwargs.get("cache_length", 512)
+            if max_new_tokens - cache_length < res_length:
+                length_reward = ((max_new_tokens - cache_length) - res_length) / cache_length * acc_score
+                reward = reward + length_reward
+    if res_length >= max_new_tokens:
+        # no reward for over length
+        print(f"Overlength response detected: res_len: {e.item()-s.item()+1}, limit:{max_new_tokens}")
+        reward *= 0.0
+        format_acc *= 0.0

     if not eval_mode:
         return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
@@ -133,12 +138,8 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     s, e = response_idx[0], response_idx[1]

     length_reward = 0.0
-    if soft_over_length_punishment:
-        max_length = kwargs.get("max_length", 1024 * 4)
-        cache_length = kwargs.get("cache_length", 512)
-        res_length = e.item() - s.item() + 1
-        if max_length - cache_length < res_length < max_length:
-            length_reward = ((max_length - cache_length) - res_length) / cache_length * acc_score
+    max_new_tokens = kwargs["max_new_tokens"]
+    res_length = e.item() - s.item() + 1

     if gt_answer is None:
         return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
@@ -161,8 +162,17 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):

     # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
     if format_valid and final_answer is not None:
         reward, ans_acc = verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, reward)
+        if soft_over_length_punishment:
+            cache_length = kwargs.get("cache_length", 512)
+            if max_new_tokens - cache_length < res_length:
+                length_reward = ((max_new_tokens - cache_length) - res_length) / cache_length * acc_score
+                reward = reward + length_reward
+    if res_length >= max_new_tokens:
+        # no reward for over length
+        print(f"Overlength response detected: res_len: {e.item()-s.item()+1}, limit:{max_new_tokens}")
+        reward *= 0.0
+        format_acc *= 0.0
-    reward = reward + length_reward
     if not eval_mode:
         return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
     else:
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index dc503459e..4f081d0f3 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -199,6 +199,7 @@ if __name__ == "__main__":
             "beta": args.kl_coeff,  # KL penalty coefficient
             "loss_variation": "sample_level",
             "reward_fn_type": args.reward_type,
+            "max_new_tokens": args.max_new_tokens,
         }
     elif args.algo == "DAPO":
         # DAPO variant settings
@@ -213,7 +214,7 @@ if __name__ == "__main__":
             "beta": 0,  # no KL penalty for DAPO
             "loss_variation": "token_level",
             "soft_over_length_punishment": True,
-            "max_length": args.max_new_tokens + args.max_prompt_tokens,
+            "max_new_tokens": args.max_new_tokens,
            "cache_length": min(1024, int(args.max_new_tokens / 4)),
             "filter_truncated_response": True,
             "reward_fn_type": args.reward_type,
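Note: both reward functions now apply the same piecewise length rule keyed on max_new_tokens rather than max_length. The sketch below isolates that arithmetic as a standalone helper; apply_length_penalty is a hypothetical name for illustration only, and unlike the patch (which applies the soft penalty only to correctly formatted, answered responses) it takes the already-computed reward as a plain input.

# Illustrative sketch of the soft over-length punishment added in reward_fn.py.
def apply_length_penalty(reward, format_acc, acc_score, res_length, max_new_tokens, cache_length=512):
    # Once the response enters the last `cache_length` tokens of the generation
    # budget, subtract a linearly growing fraction of the accuracy score.
    if max_new_tokens - cache_length < res_length:
        length_reward = ((max_new_tokens - cache_length) - res_length) / cache_length * acc_score
        reward = reward + length_reward
    # Responses at or beyond the budget get no reward and no format credit.
    if res_length >= max_new_tokens:
        reward *= 0.0
        format_acc *= 0.0
    return reward, format_acc


# Example with acc_score=10 and a 4096-token budget (default cache_length=512):
# a 3968-token response loses (3584 - 3968) / 512 * 10 = -7.5, leaving 2.5;
# a 4096-token response would be zeroed out entirely.
print(apply_length_penalty(reward=10.0, format_acc=1.0, acc_score=10.0,
                           res_length=3968, max_new_tokens=4096))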