mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 19:13:01 +00:00
[example] updated the hybrid parallel tutorial (#2444)
* [example] updated the hybrid parallel tutorial * polish code
This commit is contained in:
@@ -1,13 +1,16 @@
|
||||
import click
|
||||
import sys
|
||||
import os
|
||||
import torch
|
||||
from colossalai.context import Config
|
||||
from .multinode_runner import MultiNodeRunner
|
||||
from .hostinfo import HostInfo, HostInfoList
|
||||
import sys
|
||||
from typing import List
|
||||
|
||||
import click
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
from colossalai.context import Config
|
||||
|
||||
from .hostinfo import HostInfo, HostInfoList
|
||||
from .multinode_runner import MultiNodeRunner
|
||||
|
||||
# Constants that define our syntax
|
||||
NODE_SEP = ','
|
||||
|
||||
@@ -15,7 +18,7 @@ NODE_SEP = ','
|
||||
def fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList:
|
||||
"""
|
||||
Parse the hostfile to obtain a list of hosts.
|
||||
|
||||
|
||||
A hostfile should look like:
|
||||
worker-0
|
||||
worker-1
|
||||
@@ -63,7 +66,7 @@ def parse_device_filter(device_pool: HostInfoList, include_str=None, exclude_str
|
||||
device_pool (HostInfoList): a list of HostInfo objects
|
||||
include_str (str): --include option passed by user, default None
|
||||
exclude_str (str): --exclude option passed by user, default None
|
||||
|
||||
|
||||
Returns:
|
||||
filtered_hosts (HostInfoList): filtered hosts after inclusion/exclusion
|
||||
'''
|
||||
@@ -192,7 +195,7 @@ def launch_multi_processes(args: Config) -> None:
|
||||
Launch multiple processes on a single node or multiple nodes.
|
||||
|
||||
The overall logic can be summarized as the pseudo code below:
|
||||
|
||||
|
||||
if hostfile given:
|
||||
hostinfo = parse_hostfile(hostfile)
|
||||
hostinfo = include_or_exclude_hosts(hostinfo)
|
||||
@@ -202,7 +205,7 @@ def launch_multi_processes(args: Config) -> None:
|
||||
launch_on_multi_nodes(hostinfo)
|
||||
else:
|
||||
launch_on_current_node()
|
||||
|
||||
|
||||
Args:
|
||||
args (Config): the arguments taken from command line
|
||||
|
||||
@@ -276,6 +279,33 @@ def launch_multi_processes(args: Config) -> None:
|
||||
extra_launch_args=args.extra_launch_args)
|
||||
runner.send(hostinfo=hostinfo, cmd=cmd)
|
||||
|
||||
runner.recv_from_all()
|
||||
# start training
|
||||
msg_from_node = runner.recv_from_all()
|
||||
has_error = False
|
||||
|
||||
# print node status
|
||||
click.echo("\n====== Training on All Nodes =====")
|
||||
for hostname, msg in msg_from_node.items():
|
||||
click.echo(f"{hostname}: {msg}")
|
||||
|
||||
# check if a process failed
|
||||
if msg == "failure":
|
||||
has_error = True
|
||||
|
||||
# stop all nodes
|
||||
runner.stop_all()
|
||||
runner.recv_from_all()
|
||||
|
||||
# receive the stop status
|
||||
msg_from_node = runner.recv_from_all()
|
||||
|
||||
# print node status
|
||||
click.echo("\n====== Stopping All Nodes =====")
|
||||
for hostname, msg in msg_from_node.items():
|
||||
click.echo(f"{hostname}: {msg}")
|
||||
|
||||
# give the process an exit code
|
||||
# so that it behaves like a normal process
|
||||
if has_error:
|
||||
sys.exit(1)
|
||||
else:
|
||||
sys.exit(0)
|
||||
|
Reference in New Issue
Block a user