mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-10-30 13:31:12 +00:00
support code generation tasks
This commit is contained in:
@@ -101,7 +101,7 @@ if __name__ == "__main__":
|
||||
"--reward-type",
|
||||
type=str,
|
||||
default="think_answer_tags",
|
||||
choices=["think_answer_tags", "boxed"],
|
||||
choices=["think_answer_tags", "boxed", "code"],
|
||||
help="Reward type for GRPO.",
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -167,10 +167,29 @@ if __name__ == "__main__":
|
||||
|
||||
if args.master_address is None:
|
||||
# Default settings: Using single machine
|
||||
ray.init(address="local", namespace="ray-example")
|
||||
ray.init(
|
||||
address="local",
|
||||
namespace="ray-example",
|
||||
runtime_env={
|
||||
"env_vars": {
|
||||
# "RAY_DEBUG_POST_MORTEM": "1" # enable post-mortem debugging with ray
|
||||
"TOKENIZERS_PARALLELISM": "false"
|
||||
},
|
||||
},
|
||||
)
|
||||
else:
|
||||
# For ray distributed multi-machine training, Please change _node_ip_address to your IP address of your master node
|
||||
ray.init(_node_ip_address=args.master_address, namespace="ray-example", _temp_dir=args.ray_dir)
|
||||
ray.init(
|
||||
_node_ip_address=args.master_address,
|
||||
namespace="ray-example",
|
||||
_temp_dir=args.ray_dir,
|
||||
runtime_env={
|
||||
"env_vars": {
|
||||
# "RAY_DEBUG_POST_MORTEM": "1" # enable post-mortem debugging with ray
|
||||
"TOKENIZERS_PARALLELISM": "false"
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
if args.top_k is None:
|
||||
if args.backend == "transformers":
|
||||
@@ -178,6 +197,8 @@ if __name__ == "__main__":
|
||||
elif args.backend == "vllm":
|
||||
args.top_k = -1
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false" # Disable tokenizers parallelism to avoid deadlock
|
||||
|
||||
inference_model_config = dict(path=args.model)
|
||||
train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
|
||||
generate_config = dict(top_k=args.top_k, top_p=args.top_p, temperature=args.temperature)
|
||||
@@ -288,7 +309,6 @@ if __name__ == "__main__":
|
||||
"max_length": args.max_prompt_tokens,
|
||||
"system_prompt": args.system_prompt,
|
||||
},
|
||||
dataloaders_config={},
|
||||
inference_model_config=inference_model_config,
|
||||
generate_config=generate_config,
|
||||
num_generations=args.num_generations,
|
||||
|
||||
Reference in New Issue
Block a user