mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-26 15:32:22 +00:00
remove format reward
This commit is contained in:
parent
6e71e2a3ce
commit
447ab74fb4
@@ -3,14 +3,11 @@ import torch
|
|||||||
from .reward_utils import extract_solution, validate_response_structure
|
from .reward_utils import extract_solution, validate_response_structure
|
||||||
|
|
||||||
|
|
||||||
def math_reward_fn(step, input_ids, gt_answer, response_idx, **kwargs):
|
def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
|
||||||
tokenizer = kwargs["tokenizer"]
|
tokenizer = kwargs["tokenizer"]
|
||||||
soft_over_length_punishment = kwargs["soft_over_length_punishment"]
|
soft_over_length_punishment = kwargs["soft_over_length_punishment"]
|
||||||
format_score = 1.0
|
format_score = 0.0
|
||||||
acc_score = 9.0
|
acc_score = 10.0
|
||||||
if step > 30:
|
|
||||||
format_score = 0.0
|
|
||||||
acc_score = 10.0
|
|
||||||
reward = torch.tensor(0.0)
|
reward = torch.tensor(0.0)
|
||||||
format_reward = torch.tensor(0.0)
|
format_reward = torch.tensor(0.0)
|
||||||
acc_reward = torch.tensor(0.0)
|
acc_reward = torch.tensor(0.0)
|
||||||
@@ -21,10 +18,8 @@ def math_reward_fn(step, input_ids, gt_answer, response_idx, **kwargs):
|
|||||||
max_length = kwargs.get("max_length", 1024 * 4)
|
max_length = kwargs.get("max_length", 1024 * 4)
|
||||||
cache_length = kwargs.get("cache_length", 512)
|
cache_length = kwargs.get("cache_length", 512)
|
||||||
res_length = e.item() - s.item() + 1
|
res_length = e.item() - s.item() + 1
|
||||||
if res_length >= max_length:
|
if max_length - cache_length < res_length < max_length:
|
||||||
length_reward = -1.0 * 2
|
length_reward = ((max_length - cache_length) - res_length) / cache_length * acc_score
|
||||||
elif res_length > max_length - cache_length:
|
|
||||||
length_reward = ((max_length - cache_length) - res_length) / cache_length * 2
|
|
||||||
|
|
||||||
if gt_answer is None:
|
if gt_answer is None:
|
||||||
return reward
|
return reward
|
||||||
|
Loading…
Reference in New Issue
Block a user