small fix

This commit is contained in:
YeAnbang 2025-05-03 09:55:24 +08:00
parent 7d658402da
commit dd74f496c0
2 changed files with 4 additions and 12 deletions

View File

@ -44,7 +44,6 @@ def launch_distributed(
tokenizer_config: Optional[Dict[str, Any]] = None,
inference_backend: str = "transformers",
num_generations: int = 8,
master_addr: str = "localhost",
master_port: int = 29500,
core_algo: str = "GRPO",
project_name: Optional[str] = None,
@ -53,7 +52,7 @@ def launch_distributed(
eval_dataset_config: Optional[Dict[str, Any]] = None,
eval_interval: int = 100,
eval_save_dir: Optional[str] = None,
response_format_tags: Dict[Any] = None,
response_format_tags: Dict[str, Any] = None,
):
if core_algo not in ALGO_MAP:

View File

@ -77,12 +77,6 @@ if __name__ == "__main__":
default=None,
help="Ray master address for multi-node distributed training, Optional",
)
parser.add_argument(
"--torch_ddp_master_address",
type=str,
default=None,
help="Torch DDP master address for multi-node distributed training, Optional",
)
parser.add_argument(
"--torch_ddp_master_port",
type=int,
@ -125,7 +119,7 @@ if __name__ == "__main__":
)
parser.add_argument("-ei", "--eval-interval", type=int, default=100, help="Interval for evaluation.")
parser.add_argument(
"-rft", "--reponse_format-tags", type=str, default=None, help="Optional json string of the response format tag"
"-rft", "--response_format_tags", type=str, default=None, help="Optional json string of the response format tag"
)
# Logging/Checkpointing parameters
@ -186,11 +180,11 @@ if __name__ == "__main__":
elif args.backend == "vllm":
inference_model_config.update(
dict(
gpu_memory_utilization=0.7,
gpu_memory_utilization=0.5,
enforce_eager=True,
enable_chunked_prefill=True,
max_model_len=args.max_new_tokens + args.max_prompt_tokens,
tensor_parallel_size=1,
tensor_parallel_size=2,
)
)
generate_config.update(
@ -267,7 +261,6 @@ if __name__ == "__main__":
# "max_norm": 1.0,
# }, # for pp, tp
inference_backend=args.backend,
master_addr=args.torch_ddp_master_address,
master_port=args.torch_ddp_master_port,
core_algo=args.algo,
project_name=args.project,