mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 12:01:39 +00:00
[cli] fixed a bug in user args and refactored the module structure
This commit is contained in:
@@ -6,79 +6,10 @@ import sys
|
||||
import os
|
||||
import torch
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.context import Config
|
||||
from .multinode_runner import PDSHRunner, OpenMPIRunner, SLURMRunner
|
||||
from copy import deepcopy
|
||||
|
||||
def build_args_parser() -> ArgumentParser:
    """Build the command line parser for the distributed training launcher.

    The parser accepts launcher options (host selection, node/GPU counts,
    rendezvous address/port, launcher backend) followed by one positional
    ``user_script``. Everything after the script is collected verbatim into
    ``user_args`` via ``argparse.REMAINDER``.

    Returns:
        ArgumentParser: the configured parser; call ``parse_args`` on it.
    """
    parser = ArgumentParser(description="colossal distributed training launcher")

    # ---- host selection -------------------------------------------------
    parser.add_argument("-H",
                        "--hostfile",
                        type=str,
                        default="",
                        help="Hostfile path that defines the "
                        "device pool available to the job (e.g., "
                        "worker-name:number of slots)")

    # Adjacent string literals are joined with no separator, so each
    # fragment below ends with an explicit space to keep the help readable.
    parser.add_argument("-i",
                        "--include",
                        type=str,
                        default="",
                        help="Specify computing devices to use during execution. "
                        "String format is NODE_SPEC@NODE_SPEC "
                        "where NODE_SPEC=<worker-name>:<list-of-slots>")

    parser.add_argument("-e",
                        "--exclude",
                        type=str,
                        default="",
                        help="Specify computing devices to NOT use during execution. "
                        "Mutually exclusive with --include. Formatting "
                        "is the same as --include.")

    # ---- job size (-1 is the default for both options) ------------------
    parser.add_argument("--num_nodes",
                        type=int,
                        default=-1,
                        help="Total number of worker nodes to use.")

    parser.add_argument("--num_gpus",
                        type=int,
                        default=-1,
                        help="Number of GPUs to use on each node.")

    # ---- rendezvous -----------------------------------------------------
    parser.add_argument("--master_port",
                        default=29500,
                        type=int,
                        help="(optional) Port used by PyTorch distributed for "
                        "communication during distributed training.")

    # NOTE(review): the help text mentions inference via 'hostname -I', but
    # the default here is the loopback address - confirm which is intended.
    parser.add_argument("--master_addr",
                        default="127.0.0.1",
                        type=str,
                        help="(optional) IP address of node 0, will be "
                        "inferred via 'hostname -I' if not specified.")

    # ---- launcher backend -----------------------------------------------
    parser.add_argument("--launcher",
                        default="torch",
                        type=str,
                        help="(optional) choose launcher backend for multi-node "
                        "training. Options currently include PDSH, OpenMPI, SLURM.")

    parser.add_argument("--launcher_args",
                        default="",
                        type=str,
                        help="(optional) pass launcher specific arguments as a "
                        "single quoted argument.")

    # ---- user program ---------------------------------------------------
    parser.add_argument("user_script",
                        type=str,
                        help="User script to launch, followed by any required "
                        "arguments.")

    # REMAINDER collects every token after the script, including ones that
    # look like options, so they are not interpreted by this parser.
    parser.add_argument('user_args', nargs=argparse.REMAINDER)

    return parser
|
||||
|
||||
def fetch_hostfile(hostfile_path):
|
||||
logger = get_dist_logger()
|
||||
@@ -106,6 +37,7 @@ def fetch_hostfile(hostfile_path):
|
||||
|
||||
return device_pool
|
||||
|
||||
|
||||
def _stable_remove_duplicates(data):
|
||||
# Create a new list in the same order as original but with duplicates
|
||||
# removed, should never be more than ~16 elements so simple is best
|
||||
@@ -115,6 +47,7 @@ def _stable_remove_duplicates(data):
|
||||
new_list.append(x)
|
||||
return new_list
|
||||
|
||||
|
||||
def parse_device_filter(host_info, include_str="", exclude_str=""):
|
||||
'''Parse an inclusion or exclusion string and filter a hostfile dictionary.
|
||||
|
||||
@@ -202,26 +135,27 @@ def parse_device_filter(host_info, include_str="", exclude_str=""):
|
||||
|
||||
return ordered_hosts
|
||||
|
||||
|
||||
def parse_inclusion_exclusion(device_pool, inclusion, exclusion):
    """Expand a device pool into per-host slot lists and apply the filters.

    Args:
        device_pool: mapping of hostname to an integer slot count; it is
            iterated with ``.items()`` and each count is passed to ``range``.
        inclusion: forwarded to ``parse_device_filter`` as ``include_str``.
        exclusion: forwarded to ``parse_device_filter`` as ``exclude_str``.

    Returns:
        Whatever ``parse_device_filter`` returns for the expanded pool.
    """
    # Each host starts with every slot index 0..slots-1 enabled; the
    # insertion order of the pool is preserved.
    active_devices = collections.OrderedDict(
        (hostname, list(range(slots))) for hostname, slots in device_pool.items())

    return parse_device_filter(active_devices, include_str=inclusion, exclude_str=exclusion)
|
||||
|
||||
def main(args=None):
|
||||
logger = get_dist_logger()
|
||||
assert args is not None, "args should not be None."
|
||||
|
||||
device_pool = fetch_hostfile(args.hostfile)
|
||||
|
||||
def launch_multi_processes(args):
|
||||
assert isinstance(args, Config), f'expected args to be of type Config, but got {type(args)}'
|
||||
|
||||
# check
|
||||
if args.hostfile:
|
||||
device_pool = fetch_hostfile(args.hostfile)
|
||||
else:
|
||||
device_pool = None
|
||||
|
||||
active_devices = None
|
||||
if device_pool:
|
||||
active_devices = parse_inclusion_exclusion(device_pool,
|
||||
args.include,
|
||||
args.exclude)
|
||||
active_devices = parse_inclusion_exclusion(device_pool, args.include, args.exclude)
|
||||
if args.num_nodes > 0:
|
||||
updated_active_devices = collections.OrderedDict()
|
||||
for count, hostname in enumerate(active_devices.keys()):
|
||||
@@ -244,16 +178,15 @@ def main(args=None):
|
||||
else:
|
||||
nproc_per_node = args.num_gpus
|
||||
if torch.__version__ <= "1.09":
|
||||
cmd = [sys.executable, "-u", "-m",
|
||||
"torch.distributed.launch",
|
||||
f"--nproc_per_node={nproc_per_node}",
|
||||
f"--master_addr={args.master_addr}",
|
||||
f"--master_port={args.master_port}"] + [args.user_script] + args.user_args
|
||||
cmd = [
|
||||
sys.executable, "-u", "-m", "torch.distributed.launch", f"--nproc_per_node={nproc_per_node}",
|
||||
f"--master_addr={args.master_addr}", f"--master_port={args.master_port}"
|
||||
] + [args.user_script] + args.user_args
|
||||
else:
|
||||
cmd = ["torchrun",
|
||||
f"--nproc_per_node={nproc_per_node}",
|
||||
f"--master_addr={args.master_addr}",
|
||||
f"--master_port={args.master_port}"] + [args.user_script] + args.user_args
|
||||
cmd = [
|
||||
"torchrun", f"--nproc_per_node={nproc_per_node}", f"--master_addr={args.master_addr}",
|
||||
f"--master_port={args.master_port}"
|
||||
] + [args.user_script] + args.user_args
|
||||
else:
|
||||
if args.launcher == "torch":
|
||||
runner = PDSHRunner(args)
|
||||
@@ -272,14 +205,10 @@ def main(args=None):
|
||||
env['PYTHONPATH'] = curr_path + ":" + env['PYTHONPATH']
|
||||
else:
|
||||
env['PYTHONPATH'] = curr_path
|
||||
|
||||
|
||||
cmd = runner.get_cmd(env, active_devices, args)
|
||||
|
||||
|
||||
result = subprocess.Popen(cmd, env=env)
|
||||
result.wait()
|
||||
if result.returncode > 0:
|
||||
sys.exit(result.returncode)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
Reference in New Issue
Block a user