fix reward

YeAnbang 2025-05-05 15:40:22 +08:00
parent d4a6b6c4a7
commit 6fff36dd63
4 changed files with 30 additions and 17 deletions

View File

@@ -127,7 +127,9 @@ class GRPOConsumer(BaseConsumer):
         )
         # Initialize verifiable reward.
         reward_model_kwargs = {
-            k: v for k, v in grpo_config.items() if k in ["soft_over_length_punishment", "max_length", "cache_length"]
+            k: v
+            for k, v in grpo_config.items()
+            if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length"]
         }
         self.reward_model = VerifiableReward(
             reward_fns=[

View File

@ -93,7 +93,7 @@ class BaseProducer:
) )
self.eval_dataset_config = eval_dataset_config self.eval_dataset_config = eval_dataset_config
if self.eval_dataset_config is not None: if self.eval_dataset_config is not None and self.eval_interval > 0:
self.eval_dataloaders = {} self.eval_dataloaders = {}
for eval_task_name in self.eval_dataset_config: for eval_task_name in self.eval_dataset_config:
eval_dataset_path = eval_dataset_config[eval_task_name].pop("path") eval_dataset_path = eval_dataset_config[eval_task_name].pop("path")
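
The extra self.eval_interval > 0 condition lets a non-positive eval interval disable evaluation entirely, so no eval dataloaders are built when evaluation is switched off. A minimal sketch of the guard; the class, task name, and values are hypothetical, only the attribute names mirror the diff:

    # Hypothetical sketch of the added guard (not the repo's BaseProducer).
    class ProducerSketch:
        def __init__(self, eval_dataset_config, eval_interval):
            self.eval_dataset_config = eval_dataset_config
            self.eval_interval = eval_interval
            # Build eval dataloaders only if eval data is configured AND evaluation is enabled.
            if self.eval_dataset_config is not None and self.eval_interval > 0:
                self.eval_dataloaders = {}  # would be filled per eval task
            else:
                self.eval_dataloaders = None

    print(ProducerSketch({"math_eval": {"path": "eval.jsonl"}}, eval_interval=0).eval_dataloaders)  # None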

View File

@@ -81,12 +81,8 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     s, e = response_idx[0], response_idx[1]
     length_reward = 0.0
-    if soft_over_length_punishment:
-        max_length = kwargs.get("max_length", 1024 * 4)
-        cache_length = kwargs.get("cache_length", 512)
-        res_length = e.item() - s.item() + 1
-        if max_length - cache_length < res_length < max_length:
-            length_reward = ((max_length - cache_length) - res_length) / cache_length * acc_score
+    max_new_tokens = kwargs["max_new_tokens"]
+    res_length = e.item() - s.item() + 1

     if gt_answer is None:
         return reward
@@ -105,7 +101,16 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     if format_valid and final_answer is not None:
         reward, ans_acc = verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, reward)
-        reward = reward + length_reward
+        if soft_over_length_punishment:
+            cache_length = kwargs.get("cache_length", 512)
+            if max_new_tokens - cache_length < res_length:
+                length_reward = ((max_new_tokens - cache_length) - res_length) / cache_length * acc_score
+            reward = reward + length_reward
+    if res_length >= max_new_tokens:
+        # no reward for over length
+        print(f"Overlength response detected: res_len: {e.item()-s.item()+1}, limit:{max_new_tokens}")
+        reward *= 0.0
+        format_acc *= 0.0

     if not eval_mode:
         return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
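
The two math_reward_fn hunks above change the length penalty in three ways: it now keys off max_new_tokens (the generation budget) rather than max_length (prompt plus generation), it is only added once an answer has actually been extracted and verified, and any response that reaches the budget has its reward and format score zeroed. A worked example of the soft penalty with illustrative values (max_new_tokens, cache_length, and acc_score here are example numbers, not taken from the diff):

    # Worked example of the soft over-length penalty; all values are illustrative.
    max_new_tokens = 4096
    cache_length = 512
    acc_score = 10.0  # assumed score granted for a correct answer

    def length_penalty(res_length):
        """Penalty once a response enters the last `cache_length` tokens of the budget."""
        if max_new_tokens - cache_length < res_length:
            return ((max_new_tokens - cache_length) - res_length) / cache_length * acc_score
        return 0.0

    print(length_penalty(3000))  # 0.0   -> well under the budget, no penalty
    print(length_penalty(3840))  # -5.0  -> halfway into the penalty window: (3584 - 3840) / 512 * 10
    print(length_penalty(4096))  # -10.0 -> at the limit; the reward is then zeroed anyway
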
@@ -133,12 +138,8 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     s, e = response_idx[0], response_idx[1]
     length_reward = 0.0
-    if soft_over_length_punishment:
-        max_length = kwargs.get("max_length", 1024 * 4)
-        cache_length = kwargs.get("cache_length", 512)
-        res_length = e.item() - s.item() + 1
-        if max_length - cache_length < res_length < max_length:
-            length_reward = ((max_length - cache_length) - res_length) / cache_length * acc_score
+    max_new_tokens = kwargs["max_new_tokens"]
+    res_length = e.item() - s.item() + 1

     if gt_answer is None:
         return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
@@ -161,8 +162,17 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
     if format_valid and final_answer is not None:
         reward, ans_acc = verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, reward)
+        if soft_over_length_punishment:
+            cache_length = kwargs.get("cache_length", 512)
+            if max_new_tokens - cache_length < res_length:
+                length_reward = ((max_new_tokens - cache_length) - res_length) / cache_length * acc_score
+            reward = reward + length_reward
+    if res_length >= max_new_tokens:
+        # no reward for over length
+        print(f"Overlength response detected: res_len: {e.item()-s.item()+1}, limit:{max_new_tokens}")
+        reward *= 0.0
+        format_acc *= 0.0
-    reward = reward + length_reward

     if not eval_mode:
         return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
     else:
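
boxed_math_reward_fn now follows the same structure as math_reward_fn. A hedged, self-contained sketch of the two rules the commit introduces, written as a hypothetical helper (apply_overlength_rules is not in the diff, and the sketch omits the format_valid / final_answer guard that the real code applies before adding the soft penalty):

    import torch

    def apply_overlength_rules(reward, format_acc, acc_score, res_length, **kwargs):
        """Hypothetical helper mirroring the logic both reward functions now share:
        1. a soft penalty once the response enters the final `cache_length` tokens of the budget,
        2. zeroing of reward and format score once the generation budget is exhausted."""
        max_new_tokens = kwargs["max_new_tokens"]
        if kwargs.get("soft_over_length_punishment", False):
            cache_length = kwargs.get("cache_length", 512)
            if max_new_tokens - cache_length < res_length:
                reward = reward + ((max_new_tokens - cache_length) - res_length) / cache_length * acc_score
        if res_length >= max_new_tokens:
            reward = reward * 0.0      # no reward for an over-length response
            format_acc = format_acc * 0.0
        return reward, format_acc

    r, f = apply_overlength_rules(torch.tensor(10.0), torch.tensor(1.0), 10.0, res_length=4096,
                                  max_new_tokens=4096, soft_over_length_punishment=True, cache_length=512)
    print(r.item(), f.item())  # 0.0 0.0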

View File

@@ -199,6 +199,7 @@ if __name__ == "__main__":
             "beta": args.kl_coeff,  # KL penalty coefficient
             "loss_variation": "sample_level",
             "reward_fn_type": args.reward_type,
+            "max_new_tokens": args.max_new_tokens,
         }
     elif args.algo == "DAPO":
         # DAPO variant settings
@@ -213,7 +214,7 @@ if __name__ == "__main__":
             "beta": 0,  # no KL penalty for DAPO
             "loss_variation": "token_level",
             "soft_over_length_punishment": True,
-            "max_length": args.max_new_tokens + args.max_prompt_tokens,
+            "max_new_tokens": args.max_new_tokens,
             "cache_length": min(1024, int(args.max_new_tokens / 4)),
             "filter_truncated_response": True,
             "reward_fn_type": args.reward_type,