From 96faf5454245d001d06ef538a4fd033d87ca62de Mon Sep 17 00:00:00 2001
From: YeAnbang
Date: Thu, 5 Jun 2025 15:41:14 +0800
Subject: [PATCH] fix typo and parameter descriptions

---
 applications/ColossalChat/rl_example.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index eed0af362..b4f565617 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -126,25 +126,25 @@ if __name__ == "__main__":
         "--tensor-parallel-size",
         type=int,
         default=1,
-        help="Tensor parallel size for the inference backend. Please check the generation arguments documentation for your backend.",
+        help="Tensor parallel size for the trainer (consumer). Please check the generation arguments documentation for your backend.",
     )
     parser.add_argument(
         "-pp",
         "--pipeline-parallel-size",
         type=int,
         default=1,
-        help="Pipeline parallel size for the inference backend. Please check the generation arguments documentation for your backend.",
+        help="Pipeline parallel size for the trainer (consumer). Please check the generation arguments documentation for your backend.",
     )
     parser.add_argument(
         "-zero",
         "--zero-stage",
         type=int,
         default=0,
-        help="Zero stage for the inference backend. Please check the generation arguments documentation for your backend.",
+        help="Zero stage for the trainer (consumer). Please check the generation arguments documentation for your backend.",
     )
     parser.add_argument(
         "-ptp",
-        "--produce-tensor-parallel-size",
+        "--producer-tensor-parallel-size",
         type=int,
         default=1,
         help="Tensor parallel size for the producer. Please check the generation arguments documentation for your backend.",
@@ -206,7 +206,7 @@ if __name__ == "__main__":
             enforce_eager=True,
             enable_chunked_prefill=True,
             max_model_len=args.max_new_tokens + args.max_prompt_tokens,
-            tensor_parallel_size=args.produce_tensor_parallel_size,
+            tensor_parallel_size=args.producer_tensor_parallel_size,
         )
     )
     generate_config.update(
@@ -276,7 +276,7 @@ if __name__ == "__main__":
 
     launch_distributed(
         num_producers=args.num_inferencer,
-        num_proc_per_producer=inference_model_config.get("tensor_parallel_size", args.produce_tensor_parallel_size),
+        num_proc_per_producer=inference_model_config.get("tensor_parallel_size", args.producer_tensor_parallel_size),
         num_consumer_procs=args.num_trainers,
         num_episodes=args.num_episodes,
         inference_batch_size=args.inference_batch_size,
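
Note for reviewers (not part of the patch): a minimal argparse sketch showing
why the rename of "--produce-tensor-parallel-size" also touches the two call
sites. Only the producer flag from this patch is reproduced; the parsed value
in the example is illustrative, and the real rl_example.py defines many more
arguments.

# Minimal sketch mirroring the renamed flag from this patch.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "-ptp",
    "--producer-tensor-parallel-size",  # renamed from --produce-tensor-parallel-size
    type=int,
    default=1,
    help="Tensor parallel size for the producer.",
)
args = parser.parse_args(["-ptp", "2"])  # illustrative value

# argparse derives the attribute name from the first long option, with
# hyphens mapped to underscores, hence args.producer_tensor_parallel_size
# in the updated launch_distributed / inference config call sites.
print(args.producer_tensor_parallel_size)  # -> 2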