From 8e6c9a4ab3f10a22ee5d1d3001ea64d0bda5191a Mon Sep 17 00:00:00 2001
From: Tong Li
Date: Sun, 23 Feb 2025 11:02:54 +0800
Subject: [PATCH] add reward related function

---
 .../coati/distributed/reward/reward_fn.py     | 51 +++++++++++++
 .../coati/distributed/reward/reward_utils.py  | 76 +++++++++++++++++++
 .../distributed/reward/verifiable_reward.py   | 47 ++++++++++++
 3 files changed, 174 insertions(+)
 create mode 100644 applications/ColossalChat/coati/distributed/reward/reward_fn.py
 create mode 100644 applications/ColossalChat/coati/distributed/reward/reward_utils.py
 create mode 100644 applications/ColossalChat/coati/distributed/reward/verifiable_reward.py

diff --git a/applications/ColossalChat/coati/distributed/reward/reward_fn.py b/applications/ColossalChat/coati/distributed/reward/reward_fn.py
new file mode 100644
index 000000000..f127a1ece
--- /dev/null
+++ b/applications/ColossalChat/coati/distributed/reward/reward_fn.py
@@ -0,0 +1,51 @@
+import torch
+
+from .reward_utils import extract_solution, validate_response_structure
+
+
+def math_reward_fn(input_ids, **kwargs):
+    # apply verifiable reward:
+    # award 1 point if the response format is correct, and 10 points in total if the final answer is also correct
+
+    gt_answer = kwargs["gt_answer"]
+    tokenizer = kwargs["tokenizer"]
+    s, e = kwargs["response_start"], kwargs["response_end"]
+    reward = torch.tensor(0.0).to(input_ids.device)
+    if gt_answer is None:
+        return reward
+    decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
+    final_answer, processed_str = extract_solution(decoded_final_answer)
+
+    format_valid = validate_response_structure(processed_str, kwargs["tags"])
+    if not format_valid:
+        return reward
+    else:
+        reward += 1.0
+        if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower():
+            reward = reward + 9.0
+        return reward
+
+
+def gsm8k_reward_fn(input_ids, **kwargs):
+    gt_answer = kwargs["gt_answer"]
+    tokenizer = kwargs["tokenizer"]
+    s, e = kwargs["response_start"], kwargs["response_end"]
+    reward = torch.tensor(0.0).to(input_ids.device)
+    if gt_answer is None:
+        return reward
+    decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
+    final_answer, processed_str = extract_solution(decoded_final_answer)
+    is_valid = True
+    try:
+        int(final_answer.strip())
+    except Exception:
+        is_valid = False
+
+    format_valid = validate_response_structure(processed_str, kwargs["tags"])
+    if not is_valid or not format_valid:
+        return reward
+    else:
+        reward += 1.0
+        if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower():
+            reward = reward + 9.0
+        return reward
diff --git a/applications/ColossalChat/coati/distributed/reward/reward_utils.py b/applications/ColossalChat/coati/distributed/reward/reward_utils.py
new file mode 100644
index 000000000..c1e73d4b9
--- /dev/null
+++ b/applications/ColossalChat/coati/distributed/reward/reward_utils.py
@@ -0,0 +1,76 @@
+# Copyright Unakar
+# Modified from https://github.com/Unakar/Logic-RL/blob/086373176ac198c97277ff50f4b6e7e1bfe669d3/verl/utils/reward_score/kk.py#L99
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+from typing import Dict, Optional, Tuple
+
+
+def validate_response_structure(processed_str: str, tags: Dict = None) -> bool:
+    """Performs comprehensive validation of response structure.
+
+    Args:
+        processed_str: Processed response string from the model
+
+    Returns:
+        Boolean indicating whether all formatting requirements are met
+    """
+    validation_passed = True
+    # Check that each required tag occurs the expected number of times
+    if tags is None:
+        tags = {
+            "think_start": {"text": "<think>", "num_occur": 1},
+            "think_end": {"text": "</think>", "num_occur": 1},
+            "answer_start": {"text": "<answer>", "num_occur": 1},
+            "answer_end": {"text": "</answer>", "num_occur": 1},
+        }
+    positions = {}
+    for tag_name, tag_info in tags.items():
+        tag_str = tag_info["text"]
+        expected_count = tag_info["num_occur"]
+        count = processed_str.count(tag_str)
+        positions[tag_name] = pos = processed_str.find(tag_str)
+        if count != expected_count:
+            validation_passed = False
+    # Verify tag order: <think>...</think> must precede <answer>...</answer>
+    if (
+        positions["think_start"] > positions["think_end"]
+        or positions["think_end"] > positions["answer_start"]
+        or positions["answer_start"] > positions["answer_end"]
+    ):
+        validation_passed = False
+    if len(processed_str) - positions["answer_end"] != len(tags["answer_end"]["text"]):
+        validation_passed = False
+    return validation_passed
+
+
+def extract_solution(solution_str: str) -> Tuple[Optional[str], str]:
+    """Extracts the final answer from the model's response string.
+
+    Args:
+        solution_str: Raw response string from the language model
+
+    Returns:
+        Tuple containing (extracted_answer, processed_string)
+    """
+
+    # Extract the final answer using the XML-style <answer>...</answer> tags
+    answer_pattern = r"<answer>(.*?)</answer>"
+    matches = list(re.finditer(answer_pattern, solution_str, re.DOTALL))
+
+    if not matches:
+        return None, solution_str
+
+    final_answer = matches[-1].group(1).strip()
+    return final_answer, solution_str
diff --git a/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py b/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py
new file mode 100644
index 000000000..d1700d86f
--- /dev/null
+++ b/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py
@@ -0,0 +1,47 @@
+"""
+Function-based reward verification module.
+"""
+
+from typing import Any, Dict, List
+
+import torch
+
+
+class VerifiableReward:
+    def __init__(self, reward_fn: List[callable], reward_args: List[Dict[str, Any]]):
+        self.reward_fn_list = reward_fn
+        self.reward_args_list = reward_args
+
+    def __call__(
+        self,
+        input_ids: torch.LongTensor,
+        attention_mask: torch.LongTensor,
+        response_start: List[int] = None,
+        response_end: List[int] = None,
+        gt_answer: List[str] = None,
+    ) -> torch.Tensor:
+        # Get batch size
+        bs = input_ids.size(0)
+        # Initialize the accumulated reward for the batch
+        rewards = torch.zeros(bs, device=input_ids.device)
+
+        # Loop through the reward functions, each paired with its keyword arguments
+        for reward_fn, reward_args in zip(self.reward_fn_list, self.reward_args_list):
+            # Apply the reward function to each sample and stack the per-sample rewards
+            reward_batch = torch.stack(
+                [
+                    reward_fn(
+                        input_ids[i],
+                        attention_mask=attention_mask[i],
+                        response_start=response_start[i],
+                        response_end=response_end[i],
+                        gt_answer=gt_answer[i],
+                        **reward_args,
+                    )
+                    for i in range(bs)
+                ],
+                dim=0,
+            )
+
+            rewards += reward_batch
+        return rewards