mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 18:19:58 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -10,13 +10,14 @@ from colossalai.device.device_mesh import DeviceMesh
|
||||
|
||||
@dataclass
|
||||
class DeviceMeshInfo:
|
||||
'''
|
||||
"""
|
||||
This class is used to store the information used to initialize the device mesh.
|
||||
|
||||
Args:
|
||||
physical_ids (List[int]): The physical ids of the current booster. For example, if we have the last 4 GPUs on a 8-devices cluster, then the physical ids should be [4, 5, 6, 7].
|
||||
mesh_shapes (List[Union[torch.Size, List[int], Tuple[int]]]): The shape of the mesh. For example, if we have 4 GPUs and we want to use 2D mesh with mesh shape [2, 2], then the mesh shape should be [2, 2].
|
||||
'''
|
||||
"""
|
||||
|
||||
physical_ids: List[int]
|
||||
mesh_shape: Union[torch.Size, List[int], Tuple[int]] = None
|
||||
|
||||
@@ -24,16 +25,18 @@ class DeviceMeshInfo:
|
||||
if self.mesh_shape is not None:
|
||||
world_size = len(self.physical_ids)
|
||||
mesh_shape_numel = torch.Size(self.mesh_shape).numel()
|
||||
assert world_size == mesh_shape_numel, f'the numel of mesh_shape should be equal to world size, but got {world_size} != {mesh_shape_numel}'
|
||||
assert (
|
||||
world_size == mesh_shape_numel
|
||||
), f"the numel of mesh_shape should be equal to world size, but got {world_size} != {mesh_shape_numel}"
|
||||
|
||||
|
||||
def initialize_device_mesh(device_mesh_info: DeviceMeshInfo):
|
||||
'''
|
||||
"""
|
||||
This method is used to initialize the device mesh.
|
||||
|
||||
Args:
|
||||
device_mesh_info (DeviceMeshInfo): The information used to initialize device mesh.
|
||||
'''
|
||||
"""
|
||||
# parse the device mesh info
|
||||
physical_devices = device_mesh_info.physical_ids
|
||||
physical_mesh = torch.tensor(physical_devices)
|
||||
@@ -67,13 +70,13 @@ class DeviceMeshManager:
|
||||
Args:
|
||||
name (str): name of the device mesh
|
||||
device_mesh_info (DeviceMeshInfo): the information used to initialize the device mesh
|
||||
"""
|
||||
"""
|
||||
if name not in self.device_mesh_store:
|
||||
device_mesh = initialize_device_mesh(device_mesh_info)
|
||||
self.device_mesh_store[name] = device_mesh
|
||||
return device_mesh
|
||||
else:
|
||||
raise ValueError(f'Device mesh {name} already exists.')
|
||||
raise ValueError(f"Device mesh {name} already exists.")
|
||||
|
||||
def get(self, name: str) -> DeviceMesh:
|
||||
"""
|
||||
@@ -88,7 +91,7 @@ class DeviceMeshManager:
|
||||
if name in self.device_mesh_store:
|
||||
return self.device_mesh_store[name]
|
||||
else:
|
||||
raise ValueError(f'Device mesh {name} does not exist.')
|
||||
raise ValueError(f"Device mesh {name} does not exist.")
|
||||
|
||||
def destroy(self, name: str) -> None:
|
||||
"""
|
||||
@@ -103,7 +106,7 @@ class DeviceMeshManager:
|
||||
dist.destroy_process_group(pg)
|
||||
del self.device_mesh_store[name]
|
||||
else:
|
||||
raise ValueError(f'Device mesh {name} does not exist.')
|
||||
raise ValueError(f"Device mesh {name} does not exist.")
|
||||
|
||||
def destroy_all(self):
|
||||
"""
|
||||
|
Reference in New Issue
Block a user