mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-09 11:58:06 +00:00
update reward fn: decode gt_answer with skip_special_tokens=True and lower the exact-match bonus from 9.0 to 2.0
This commit is contained in:
parent
678f5a9eca
commit
d03cdea949
@ -11,7 +11,7 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
|
|||||||
return reward
|
return reward
|
||||||
|
|
||||||
decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
|
decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
|
||||||
gt_answer = tokenizer.decode(gt_answer.squeeze(0))
|
gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
|
||||||
final_answer, processed_str = extract_solution(decoded_final_answer)
|
final_answer, processed_str = extract_solution(decoded_final_answer)
|
||||||
|
|
||||||
format_valid = validate_response_structure(processed_str, kwargs["tags"])
|
format_valid = validate_response_structure(processed_str, kwargs["tags"])
|
||||||
@ -20,7 +20,7 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
|
|||||||
else:
|
else:
|
||||||
reward += 1.0
|
reward += 1.0
|
||||||
if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower():
|
if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower():
|
||||||
reward = reward + 9.0
|
reward = reward + 2.0
|
||||||
return reward
|
return reward
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user