fix typ and parameter description

This commit is contained in:
YeAnbang 2025-06-05 15:41:14 +08:00
parent 0d008110e7
commit 96faf54542

View File

@ -126,25 +126,25 @@ if __name__ == "__main__":
"--tensor-parallel-size", "--tensor-parallel-size",
type=int, type=int,
default=1, default=1,
help="Tensor parallel size for the inference backend. Please check the generation arguments documentation for your backend.", help="Tensor parallel size for the trainer (consumer). Please check the generation arguments documentation for your backend.",
) )
parser.add_argument( parser.add_argument(
"-pp", "-pp",
"--pipeline-parallel-size", "--pipeline-parallel-size",
type=int, type=int,
default=1, default=1,
help="Pipeline parallel size for the inference backend. Please check the generation arguments documentation for your backend.", help="Pipeline parallel size for the trainer (consumer). Please check the generation arguments documentation for your backend.",
) )
parser.add_argument( parser.add_argument(
"-zero", "-zero",
"--zero-stage", "--zero-stage",
type=int, type=int,
default=0, default=0,
help="Zero stage for the inference backend. Please check the generation arguments documentation for your backend.", help="Zero stage for the trainer (consumer). Please check the generation arguments documentation for your backend.",
) )
parser.add_argument( parser.add_argument(
"-ptp", "-ptp",
"--produce-tensor-parallel-size", "--producer-tensor-parallel-size",
type=int, type=int,
default=1, default=1,
help="Tensor parallel size for the producer. Please check the generation arguments documentation for your backend.", help="Tensor parallel size for the producer. Please check the generation arguments documentation for your backend.",
@ -206,7 +206,7 @@ if __name__ == "__main__":
enforce_eager=True, enforce_eager=True,
enable_chunked_prefill=True, enable_chunked_prefill=True,
max_model_len=args.max_new_tokens + args.max_prompt_tokens, max_model_len=args.max_new_tokens + args.max_prompt_tokens,
tensor_parallel_size=args.produce_tensor_parallel_size, tensor_parallel_size=args.producer_tensor_parallel_size,
) )
) )
generate_config.update( generate_config.update(
@ -276,7 +276,7 @@ if __name__ == "__main__":
launch_distributed( launch_distributed(
num_producers=args.num_inferencer, num_producers=args.num_inferencer,
num_proc_per_producer=inference_model_config.get("tensor_parallel_size", args.produce_tensor_parallel_size), num_proc_per_producer=inference_model_config.get("tensor_parallel_size", args.producer_tensor_parallel_size),
num_consumer_procs=args.num_trainers, num_consumer_procs=args.num_trainers,
num_episodes=args.num_episodes, num_episodes=args.num_episodes,
inference_batch_size=args.inference_batch_size, inference_batch_size=args.inference_batch_size,