[cli] fixed a bug in user args and refactored the module structure

This commit is contained in:
FrankLeeeee
2022-04-19 15:14:54 +08:00
parent e761ad2cd7
commit f63e91d280
4 changed files with 108 additions and 152 deletions

View File

@@ -6,79 +6,10 @@ import sys
import os
import torch
from colossalai.logging import get_dist_logger
from colossalai.context import Config
from .multinode_runner import PDSHRunner, OpenMPIRunner, SLURMRunner
from copy import deepcopy
def build_args_parser() -> ArgumentParser:
"""Helper function parsing the command line options."""
parser = ArgumentParser(description="colossal distributed training launcher")
parser.add_argument("-H",
"--hostfile",
type=str,
default="",
help="Hostfile path that defines the "
"device pool available to the job (e.g., "
"worker-name:number of slots)")
parser.add_argument("-i",
"--include",
type=str,
default="",
help="Specify computing devices to use during execution."
"String format is NODE_SPEC@NODE_SPEC"
"where NODE_SPEC=<worker-name>:<list-of-slots>")
parser.add_argument("-e",
"--exclude",
type=str,
default="",
help="Specify computing devices to NOT use during execution."
"Mutually exclusive with --include. Formatting"
"is the same as --include.")
parser.add_argument("--num_nodes",
type=int,
default=-1,
help="Total number of worker nodes to use.")
parser.add_argument("--num_gpus",
type=int,
default=-1,
help="Number of GPUs to use on each node.")
parser.add_argument("--master_port",
default=29500,
type=int,
help="(optional) Port used by PyTorch distributed for "
"communication during distributed training.")
parser.add_argument("--master_addr",
default="127.0.0.1",
type=str,
help="(optional) IP address of node 0, will be "
"inferred via 'hostname -I' if not specified.")
parser.add_argument("--launcher",
default="torch",
type=str,
help="(optional) choose launcher backend for multi-node "
"training. Options currently include PDSH, OpenMPI, SLURM.")
parser.add_argument("--launcher_args",
default="",
type=str,
help="(optional) pass launcher specific arguments as a "
"single quoted argument.")
parser.add_argument("user_script",
type=str,
help="User script to launch, followed by any required "
"arguments.")
parser.add_argument('user_args', nargs=argparse.REMAINDER)
return parser
def fetch_hostfile(hostfile_path):
logger = get_dist_logger()
@@ -106,6 +37,7 @@ def fetch_hostfile(hostfile_path):
return device_pool
def _stable_remove_duplicates(data):
# Create a new list in the same order as original but with duplicates
# removed, should never be more than ~16 elements so simple is best
@@ -115,6 +47,7 @@ def _stable_remove_duplicates(data):
new_list.append(x)
return new_list
def parse_device_filter(host_info, include_str="", exclude_str=""):
'''Parse an inclusion or exclusion string and filter a hostfile dictionary.
@@ -202,26 +135,27 @@ def parse_device_filter(host_info, include_str="", exclude_str=""):
return ordered_hosts
def parse_inclusion_exclusion(device_pool, inclusion, exclusion):
active_devices = collections.OrderedDict()
for hostname, slots in device_pool.items():
active_devices[hostname] = list(range(slots))
return parse_device_filter(active_devices,
include_str=inclusion,
exclude_str=exclusion)
return parse_device_filter(active_devices, include_str=inclusion, exclude_str=exclusion)
def main(args=None):
logger = get_dist_logger()
assert args is not None, "args should not be None."
device_pool = fetch_hostfile(args.hostfile)
def launch_multi_processes(args):
assert isinstance(args, Config), f'expected args to be of type Config, but got {type(args)}'
# check
if args.hostfile:
device_pool = fetch_hostfile(args.hostfile)
else:
device_pool = None
active_devices = None
if device_pool:
active_devices = parse_inclusion_exclusion(device_pool,
args.include,
args.exclude)
active_devices = parse_inclusion_exclusion(device_pool, args.include, args.exclude)
if args.num_nodes > 0:
updated_active_devices = collections.OrderedDict()
for count, hostname in enumerate(active_devices.keys()):
@@ -244,16 +178,15 @@ def main(args=None):
else:
nproc_per_node = args.num_gpus
if torch.__version__ <= "1.09":
cmd = [sys.executable, "-u", "-m",
"torch.distributed.launch",
f"--nproc_per_node={nproc_per_node}",
f"--master_addr={args.master_addr}",
f"--master_port={args.master_port}"] + [args.user_script] + args.user_args
cmd = [
sys.executable, "-u", "-m", "torch.distributed.launch", f"--nproc_per_node={nproc_per_node}",
f"--master_addr={args.master_addr}", f"--master_port={args.master_port}"
] + [args.user_script] + args.user_args
else:
cmd = ["torchrun",
f"--nproc_per_node={nproc_per_node}",
f"--master_addr={args.master_addr}",
f"--master_port={args.master_port}"] + [args.user_script] + args.user_args
cmd = [
"torchrun", f"--nproc_per_node={nproc_per_node}", f"--master_addr={args.master_addr}",
f"--master_port={args.master_port}"
] + [args.user_script] + args.user_args
else:
if args.launcher == "torch":
runner = PDSHRunner(args)
@@ -272,14 +205,10 @@ def main(args=None):
env['PYTHONPATH'] = curr_path + ":" + env['PYTHONPATH']
else:
env['PYTHONPATH'] = curr_path
cmd = runner.get_cmd(env, active_devices, args)
result = subprocess.Popen(cmd, env=env)
result.wait()
if result.returncode > 0:
sys.exit(result.returncode)
if __name__ == "__main__":
main()