[CLI] refactored the launch CLI and fixed bugs in multi-node launching (#844)
* [cli] fixed multi-node job launching
* [cli] fixed a bug in version comparison
* [cli] support launching with env var
* added docstring
* [cli] added extra launch arguments
* [cli] added default launch rdzv args
* [cli] fixed version comparison
* [cli] added docstring examples and requirement
* polish docstring
* polish code
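The "launching with env var" support works by forwarding a dictionary of environment variables to every node: each command is wrapped in an `export` prefix before it is executed, which is what `run_on_host` in the diff below does with fabric's `cd()` and `prefix()` context managers. A minimal, self-contained sketch of that wrapping follows; the variable values and the working directory are made-up examples, not part of the commit.

```python
# Sketch: how an env dict becomes a shell prefix for each command sent to a host.
# The values below are illustrative only; the real dict is assembled by the CLI.
env = {"MASTER_ADDR": "node0", "MASTER_PORT": "29500", "OMP_NUM_THREADS": "8"}
cmd = "python -m torch.distributed.launch train.py"
workdir = "/workspace/project"

# mirrors env_msg in run_on_host: k="v" pairs joined by spaces
env_msg = ' '.join([f'{k}="{v}"' for k, v in env.items()])

# roughly what fabric's cd()/prefix() produce around the command
full_cmd = f"cd {workdir} && export {env_msg} && {cmd}"
print(full_cmd)
# cd /workspace/project && export MASTER_ADDR="node0" MASTER_PORT="29500" OMP_NUM_THREADS="8" && python -m torch.distributed.launch train.py
```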
@@ -1,69 +1,120 @@
 import os
 import sys
-import shutil
-from shlex import quote
-from abc import ABC, abstractmethod
-
-from colossalai.logging import get_dist_logger
+import fabric
+from fabric import Connection
+from .hostinfo import HostInfo, HostInfoList
+from multiprocessing import Pipe, Process
+from multiprocessing import connection as mp_connection
+import click
 
 
-class MultiNodeRunner(ABC):
-
-    def __init__(self, args):
-        self.args = args
-        self.user_arguments = self.args.user_args
-        self.user_script = args.user_script
-        self.exports = {}
-
-    @abstractmethod
-    def backend_exists(self):
-        """Return whether the corresponding backend exists"""
-
-    @abstractmethod
-    def get_cmd(self, environment, active_devices):
-        """Return the command to execute on node"""
-
-    def add_export(self, key, var):
-        self.exports[key.strip()] = var.strip()
-
-    @property
-    def name(self):
-        """Return the name of the backend"""
-        return self.__class__.__name__
+def run_on_host(hostinfo: HostInfo, workdir: str, recv_conn: mp_connection.Connection,
+                send_conn: mp_connection.Connection, env: dict) -> None:
+    """
+    Use a fabric connection to execute commands on a local or remote host.
+
+    Args:
+        hostinfo (HostInfo): host information
+        workdir (str): the directory to execute the command in
+        recv_conn (multiprocessing.connection.Connection): receive messages from the master sender
+        send_conn (multiprocessing.connection.Connection): send messages to the master receiver
+        env (dict): a dictionary of environment variables
+    """
+
+    fab_conn = fabric.Connection(hostinfo.hostname, port=hostinfo.port)
+    finish = False
+    env_msg = ' '.join([f'{k}=\"{v}\"' for k, v in env.items()])
+
+    # keep listening until exit
+    while not finish:
+        # receive cmd
+        cmds = recv_conn.recv()
+
+        if cmds == 'exit':
+            # exit from the loop
+            finish = True
+            break
+        else:
+            # execute the commands
+            try:
+                # cd to the execution directory
+                with fab_conn.cd(workdir):
+                    # propagate the runtime environment
+                    with fab_conn.prefix(f"export {env_msg}"):
+                        if hostinfo.is_local_host:
+                            # execute on the local machine
+                            fab_conn.local(cmds, hide=False)
+                        else:
+                            # execute on the remote machine
+                            fab_conn.run(cmds, hide=False)
+                    send_conn.send('success')
+            except:
+                click.echo(f"Error: failed to run {cmds} on {hostinfo.hostname}")
+                send_conn.send('failure')
+
+    # shutdown
+    send_conn.send("finish")
+    fab_conn.close()
 
 
-class PDSHRunner(MultiNodeRunner):
-
-    def __init__(self, args):
-        super().__init__(args)
-
-    def backend_exists(self):
-        return shutil.which('pdsh')
-
-    @property
-    def name(self):
-        return "pdsh"
-
-    def parse_user_args(self):
-        return list(map(lambda x: x if x.startswith("-") else f"'{x}'", self.args.user_args))
-
-    def get_cmd(self, environment, active_devices, args):
-        environment['PDSH_RCMD_TYPE'] = 'ssh'
-
-        active_workers = ",".join(active_devices.keys())
-        print("Running on the following workers: %s" % active_workers)
-
-        pdsh_cmd_args = ['pdsh', '-f', str(1024), '-w', active_workers]
-
-        exports = ""
-        for key, val in self.exports.items():
-            exports += f"export {key}={quote(val)}; "
-
-        # https://linux.die.net/man/1/pdsh
-        # %n will be replaced by pdsh command
-        colossal_launch = [
-            exports, f"cd {os.path.abspath('.')};", sys.executable, "-u", "-m", "torch.distributed.launch",
-            f"--nproc_per_node={args.nproc_per_node}", f"--master_addr={args.master_addr}",
-            f"--master_port={args.master_port}"
-        ]
-        return pdsh_cmd_args + colossal_launch + [self.user_script] + self.user_arguments
+class MultiNodeRunner:
+    """
+    A runner to execute commands on an array of machines. This runner
+    is inspired by Nezha (https://github.com/zhuzilin/NeZha).
+    """
+
+    def __init__(self):
+        self.processes = {}
+        self.master_send_conns = {}
+        self.master_recv_conns = {}
+
+    def connect(self, host_info_list: HostInfoList, workdir: str, env: dict) -> None:
+        """
+        Establish connections to a list of hosts.
+
+        Args:
+            host_info_list (HostInfoList): a list of HostInfo objects
+            workdir (str): the directory where the command is executed
+            env (dict): environment variables to propagate to hosts
+        """
+        for hostinfo in host_info_list:
+            master_send_conn, worker_recv_conn = Pipe()
+            master_recv_conn, worker_send_conn = Pipe()
+            p = Process(target=run_on_host, args=(hostinfo, workdir, worker_recv_conn, worker_send_conn, env))
+            p.start()
+            self.processes[hostinfo.hostname] = p
+            self.master_recv_conns[hostinfo.hostname] = master_recv_conn
+            self.master_send_conns[hostinfo.hostname] = master_send_conn
+
+    def send(self, hostinfo: HostInfo, cmd: str) -> None:
+        """
+        Send a command to a local/remote host.
+
+        Args:
+            hostinfo (HostInfo): host information
+            cmd (str): the command to execute
+        """
+
+        assert hostinfo.hostname in self.master_send_conns, \
+            f'{hostinfo} is not found in the current connections'
+        conn = self.master_send_conns[hostinfo.hostname]
+        conn.send(cmd)
+
+    def stop_all(self) -> None:
+        """
+        Stop connections to all hosts.
+        """
+
+        for hostname, conn in self.master_send_conns.items():
+            conn.send('exit')
+
+    def recv_from_all(self) -> dict:
+        """
+        Receive messages from all hosts.
+
+        Returns:
+            msg_from_node (dict): a dictionary which contains messages from each node
+        """
+
+        msg_from_node = dict()
+        for hostname, conn in self.master_recv_conns.items():
+            msg_from_node[hostname] = conn.recv()
+        return msg_from_node
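For reference, the master side of the new runner is driven roughly as follows: `connect()` forks one worker process and one pair of pipes per host, `send()` pushes a shell command to a host, `recv_from_all()` blocks for per-host status, and `stop_all()` asks every worker loop to exit. The sketch below is not part of the commit: the module paths, the `HostInfo`/`HostInfoList` constructors and `append()` method, and the training command are assumptions, since `hostinfo.py` is not shown in this diff.

```python
# Usage sketch under the assumptions stated above.
from colossalai.cli.launcher.hostinfo import HostInfo, HostInfoList      # assumed module path
from colossalai.cli.launcher.multinode_runner import MultiNodeRunner     # assumed module path

hosts = HostInfoList()
for name in ["node0", "node1"]:
    hosts.append(HostInfo(hostname=name, port=22))    # assumed constructor/append

runner = MultiNodeRunner()
# one worker process + Pipe pair per host; env is exported before every command
runner.connect(host_info_list=hosts, workdir="/workspace/project", env={"OMP_NUM_THREADS": "8"})

# broadcast the same command to every host
for hostinfo in hosts:
    runner.send(hostinfo, "python -m torch.distributed.launch train.py")

# each worker replies 'success' or 'failure' once its command returns
print(runner.recv_from_all())    # e.g. {'node0': 'success', 'node1': 'success'}

# tell every worker loop to exit; workers answer with a final 'finish'
runner.stop_all()
runner.recv_from_all()
```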
|