Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-07-23 11:44:15 +00:00
fix code evaluation
This commit is contained in:
parent 509274c47e, commit c5e97f4e25
@@ -83,7 +83,7 @@ class BaseProducer:
             reward_model_kwargs = {
                 k: v
                 for k, v in grpo_config.items()
-                if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length"]
+                if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length", "code_verifier_api_url"]
             }
             self.response_format_tags = grpo_config.get("response_format_tags", None)
         if producer_idx == 0:
@@ -250,7 +250,7 @@ class BaseProducer:
                         for m in range(eval_outputs["input_ids"].size(0))
                         for n in range(eval_outputs["input_ids"].size(1))
                     ]
-                    eval_statistics_tensor[0] += len([res for res in eval_results if res["ans_valid"] == 1])
+                    eval_statistics_tensor[0] += sum([max(0, res["ans_valid"]) for res in eval_results])
                     eval_statistics_tensor[1] += len(eval_results)
                     allreduce(eval_statistics_tensor, op=ReduceOp.SUM, group_name="producer_group")
                     to_log_msg[f"eval/{eval_task_name}"] = (
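A note on the metric change above: the old expression counted only responses whose `ans_valid` was exactly 1, while the new `sum(max(0, ...))` also credits fractional accuracy values and clamps negative error sentinels to zero. A minimal sketch of the difference, with made-up `ans_valid` values (that `ans_valid` can be fractional or negative is implied by the clamped sum, not stated in the commit):

    eval_results = [{"ans_valid": 1}, {"ans_valid": 0.5}, {"ans_valid": -1}]
    old_count = len([r for r in eval_results if r["ans_valid"] == 1])  # 1
    new_count = sum([max(0, r["ans_valid"]) for r in eval_results])    # 1.5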
@@ -89,7 +89,7 @@ def clean_traceback(error_traceback):
     return error_traceback


-def run_test(in_outs, test=None, debug=False, timeout=15):
+def run_test(in_outs, test=None, debug=False, timeout=15, run_all_tests=False):
     """
     if test(generated_code) is not None it'll try to run the code.
     otherwise it'll just return an input and output pair.
@@ -180,8 +180,8 @@ def run_test(in_outs, test=None, debug=False, timeout=15):
                 tmp_test = new_test

         sol += tmp_test
-        if debug:
-            print(f"sol = {sol}")
+        # if debug:
+        #     print(f"sol = {sol}")
         method_name = "code"
         signal.alarm(timeout)
         try:
@@ -202,8 +202,7 @@ def run_test(in_outs, test=None, debug=False, timeout=15):
             }
         signal.alarm(0)
         if debug:
-            print(f"get method = {datetime.now().time()}")
-
+            print(f"get method {method_name} = {datetime.now().time()}")
         try:
             method = getattr(tmp, method_name)  # get_attr second arg must be str
         except Exception:
@@ -329,6 +328,9 @@ def run_test(in_outs, test=None, debug=False, timeout=15):
             error_traceback = traceback.format_exc()
             print(f"Call-based runtime error or time limit exceeded error = {repr(e)}{e}")
             results.append(-1)
             signal.alarm(0)
+            if run_all_tests:
+                continue
             return results, {
                 "error": repr(e),
                 "traceback": clean_traceback(error_traceback),
@@ -519,6 +521,10 @@ def run_test(in_outs, test=None, debug=False, timeout=15):

                 results.append(tmp_result)
                 if tmp_result is not True:
                     if debug:
                         print("final result:", results)
+                    if run_all_tests:
+                        continue
                     return results, {
                         "output": raw_true_output_copy,
                         "expected": raw_outputs,
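The two `run_all_tests` hunks above change the early-exit behavior of `run_test`: previously the function returned on the first failing in/out pair, whereas with `run_all_tests=True` it `continue`s and scores every pair. A minimal, self-contained sketch of that control-flow pattern (names simplified; the real function executes generated code rather than comparing values):

    def run_all(cases, run_all_tests=False):
        results = []
        for pred, expected in cases:
            ok = pred == expected  # stand-in for actually executing a test case
            results.append(ok)
            if not ok and not run_all_tests:
                return results  # old behavior: stop at the first failure
        return results  # with run_all_tests=True, every case is scored

    print(run_all([(1, 1), (2, 3), (4, 4)]))                      # [True, False]
    print(run_all([(1, 1), (2, 3), (4, 4)], run_all_tests=True))  # [True, False, True]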
@@ -539,7 +545,8 @@ def run_test(in_outs, test=None, debug=False, timeout=15):
                 )

-    print(f"results = {results}")
+    if debug:
+        print("final results", results)

     return results, {}
@@ -16,18 +16,15 @@
 # limitations under the License.

 import multiprocessing
 import os
 import sys
 import traceback
 from typing import Optional

+import requests
+
 from .testing_util import run_test


 def _temp_run(sample, generation, debug, result, metadata_list, timeout):
     with open(os.devnull, "w") as devnull:
         sys.stdout = devnull
         sys.stderr = devnull
         try:
             res, metadata = run_test(in_outs=sample, test=generation, debug=debug, timeout=timeout)
             result.append(res)
@@ -49,7 +46,7 @@ def check_correctness(in_outs: Optional[dict], generation, timeout=10, debug=True):
     metadata_list = manager.list()
     p = multiprocessing.Process(target=_temp_run, args=(in_outs, generation, debug, result, metadata_list, timeout))
     p.start()
-    p.join(timeout=timeout + 1)
+    p.join(timeout=600)  # global timeout of 10 minutes for all test cases combined
     if p.is_alive():
         p.kill()
     # p.terminate()
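Since `run_all_tests` can now execute every test case in one worker process, the join budget above becomes a fixed global ceiling instead of a per-case `timeout + 1`. The watchdog pattern it relies on, reduced to a minimal runnable sketch (durations are made up):

    import multiprocessing
    import time

    def worker():
        time.sleep(5)  # stand-in for executing generated code

    if __name__ == "__main__":
        p = multiprocessing.Process(target=worker)
        p.start()
        p.join(timeout=1)  # global budget, analogous to p.join(timeout=600) above
        if p.is_alive():   # budget exceeded: hard-kill the worker
            p.kill()
            p.join()
        print("killed" if p.exitcode != 0 else "finished")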
@@ -59,3 +56,16 @@ def check_correctness(in_outs: Optional[dict], generation, timeout=10, debug=True):
         if debug:
             print("global timeout")
     return result[0], metadata_list
+
+
+def check_correctness_code_api(
+    in_outs: Optional[dict], generation, timeout=10, debug=True, url="http://localhost:8000/check_correctness"
+):
+    payload = {"in_outs": in_outs, "generation": generation, "timeout": timeout, "debug": debug}
+    response = requests.post(url, json=payload)
+    if response.status_code == 200:
+        results = response.json()
+        return results["result"], results["metadata"]
+    else:
+        print(f"Error: {response.status_code} - {response.text}")
+        return [-1 for i in range(len(in_outs["inputs"]))], {}
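The helper above wraps the HTTP contract of the verifier service added at the end of this commit. A client-side sketch of that contract (the in/out pair and generated code are made up; the URL is the default from the function signature):

    import requests

    payload = {
        "in_outs": {"inputs": ["1 2\n"], "outputs": ["3\n"]},  # made-up test case
        "generation": "a, b = map(int, input().split())\nprint(a + b)",
        "timeout": 10,
        "debug": False,
    }
    resp = requests.post("http://localhost:8000/check_correctness", json=payload)
    resp.raise_for_status()
    body = resp.json()
    print(body["result"], body["metadata"])  # e.g. [1] and a list of per-case metadata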
@@ -24,7 +24,7 @@ import torch
 from latex2sympy2_extended import NormalizationConfig
 from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify

-from .code_reward.utils import check_correctness as check_correctness_code
+from .code_reward.utils import check_correctness_code_api as check_correctness_code
 from .reward_utils import extract_boxed_solution, extract_solution, validate_response_structure

 CANNOT_PARSE_GT_ANSWER = -1
@@ -223,6 +223,7 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):


 def code_reward_fn(input_ids, test_cases, response_idx, **kwargs):
+    url = kwargs.get("url", "http://localhost:8000/check_correctness")
     tokenizer = kwargs["tokenizer"]
     eval_mode = kwargs.get("eval_mode", False)
     soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False)
@@ -255,6 +256,9 @@ def code_reward_fn(input_ids, test_cases, response_idx, **kwargs):
         if format_valid:
             format_acc += 1

+    res = []
+    metadata = []
+
     try:
         try:
             if not isinstance(test_cases, dict):
@@ -264,15 +268,18 @@ def code_reward_fn(input_ids, test_cases, response_idx, **kwargs):
                 raise e
         # Complete check on all in-out pairs first. If there is no failure, per-sample test can be skipped.
         try:
-            res, metadata = check_correctness_code(in_outs=test_cases, generation=solution, timeout=10, debug=True)
+            res, metadata = check_correctness_code(
+                in_outs=test_cases, generation=solution, timeout=10, debug=False, url=url
+            )
             metadata = dict(enumerate(metadata))[0]
-            success = all(map(lambda x: x is True, res))
+            success = all(map(lambda x: x == 1, res))
             if success:
                 ans_acc += 1
                 if eval_mode or format_valid:
                     reward += acc_score
                 if not eval_mode:
                     reward = reward + length_reward

         except Exception:
             pass
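The `success` check above changes because of the transport: `run_test` appends Python booleans (plus negative error sentinels) to its results, but the API's response model declares `result: List[int]`, so passing cases come back as the integer 1 rather than the `True` singleton. Identity comparison then silently fails where equality does not:

    res = [1, 1, 1]  # results as returned by the HTTP API
    print(all(map(lambda x: x is True, res)))  # False: the int 1 is not the True singleton
    print(all(map(lambda x: x == 1, res)))     # True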
@@ -288,7 +295,9 @@ def code_reward_fn(input_ids, test_cases, response_idx, **kwargs):
         return {
             "prompt": prompt,
             "prediction": decoded_final_answer,
-            "gold": test_cases["outputs"],
+            "test_cases": test_cases,
+            "test_results": res,
+            "test_metadata": metadata,
             "parsed": solution,
             "format_valid": format_acc.item(),
             "ans_valid": ans_acc.item(),
@@ -12,6 +12,9 @@ DEFAUT_SYSTEM_PROMPT = {
     "code": "You are a helpful assistant.",
 }

+# bypass the proxy for local addresses
+os.environ["no_proxy"] = "127.0.0.1,localhost"
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
@@ -138,6 +141,13 @@ if __name__ == "__main__":
         choices=["think_answer_tags", "boxed", "code"],
         help="Reward type for GRPO.",
     )
+    parser.add_argument(
+        "-cv",
+        "--code-verifier-api-url",
+        type=str,
+        default=None,
+        help="API URL for code verifier. If not provided, the code verifier will be disabled.",
+    )
     parser.add_argument(
         "-ei",
         "--eval-interval",
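With the new flag, a code-reward run points the trainer at a running verifier service. A hedged example invocation (the launcher script name and the reward-type flag spelling are assumptions; only `-cv`/`--code-verifier-api-url` is defined in this hunk):

    python rl_example.py --reward-type code -cv http://localhost:8000/check_correctness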
@@ -165,6 +175,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--enable_profiling", action="store_true", default=False, help="Enable profiling for the training process."
     )

     args = parser.parse_args()

     if args.train_minibatch_size is None:
@@ -188,7 +199,7 @@ if __name__ == "__main__":
             namespace="ray-example",
             runtime_env={
                 "env_vars": {
-                    # "RAY_DEBUG_POST_MORTEM": "1"  # enable post-mortem debugging with ray
+                    # "RAY_DEBUG_POST_MORTEM": "1",  # enable post-mortem debugging with ray
                     "TOKENIZERS_PARALLELISM": "false"
                 },
             },
@@ -201,7 +212,7 @@ if __name__ == "__main__":
             _temp_dir=args.ray_dir,
             runtime_env={
                 "env_vars": {
-                    # "RAY_DEBUG_POST_MORTEM": "1"  # enable post-mortem debugging with ray
+                    # "RAY_DEBUG_POST_MORTEM": "1",  # enable post-mortem debugging with ray
                     "TOKENIZERS_PARALLELISM": "false"
                 },
             },
@@ -321,7 +332,9 @@ if __name__ == "__main__":
         }
     else:
         raise ValueError(f"Unsupported algorithm: {args.algo}")

+    if args.reward_type == "code":
+        assert args.code_verifier_api_url is not None, "Please provide a code verifier API URL for code reward type."
+        grpo_config.update({"code_verifier_api_url": args.code_verifier_api_url})
     if args.system_prompt is None:
         # Default system prompt
         args.system_prompt = DEFAUT_SYSTEM_PROMPT[args.reward_type]
applications/ColossalChat/start_code_verifier.py (new file, 35 lines)
@@ -0,0 +1,35 @@
+from typing import List, Optional
+
+from coati.distributed.reward.code_reward.utils import check_correctness  # Assuming utils.py is in the same directory
+from fastapi import FastAPI, HTTPException
+from pydantic import BaseModel
+
+app = FastAPI()
+
+
+class CheckCorrectnessRequest(BaseModel):
+    in_outs: Optional[dict]
+    generation: str
+    timeout: int = 10
+    debug: bool = True
+    eval_mode: bool = False
+
+
+class CheckCorrectnessResponse(BaseModel):
+    result: List[int]
+    metadata: List[dict]
+
+
+@app.post("/check_correctness", response_model=CheckCorrectnessResponse)
+def check_correctness_api(request: CheckCorrectnessRequest):
+    try:
+        result, metadata = check_correctness(
+            in_outs=request.in_outs,
+            generation=request.generation,
+            timeout=request.timeout,
+            debug=request.debug,
+            eval_mode=request.eval_mode,
+        )
+        return CheckCorrectnessResponse(result=result, metadata=metadata)
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
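The new file exposes `check_correctness` behind a FastAPI endpoint matching the client helper added earlier in this commit. One way to serve it, assuming FastAPI and uvicorn are installed (the launch command is an assumption; the commit does not include one):

    pip install fastapi uvicorn
    uvicorn start_code_verifier:app --host 0.0.0.0 --port 8000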