Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-06-24 06:29:09 +00:00
small fix
This commit is contained in:
parent
7d658402da
commit
dd74f496c0
@@ -44,7 +44,6 @@ def launch_distributed(
     tokenizer_config: Optional[Dict[str, Any]] = None,
     inference_backend: str = "transformers",
     num_generations: int = 8,
     master_addr: str = "localhost",
     master_port: int = 29500,
     core_algo: str = "GRPO",
     project_name: Optional[str] = None,
@@ -53,7 +52,7 @@ def launch_distributed(
     eval_dataset_config: Optional[Dict[str, Any]] = None,
     eval_interval: int = 100,
     eval_save_dir: Optional[str] = None,
-    response_format_tags: Dict[Any] = None,
+    response_format_tags: Dict[str, Any] = None,
 ):

     if core_algo not in ALGO_MAP:
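The hunk above corrects the annotation of response_format_tags from Dict[Any] to Dict[str, Any]. A minimal sketch of a value matching that annotation follows; the tag names and the "text"/"num_occur" fields are assumptions for illustration, not taken from this commit.

# Illustrative only: a dict that satisfies the corrected Dict[str, Any] annotation.
# The tag names and per-tag fields are assumed for this example.
response_format_tags = {
    "think_start": {"text": "<think>", "num_occur": 1},
    "think_end": {"text": "</think>", "num_occur": 1},
    "answer_start": {"text": "<answer>", "num_occur": 1},
    "answer_end": {"text": "</answer>", "num_occur": 1},
}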
@@ -77,12 +77,6 @@ if __name__ == "__main__":
         default=None,
         help="Ray master address for multi-node distributed training, Optional",
     )
     parser.add_argument(
         "--torch_ddp_master_address",
         type=str,
         default=None,
         help="Torch DDP master address for multi-node distributed training, Optional",
     )
     parser.add_argument(
         "--torch_ddp_master_port",
         type=int,
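Both master-address flags above default to None, while launch_distributed itself defaults to "localhost" and 29500 (first hunk). A minimal sketch of falling back to those defaults when the flags are omitted; this fallback logic is an assumption for illustration, not code from the commit.

import argparse

# Self-contained parser mirroring the two flags shown above.
parser = argparse.ArgumentParser()
parser.add_argument("--torch_ddp_master_address", type=str, default=None)
parser.add_argument("--torch_ddp_master_port", type=int, default=None)
args = parser.parse_args([])  # parse an empty argv so the defaults apply

# Fall back to the launch_distributed defaults when the flags are not set (assumption).
master_addr = args.torch_ddp_master_address or "localhost"
master_port = args.torch_ddp_master_port or 29500
print(master_addr, master_port)  # localhost 29500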
@@ -125,7 +119,7 @@ if __name__ == "__main__":
     )
     parser.add_argument("-ei", "--eval-interval", type=int, default=100, help="Interval for evaluation.")
     parser.add_argument(
-        "-rft", "--reponse_format-tags", type=str, default=None, help="Optional json string of the response format tag"
+        "-rft", "--response_format_tags", type=str, default=None, help="Optional json string of the response format tag"
     )

     # Logging/Checkpointing parameters
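The renamed --response_format_tags flag carries a JSON string, while launch_distributed expects a Dict[str, Any]. A minimal parsing sketch, assuming json.loads bridges the two; the commit itself only fixes the flag name.

import json

# e.g. the value passed on the command line via -rft / --response_format_tags
raw = '{"think_start": {"text": "<think>", "num_occur": 1}}'

# Convert the JSON string into the Dict[str, Any] expected by launch_distributed.
response_format_tags = json.loads(raw) if raw is not None else None
print(response_format_tags["think_start"]["text"])  # <think>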
@@ -186,11 +180,11 @@ if __name__ == "__main__":
     elif args.backend == "vllm":
         inference_model_config.update(
             dict(
-                gpu_memory_utilization=0.7,
+                gpu_memory_utilization=0.5,
                 enforce_eager=True,
                 enable_chunked_prefill=True,
                 max_model_len=args.max_new_tokens + args.max_prompt_tokens,
-                tensor_parallel_size=1,
+                tensor_parallel_size=2,
             )
         )
         generate_config.update(
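The keys changed above (gpu_memory_utilization lowered from 0.7 to 0.5, tensor_parallel_size raised from 1 to 2) are standard vLLM engine arguments. A minimal sketch of constructing a vLLM engine directly with the same settings; the model id and token budgets are placeholders, and this is not code from the repository.

from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",  # placeholder model id
    gpu_memory_utilization=0.5,        # fraction of GPU memory vLLM may reserve
    enforce_eager=True,                # disable CUDA graph capture, run eagerly
    enable_chunked_prefill=True,
    max_model_len=512 + 512,           # stands in for max_new_tokens + max_prompt_tokens
    tensor_parallel_size=2,            # shard the model across two GPUs
)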
@@ -267,7 +261,6 @@ if __name__ == "__main__":
         # "max_norm": 1.0,
         # }, # for pp, tp
         inference_backend=args.backend,
         master_addr=args.torch_ddp_master_address,
         master_port=args.torch_ddp_master_port,
         core_algo=args.algo,
         project_name=args.project,