[example] updated the hybrid parallel tutorial (#2444)

* [example] updated the hybrid parallel tutorial

* polish code
This commit is contained in:
Frank Lee
2023-01-11 15:17:17 +08:00
committed by GitHub
parent 5521af7877
commit 39163417a1
6 changed files with 82 additions and 65 deletions

View File

@@ -1,13 +1,16 @@
import click
import sys
import os
import torch
from colossalai.context import Config
from .multinode_runner import MultiNodeRunner
from .hostinfo import HostInfo, HostInfoList
import sys
from typing import List
import click
import torch
from packaging import version
from colossalai.context import Config
from .hostinfo import HostInfo, HostInfoList
from .multinode_runner import MultiNodeRunner
# Constants that define our syntax
NODE_SEP = ','
@@ -15,7 +18,7 @@ NODE_SEP = ','
def fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList:
"""
Parse the hostfile to obtain a list of hosts.
A hostfile should look like:
worker-0
worker-1
@@ -63,7 +66,7 @@ def parse_device_filter(device_pool: HostInfoList, include_str=None, exclude_str
device_pool (HostInfoList): a list of HostInfo objects
include_str (str): --include option passed by user, default None
exclude_str (str): --exclude option passed by user, default None
Returns:
filtered_hosts (HostInfoList): filtered hosts after inclusion/exclusion
'''
@@ -192,7 +195,7 @@ def launch_multi_processes(args: Config) -> None:
Launch multiple processes on a single node or multiple nodes.
The overall logic can be summarized as the pseudo code below:
if hostfile given:
hostinfo = parse_hostfile(hostfile)
hostinfo = include_or_exclude_hosts(hostinfo)
@@ -202,7 +205,7 @@ def launch_multi_processes(args: Config) -> None:
launch_on_multi_nodes(hostinfo)
else:
launch_on_current_node()
Args:
args (Config): the arguments taken from command line
@@ -276,6 +279,33 @@ def launch_multi_processes(args: Config) -> None:
extra_launch_args=args.extra_launch_args)
runner.send(hostinfo=hostinfo, cmd=cmd)
runner.recv_from_all()
# start training
msg_from_node = runner.recv_from_all()
has_error = False
# print node status
click.echo("\n====== Training on All Nodes =====")
for hostname, msg in msg_from_node.items():
click.echo(f"{hostname}: {msg}")
# check if a process failed
if msg == "failure":
has_error = True
# stop all nodes
runner.stop_all()
runner.recv_from_all()
# receive the stop status
msg_from_node = runner.recv_from_all()
# print node status
click.echo("\n====== Stopping All Nodes =====")
for hostname, msg in msg_from_node.items():
click.echo(f"{hostname}: {msg}")
# give the process an exit code
# so that it behaves like a normal process
if has_error:
sys.exit(1)
else:
sys.exit(0)