Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-06-23 14:10:29 +00:00
[booster] implemented the cluster module (#3191)
* [booster] implemented the cluster module
* polish code
This commit is contained in:
parent 019a847432
commit e3ad88fb48
colossalai/cluster/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
from .device_mesh_manager import DeviceMeshManager
from .dist_coordinator import DistCoordinator
from .process_group_manager import ProcessGroupManager

__all__ = ['DistCoordinator', 'ProcessGroupManager', 'DeviceMeshManager']
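These re-exports make the package itself the import point, so a minimal usage sketch of the module above is simply:

    from colossalai.cluster import DeviceMeshManager, DistCoordinator, ProcessGroupManager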
colossalai/cluster/device_mesh_manager.py (new file, 36 lines)
@@ -0,0 +1,36 @@
from colossalai.device.device_mesh import DeviceMesh


class DeviceMeshManager:
    """
    Device mesh manager is responsible for creating and managing device meshes.
    """

    def __init__(self):
        self.device_mesh_store = dict()

    def create_device_mesh(self, name, *args, **kwargs) -> DeviceMesh:
        """
        Create a device mesh and store it in the manager.

        Args:
            name (str): name of the device mesh
            *args: args for DeviceMesh
            **kwargs: kwargs for DeviceMesh
        """
        # TODO(Yuliang): replace *args, **kwargs with explicit arguments
        if name not in self.device_mesh_store:
            device_mesh = DeviceMesh(*args, **kwargs)
            self.device_mesh_store[name] = device_mesh
            return device_mesh
        else:
            raise ValueError(f'Device mesh {name} already exists.')

    def get(self, name: str) -> DeviceMesh:
        pass

    def destroy(self):
        pass

    def destroy_all(self):
        pass
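A minimal sketch of how the manager above could be used. The manager only forwards *args/**kwargs to DeviceMesh, so the arguments shown (a flat tensor of physical device ids plus a 2x2 logical shape) are an assumption about the DeviceMesh signature in colossalai.device.device_mesh and may need adjusting:

    import torch

    from colossalai.cluster import DeviceMeshManager

    manager = DeviceMeshManager()

    # The arguments below are forwarded verbatim to DeviceMesh; they assume it takes
    # a physical mesh id tensor and a logical mesh shape.
    mesh = manager.create_device_mesh('tp_dp_mesh', torch.arange(4), (2, 2))

    # Creating a second mesh under the same name is rejected.
    try:
        manager.create_device_mesh('tp_dp_mesh', torch.arange(4), (2, 2))
    except ValueError as e:
        print(e)  # Device mesh tp_dp_mesh already exists.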
colossalai/cluster/dist_coordinator.py (new file, 158 lines)
@@ -0,0 +1,158 @@
import os
from contextlib import contextmanager

import torch.distributed as dist
from torch.distributed import ProcessGroup

from colossalai.context.singleton_meta import SingletonMeta


class DistCoordinator(metaclass=SingletonMeta):
    """
    This class is used to coordinate distributed training. It is a singleton class, which means that
    there is only one instance of this class in the whole program.

    There are some terms that are used in this class:
        - rank: the rank of the current process
        - world size: the total number of processes
        - local rank: the rank of the current process on the current node
        - master: the process with rank 0
        - node master: the process with local rank 0 on the current node

    Example:
        >>> from colossalai.cluster.dist_coordinator import DistCoordinator
        >>> coordinator = DistCoordinator()
        >>>
        >>> if coordinator.is_master():
        >>>     do_something()
        >>>
        >>> coordinator.print_on_master('hello world')

    Attributes:
        rank (int): the rank of the current process
        world_size (int): the total number of processes
        local_rank (int): the rank of the current process on the current node
    """

    def __init__(self):
        assert dist.is_initialized(), \
            'Distributed is not initialized. Please call `torch.distributed.init_process_group` or `colossalai.launch` first.'
        self._rank = dist.get_rank()
        self._world_size = dist.get_world_size()
        # LOCAL_RANK is often passed by launchers such as torchrun; it arrives as a string,
        # so cast it to int (-1 means it is not set)
        self._local_rank = int(os.environ.get('LOCAL_RANK', -1))

    @property
    def rank(self) -> int:
        return self._rank

    @property
    def world_size(self) -> int:
        return self._world_size

    @property
    def local_rank(self) -> int:
        return self._local_rank

    def _assert_local_rank_set(self):
        """
        Assert that the local rank is set. This is often passed by launchers such as torchrun.
        """
        assert self.local_rank >= 0, 'The environment variable LOCAL_RANK is not set, thus the coordinator is not aware of the local rank of the current process.'

    def is_master(self, process_group: ProcessGroup = None) -> bool:
        """
        Check if the current process is the master process (rank is 0). It can accept a sub process group
        to check rank 0 with respect to that group.

        Args:
            process_group (ProcessGroup, optional): process group to use for the rank 0 check. Defaults to None, which refers to the default process group.

        Returns:
            bool: True if the current process is the master process, False otherwise
        """
        rank = dist.get_rank(group=process_group)
        return rank == 0

    def is_node_master(self) -> bool:
        """
        Check if the current process is the master process on the current node (local rank is 0).

        Returns:
            bool: True if the current process is the master process on the current node, False otherwise
        """
        self._assert_local_rank_set()
        return self.local_rank == 0

    def is_last_process(self, process_group: ProcessGroup = None) -> bool:
        """
        Check if the current process is the last process (rank is world size - 1). It can accept a sub process group
        to check the last rank with respect to that group.

        Args:
            process_group (ProcessGroup, optional): process group to use for the last rank check. Defaults to None, which refers to the default process group.

        Returns:
            bool: True if the current process is the last process, False otherwise
        """
        rank = dist.get_rank(group=process_group)
        world_size = dist.get_world_size(group=process_group)
        return rank == world_size - 1

    def print_on_master(self, msg: str, process_group: ProcessGroup = None):
        """
        Print a message only from rank 0.

        Args:
            msg (str): message to print
            process_group (ProcessGroup, optional): process group to use for the rank 0 check. Defaults to None, which refers to the default process group.
        """
        rank = dist.get_rank(group=process_group)
        if rank == 0:
            print(msg)

    def print_on_node_master(self, msg: str):
        """
        Print a message only from local rank 0. Local rank 0 refers to the 0th process running on the current node.

        Args:
            msg (str): message to print
        """
        self._assert_local_rank_set()
        if self.local_rank == 0:
            print(msg)

    @contextmanager
    def priority_execution(self, executor_rank: int = 0, process_group: ProcessGroup = None):
        """
        This context manager is used to allow one process to execute while blocking all
        other processes in the same process group. This is often useful when downloading is required
        as we only want to download in one process to prevent file corruption.

        Example:
            >>> from colossalai.cluster import DistCoordinator
            >>> dist_coordinator = DistCoordinator()
            >>> with dist_coordinator.priority_execution():
            >>>     dataset = CIFAR10(root='./data', download=True)

        Args:
            executor_rank (int): the process rank to execute without blocking, all other processes will be blocked
            process_group (ProcessGroup, optional): process group to use for the executor rank check. Defaults to None, which refers to the default process group.
        """
        rank = dist.get_rank(group=process_group)
        should_block = rank != executor_rank

        # non-executor ranks wait here until the executor has finished its work
        if should_block:
            dist.barrier(group=process_group)

        yield

        # the executor releases the waiting ranks once its own work is done
        if not should_block:
            dist.barrier(group=process_group)

    def destroy(self, process_group: ProcessGroup = None):
        """
        Destroy the distributed process group.

        Args:
            process_group (ProcessGroup, optional): process group to destroy. Defaults to None, which refers to the default process group.
        """
        dist.destroy_process_group(process_group)
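A minimal end-to-end sketch of the coordinator above, meant to be launched with torchrun. The script name, the gloo backend, and the placeholder work inside priority_execution are illustrative assumptions rather than part of the commit:

    # coordinator_demo.py (hypothetical file name); launch with:
    #   torchrun --nproc_per_node=2 coordinator_demo.py
    import torch.distributed as dist

    from colossalai.cluster import DistCoordinator

    # DistCoordinator requires an initialized process group; gloo is used here so the
    # sketch also runs on CPU-only machines.
    dist.init_process_group(backend='gloo')

    coordinator = DistCoordinator()
    coordinator.print_on_master(f'world size: {coordinator.world_size}')  # printed once, by rank 0

    # Exactly one rank runs the block first; the other ranks wait at a barrier until it is done.
    with coordinator.priority_execution():
        print(f'rank {coordinator.rank} does the one-off setup work')  # placeholder for e.g. a dataset download

    if coordinator.is_last_process():
        print('last rank reached the end')

    coordinator.destroy()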
colossalai/cluster/process_group_manager.py (new file, 75 lines)
@@ -0,0 +1,75 @@
from typing import List

import torch.distributed as dist
from torch.distributed import ProcessGroup


class ProcessGroupManager:
    """
    ProcessGroupManager is used to manage the process groups in the cluster.

    There are some terms used in this class:
        - pg: the short name for process group
        - pg_name: the name of the process group
        - pg_size: the world size of the process group
        - rank: the rank of the current process in the process group
        - world_size: the total number of processes in the process group
    """

    def __init__(self):
        self.pg_store = dict()

    def create_process_group(self, name: str, ranks: List[int], backend: str = 'nccl') -> ProcessGroup:
        """
        Create a process group with the given name and store it in the manager. A ValueError is raised
        if a process group with the same name already exists.

        Args:
            name (str): name of the process group
            ranks (List[int]): ranks of the process group
            backend (str, optional): backend of the process group. Defaults to 'nccl'.

        Returns:
            ProcessGroup: the process group
        """
        if name not in self.pg_store:
            pg = dist.new_group(ranks=ranks, backend=backend)
            self.pg_store[name] = pg
            return pg
        else:
            raise ValueError(f'Process group {name} already exists.')

    def get(self, name: str) -> ProcessGroup:
        """
        Get a process group by name.

        Args:
            name (str): name of the process group

        Returns:
            ProcessGroup: the process group
        """
        if name in self.pg_store:
            return self.pg_store[name]
        else:
            raise ValueError(f'Process group {name} does not exist.')

    def destroy(self, name: str) -> None:
        """
        Destroy a process group by name.

        Args:
            name (str): name of the process group
        """
        if name in self.pg_store:
            dist.destroy_process_group(self.pg_store[name])
            del self.pg_store[name]
        else:
            raise ValueError(f'Process group {name} does not exist.')

    def destroy_all(self) -> None:
        """
        Destroy all process groups.
        """
        for name in self.pg_store:
            dist.destroy_process_group(self.pg_store[name])
        self.pg_store.clear()
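A minimal sketch of the manager above, assuming a gloo backend so it also runs without GPUs; the group name 'even_ranks' is illustrative. Note that dist.new_group is a collective call, so every rank must call create_process_group with the same arguments even if it will not be a member of the new group:

    import torch
    import torch.distributed as dist

    from colossalai.cluster import ProcessGroupManager

    dist.init_process_group(backend='gloo')  # gloo so the sketch works on CPU-only machines
    pg_manager = ProcessGroupManager()

    # Every rank calls create_process_group with the same ranks list (new_group is collective),
    # even ranks that are not members of the resulting group.
    even_ranks = list(range(0, dist.get_world_size(), 2))
    pg_manager.create_process_group('even_ranks', ranks=even_ranks, backend='gloo')

    # Only member ranks should use the group for communication.
    if dist.get_rank() in even_ranks:
        t = torch.ones(1)
        dist.all_reduce(t, group=pg_manager.get('even_ranks'))

    pg_manager.destroy('even_ranks')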