[CLI] refactored the launch CLI and fixed bugs in multi-node launching (#844)

* [cli] fixed multi-node job launching

* [cli] fixed a bug in version comparison

* [cli] support launching with env var

* [cli] fixed multi-node job launching

* [cli] fixed a bug in version comparison

* [cli] support launching with env var

* added docstring

* [cli] added extra launch arguments

* [cli] added default launch rdzv args

* [cli] fixed version comparison

* [cli] added docstring examples and requierment

* polish docstring

* polish code

* polish code
This commit is contained in:
Frank Lee
2022-04-24 13:26:26 +08:00
committed by GitHub
parent e5ea3fdeef
commit cf6d1c9284
5 changed files with 468 additions and 240 deletions

View File

@@ -1,69 +1,120 @@
import os
import sys
import shutil
from shlex import quote
from abc import ABC, abstractmethod
from colossalai.logging import get_dist_logger
import fabric
from fabric import Connection
from .hostinfo import HostInfo, HostInfoList
from multiprocessing import Pipe, Process
from multiprocessing import connection as mp_connection
import click
class MultiNodeRunner(ABC):
def run_on_host(hostinfo: HostInfo, workdir: str, recv_conn: mp_connection.Connection,
send_conn: mp_connection.Connection, env: dict) -> None:
"""
Use fabric connection to execute command on local or remote hosts.
def __init__(self, args):
self.args = args
self.user_arguments = self.args.user_args
self.user_script = args.user_script
self.exports = {}
Args:
hostinfo (HostInfo): host information
workdir (str): the directory to execute the command
recv_conn (multiprocessing.connection.Connection): receive messages from the master sender
send_conn (multiprocessing.connection.Connection): send messages to the master receiver
env (dict): a dictionary for environment variables
"""
@abstractmethod
def backend_exists(self):
"""Return whether the corresponding backend exists"""
fab_conn = fabric.Connection(hostinfo.hostname, port=hostinfo.port)
finish = False
env_msg = ' '.join([f'{k}=\"{v}\"' for k, v in env.items()])
@abstractmethod
def get_cmd(self, environment, active_devices):
"""Return the command to execute on node"""
# keep listening until exit
while not finish:
# receive cmd
cmds = recv_conn.recv()
def add_export(self, key, var):
self.exports[key.strip()] = var.strip()
if cmds == 'exit':
# exit from the loop
finish = True
break
else:
# execute the commands
try:
# cd to execute directory
with fab_conn.cd(workdir):
# propagate the runtime environment
with fab_conn.prefix(f"export {env_msg}"):
if hostinfo.is_local_host:
# execute on the local machine
fab_conn.local(cmds, hide=False)
else:
# execute on the remote machine
fab_conn.run(cmds, hide=False)
send_conn.send('success')
except:
click.echo(f"Error: failed to run {cmds} on {hostinfo.hostname}")
send_conn.send('failure')
@property
def name(self):
"""Return the name of the backend"""
return self.__class__.__name__
# shutdown
send_conn.send("finish")
fab_conn.close()
class PDSHRunner(MultiNodeRunner):
class MultiNodeRunner:
"""
A runner to execute commands on an array of machines. This runner
is inspired by Nezha (https://github.com/zhuzilin/NeZha).
"""
def __init__(self, args):
super().__init__(args)
def __init__(self):
self.processes = {}
self.master_send_conns = {}
self.master_recv_conns = {}
def backend_exists(self):
return shutil.which('pdsh')
def connect(self, host_info_list: HostInfoList, workdir: str, env: dict) -> None:
"""
Establish connections to a list of hosts
@property
def name(self):
return "pdsh"
Args:
host_info_list (HostInfoList): a list of HostInfo objects
workdir (str): the directory where command is executed
env (dict): environment variables to propagate to hosts
"""
for hostinfo in host_info_list:
master_send_conn, worker_recv_conn = Pipe()
master_recv_conn, worker_send_conn = Pipe()
p = Process(target=run_on_host, args=(hostinfo, workdir, worker_recv_conn, worker_send_conn, env))
p.start()
self.processes[hostinfo.hostname] = p
self.master_recv_conns[hostinfo.hostname] = master_recv_conn
self.master_send_conns[hostinfo.hostname] = master_send_conn
def parse_user_args(self):
return list(map(lambda x: x if x.startswith("-") else f"'{x}'", self.args.user_args))
def send(self, hostinfo: HostInfo, cmd: str) -> None:
"""
Send a command to a local/remote host.
def get_cmd(self, environment, active_devices, args):
environment['PDSH_RCMD_TYPE'] = 'ssh'
Args:
hostinfo (HostInfo): host information
cmd (str): the command to execute
"""
active_workers = ",".join(active_devices.keys())
print("Running on the following workers: %s" % active_workers)
assert hostinfo.hostname in self.master_send_conns, \
f'{hostinfo} is not found in the current connections'
conn = self.master_send_conns[hostinfo.hostname]
conn.send(cmd)
pdsh_cmd_args = ['pdsh', '-f', str(1024), '-w', active_workers]
def stop_all(self) -> None:
"""
Stop connections to all hosts.
"""
exports = ""
for key, val in self.exports.items():
exports += f"export {key}={quote(val)}; "
for hostname, conn in self.master_send_conns.items():
conn.send('exit')
# https://linux.die.net/man/1/pdsh
# %n will be replaced by pdsh command
colossal_launch = [
exports, f"cd {os.path.abspath('.')};", sys.executable, "-u", "-m", "torch.distributed.launch",
f"--nproc_per_node={args.nproc_per_node}", f"--master_addr={args.master_addr}",
f"--master_port={args.master_port}"
]
return pdsh_cmd_args + colossal_launch + [self.user_script] + self.user_arguments
def recv_from_all(self) -> dict:
"""
Receive messages from all hosts
Returns:
msg_from_node (dict): a dictionry which contains messages from each node
"""
msg_from_node = dict()
for hostname, conn in self.master_recv_conns.items():
msg_from_node[hostname] = conn.recv()
return msg_from_node