fix code evaluation

This commit is contained in:
YeAnbang 2025-07-14 16:25:03 +08:00
parent b1f646c7e7
commit a992da9f0f
6 changed files with 104 additions and 30 deletions

View File

@ -83,7 +83,7 @@ class BaseProducer:
reward_model_kwargs = {
k: v
for k, v in grpo_config.items()
if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length"]
if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length", "code_verifier_api_url"]
}
self.response_format_tags = grpo_config.get("response_format_tags", None)
if producer_idx == 0:
@ -250,7 +250,7 @@ class BaseProducer:
for m in range(eval_outputs["input_ids"].size(0))
for n in range(eval_outputs["input_ids"].size(1))
]
eval_statistics_tensor[0] += len([res for res in eval_results if res["ans_valid"] == 1])
eval_statistics_tensor[0] += sum([max(0, res["ans_valid"]) for res in eval_results])
eval_statistics_tensor[1] += len(eval_results)
allreduce(eval_statistics_tensor, op=ReduceOp.SUM, group_name="producer_group")
to_log_msg[f"eval/{eval_task_name}"] = (

View File

@ -89,7 +89,7 @@ def clean_traceback(error_traceback):
return error_traceback
def run_test(in_outs, test=None, debug=False, timeout=15):
def run_test(in_outs, test=None, debug=False, timeout=15, run_all_tests=False):
"""
if test(generated_code) is not None it'll try to run the code.
otherwise it'll just return an input and output pair.
@ -180,8 +180,8 @@ def run_test(in_outs, test=None, debug=False, timeout=15):
tmp_test = new_test
sol += tmp_test
if debug:
print(f"sol = {sol}")
# if debug:
# print(f"sol = {sol}")
method_name = "code"
signal.alarm(timeout)
try:
@ -202,8 +202,7 @@ def run_test(in_outs, test=None, debug=False, timeout=15):
}
signal.alarm(0)
if debug:
print(f"get method = {datetime.now().time()}")
print(f"get method {method_name} = {datetime.now().time()}")
try:
method = getattr(tmp, method_name) # get_attr second arg must be str
except Exception:
@ -329,6 +328,9 @@ def run_test(in_outs, test=None, debug=False, timeout=15):
error_traceback = traceback.format_exc()
print(f"Call-based runtime error or time limit exceeded error = {repr(e)}{e}")
results.append(-1)
signal.alarm(0)
if run_all_tests:
continue
return results, {
"error": repr(e),
"traceback": clean_traceback(error_traceback),
@ -519,6 +521,10 @@ def run_test(in_outs, test=None, debug=False, timeout=15):
results.append(tmp_result)
if tmp_result is not True:
if debug:
print("final result:", results)
if run_all_tests:
continue
return results, {
"output": raw_true_output_copy,
"expected": raw_outputs,
@ -539,7 +545,8 @@ def run_test(in_outs, test=None, debug=False, timeout=15):
)
print(f"results = {results}")
if debug:
print("final results", results)
return results, {}

View File

@ -16,27 +16,24 @@
# limitations under the License.
import multiprocessing
import os
import sys
import traceback
from typing import Optional
import requests
from .testing_util import run_test
def _temp_run(sample, generation, debug, result, metadata_list, timeout):
with open(os.devnull, "w") as devnull:
sys.stdout = devnull
sys.stderr = devnull
try:
res, metadata = run_test(in_outs=sample, test=generation, debug=debug, timeout=timeout)
result.append(res)
metadata_list.append(metadata)
except Exception:
# print(e) # some tracebacks are extremely long.
traceback.print_exc(10)
result.append([-1 for i in range(len(sample["inputs"]))])
metadata_list.append({})
try:
res, metadata = run_test(in_outs=sample, test=generation, debug=debug, timeout=timeout)
result.append(res)
metadata_list.append(metadata)
except Exception:
# print(e) # some tracebacks are extremely long.
traceback.print_exc(10)
result.append([-1 for i in range(len(sample["inputs"]))])
metadata_list.append({})
def check_correctness(in_outs: Optional[dict], generation, timeout=10, debug=True):
@ -49,7 +46,7 @@ def check_correctness(in_outs: Optional[dict], generation, timeout=10, debug=Tru
metadata_list = manager.list()
p = multiprocessing.Process(target=_temp_run, args=(in_outs, generation, debug, result, metadata_list, timeout))
p.start()
p.join(timeout=timeout + 1)
p.join(timeout=600) # Global timeout of 10 minutes that's for all test cases combined
if p.is_alive():
p.kill()
# p.terminate()
@ -59,3 +56,16 @@ def check_correctness(in_outs: Optional[dict], generation, timeout=10, debug=Tru
if debug:
print("global timeout")
return result[0], metadata_list
def check_correctness_code_api(
in_outs: Optional[dict], generation, timeout=10, debug=True, url="http://localhost:8000/check_correctness"
):
payload = {"in_outs": in_outs, "generation": generation, "timeout": timeout, "debug": debug}
response = requests.post(url, json=payload)
if response.status_code == 200:
results = response.json()
return results["result"], results["metadata"]
else:
print(f"Error: {response.status_code} - {response.text}")
return [-1 for i in range(len(in_outs["inputs"]))], {}

View File

@ -24,7 +24,7 @@ import torch
from latex2sympy2_extended import NormalizationConfig
from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify
from .code_reward.utils import check_correctness as check_correctness_code
from .code_reward.utils import check_correctness_code_api as check_correctness_code
from .reward_utils import extract_boxed_solution, extract_solution, validate_response_structure
CANNOT_PARSE_GT_ANSWER = -1
@ -223,6 +223,7 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
def code_reward_fn(input_ids, test_cases, response_idx, **kwargs):
url = kwargs.get("url", "http://localhost:8000/check_correctness")
tokenizer = kwargs["tokenizer"]
eval_mode = kwargs.get("eval_mode", False)
soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False)
@ -255,6 +256,9 @@ def code_reward_fn(input_ids, test_cases, response_idx, **kwargs):
if format_valid:
format_acc += 1
res = []
metadata = []
try:
try:
if not isinstance(test_cases, dict):
@ -264,15 +268,18 @@ def code_reward_fn(input_ids, test_cases, response_idx, **kwargs):
raise e
# Complete check on all in-out pairs first. If there is no failure, per-sample test can be skipped.
try:
res, metadata = check_correctness_code(in_outs=test_cases, generation=solution, timeout=10, debug=True)
res, metadata = check_correctness_code(
in_outs=test_cases, generation=solution, timeout=10, debug=False, url=url
)
metadata = dict(enumerate(metadata))[0]
success = all(map(lambda x: x is True, res))
success = all(map(lambda x: x == 1, res))
if success:
ans_acc += 1
if eval_mode or format_valid:
reward += acc_score
if not eval_mode:
reward = reward + length_reward
except Exception:
pass
@ -288,7 +295,9 @@ def code_reward_fn(input_ids, test_cases, response_idx, **kwargs):
return {
"prompt": prompt,
"prediction": decoded_final_answer,
"gold": test_cases["outputs"],
"test_cases": test_cases,
"test_results": res,
"test_metadata": metadata,
"parsed": solution,
"format_valid": format_acc.item(),
"ans_valid": ans_acc.item(),

View File

@ -12,6 +12,9 @@ DEFAUT_SYSTEM_PROMPT = {
"code": "You are a helpful assistant.",
}
# bypass the proxy for local addresses
os.environ["no_proxy"] = "127.0.0.1,localhost"
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
@ -138,6 +141,13 @@ if __name__ == "__main__":
choices=["think_answer_tags", "boxed", "code"],
help="Reward type for GRPO.",
)
parser.add_argument(
"-cv",
"--code-verifier-api-url",
type=str,
default=None,
help="API URL for code verifier. If not provided, the code verifier will be disabled.",
)
parser.add_argument(
"-ei",
"--eval-interval",
@ -165,6 +175,7 @@ if __name__ == "__main__":
parser.add_argument(
"--enable_profiling", action="store_true", default=False, help="Enable profiling for the training process."
)
args = parser.parse_args()
if args.train_minibatch_size is None:
@ -188,7 +199,7 @@ if __name__ == "__main__":
namespace="ray-example",
runtime_env={
"env_vars": {
# "RAY_DEBUG_POST_MORTEM": "1" # enable post-mortem debugging with ray
# "RAY_DEBUG_POST_MORTEM": "1", # enable post-mortem debugging with ray
"TOKENIZERS_PARALLELISM": "false"
},
},
@ -201,7 +212,7 @@ if __name__ == "__main__":
_temp_dir=args.ray_dir,
runtime_env={
"env_vars": {
# "RAY_DEBUG_POST_MORTEM": "1" # enable post-mortem debugging with ray
# "RAY_DEBUG_POST_MORTEM": "1", # enable post-mortem debugging with ray
"TOKENIZERS_PARALLELISM": "false"
},
},
@ -321,7 +332,9 @@ if __name__ == "__main__":
}
else:
raise ValueError(f"Unsupported algorithm: {args.algo}")
if args.reward_type == "code":
assert args.code_verifier_api_url is not None, "Please provide a code verifier API URL for code reward type."
grpo_config.update({"code_verifier_api_url": args.code_verifier_api_url})
if args.system_prompt is None:
# Default system prompt
args.system_prompt = DEFAUT_SYSTEM_PROMPT[args.reward_type]

View File

@ -0,0 +1,35 @@
from typing import List, Optional
from coati.distributed.reward.code_reward.utils import check_correctness # Assuming utils.py is in the same directory
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
app = FastAPI()
class CheckCorrectnessRequest(BaseModel):
in_outs: Optional[dict]
generation: str
timeout: int = 10
debug: bool = True
eval_mode: bool = False
class CheckCorrectnessResponse(BaseModel):
result: List[int]
metadata: List[dict]
@app.post("/check_correctness", response_model=CheckCorrectnessResponse)
def check_correctness_api(request: CheckCorrectnessRequest):
try:
result, metadata = check_correctness(
in_outs=request.in_outs,
generation=request.generation,
timeout=request.timeout,
debug=request.debug,
eval_mode=request.eval_mode,
)
return CheckCorrectnessResponse(result=result, metadata=metadata)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))