fix code evaluation
Commit: c5e97f4e25
Parent: 509274c47e
Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-07-24 20:20:53 +00:00)
@@ -83,7 +83,7 @@ class BaseProducer:
         reward_model_kwargs = {
             k: v
             for k, v in grpo_config.items()
-            if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length"]
+            if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length", "code_verifier_api_url"]
         }
         self.response_format_tags = grpo_config.get("response_format_tags", None)
         if producer_idx == 0:
@@ -250,7 +250,7 @@ class BaseProducer:
                     for m in range(eval_outputs["input_ids"].size(0))
                     for n in range(eval_outputs["input_ids"].size(1))
                 ]
-                eval_statistics_tensor[0] += len([res for res in eval_results if res["ans_valid"] == 1])
+                eval_statistics_tensor[0] += sum([max(0, res["ans_valid"]) for res in eval_results])
                 eval_statistics_tensor[1] += len(eval_results)
                 allreduce(eval_statistics_tensor, op=ReduceOp.SUM, group_name="producer_group")
                 to_log_msg[f"eval/{eval_task_name}"] = (
@@ -89,7 +89,7 @@ def clean_traceback(error_traceback):
     return error_traceback


-def run_test(in_outs, test=None, debug=False, timeout=15):
+def run_test(in_outs, test=None, debug=False, timeout=15, run_all_tests=False):
     """
     if test(generated_code) is not None it'll try to run the code.
     otherwise it'll just return an input and output pair.
@@ -180,8 +180,8 @@ def run_test(in_outs, test=None, debug=False, timeout=15):
         tmp_test = new_test

         sol += tmp_test
-        if debug:
-            print(f"sol = {sol}")
+        # if debug:
+        #     print(f"sol = {sol}")
         method_name = "code"
         signal.alarm(timeout)
         try:
@@ -202,8 +202,7 @@ def run_test(in_outs, test=None, debug=False, timeout=15):
             }
         signal.alarm(0)
         if debug:
-            print(f"get method = {datetime.now().time()}")
-
+            print(f"get method {method_name} = {datetime.now().time()}")
         try:
             method = getattr(tmp, method_name)  # get_attr second arg must be str
         except Exception:
@@ -329,6 +328,9 @@ def run_test(in_outs, test=None, debug=False, timeout=15):
                 error_traceback = traceback.format_exc()
                 print(f"Call-based runtime error or time limit exceeded error = {repr(e)}{e}")
                 results.append(-1)
+                signal.alarm(0)
+                if run_all_tests:
+                    continue
                 return results, {
                     "error": repr(e),
                     "traceback": clean_traceback(error_traceback),
@@ -519,6 +521,10 @@ def run_test(in_outs, test=None, debug=False, timeout=15):

                     results.append(tmp_result)
                     if tmp_result is not True:
+                        if debug:
+                            print("final result:", results)
+                        if run_all_tests:
+                            continue
                         return results, {
                             "output": raw_true_output_copy,
                             "expected": raw_outputs,
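The two hunks above add a `run_all_tests` escape hatch: instead of returning on the first failing test case, the loop records the failure and moves on so every case contributes to the result list. A minimal sketch of that control flow (a hypothetical helper for illustration, not the repo's code):

```python
def run_cases(cases, check, run_all_tests=False):
    """Illustrates the early-return vs. continue behavior added above."""
    results = []
    for case in cases:
        passed = check(case)                 # stand-in for executing one in/out pair
        results.append(True if passed else -1)
        if not passed and not run_all_tests:
            return results                   # previous behavior: stop at the first failure
    return results                           # run_all_tests=True: all cases are scored
```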
@@ -539,7 +545,8 @@ def run_test(in_outs, test=None, debug=False, timeout=15):
                 )

                 print(f"results = {results}")
-
+    if debug:
+        print("final results", results)
     return results, {}


@@ -16,27 +16,24 @@
 # limitations under the License.

 import multiprocessing
-import os
-import sys
 import traceback
 from typing import Optional

+import requests
+
 from .testing_util import run_test


 def _temp_run(sample, generation, debug, result, metadata_list, timeout):
-    with open(os.devnull, "w") as devnull:
-        sys.stdout = devnull
-        sys.stderr = devnull
-        try:
-            res, metadata = run_test(in_outs=sample, test=generation, debug=debug, timeout=timeout)
-            result.append(res)
-            metadata_list.append(metadata)
-        except Exception:
-            # print(e) # some tracebacks are extremely long.
-            traceback.print_exc(10)
-            result.append([-1 for i in range(len(sample["inputs"]))])
-            metadata_list.append({})
+    try:
+        res, metadata = run_test(in_outs=sample, test=generation, debug=debug, timeout=timeout)
+        result.append(res)
+        metadata_list.append(metadata)
+    except Exception:
+        # print(e) # some tracebacks are extremely long.
+        traceback.print_exc(10)
+        result.append([-1 for i in range(len(sample["inputs"]))])
+        metadata_list.append({})


 def check_correctness(in_outs: Optional[dict], generation, timeout=10, debug=True):
@@ -49,7 +46,7 @@ def check_correctness(in_outs: Optional[dict], generation, timeout=10, debug=Tru
     metadata_list = manager.list()
     p = multiprocessing.Process(target=_temp_run, args=(in_outs, generation, debug, result, metadata_list, timeout))
     p.start()
-    p.join(timeout=timeout + 1)
+    p.join(timeout=600)  # Global timeout of 10 minutes that's for all test cases combined
     if p.is_alive():
         p.kill()
         # p.terminate()
@@ -59,3 +56,16 @@ def check_correctness(in_outs: Optional[dict], generation, timeout=10, debug=Tru
         if debug:
             print("global timeout")
     return result[0], metadata_list
+
+
+def check_correctness_code_api(
+    in_outs: Optional[dict], generation, timeout=10, debug=True, url="http://localhost:8000/check_correctness"
+):
+    payload = {"in_outs": in_outs, "generation": generation, "timeout": timeout, "debug": debug}
+    response = requests.post(url, json=payload)
+    if response.status_code == 200:
+        results = response.json()
+        return results["result"], results["metadata"]
+    else:
+        print(f"Error: {response.status_code} - {response.text}")
+        return [-1 for i in range(len(in_outs["inputs"]))], {}
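A rough usage sketch of the new HTTP client helper added above, assuming a verifier service is already listening on the default URL; the test-case payload below is made up for illustration:

```python
from coati.distributed.reward.code_reward.utils import check_correctness_code_api

# Hypothetical sample: one stdin/stdout test case for a program that echoes its input.
sample_in_outs = {"inputs": ["hello\n"], "outputs": ["hello\n"]}
generated_code = "print(input())"

result, metadata = check_correctness_code_api(
    in_outs=sample_in_outs,
    generation=generated_code,
    timeout=10,
    debug=False,
    url="http://localhost:8000/check_correctness",
)
print(result, metadata)  # e.g. [1] and the per-case metadata when the verifier accepts the solution
```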
@@ -24,7 +24,7 @@ import torch
 from latex2sympy2_extended import NormalizationConfig
 from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify

-from .code_reward.utils import check_correctness as check_correctness_code
+from .code_reward.utils import check_correctness_code_api as check_correctness_code
 from .reward_utils import extract_boxed_solution, extract_solution, validate_response_structure

 CANNOT_PARSE_GT_ANSWER = -1
@@ -223,6 +223,7 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):


 def code_reward_fn(input_ids, test_cases, response_idx, **kwargs):
+    url = kwargs.get("url", "http://localhost:8000/check_correctness")
     tokenizer = kwargs["tokenizer"]
     eval_mode = kwargs.get("eval_mode", False)
     soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False)
@@ -255,6 +256,9 @@ def code_reward_fn(input_ids, test_cases, response_idx, **kwargs):
     if format_valid:
         format_acc += 1

+    res = []
+    metadata = []
+
     try:
         try:
             if not isinstance(test_cases, dict):
|
|||||||
raise e
|
raise e
|
||||||
# Complete check on all in-out pairs first. If there is no failure, per-sample test can be skipped.
|
# Complete check on all in-out pairs first. If there is no failure, per-sample test can be skipped.
|
||||||
try:
|
try:
|
||||||
res, metadata = check_correctness_code(in_outs=test_cases, generation=solution, timeout=10, debug=True)
|
res, metadata = check_correctness_code(
|
||||||
|
in_outs=test_cases, generation=solution, timeout=10, debug=False, url=url
|
||||||
|
)
|
||||||
metadata = dict(enumerate(metadata))[0]
|
metadata = dict(enumerate(metadata))[0]
|
||||||
success = all(map(lambda x: x is True, res))
|
success = all(map(lambda x: x == 1, res))
|
||||||
if success:
|
if success:
|
||||||
ans_acc += 1
|
ans_acc += 1
|
||||||
if eval_mode or format_valid:
|
if eval_mode or format_valid:
|
||||||
reward += acc_score
|
reward += acc_score
|
||||||
if not eval_mode:
|
if not eval_mode:
|
||||||
reward = reward + length_reward
|
reward = reward + length_reward
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -288,7 +295,9 @@ def code_reward_fn(input_ids, test_cases, response_idx, **kwargs):
         return {
             "prompt": prompt,
             "prediction": decoded_final_answer,
-            "gold": test_cases["outputs"],
+            "test_cases": test_cases,
+            "test_results": res,
+            "test_metadata": metadata,
             "parsed": solution,
             "format_valid": format_acc.item(),
             "ans_valid": ans_acc.item(),
@@ -12,6 +12,9 @@ DEFAUT_SYSTEM_PROMPT = {
     "code": "You are a helpful assistant.",
 }

+# bypass the proxy for local addresses
+os.environ["no_proxy"] = "127.0.0.1,localhost"
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
@@ -138,6 +141,13 @@
         choices=["think_answer_tags", "boxed", "code"],
         help="Reward type for GRPO.",
     )
+    parser.add_argument(
+        "-cv",
+        "--code-verifier-api-url",
+        type=str,
+        default=None,
+        help="API URL for code verifier. If not provided, the code verifier will be disabled.",
+    )
     parser.add_argument(
         "-ei",
         "--eval-interval",
@@ -165,6 +175,7 @@
     parser.add_argument(
         "--enable_profiling", action="store_true", default=False, help="Enable profiling for the training process."
     )
+
     args = parser.parse_args()

     if args.train_minibatch_size is None:
@@ -188,7 +199,7 @@
         namespace="ray-example",
         runtime_env={
             "env_vars": {
-                # "RAY_DEBUG_POST_MORTEM": "1" # enable post-mortem debugging with ray
+                # "RAY_DEBUG_POST_MORTEM": "1",  # enable post-mortem debugging with ray
                 "TOKENIZERS_PARALLELISM": "false"
             },
         },
@@ -201,7 +212,7 @@
         _temp_dir=args.ray_dir,
         runtime_env={
             "env_vars": {
-                # "RAY_DEBUG_POST_MORTEM": "1" # enable post-mortem debugging with ray
+                # "RAY_DEBUG_POST_MORTEM": "1",  # enable post-mortem debugging with ray
                 "TOKENIZERS_PARALLELISM": "false"
             },
         },
@@ -321,7 +332,9 @@
         }
     else:
         raise ValueError(f"Unsupported algorithm: {args.algo}")
+    if args.reward_type == "code":
+        assert args.code_verifier_api_url is not None, "Please provide a code verifier API URL for code reward type."
+        grpo_config.update({"code_verifier_api_url": args.code_verifier_api_url})
     if args.system_prompt is None:
         # Default system prompt
         args.system_prompt = DEFAUT_SYSTEM_PROMPT[args.reward_type]
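Putting the launcher, producer, and reward-function hunks together, the verifier URL flows from the CLI flag into `grpo_config`, is whitelisted into `reward_model_kwargs`, and is eventually read by `code_reward_fn`. A condensed sketch of that plumbing (variable names abbreviated; how `code_verifier_api_url` is renamed to the `url` kwarg is not shown in this diff):

```python
# Hypothetical end-to-end wiring, mirroring the hunks above.
code_verifier_api_url = "http://localhost:8000/check_correctness"

grpo_config = {"soft_over_length_punishment": True, "max_new_tokens": 1024}
grpo_config.update({"code_verifier_api_url": code_verifier_api_url})

# BaseProducer keeps only the whitelisted keys for the reward model.
reward_model_kwargs = {
    k: v
    for k, v in grpo_config.items()
    if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length", "code_verifier_api_url"]
}

# code_reward_fn reads kwargs.get("url", ...); mapping code_verifier_api_url to that
# kwarg happens outside the hunks shown here.
url = reward_model_kwargs.get("code_verifier_api_url", "http://localhost:8000/check_correctness")
```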
applications/ColossalChat/start_code_verifier.py (new file, 35 lines)
@@ -0,0 +1,35 @@
+from typing import List, Optional
+
+from coati.distributed.reward.code_reward.utils import check_correctness  # Assuming utils.py is in the same directory
+from fastapi import FastAPI, HTTPException
+from pydantic import BaseModel
+
+app = FastAPI()
+
+
+class CheckCorrectnessRequest(BaseModel):
+    in_outs: Optional[dict]
+    generation: str
+    timeout: int = 10
+    debug: bool = True
+    eval_mode: bool = False
+
+
+class CheckCorrectnessResponse(BaseModel):
+    result: List[int]
+    metadata: List[dict]
+
+
+@app.post("/check_correctness", response_model=CheckCorrectnessResponse)
+def check_correctness_api(request: CheckCorrectnessRequest):
+    try:
+        result, metadata = check_correctness(
+            in_outs=request.in_outs,
+            generation=request.generation,
+            timeout=request.timeout,
+            debug=request.debug,
+            eval_mode=request.eval_mode,
+        )
+        return CheckCorrectnessResponse(result=result, metadata=metadata)
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
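One way to serve the new FastAPI app so the default client URL resolves (the uvicorn invocation, host, and port below are assumptions, not part of this commit):

```python
# Assumed launch helper; the commit itself only defines the FastAPI `app`.
import uvicorn

if __name__ == "__main__":
    # Exposes POST /check_correctness on port 8000, matching the default URL
    # used by check_correctness_code_api and code_reward_fn.
    uvicorn.run("start_code_verifier:app", host="0.0.0.0", port=8000)
```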