From c5e97f4e25b68371cd4c3e52693eb517440b712a Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Mon, 14 Jul 2025 16:25:03 +0800 Subject: [PATCH] fix code evaluation --- .../coati/distributed/producer.py | 4 +- .../reward/code_reward/testing_util.py | 19 ++++++--- .../distributed/reward/code_reward/utils.py | 40 ++++++++++++------- .../coati/distributed/reward/reward_fn.py | 17 ++++++-- applications/ColossalChat/rl_example.py | 19 +++++++-- .../ColossalChat/start_code_verifier.py | 35 ++++++++++++++++ 6 files changed, 104 insertions(+), 30 deletions(-) create mode 100644 applications/ColossalChat/start_code_verifier.py diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py index 2a3746391..11fb5d3aa 100644 --- a/applications/ColossalChat/coati/distributed/producer.py +++ b/applications/ColossalChat/coati/distributed/producer.py @@ -83,7 +83,7 @@ class BaseProducer: reward_model_kwargs = { k: v for k, v in grpo_config.items() - if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length"] + if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length", "code_verifier_api_url"] } self.response_format_tags = grpo_config.get("response_format_tags", None) if producer_idx == 0: @@ -250,7 +250,7 @@ class BaseProducer: for m in range(eval_outputs["input_ids"].size(0)) for n in range(eval_outputs["input_ids"].size(1)) ] - eval_statistics_tensor[0] += len([res for res in eval_results if res["ans_valid"] == 1]) + eval_statistics_tensor[0] += sum([max(0, res["ans_valid"]) for res in eval_results]) eval_statistics_tensor[1] += len(eval_results) allreduce(eval_statistics_tensor, op=ReduceOp.SUM, group_name="producer_group") to_log_msg[f"eval/{eval_task_name}"] = ( diff --git a/applications/ColossalChat/coati/distributed/reward/code_reward/testing_util.py b/applications/ColossalChat/coati/distributed/reward/code_reward/testing_util.py index c6d6d8fad..12973337e 100644 --- a/applications/ColossalChat/coati/distributed/reward/code_reward/testing_util.py +++ b/applications/ColossalChat/coati/distributed/reward/code_reward/testing_util.py @@ -89,7 +89,7 @@ def clean_traceback(error_traceback): return error_traceback -def run_test(in_outs, test=None, debug=False, timeout=15): +def run_test(in_outs, test=None, debug=False, timeout=15, run_all_tests=False): """ if test(generated_code) is not None it'll try to run the code. otherwise it'll just return an input and output pair. 
@@ -180,8 +180,8 @@ def run_test(in_outs, test=None, debug=False, timeout=15): tmp_test = new_test sol += tmp_test - if debug: - print(f"sol = {sol}") + # if debug: + # print(f"sol = {sol}") method_name = "code" signal.alarm(timeout) try: @@ -202,8 +202,7 @@ def run_test(in_outs, test=None, debug=False, timeout=15): } signal.alarm(0) if debug: - print(f"get method = {datetime.now().time()}") - + print(f"get method {method_name} = {datetime.now().time()}") try: method = getattr(tmp, method_name) # get_attr second arg must be str except Exception: @@ -329,6 +328,9 @@ def run_test(in_outs, test=None, debug=False, timeout=15): error_traceback = traceback.format_exc() print(f"Call-based runtime error or time limit exceeded error = {repr(e)}{e}") results.append(-1) + signal.alarm(0) + if run_all_tests: + continue return results, { "error": repr(e), "traceback": clean_traceback(error_traceback), @@ -519,6 +521,10 @@ def run_test(in_outs, test=None, debug=False, timeout=15): results.append(tmp_result) if tmp_result is not True: + if debug: + print("final result:", results) + if run_all_tests: + continue return results, { "output": raw_true_output_copy, "expected": raw_outputs, @@ -539,7 +545,8 @@ def run_test(in_outs, test=None, debug=False, timeout=15): ) print(f"results = {results}") - + if debug: + print("final results", results) return results, {} diff --git a/applications/ColossalChat/coati/distributed/reward/code_reward/utils.py b/applications/ColossalChat/coati/distributed/reward/code_reward/utils.py index e19e9e387..dbb3b2149 100644 --- a/applications/ColossalChat/coati/distributed/reward/code_reward/utils.py +++ b/applications/ColossalChat/coati/distributed/reward/code_reward/utils.py @@ -16,27 +16,24 @@ # limitations under the License. import multiprocessing -import os -import sys import traceback from typing import Optional +import requests + from .testing_util import run_test def _temp_run(sample, generation, debug, result, metadata_list, timeout): - with open(os.devnull, "w") as devnull: - sys.stdout = devnull - sys.stderr = devnull - try: - res, metadata = run_test(in_outs=sample, test=generation, debug=debug, timeout=timeout) - result.append(res) - metadata_list.append(metadata) - except Exception: - # print(e) # some tracebacks are extremely long. - traceback.print_exc(10) - result.append([-1 for i in range(len(sample["inputs"]))]) - metadata_list.append({}) + try: + res, metadata = run_test(in_outs=sample, test=generation, debug=debug, timeout=timeout) + result.append(res) + metadata_list.append(metadata) + except Exception: + # print(e) # some tracebacks are extremely long. 
+ traceback.print_exc(10) + result.append([-1 for i in range(len(sample["inputs"]))]) + metadata_list.append({}) def check_correctness(in_outs: Optional[dict], generation, timeout=10, debug=True): @@ -49,7 +46,7 @@ def check_correctness(in_outs: Optional[dict], generation, timeout=10, debug=Tru metadata_list = manager.list() p = multiprocessing.Process(target=_temp_run, args=(in_outs, generation, debug, result, metadata_list, timeout)) p.start() - p.join(timeout=timeout + 1) + p.join(timeout=600) # Global timeout of 10 minutes that's for all test cases combined if p.is_alive(): p.kill() # p.terminate() @@ -59,3 +56,16 @@ def check_correctness(in_outs: Optional[dict], generation, timeout=10, debug=Tru if debug: print("global timeout") return result[0], metadata_list + + +def check_correctness_code_api( + in_outs: Optional[dict], generation, timeout=10, debug=True, url="http://localhost:8000/check_correctness" +): + payload = {"in_outs": in_outs, "generation": generation, "timeout": timeout, "debug": debug} + response = requests.post(url, json=payload) + if response.status_code == 200: + results = response.json() + return results["result"], results["metadata"] + else: + print(f"Error: {response.status_code} - {response.text}") + return [-1 for i in range(len(in_outs["inputs"]))], {} diff --git a/applications/ColossalChat/coati/distributed/reward/reward_fn.py b/applications/ColossalChat/coati/distributed/reward/reward_fn.py index 5c8810fdf..f7a2fb89c 100644 --- a/applications/ColossalChat/coati/distributed/reward/reward_fn.py +++ b/applications/ColossalChat/coati/distributed/reward/reward_fn.py @@ -24,7 +24,7 @@ import torch from latex2sympy2_extended import NormalizationConfig from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify -from .code_reward.utils import check_correctness as check_correctness_code +from .code_reward.utils import check_correctness_code_api as check_correctness_code from .reward_utils import extract_boxed_solution, extract_solution, validate_response_structure CANNOT_PARSE_GT_ANSWER = -1 @@ -223,6 +223,7 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): def code_reward_fn(input_ids, test_cases, response_idx, **kwargs): + url = kwargs.get("url", "http://localhost:8000/check_correctness") tokenizer = kwargs["tokenizer"] eval_mode = kwargs.get("eval_mode", False) soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False) @@ -255,6 +256,9 @@ def code_reward_fn(input_ids, test_cases, response_idx, **kwargs): if format_valid: format_acc += 1 + res = [] + metadata = [] + try: try: if not isinstance(test_cases, dict): @@ -264,15 +268,18 @@ def code_reward_fn(input_ids, test_cases, response_idx, **kwargs): raise e # Complete check on all in-out pairs first. If there is no failure, per-sample test can be skipped. 
try: - res, metadata = check_correctness_code(in_outs=test_cases, generation=solution, timeout=10, debug=True) + res, metadata = check_correctness_code( + in_outs=test_cases, generation=solution, timeout=10, debug=False, url=url + ) metadata = dict(enumerate(metadata))[0] - success = all(map(lambda x: x is True, res)) + success = all(map(lambda x: x == 1, res)) if success: ans_acc += 1 if eval_mode or format_valid: reward += acc_score if not eval_mode: reward = reward + length_reward + except Exception: pass @@ -288,7 +295,9 @@ def code_reward_fn(input_ids, test_cases, response_idx, **kwargs): return { "prompt": prompt, "prediction": decoded_final_answer, - "gold": test_cases["outputs"], + "test_cases": test_cases, + "test_results": res, + "test_metadata": metadata, "parsed": solution, "format_valid": format_acc.item(), "ans_valid": ans_acc.item(), diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index d7b7c2a5d..08814f9f1 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -12,6 +12,9 @@ DEFAUT_SYSTEM_PROMPT = { "code": "You are a helpful assistant.", } +# bypass the proxy for local addresses +os.environ["no_proxy"] = "127.0.0.1,localhost" + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B") @@ -138,6 +141,13 @@ if __name__ == "__main__": choices=["think_answer_tags", "boxed", "code"], help="Reward type for GRPO.", ) + parser.add_argument( + "-cv", + "--code-verifier-api-url", + type=str, + default=None, + help="API URL for code verifier. If not provided, the code verifier will be disabled.", + ) parser.add_argument( "-ei", "--eval-interval", @@ -165,6 +175,7 @@ if __name__ == "__main__": parser.add_argument( "--enable_profiling", action="store_true", default=False, help="Enable profiling for the training process." ) + args = parser.parse_args() if args.train_minibatch_size is None: @@ -188,7 +199,7 @@ if __name__ == "__main__": namespace="ray-example", runtime_env={ "env_vars": { - # "RAY_DEBUG_POST_MORTEM": "1" # enable post-mortem debugging with ray + # "RAY_DEBUG_POST_MORTEM": "1", # enable post-mortem debugging with ray "TOKENIZERS_PARALLELISM": "false" }, }, @@ -201,7 +212,7 @@ if __name__ == "__main__": _temp_dir=args.ray_dir, runtime_env={ "env_vars": { - # "RAY_DEBUG_POST_MORTEM": "1" # enable post-mortem debugging with ray + # "RAY_DEBUG_POST_MORTEM": "1", # enable post-mortem debugging with ray "TOKENIZERS_PARALLELISM": "false" }, }, @@ -321,7 +332,9 @@ if __name__ == "__main__": } else: raise ValueError(f"Unsupported algorithm: {args.algo}") - + if args.reward_type == "code": + assert args.code_verifier_api_url is not None, "Please provide a code verifier API URL for code reward type." 
+        grpo_config.update({"code_verifier_api_url": args.code_verifier_api_url})
     if args.system_prompt is None:
         # Default system prompt
         args.system_prompt = DEFAUT_SYSTEM_PROMPT[args.reward_type]
diff --git a/applications/ColossalChat/start_code_verifier.py b/applications/ColossalChat/start_code_verifier.py
new file mode 100644
index 000000000..d1924f610
--- /dev/null
+++ b/applications/ColossalChat/start_code_verifier.py
@@ -0,0 +1,35 @@
+from typing import List, Optional
+
+from coati.distributed.reward.code_reward.utils import check_correctness
+from fastapi import FastAPI, HTTPException
+from pydantic import BaseModel
+
+app = FastAPI()
+
+
+class CheckCorrectnessRequest(BaseModel):
+    in_outs: Optional[dict]
+    generation: str
+    timeout: int = 10
+    debug: bool = True
+    eval_mode: bool = False  # accepted for API compatibility; not forwarded to check_correctness
+
+
+class CheckCorrectnessResponse(BaseModel):
+    result: List[int]
+    metadata: List[dict]
+
+
+@app.post("/check_correctness", response_model=CheckCorrectnessResponse)
+def check_correctness_api(request: CheckCorrectnessRequest):
+    try:
+        # check_correctness() does not accept eval_mode, so only its supported arguments are passed
+        result, metadata = check_correctness(
+            in_outs=request.in_outs,
+            generation=request.generation,
+            timeout=request.timeout,
+            debug=request.debug,
+        )
+        # convert manager/proxy objects to plain lists so they serialize cleanly
+        return CheckCorrectnessResponse(result=list(result), metadata=list(metadata))
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
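
Usage sketch (illustrative, not part of the patch): the verifier service added in start_code_verifier.py can be served with any ASGI server and queried over HTTP, which is what check_correctness_code_api in utils.py does with the same payload shape. The host, port, and sample test case below are assumptions for illustration only.

# Minimal client sketch for the /check_correctness endpoint, assuming the service
# was started separately, e.g. with: uvicorn start_code_verifier:app --port 8000
import requests

payload = {
    "in_outs": {"inputs": ["1 2\n"], "outputs": ["3\n"]},  # hypothetical stdin/stdout test case
    "generation": "a, b = map(int, input().split())\nprint(a + b)",
    "timeout": 10,
    "debug": False,
}
response = requests.post("http://localhost:8000/check_correctness", json=payload)
response.raise_for_status()
body = response.json()
print(body["result"], body["metadata"])  # e.g. [1] when every test case passes

The training entry point is then pointed at the same endpoint, e.g. --code-verifier-api-url http://localhost:8000/check_correctness together with the code reward type, which is what the new assert in rl_example.py enforces.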