mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-25 15:01:43 +00:00
small fix
This commit is contained in:
parent
7d658402da
commit
dd74f496c0
@ -44,7 +44,6 @@ def launch_distributed(
|
|||||||
tokenizer_config: Optional[Dict[str, Any]] = None,
|
tokenizer_config: Optional[Dict[str, Any]] = None,
|
||||||
inference_backend: str = "transformers",
|
inference_backend: str = "transformers",
|
||||||
num_generations: int = 8,
|
num_generations: int = 8,
|
||||||
master_addr: str = "localhost",
|
|
||||||
master_port: int = 29500,
|
master_port: int = 29500,
|
||||||
core_algo: str = "GRPO",
|
core_algo: str = "GRPO",
|
||||||
project_name: Optional[str] = None,
|
project_name: Optional[str] = None,
|
||||||
@ -53,7 +52,7 @@ def launch_distributed(
|
|||||||
eval_dataset_config: Optional[Dict[str, Any]] = None,
|
eval_dataset_config: Optional[Dict[str, Any]] = None,
|
||||||
eval_interval: int = 100,
|
eval_interval: int = 100,
|
||||||
eval_save_dir: Optional[str] = None,
|
eval_save_dir: Optional[str] = None,
|
||||||
response_format_tags: Dict[Any] = None,
|
response_format_tags: Dict[str, Any] = None,
|
||||||
):
|
):
|
||||||
|
|
||||||
if core_algo not in ALGO_MAP:
|
if core_algo not in ALGO_MAP:
|
||||||
|
@ -77,12 +77,6 @@ if __name__ == "__main__":
|
|||||||
default=None,
|
default=None,
|
||||||
help="Ray master address for multi-node distributed training, Optional",
|
help="Ray master address for multi-node distributed training, Optional",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--torch_ddp_master_address",
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help="Torch DDP master address for multi-node distributed training, Optional",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--torch_ddp_master_port",
|
"--torch_ddp_master_port",
|
||||||
type=int,
|
type=int,
|
||||||
@ -125,7 +119,7 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
parser.add_argument("-ei", "--eval-interval", type=int, default=100, help="Interval for evaluation.")
|
parser.add_argument("-ei", "--eval-interval", type=int, default=100, help="Interval for evaluation.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"-rft", "--reponse_format-tags", type=str, default=None, help="Optional json string of the response format tag"
|
"-rft", "--response_format_tags", type=str, default=None, help="Optional json string of the response format tag"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Logging/Checkpointing parameters
|
# Logging/Checkpointing parameters
|
||||||
@ -186,11 +180,11 @@ if __name__ == "__main__":
|
|||||||
elif args.backend == "vllm":
|
elif args.backend == "vllm":
|
||||||
inference_model_config.update(
|
inference_model_config.update(
|
||||||
dict(
|
dict(
|
||||||
gpu_memory_utilization=0.7,
|
gpu_memory_utilization=0.5,
|
||||||
enforce_eager=True,
|
enforce_eager=True,
|
||||||
enable_chunked_prefill=True,
|
enable_chunked_prefill=True,
|
||||||
max_model_len=args.max_new_tokens + args.max_prompt_tokens,
|
max_model_len=args.max_new_tokens + args.max_prompt_tokens,
|
||||||
tensor_parallel_size=1,
|
tensor_parallel_size=2,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
generate_config.update(
|
generate_config.update(
|
||||||
@ -267,7 +261,6 @@ if __name__ == "__main__":
|
|||||||
# "max_norm": 1.0,
|
# "max_norm": 1.0,
|
||||||
# }, # for pp, tp
|
# }, # for pp, tp
|
||||||
inference_backend=args.backend,
|
inference_backend=args.backend,
|
||||||
master_addr=args.torch_ddp_master_address,
|
|
||||||
master_port=args.torch_ddp_master_port,
|
master_port=args.torch_ddp_master_port,
|
||||||
core_algo=args.algo,
|
core_algo=args.algo,
|
||||||
project_name=args.project,
|
project_name=args.project,
|
||||||
|
Loading…
Reference in New Issue
Block a user