mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 02:26:51 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -5,56 +5,81 @@ from colossalai.context import Config
|
||||
from .run import launch_multi_processes
|
||||
|
||||
|
||||
@click.command(help="Launch distributed training on a single node or multiple nodes",
|
||||
context_settings=dict(ignore_unknown_options=True))
|
||||
@click.option("-H",
|
||||
"-host",
|
||||
"--host",
|
||||
type=str,
|
||||
default=None,
|
||||
help="the list of hostnames to launch in the format <host1>,<host2>")
|
||||
@click.command(
|
||||
help="Launch distributed training on a single node or multiple nodes",
|
||||
context_settings=dict(ignore_unknown_options=True),
|
||||
)
|
||||
@click.option(
|
||||
"-H",
|
||||
"-host",
|
||||
"--host",
|
||||
type=str,
|
||||
default=None,
|
||||
help="the list of hostnames to launch in the format <host1>,<host2>",
|
||||
)
|
||||
@click.option(
|
||||
"--hostfile",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Hostfile path that defines the device pool available to the job, each line in the file is a hostname")
|
||||
@click.option("--include",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Specify computing devices to use during execution. String format is <host1>,<host2>,"
|
||||
" only effective when used with --hostfile.")
|
||||
help="Hostfile path that defines the device pool available to the job, each line in the file is a hostname",
|
||||
)
|
||||
@click.option(
|
||||
"--include",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Specify computing devices to use during execution. String format is <host1>,<host2>,"
|
||||
" only effective when used with --hostfile.",
|
||||
)
|
||||
@click.option(
|
||||
"--exclude",
|
||||
type=str,
|
||||
default=None,
|
||||
help=
|
||||
"Specify computing devices to NOT use during execution. Mutually exclusive with --include. Formatting is the same as --include,"
|
||||
" only effective when used with --hostfile.")
|
||||
@click.option("--num_nodes",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="Total number of worker nodes to use, only effective when used with --hostfile.")
|
||||
help="Specify computing devices to NOT use during execution. Mutually exclusive with --include. Formatting is the same as --include,"
|
||||
" only effective when used with --hostfile.",
|
||||
)
|
||||
@click.option(
|
||||
"--num_nodes",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="Total number of worker nodes to use, only effective when used with --hostfile.",
|
||||
)
|
||||
@click.option("--nproc_per_node", type=int, default=None, help="Number of GPUs to use on each node.")
|
||||
@click.option("--master_port",
|
||||
type=int,
|
||||
default=29500,
|
||||
help="(optional) Port used by PyTorch distributed for communication during distributed training.")
|
||||
@click.option("--master_addr",
|
||||
type=str,
|
||||
default="127.0.0.1",
|
||||
help="(optional) IP address of node 0, will be inferred via 'hostname -I' if not specified.")
|
||||
@click.option(
|
||||
"--master_port",
|
||||
type=int,
|
||||
default=29500,
|
||||
help="(optional) Port used by PyTorch distributed for communication during distributed training.",
|
||||
)
|
||||
@click.option(
|
||||
"--master_addr",
|
||||
type=str,
|
||||
default="127.0.0.1",
|
||||
help="(optional) IP address of node 0, will be inferred via 'hostname -I' if not specified.",
|
||||
)
|
||||
@click.option(
|
||||
"--extra_launch_args",
|
||||
type=str,
|
||||
default=None,
|
||||
help=
|
||||
"Set additional torch distributed launcher arguments such as --standalone. The format is --extra_launch_args arg1=1,arg2=2. "
|
||||
"This will be converted to --arg1=1 --arg2=2 during execution")
|
||||
help="Set additional torch distributed launcher arguments such as --standalone. The format is --extra_launch_args arg1=1,arg2=2. "
|
||||
"This will be converted to --arg1=1 --arg2=2 during execution",
|
||||
)
|
||||
@click.option("--ssh-port", type=int, default=None, help="(optional) the port used for ssh connection")
|
||||
@click.argument("user_script", type=str)
|
||||
@click.argument('user_args', nargs=-1)
|
||||
def run(host: str, hostfile: str, num_nodes: int, nproc_per_node: int, include: str, exclude: str, master_addr: str,
|
||||
master_port: int, extra_launch_args: str, ssh_port: int, user_script: str, user_args: str) -> None:
|
||||
@click.argument("user_args", nargs=-1)
|
||||
def run(
|
||||
host: str,
|
||||
hostfile: str,
|
||||
num_nodes: int,
|
||||
nproc_per_node: int,
|
||||
include: str,
|
||||
exclude: str,
|
||||
master_addr: str,
|
||||
master_port: int,
|
||||
extra_launch_args: str,
|
||||
ssh_port: int,
|
||||
user_script: str,
|
||||
user_args: str,
|
||||
) -> None:
|
||||
"""
|
||||
To launch multiple processes on a single node or multiple nodes via command line.
|
||||
|
||||
@@ -77,8 +102,8 @@ def run(host: str, hostfile: str, num_nodes: int, nproc_per_node: int, include:
|
||||
# run with hostfile excluding the hosts selected
|
||||
colossalai run --hostfile <file_path> --master_addr host1 --exclude host2 --nprocs_per_node 4 train.py
|
||||
"""
|
||||
if not user_script.endswith('.py'):
|
||||
click.echo(f'Error: invalid Python file {user_script}. Did you use a wrong option? Try colossalai run --help')
|
||||
if not user_script.endswith(".py"):
|
||||
click.echo(f"Error: invalid Python file {user_script}. Did you use a wrong option? Try colossalai run --help")
|
||||
exit()
|
||||
|
||||
args_dict = locals()
|
||||
|
@@ -1,5 +1,4 @@
|
||||
import socket
|
||||
from typing import List
|
||||
|
||||
|
||||
class HostInfo:
|
||||
@@ -34,7 +33,7 @@ class HostInfo:
|
||||
"""
|
||||
|
||||
if port is None:
|
||||
port = 22 # no port specified, lets just use the ssh port
|
||||
port = 22 # no port specified, lets just use the ssh port
|
||||
|
||||
# socket.getfqdn("127.0.0.1") does not return localhost
|
||||
# on some users' machines
|
||||
@@ -50,7 +49,7 @@ class HostInfo:
|
||||
return localaddrs == targetaddrs
|
||||
|
||||
def __str__(self):
|
||||
return f'hostname: {self.hostname}, port: {self.port}'
|
||||
return f"hostname: {self.hostname}, port: {self.port}"
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
@@ -7,8 +7,13 @@ import fabric
|
||||
from .hostinfo import HostInfo, HostInfoList
|
||||
|
||||
|
||||
def run_on_host(hostinfo: HostInfo, workdir: str, recv_conn: mp_connection.Connection,
|
||||
send_conn: mp_connection.Connection, env: dict) -> None:
|
||||
def run_on_host(
|
||||
hostinfo: HostInfo,
|
||||
workdir: str,
|
||||
recv_conn: mp_connection.Connection,
|
||||
send_conn: mp_connection.Connection,
|
||||
env: dict,
|
||||
) -> None:
|
||||
"""
|
||||
Use fabric connection to execute command on local or remote hosts.
|
||||
|
||||
@@ -22,14 +27,14 @@ def run_on_host(hostinfo: HostInfo, workdir: str, recv_conn: mp_connection.Conne
|
||||
|
||||
fab_conn = fabric.Connection(hostinfo.hostname, port=hostinfo.port)
|
||||
finish = False
|
||||
env_msg = ' '.join([f'{k}=\"{v}\"' for k, v in env.items()])
|
||||
env_msg = " ".join([f'{k}="{v}"' for k, v in env.items()])
|
||||
|
||||
# keep listening until exit
|
||||
while not finish:
|
||||
# receive cmd
|
||||
cmds = recv_conn.recv()
|
||||
|
||||
if cmds == 'exit':
|
||||
if cmds == "exit":
|
||||
# exit from the loop
|
||||
finish = True
|
||||
break
|
||||
@@ -46,12 +51,12 @@ def run_on_host(hostinfo: HostInfo, workdir: str, recv_conn: mp_connection.Conne
|
||||
else:
|
||||
# execute on the remote machine
|
||||
fab_conn.run(cmds, hide=False)
|
||||
send_conn.send('success')
|
||||
send_conn.send("success")
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
f"Error: failed to run {cmds} on {hostinfo.hostname}, is localhost: {hostinfo.is_local_host}, exception: {e}"
|
||||
)
|
||||
send_conn.send('failure')
|
||||
send_conn.send("failure")
|
||||
|
||||
# shutdown
|
||||
send_conn.send("finish")
|
||||
@@ -96,8 +101,7 @@ class MultiNodeRunner:
|
||||
cmd (str): the command to execute
|
||||
"""
|
||||
|
||||
assert hostinfo.hostname in self.master_send_conns, \
|
||||
f'{hostinfo} is not found in the current connections'
|
||||
assert hostinfo.hostname in self.master_send_conns, f"{hostinfo} is not found in the current connections"
|
||||
conn = self.master_send_conns[hostinfo.hostname]
|
||||
conn.send(cmd)
|
||||
|
||||
@@ -107,7 +111,7 @@ class MultiNodeRunner:
|
||||
"""
|
||||
|
||||
for hostname, conn in self.master_send_conns.items():
|
||||
conn.send('exit')
|
||||
conn.send("exit")
|
||||
|
||||
def recv_from_all(self) -> dict:
|
||||
"""
|
||||
|
@@ -12,7 +12,7 @@ from .hostinfo import HostInfo, HostInfoList
|
||||
from .multinode_runner import MultiNodeRunner
|
||||
|
||||
# Constants that define our syntax
|
||||
NODE_SEP = ','
|
||||
NODE_SEP = ","
|
||||
|
||||
|
||||
def fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList:
|
||||
@@ -34,12 +34,12 @@ def fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList:
|
||||
click.echo(f"Error: Unable to find the hostfile, no such file: {hostfile_path}")
|
||||
exit()
|
||||
|
||||
with open(hostfile_path, 'r') as fd:
|
||||
with open(hostfile_path, "r") as fd:
|
||||
device_pool = HostInfoList()
|
||||
|
||||
for line in fd.readlines():
|
||||
line = line.strip()
|
||||
if line == '':
|
||||
if line == "":
|
||||
# skip empty lines
|
||||
continue
|
||||
|
||||
@@ -56,7 +56,7 @@ def fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList:
|
||||
|
||||
|
||||
def parse_device_filter(device_pool: HostInfoList, include_str=None, exclude_str=None) -> HostInfoList:
|
||||
'''Parse an inclusion or exclusion string and filter a hostfile dictionary.
|
||||
"""Parse an inclusion or exclusion string and filter a hostfile dictionary.
|
||||
|
||||
Examples:
|
||||
include_str="worker-0,worker-1" will execute jobs only on worker-0 and worker-1.
|
||||
@@ -69,7 +69,7 @@ def parse_device_filter(device_pool: HostInfoList, include_str=None, exclude_str
|
||||
|
||||
Returns:
|
||||
filtered_hosts (HostInfoList): filtered hosts after inclusion/exclusion
|
||||
'''
|
||||
"""
|
||||
|
||||
# Ensure include/exclude are mutually exclusive
|
||||
if include_str and exclude_str:
|
||||
@@ -136,16 +136,16 @@ def get_launch_command(
|
||||
|
||||
for k, v in arg_dict.items():
|
||||
if v:
|
||||
ret.append(f'--{k}={v}')
|
||||
ret.append(f"--{k}={v}")
|
||||
else:
|
||||
ret.append(f'--{k}')
|
||||
ret.append(f"--{k}")
|
||||
return ret
|
||||
|
||||
if extra_launch_args:
|
||||
extra_launch_args_dict = dict()
|
||||
for arg in extra_launch_args.split(','):
|
||||
if '=' in arg:
|
||||
k, v = arg.split('=')
|
||||
for arg in extra_launch_args.split(","):
|
||||
if "=" in arg:
|
||||
k, v = arg.split("=")
|
||||
extra_launch_args_dict[k] = v
|
||||
else:
|
||||
extra_launch_args_dict[arg] = None
|
||||
@@ -158,9 +158,14 @@ def get_launch_command(
|
||||
|
||||
if torch_version.minor < 9:
|
||||
cmd = [
|
||||
sys.executable, "-m", "torch.distributed.launch", f"--nproc_per_node={nproc_per_node}",
|
||||
f"--master_addr={master_addr}", f"--master_port={master_port}", f"--nnodes={num_nodes}",
|
||||
f"--node_rank={node_rank}"
|
||||
sys.executable,
|
||||
"-m",
|
||||
"torch.distributed.launch",
|
||||
f"--nproc_per_node={nproc_per_node}",
|
||||
f"--master_addr={master_addr}",
|
||||
f"--master_port={master_port}",
|
||||
f"--nnodes={num_nodes}",
|
||||
f"--node_rank={node_rank}",
|
||||
]
|
||||
else:
|
||||
# extra launch args for torch distributed launcher with torch >= 1.9
|
||||
@@ -174,17 +179,24 @@ def get_launch_command(
|
||||
|
||||
if torch_version.minor < 10:
|
||||
cmd = [
|
||||
sys.executable, "-m", "torch.distributed.run", f"--nproc_per_node={nproc_per_node}",
|
||||
f"--nnodes={num_nodes}", f"--node_rank={node_rank}"
|
||||
sys.executable,
|
||||
"-m",
|
||||
"torch.distributed.run",
|
||||
f"--nproc_per_node={nproc_per_node}",
|
||||
f"--nnodes={num_nodes}",
|
||||
f"--node_rank={node_rank}",
|
||||
]
|
||||
else:
|
||||
cmd = [
|
||||
"torchrun", f"--nproc_per_node={nproc_per_node}", f"--nnodes={num_nodes}", f"--node_rank={node_rank}"
|
||||
"torchrun",
|
||||
f"--nproc_per_node={nproc_per_node}",
|
||||
f"--nnodes={num_nodes}",
|
||||
f"--node_rank={node_rank}",
|
||||
]
|
||||
cmd += _arg_dict_to_list(default_torchrun_rdzv_args)
|
||||
|
||||
cmd += _arg_dict_to_list(extra_launch_args) + [user_script] + user_args
|
||||
cmd = ' '.join(cmd)
|
||||
cmd = " ".join(cmd)
|
||||
return cmd
|
||||
|
||||
|
||||
@@ -248,18 +260,18 @@ def launch_multi_processes(args: Config) -> None:
|
||||
# run on local node if not hosts or hostfile is given
|
||||
# add local node to host info list
|
||||
active_device_pool = HostInfoList()
|
||||
localhost_info = HostInfo(hostname='127.0.0.1', port=args.ssh_port)
|
||||
localhost_info = HostInfo(hostname="127.0.0.1", port=args.ssh_port)
|
||||
active_device_pool.append(localhost_info)
|
||||
|
||||
# launch distributed processes
|
||||
runner = MultiNodeRunner()
|
||||
curr_path = os.path.abspath('.')
|
||||
curr_path = os.path.abspath(".")
|
||||
|
||||
# collect current path env
|
||||
env = dict()
|
||||
for k, v in os.environ.items():
|
||||
# do not support multi-line env var
|
||||
if v and '\n' not in v:
|
||||
if v and "\n" not in v:
|
||||
env[k] = v
|
||||
|
||||
# establish remote connection
|
||||
@@ -271,14 +283,16 @@ def launch_multi_processes(args: Config) -> None:
|
||||
|
||||
# execute distributed launching command
|
||||
for node_id, hostinfo in enumerate(active_device_pool):
|
||||
cmd = get_launch_command(master_addr=args.master_addr,
|
||||
master_port=args.master_port,
|
||||
nproc_per_node=args.nproc_per_node,
|
||||
user_script=args.user_script,
|
||||
user_args=args.user_args,
|
||||
node_rank=node_id,
|
||||
num_nodes=len(active_device_pool),
|
||||
extra_launch_args=args.extra_launch_args)
|
||||
cmd = get_launch_command(
|
||||
master_addr=args.master_addr,
|
||||
master_port=args.master_port,
|
||||
nproc_per_node=args.nproc_per_node,
|
||||
user_script=args.user_script,
|
||||
user_args=args.user_args,
|
||||
node_rank=node_id,
|
||||
num_nodes=len(active_device_pool),
|
||||
extra_launch_args=args.extra_launch_args,
|
||||
)
|
||||
runner.send(hostinfo=hostinfo, cmd=cmd)
|
||||
|
||||
# start training
|
||||
|
Reference in New Issue
Block a user