mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-21 13:11:27 +00:00
fix reward score
This commit is contained in:
parent
71a0181fce
commit
abca66e69f
@ -4,6 +4,8 @@ from .reward_utils import extract_solution, validate_response_structure
|
|||||||
|
|
||||||
|
|
||||||
def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
|
def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
|
||||||
|
format_score = 1.0
|
||||||
|
acc_score = 9.0
|
||||||
tokenizer = kwargs["tokenizer"]
|
tokenizer = kwargs["tokenizer"]
|
||||||
reward = torch.tensor(0.0)
|
reward = torch.tensor(0.0)
|
||||||
format_reward = torch.tensor(0.0)
|
format_reward = torch.tensor(0.0)
|
||||||
@ -20,16 +22,16 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
|
|||||||
|
|
||||||
# Check format accuracy
|
# Check format accuracy
|
||||||
if format_valid:
|
if format_valid:
|
||||||
format_reward += 1.0
|
format_reward += format_score
|
||||||
reward += 1.0
|
reward += format_score
|
||||||
|
|
||||||
# Check answer accuracy
|
# Check answer accuracy
|
||||||
if (
|
if (
|
||||||
final_answer is not None
|
final_answer is not None
|
||||||
and gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower()
|
and gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower()
|
||||||
):
|
):
|
||||||
acc_reward += 5.0
|
acc_reward += acc_score
|
||||||
reward += 5.0
|
reward += acc_score
|
||||||
|
|
||||||
return torch.tensor([reward, format_reward, acc_reward]).to(input_ids.device)
|
return torch.tensor([reward, format_reward, acc_reward]).to(input_ids.device)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user