mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 17:17:05 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -3,4 +3,4 @@ from .dist_coordinator import DistCoordinator
|
||||
from .process_group_manager import ProcessGroupManager
|
||||
from .process_group_mesh import ProcessGroupMesh
|
||||
|
||||
__all__ = ['DistCoordinator', 'ProcessGroupManager', 'DeviceMeshManager', 'ProcessGroupMesh']
|
||||
__all__ = ["DistCoordinator", "ProcessGroupManager", "DeviceMeshManager", "ProcessGroupMesh"]
|
||||
|
@@ -10,13 +10,14 @@ from colossalai.device.device_mesh import DeviceMesh
|
||||
|
||||
@dataclass
|
||||
class DeviceMeshInfo:
|
||||
'''
|
||||
"""
|
||||
This class is used to store the information used to initialize the device mesh.
|
||||
|
||||
Args:
|
||||
physical_ids (List[int]): The physical ids of the current booster. For example, if we have the last 4 GPUs on an 8-device cluster, then the physical ids should be [4, 5, 6, 7].
|
||||
mesh_shape (Union[torch.Size, List[int], Tuple[int]]): The shape of the mesh. For example, to arrange 4 GPUs as a 2D mesh with 2 devices along each axis, the mesh shape should be [2, 2].
|
||||
'''
|
||||
"""
|
||||
|
||||
physical_ids: List[int]
|
||||
mesh_shape: Union[torch.Size, List[int], Tuple[int]] = None
|
||||
|
||||
@@ -24,16 +25,18 @@ class DeviceMeshInfo:
|
||||
if self.mesh_shape is not None:
|
||||
world_size = len(self.physical_ids)
|
||||
mesh_shape_numel = torch.Size(self.mesh_shape).numel()
|
||||
assert world_size == mesh_shape_numel, f'the numel of mesh_shape should be equal to world size, but got {world_size} != {mesh_shape_numel}'
|
||||
assert (
|
||||
world_size == mesh_shape_numel
|
||||
), f"the numel of mesh_shape should be equal to world size, but got {world_size} != {mesh_shape_numel}"
|
||||
|
||||
|
||||
def initialize_device_mesh(device_mesh_info: DeviceMeshInfo):
|
||||
'''
|
||||
"""
|
||||
This method is used to initialize the device mesh.
|
||||
|
||||
Args:
|
||||
device_mesh_info (DeviceMeshInfo): The information used to initialize device mesh.
|
||||
'''
|
||||
"""
|
||||
# parse the device mesh info
|
||||
physical_devices = device_mesh_info.physical_ids
|
||||
physical_mesh = torch.tensor(physical_devices)
|
||||
@@ -67,13 +70,13 @@ class DeviceMeshManager:
|
||||
Args:
|
||||
name (str): name of the device mesh
|
||||
device_mesh_info (DeviceMeshInfo): the information used to initialize the device mesh
|
||||
"""
|
||||
"""
|
||||
if name not in self.device_mesh_store:
|
||||
device_mesh = initialize_device_mesh(device_mesh_info)
|
||||
self.device_mesh_store[name] = device_mesh
|
||||
return device_mesh
|
||||
else:
|
||||
raise ValueError(f'Device mesh {name} already exists.')
|
||||
raise ValueError(f"Device mesh {name} already exists.")
|
||||
|
||||
def get(self, name: str) -> DeviceMesh:
|
||||
"""
|
||||
@@ -88,7 +91,7 @@ class DeviceMeshManager:
|
||||
if name in self.device_mesh_store:
|
||||
return self.device_mesh_store[name]
|
||||
else:
|
||||
raise ValueError(f'Device mesh {name} does not exist.')
|
||||
raise ValueError(f"Device mesh {name} does not exist.")
|
||||
|
||||
def destroy(self, name: str) -> None:
|
||||
"""
|
||||
@@ -103,7 +106,7 @@ class DeviceMeshManager:
|
||||
dist.destroy_process_group(pg)
|
||||
del self.device_mesh_store[name]
|
||||
else:
|
||||
raise ValueError(f'Device mesh {name} does not exist.')
|
||||
raise ValueError(f"Device mesh {name} does not exist.")
|
||||
|
||||
def destroy_all(self):
|
||||
"""
|
||||
|
@@ -36,12 +36,13 @@ class DistCoordinator(metaclass=SingletonMeta):
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
assert dist.is_initialized(
|
||||
), 'Distributed is not initialized. Please call `torch.distributed.init_process_group` or `colossalai.launch` first.'
|
||||
assert (
|
||||
dist.is_initialized()
|
||||
), "Distributed is not initialized. Please call `torch.distributed.init_process_group` or `colossalai.launch` first."
|
||||
self._rank = dist.get_rank()
|
||||
self._world_size = dist.get_world_size()
|
||||
# this is often passed by launchers such as torchrun
|
||||
self._local_rank = os.environ.get('LOCAL_RANK', -1)
|
||||
self._local_rank = os.environ.get("LOCAL_RANK", -1)
|
||||
|
||||
@property
|
||||
def rank(self) -> int:
|
||||
@@ -59,7 +60,9 @@ class DistCoordinator(metaclass=SingletonMeta):
|
||||
"""
|
||||
Assert that the local rank is set. This is often passed by launchers such as torchrun.
|
||||
"""
|
||||
assert self.local_rank >= 0, 'The environment variable LOCAL_RANK is not set, thus the coordinator is not aware of the local rank of the current process.'
|
||||
assert (
|
||||
self.local_rank >= 0
|
||||
), "The environment variable LOCAL_RANK is not set, thus the coordinator is not aware of the local rank of the current process."
|
||||
|
||||
def is_master(self, process_group: ProcessGroup = None) -> bool:
|
||||
"""
|
||||
@@ -183,7 +186,6 @@ class DistCoordinator(metaclass=SingletonMeta):
|
||||
|
||||
# define an inner function
|
||||
def decorator(func):
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
if is_master:
|
||||
|
@@ -19,7 +19,7 @@ class ProcessGroupManager:
|
||||
def __init__(self):
|
||||
self.pg_store = dict()
|
||||
|
||||
def create_process_group(self, name: str, ranks: List[int], backend: str = 'nccl') -> ProcessGroup:
|
||||
def create_process_group(self, name: str, ranks: List[int], backend: str = "nccl") -> ProcessGroup:
|
||||
"""
|
||||
Create a process group with the given name. If a process group with this name already exists, a ValueError is raised.
|
||||
|
||||
@@ -36,7 +36,7 @@ class ProcessGroupManager:
|
||||
self.pg_store[name] = pg
|
||||
return pg
|
||||
else:
|
||||
raise ValueError(f'Process group {name} already exists.')
|
||||
raise ValueError(f"Process group {name} already exists.")
|
||||
|
||||
def get(self, name: str) -> ProcessGroup:
|
||||
"""
|
||||
@@ -51,7 +51,7 @@ class ProcessGroupManager:
|
||||
if name in self.pg_store:
|
||||
return self.pg_store[name]
|
||||
else:
|
||||
raise ValueError(f'Process group {name} does not exist.')
|
||||
raise ValueError(f"Process group {name} does not exist.")
|
||||
|
||||
def destroy(self, name: str) -> None:
|
||||
"""
|
||||
@@ -64,7 +64,7 @@ class ProcessGroupManager:
|
||||
dist.destroy_process_group(self.pg_store[name])
|
||||
del self.pg_store[name]
|
||||
else:
|
||||
raise ValueError(f'Process group {name} does not exist.')
|
||||
raise ValueError(f"Process group {name} does not exist.")
|
||||
|
||||
def destroy_all(self) -> None:
|
||||
"""
|
||||
|
@@ -94,7 +94,7 @@ class ProcessGroupMesh:
|
||||
return np.unravel_index(rank, shape)
|
||||
|
||||
@staticmethod
|
||||
def ravel(coord: Tuple[int, ...], shape: Tuple[int, ...], mode: str = 'raise') -> int:
|
||||
def ravel(coord: Tuple[int, ...], shape: Tuple[int, ...], mode: str = "raise") -> int:
|
||||
"""Convert a coordinate to a rank.
|
||||
mode: ['raise', 'wrap', 'clip'], see https://numpy.org/doc/stable/reference/generated/numpy.ravel_multi_index.html.
|
||||
with wrap, index out of range would be wrapped around.
|
||||
@@ -141,8 +141,9 @@ class ProcessGroupMesh:
|
||||
return list(self._group_to_ranks[group])
|
||||
|
||||
@staticmethod
|
||||
def get_coords_along_axis(base_coord: Tuple[int, ...], axis: int,
|
||||
indices_at_axis: List[int]) -> List[Tuple[int, ...]]:
|
||||
def get_coords_along_axis(
|
||||
base_coord: Tuple[int, ...], axis: int, indices_at_axis: List[int]
|
||||
) -> List[Tuple[int, ...]]:
|
||||
"""Get coordinates along the given axis.
|
||||
|
||||
Args:
|
||||
@@ -155,13 +156,12 @@ class ProcessGroupMesh:
|
||||
"""
|
||||
coords_in_group = []
|
||||
for idx in indices_at_axis:
|
||||
coords_in_group.append(base_coord[:axis] + (idx,) + base_coord[axis + 1:])
|
||||
coords_in_group.append(base_coord[:axis] + (idx,) + base_coord[axis + 1 :])
|
||||
return coords_in_group
|
||||
|
||||
def create_group_along_axis(self,
|
||||
axis: int,
|
||||
indices_at_axis: Optional[List[int]] = None,
|
||||
backend: Optional[str] = None) -> ProcessGroup:
|
||||
def create_group_along_axis(
|
||||
self, axis: int, indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None
|
||||
) -> ProcessGroup:
|
||||
"""Create all process groups along the given axis, and return the one which the current process belongs to.
|
||||
|
||||
Args:
|
||||
@@ -186,10 +186,9 @@ class ProcessGroupMesh:
|
||||
target_group = group
|
||||
return target_group
|
||||
|
||||
def get_group_along_axis(self,
|
||||
axis: int,
|
||||
indices_at_axis: Optional[List[int]] = None,
|
||||
backend: Optional[str] = None) -> ProcessGroup:
|
||||
def get_group_along_axis(
|
||||
self, axis: int, indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None
|
||||
) -> ProcessGroup:
|
||||
"""Get the process group along the given axis which the current process belongs to. If the process group doesn't exist, it will be created.
|
||||
|
||||
Args:
|
||||
|
Reference in New Issue
Block a user