mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-12-01 00:24:04 +00:00
fix code evaluation
This commit is contained in:
@@ -12,6 +12,9 @@ DEFAUT_SYSTEM_PROMPT = {
|
||||
"code": "You are a helpful assistant.",
|
||||
}
|
||||
|
||||
# bypass the proxy for local addresses
|
||||
os.environ["no_proxy"] = "127.0.0.1,localhost"
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
|
||||
@@ -138,6 +141,13 @@ if __name__ == "__main__":
|
||||
choices=["think_answer_tags", "boxed", "code"],
|
||||
help="Reward type for GRPO.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-cv",
|
||||
"--code-verifier-api-url",
|
||||
type=str,
|
||||
default=None,
|
||||
help="API URL for code verifier. If not provided, the code verifier will be disabled.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-ei",
|
||||
"--eval-interval",
|
||||
@@ -165,6 +175,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--enable_profiling", action="store_true", default=False, help="Enable profiling for the training process."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.train_minibatch_size is None:
|
||||
@@ -188,7 +199,7 @@ if __name__ == "__main__":
|
||||
namespace="ray-example",
|
||||
runtime_env={
|
||||
"env_vars": {
|
||||
# "RAY_DEBUG_POST_MORTEM": "1" # enable post-mortem debugging with ray
|
||||
# "RAY_DEBUG_POST_MORTEM": "1", # enable post-mortem debugging with ray
|
||||
"TOKENIZERS_PARALLELISM": "false"
|
||||
},
|
||||
},
|
||||
@@ -201,7 +212,7 @@ if __name__ == "__main__":
|
||||
_temp_dir=args.ray_dir,
|
||||
runtime_env={
|
||||
"env_vars": {
|
||||
# "RAY_DEBUG_POST_MORTEM": "1" # enable post-mortem debugging with ray
|
||||
# "RAY_DEBUG_POST_MORTEM": "1", # enable post-mortem debugging with ray
|
||||
"TOKENIZERS_PARALLELISM": "false"
|
||||
},
|
||||
},
|
||||
@@ -321,7 +332,9 @@ if __name__ == "__main__":
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unsupported algorithm: {args.algo}")
|
||||
|
||||
if args.reward_type == "code":
|
||||
assert args.code_verifier_api_url is not None, "Please provide a code verifier API URL for code reward type."
|
||||
grpo_config.update({"code_verifier_api_url": args.code_verifier_api_url})
|
||||
if args.system_prompt is None:
|
||||
# Default system prompt
|
||||
args.system_prompt = DEFAUT_SYSTEM_PROMPT[args.reward_type]
|
||||
|
||||
Reference in New Issue
Block a user