mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-28 04:33:28 +00:00
62 lines
2.2 KiB
Python
62 lines
2.2 KiB
Python
# Code from the verl Project (https://github.com/agentica-project/rllm),
|
|
# which itself is adapted from Prime (https://github.com/PRIME-RL/PRIME)
|
|
#
|
|
# Copyright 2024 ByteDance Group
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import multiprocessing
|
|
import os
|
|
import sys
|
|
import traceback
|
|
from typing import Optional
|
|
|
|
from .testing_util import run_test
|
|
|
|
|
|
def _temp_run(sample, generation, debug, result, metadata_list, timeout):
|
|
with open(os.devnull, "w") as devnull:
|
|
sys.stdout = devnull
|
|
sys.stderr = devnull
|
|
try:
|
|
res, metadata = run_test(in_outs=sample, test=generation, debug=debug, timeout=timeout)
|
|
result.append(res)
|
|
metadata_list.append(metadata)
|
|
except Exception:
|
|
# print(e) # some tracebacks are extremely long.
|
|
traceback.print_exc(10)
|
|
result.append([-1 for i in range(len(sample["inputs"]))])
|
|
metadata_list.append({})
|
|
|
|
|
|
def check_correctness(in_outs: Optional[dict], generation, timeout=10, debug=True):
    """Check correctness of code generation with a global timeout.

    The global timeout is to catch some extreme/rare cases not handled by the timeouts
    inside `run_test`.

    ``_temp_run`` is executed in a separate process and reports back through
    ``multiprocessing.Manager`` lists. A worker still running after ``timeout + 1``
    seconds is killed and reaped. The manager is shut down before returning, so the
    results are copied out into plain lists first.

    Args:
        in_outs: Test cases forwarded to the worker; its ``"inputs"`` entry is used
            (for its length) to build the all-failed fallback result.
        generation: Generated code to test.
        timeout: Per-test timeout forwarded to ``run_test``; the worker as a whole
            gets ``timeout + 1`` seconds.
        debug: Forwarded to the worker; also prints ``"global timeout"`` when the
            worker produced no result.

    Returns:
        A tuple ``(per_test_results, metadata_list)``. ``per_test_results`` is the
        first result the worker recorded, or a list of ``-1`` (one per input) if it
        recorded none. ``metadata_list`` is a plain list of the metadata dicts the
        worker recorded (empty if none).
    """
    with multiprocessing.Manager() as manager:
        result = manager.list()
        metadata_list = manager.list()
        p = multiprocessing.Process(
            target=_temp_run, args=(in_outs, generation, debug, result, metadata_list, timeout)
        )
        p.start()
        p.join(timeout=timeout + 1)
        if p.is_alive():
            p.kill()
            # Reap the killed worker so it does not linger as an unreaped child.
            p.join()
        # Slicing a ListProxy returns a plain list in one round trip; copy now because
        # the proxies become unusable once the manager shuts down.
        result = result[:]
        metadata_list = metadata_list[:]

    if not result:
        # The worker recorded nothing (killed on timeout, or died before appending):
        # consider that all tests failed.
        result = [[-1] * len(in_outs["inputs"])]
        if debug:
            print("global timeout")
    return result[0], metadata_list
|