Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-06 11:32:10 +00:00)
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
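The message above covers the tooling change itself; the diff below is the result of running the refreshed hooks over the repository and is purely mechanical: string literals move from single to double quotes and long argument lists are split one element per line with a trailing comma, consistent with what a Black-style formatter enforces (the hook list itself is not part of this diff, so that is an inference). Locally this typically corresponds to running "pre-commit autoupdate" followed by "pre-commit run --all-files". A trivial sketch of why the quote change is behavior-neutral:

# Quote style is purely cosmetic in Python: both spellings denote the same
# string object, so the reformatted code behaves identically.
assert ',' == ","
assert 'worker-0' == "worker-0"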
@@ -12,7 +12,7 @@ from .hostinfo import HostInfo, HostInfoList
 from .multinode_runner import MultiNodeRunner
 
 # Constants that define our syntax
-NODE_SEP = ','
+NODE_SEP = ","
 
 
 def fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList:
@@ -34,12 +34,12 @@ def fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList:
         click.echo(f"Error: Unable to find the hostfile, no such file: {hostfile_path}")
         exit()
 
-    with open(hostfile_path, 'r') as fd:
+    with open(hostfile_path, "r") as fd:
         device_pool = HostInfoList()
 
         for line in fd.readlines():
             line = line.strip()
-            if line == '':
+            if line == "":
                 # skip empty lines
                 continue
 
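For illustration, a standalone sketch of the parsing behaviour in the hunk above (not code from the repository): fetch_hostfile strips each line, skips empty ones, and turns every surviving line into a HostInfo entry. The hostfile layout of one hostname per line is an assumption here, inferred from the loop rather than stated in the diff.

# Hypothetical hostfile contents; the blank line is skipped, mirroring the loop above.
hostfile_text = "worker-0\n\nworker-1\n"
hostnames = [line.strip() for line in hostfile_text.splitlines() if line.strip() != ""]
print(hostnames)  # ['worker-0', 'worker-1']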
@@ -56,7 +56,7 @@ def fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList:
 
 
 def parse_device_filter(device_pool: HostInfoList, include_str=None, exclude_str=None) -> HostInfoList:
-    '''Parse an inclusion or exclusion string and filter a hostfile dictionary.
+    """Parse an inclusion or exclusion string and filter a hostfile dictionary.
 
     Examples:
         include_str="worker-0,worker-1" will execute jobs only on worker-0 and worker-1.
@@ -69,7 +69,7 @@ def parse_device_filter(device_pool: HostInfoList, include_str=None, exclude_str
 
     Returns:
         filtered_hosts (HostInfoList): filtered hosts after inclusion/exclusion
-    '''
+    """
 
     # Ensure include/exclude are mutually exclusive
     if include_str and exclude_str:
@@ -136,16 +136,16 @@ def get_launch_command(
 
         for k, v in arg_dict.items():
             if v:
-                ret.append(f'--{k}={v}')
+                ret.append(f"--{k}={v}")
             else:
-                ret.append(f'--{k}')
+                ret.append(f"--{k}")
         return ret
 
     if extra_launch_args:
         extra_launch_args_dict = dict()
-        for arg in extra_launch_args.split(','):
-            if '=' in arg:
-                k, v = arg.split('=')
+        for arg in extra_launch_args.split(","):
+            if "=" in arg:
+                k, v = arg.split("=")
                 extra_launch_args_dict[k] = v
             else:
                 extra_launch_args_dict[arg] = None
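As a self-contained sketch (same logic as the hunk above, re-implemented outside the module), this is how an --extra_launch_args string is split on commas and "=" and then rendered back into CLI flags by the dict-to-list helper; the input value is made up for the example:

def arg_dict_to_list(arg_dict):
    # keys with a value render as --key=value, bare keys as --key
    return [f"--{k}={v}" if v else f"--{k}" for k, v in arg_dict.items()]

extra_launch_args = "rdzv_backend=c10d,standalone"   # hypothetical input
parsed = {}
for arg in extra_launch_args.split(","):
    if "=" in arg:
        k, v = arg.split("=")
        parsed[k] = v
    else:
        parsed[arg] = None

print(arg_dict_to_list(parsed))  # ['--rdzv_backend=c10d', '--standalone']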
@@ -158,9 +158,14 @@ def get_launch_command(
 
     if torch_version.minor < 9:
         cmd = [
-            sys.executable, "-m", "torch.distributed.launch", f"--nproc_per_node={nproc_per_node}",
-            f"--master_addr={master_addr}", f"--master_port={master_port}", f"--nnodes={num_nodes}",
-            f"--node_rank={node_rank}"
+            sys.executable,
+            "-m",
+            "torch.distributed.launch",
+            f"--nproc_per_node={nproc_per_node}",
+            f"--master_addr={master_addr}",
+            f"--master_port={master_port}",
+            f"--nnodes={num_nodes}",
+            f"--node_rank={node_rank}",
         ]
     else:
         # extra launch args for torch distributed launcher with torch >= 1.9
@@ -174,17 +179,24 @@ def get_launch_command(
 
         if torch_version.minor < 10:
             cmd = [
-                sys.executable, "-m", "torch.distributed.run", f"--nproc_per_node={nproc_per_node}",
-                f"--nnodes={num_nodes}", f"--node_rank={node_rank}"
+                sys.executable,
+                "-m",
+                "torch.distributed.run",
+                f"--nproc_per_node={nproc_per_node}",
+                f"--nnodes={num_nodes}",
+                f"--node_rank={node_rank}",
             ]
         else:
             cmd = [
-                "torchrun", f"--nproc_per_node={nproc_per_node}", f"--nnodes={num_nodes}", f"--node_rank={node_rank}"
+                "torchrun",
+                f"--nproc_per_node={nproc_per_node}",
+                f"--nnodes={num_nodes}",
+                f"--node_rank={node_rank}",
             ]
         cmd += _arg_dict_to_list(default_torchrun_rdzv_args)
 
     cmd += _arg_dict_to_list(extra_launch_args) + [user_script] + user_args
-    cmd = ' '.join(cmd)
+    cmd = " ".join(cmd)
     return cmd
 
 
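To make the end product of this function concrete, here is a hedged sketch of the torch >= 1.10 branch: flags are collected into a flat list and space-joined into a single shell command. All concrete values below (process counts, rendezvous flags, the user script and its arguments) are placeholders; the real rendezvous defaults come from default_torchrun_rdzv_args, which is defined earlier in the file and is not part of this diff.

nproc_per_node, num_nodes, node_rank = 8, 2, 0
cmd = [
    "torchrun",
    f"--nproc_per_node={nproc_per_node}",
    f"--nnodes={num_nodes}",
    f"--node_rank={node_rank}",
]
cmd += ["--rdzv_backend=c10d", "--rdzv_endpoint=192.168.0.1:29500"]  # placeholder rdzv args
cmd += ["train.py", "--config", "config.py"]                         # placeholder user script and args
print(" ".join(cmd))
# torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --rdzv_backend=c10d --rdzv_endpoint=192.168.0.1:29500 train.py --config config.py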
@@ -248,18 +260,18 @@ def launch_multi_processes(args: Config) -> None:
         # run on local node if not hosts or hostfile is given
         # add local node to host info list
         active_device_pool = HostInfoList()
-        localhost_info = HostInfo(hostname='127.0.0.1', port=args.ssh_port)
+        localhost_info = HostInfo(hostname="127.0.0.1", port=args.ssh_port)
         active_device_pool.append(localhost_info)
 
     # launch distributed processes
     runner = MultiNodeRunner()
-    curr_path = os.path.abspath('.')
+    curr_path = os.path.abspath(".")
 
     # collect current path env
     env = dict()
     for k, v in os.environ.items():
         # do not support multi-line env var
-        if v and '\n' not in v:
+        if v and "\n" not in v:
             env[k] = v
 
     # establish remote connection
@@ -271,14 +283,16 @@ def launch_multi_processes(args: Config) -> None:
 
     # execute distributed launching command
     for node_id, hostinfo in enumerate(active_device_pool):
-        cmd = get_launch_command(master_addr=args.master_addr,
-                                 master_port=args.master_port,
-                                 nproc_per_node=args.nproc_per_node,
-                                 user_script=args.user_script,
-                                 user_args=args.user_args,
-                                 node_rank=node_id,
-                                 num_nodes=len(active_device_pool),
-                                 extra_launch_args=args.extra_launch_args)
+        cmd = get_launch_command(
+            master_addr=args.master_addr,
+            master_port=args.master_port,
+            nproc_per_node=args.nproc_per_node,
+            user_script=args.user_script,
+            user_args=args.user_args,
+            node_rank=node_id,
+            num_nodes=len(active_device_pool),
+            extra_launch_args=args.extra_launch_args,
+        )
         runner.send(hostinfo=hostinfo, cmd=cmd)
 
     # start training
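A final hedged sketch of the dispatch loop above, with the launcher-specific pieces replaced by stand-ins: every host in the pool receives essentially the same command, differing only in --node_rank, which is its index in the active device pool.

active_hosts = ["worker-0", "worker-1"]            # stand-in for active_device_pool
for node_id, host in enumerate(active_hosts):
    cmd = f"torchrun --nnodes={len(active_hosts)} --node_rank={node_id} train.py"
    print(f"{host}: {cmd}")                        # stand-in for runner.send(hostinfo, cmd)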