[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
This commit is contained in:
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -40,14 +40,16 @@ class DeviceMesh:
_DIST_BACKEND = {"cuda": "nccl", "cpu": "gloo"}
def __init__(self,
physical_mesh_id: torch.Tensor,
mesh_shape: torch.Size = None,
logical_mesh_id: torch.Tensor = None,
mesh_alpha: List[float] = None,
mesh_beta: List[float] = None,
init_process_group: bool = False,
device: str = 'cuda'):
def __init__(
self,
physical_mesh_id: torch.Tensor,
mesh_shape: torch.Size = None,
logical_mesh_id: torch.Tensor = None,
mesh_alpha: List[float] = None,
mesh_beta: List[float] = None,
init_process_group: bool = False,
device: str = "cuda",
):
# ============================
# Physical & Logical Mesh IDs
# ============================
@@ -57,9 +59,10 @@ class DeviceMesh:
# logical mesh ids can be obtained via two ways
# 1. provide physical mesh id and provide mesh shape
# 2. directly supply the logical mesh id
assert mesh_shape is None or logical_mesh_id is None, \
"Only one of mesh_shape and logical_mesh_id can be specified." \
assert mesh_shape is None or logical_mesh_id is None, (
"Only one of mesh_shape and logical_mesh_id can be specified."
"Logical mesh IDs are obtained from either mesh_shape + physical_mesh_id or directly from the user-supplied logical_mesh_id"
)
if logical_mesh_id is None:
self._mesh_shape = mesh_shape
@@ -71,12 +74,15 @@ class DeviceMesh:
# ensure two things:
# 1. logical and physical mesh IDs should contain the same elements
# 2. there is no duplicate IDs in each mesh, e.g. [2, 2] is not allowed
assert torch.equal(torch.unique(self._physical_mesh_id), torch.unique(self.logical_mesh_id)), \
"physical and logical mesh IDs should contain the same elements, please check if you have consistent physical_mesh_id and logical_mesh_id."
assert torch.unique(self._physical_mesh_id).numel() == self._physical_mesh_id.numel(), \
"Found duplicate IDs in the physical_mesh_id and this is not allowed, please check your physical_mesh_id again."
assert torch.unique(self.logical_mesh_id).numel() == self.logical_mesh_id.numel(), \
"Found duplicate IDs in the logical_mesh_id and this is not allowed, please check your logical_mesh_id again."
assert torch.equal(
torch.unique(self._physical_mesh_id), torch.unique(self.logical_mesh_id)
), "physical and logical mesh IDs should contain the same elements, please check if you have consistent physical_mesh_id and logical_mesh_id."
assert (
torch.unique(self._physical_mesh_id).numel() == self._physical_mesh_id.numel()
), "Found duplicate IDs in the physical_mesh_id and this is not allowed, please check your physical_mesh_id again."
assert (
torch.unique(self.logical_mesh_id).numel() == self.logical_mesh_id.numel()
), "Found duplicate IDs in the logical_mesh_id and this is not allowed, please check your logical_mesh_id again."
# ===============================================
# coefficient for alpha-beta communication model
@@ -92,8 +98,9 @@ class DeviceMesh:
self.mesh_beta = tuple(mesh_beta)
# ensure the alpha and beta have the same shape
assert len(self.mesh_alpha) == len(self.mesh_beta), \
"mesh_alpha and mesh_beta should have the same length, please check your mesh_alpha and mesh_beta again."
assert len(self.mesh_alpha) == len(
self.mesh_beta
), "mesh_alpha and mesh_beta should have the same length, please check your mesh_alpha and mesh_beta again."
# =========================
# Device for Process Group
@@ -109,8 +116,9 @@ class DeviceMesh:
# <global-rank>: [ <local-rank-on-axis-0>, <local-rank-on-axis-1>, <local-rank-on-axis-2>, ...]
# }
self._global_to_local_rank_mapping = dict()
self._init_global_to_logical_rank_mapping(mapping=self._global_to_local_rank_mapping,
tensor=self.logical_mesh_id)
self._init_global_to_logical_rank_mapping(
mapping=self._global_to_local_rank_mapping, tensor=self.logical_mesh_id
)
# create process group
self._process_group_dict = {}
@@ -194,8 +202,9 @@ class DeviceMesh:
device_list = [_get_device_by_backend(pg) for pg in process_group]
# make sure all devices are the same
assert all([device == device_list[0] for device in device_list]), \
"All devices should be the same, please check your input process groups are created with the same distributed backend."
assert all(
[device == device_list[0] for device in device_list]
), "All devices should be the same, please check your input process groups are created with the same distributed backend."
# create a fake physical mesh id
# as we only get the process group associated with the current process,
@@ -270,7 +279,7 @@ class DeviceMesh:
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
if k != '_process_group_dict':
if k != "_process_group_dict":
setattr(result, k, __import__("copy").deepcopy(v, memo))
else:
# process group cannot be copied
@@ -278,10 +287,9 @@ class DeviceMesh:
setattr(result, k, v)
return result
def _init_global_to_logical_rank_mapping(self,
mapping: Dict,
tensor: torch.Tensor,
index_list: List[int] = []) -> Dict[int, List[int]]:
def _init_global_to_logical_rank_mapping(
self, mapping: Dict, tensor: torch.Tensor, index_list: List[int] = []
) -> Dict[int, List[int]]:
"""
Build a global rank to local rank mapping for each process group in different axis in the logical device mesh.
@@ -311,15 +319,19 @@ class DeviceMesh:
self._init_global_to_logical_rank_mapping(mapping, inner_tensor, index_list + [index])
def init_logical_process_group(self):
'''
"""
This method is used to initialize the logical process groups which will be used in communications
among logical device mesh.
Note: if init_process_group set to False, you have to call this method manually. Otherwise,
the communication related function, such as ShapeConsistencyManager.apply will raise errors.
'''
"""
# sanity check
assert dist.is_initialized, "The torch.distributed should be initialized before calling init_logical_process_group"
assert not self._is_initialized, "The logical process group has been initialized, do not call init_logical_process_group twice"
assert (
dist.is_initialized
), "The torch.distributed should be initialized before calling init_logical_process_group"
assert (
not self._is_initialized
), "The logical process group has been initialized, do not call init_logical_process_group twice"
# update the global rank of the current process
self._global_rank_of_current_process = dist.get_rank()
@@ -389,7 +401,7 @@ class DeviceMesh:
return local_ranks
def _collate_global_ranks_in_same_process_group(self, global_rank):
'''
"""
Give a global rank and return all global ranks involved in its associated process group in each axis.
Example:
@@ -414,7 +426,7 @@ class DeviceMesh:
0: [0, 4, 8, 12],
1: [0, 1, 2, 3]
# }
'''
"""
# We have init the global rank to local rank by calling _init_global_to_logical_rank_mapping
# for self._global_to_local_rank_mapping
# the key is the global rank
@@ -437,7 +449,6 @@ class DeviceMesh:
# in the same process group in the given axis
# the _local_rank refers to the local rank of the current process
for _local_rank in range(self.logical_mesh_id.shape[dim]):
# if this dimension is not initialized yet,
# initialize it with an empty array
if dim not in processes_in_the_same_process_group:
@@ -478,29 +489,37 @@ class DeviceMesh:
flatten_mesh_shape_size = len(self._mesh_shape)
flatten_mesh_shape = [self.num_devices]
return DeviceMesh(self._physical_mesh_id,
tuple(flatten_mesh_shape),
mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1),
mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1),
init_process_group=self._init_process_group)
return DeviceMesh(
self._physical_mesh_id,
tuple(flatten_mesh_shape),
mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1),
mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1),
init_process_group=self._init_process_group,
)
def all_gather_cost(self, num_bytes, mesh_dim):
num_devices = self.logical_mesh_id.shape[mesh_dim]
return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes +
0.1)
return self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes + 0.1
def all_reduce_cost(self, num_bytes, mesh_dim):
num_devices = self.logical_mesh_id.shape[mesh_dim]
return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * 2 * (num_devices - 1) / num_devices * num_bytes +
0.01)
return (
self.mesh_alpha[mesh_dim]
+ self.mesh_beta[mesh_dim] * 2 * (num_devices - 1) / num_devices * num_bytes
+ 0.01
)
def reduce_scatter_cost(self, num_bytes, mesh_dim):
num_devices = self.logical_mesh_id.shape[mesh_dim]
return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes +
0.001)
return (
self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes + 0.001
)
def all_to_all_cost(self, num_bytes, mesh_dim):
num_devices = self.logical_mesh_id.shape[mesh_dim]
penalty_factor = num_devices / 2.0
return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] *
(num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor + 0.001)
return (
self.mesh_alpha[mesh_dim]
+ self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor
+ 0.001
)