mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 18:40:28 +00:00
[booster] implemented the cluster module (#3191)
* [booster] implemented the cluster module * polish code
This commit is contained in:
75
colossalai/cluster/process_group_manager.py
Normal file
75
colossalai/cluster/process_group_manager.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from typing import List
|
||||
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
|
||||
class ProcessGroupManager:
|
||||
"""
|
||||
ProcessGroupManager is used to manage the process groups in the cluster.
|
||||
|
||||
There are some terms used in this class:
|
||||
- pg: the short name for process group
|
||||
- pg_name: the name of the process group
|
||||
- pg_size: the world size of the process group
|
||||
- rank: the rank of the current process in the process group
|
||||
- world_size: the total number of processes in the process group
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.pg_store = dict()
|
||||
|
||||
def create_process_group(self, name: str, ranks: List[int], backend: str = 'nccl') -> ProcessGroup:
|
||||
"""
|
||||
Get a process group by name. If the process group does not exist, it will be created.
|
||||
|
||||
Args:
|
||||
name (str): name of the process group
|
||||
ranks (List[int]): ranks of the process group
|
||||
backend (str, optional): backend of the process group. Defaults to 'nccl'.
|
||||
|
||||
Returns:
|
||||
ProcessGroup: the process group
|
||||
"""
|
||||
if name not in self.pg_store:
|
||||
pg = dist.new_group(ranks=ranks, backend=backend)
|
||||
self.pg_store[name] = pg
|
||||
return pg
|
||||
else:
|
||||
raise ValueError(f'Process group {name} already exists.')
|
||||
|
||||
def get(self, name: str) -> ProcessGroup:
|
||||
"""
|
||||
Get a process group by name.
|
||||
|
||||
Args:
|
||||
name (str): name of the process group
|
||||
|
||||
Returns:
|
||||
ProcessGroup: the process group
|
||||
"""
|
||||
if name in self.pg_store:
|
||||
return self.pg_store[name]
|
||||
else:
|
||||
raise ValueError(f'Process group {name} does not exist.')
|
||||
|
||||
def destroy(self, name: str) -> None:
|
||||
"""
|
||||
Destroy a process group by name.
|
||||
|
||||
Args:
|
||||
name (str): name of the process group
|
||||
"""
|
||||
if name in self.pg_store:
|
||||
dist.destroy_process_group(self.pg_store[name])
|
||||
del self.pg_store[name]
|
||||
else:
|
||||
raise ValueError(f'Process group {name} does not exist.')
|
||||
|
||||
def destroy_all(self) -> None:
|
||||
"""
|
||||
Destroy all process groups.
|
||||
"""
|
||||
for name in self.pg_store:
|
||||
dist.destroy_process_group(self.pg_store[name])
|
||||
self.pg_store.clear()
|
Reference in New Issue
Block a user