Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-04 18:40:28 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -1,4 +1,4 @@
 from .alpha_beta_profiler import AlphaBetaProfiler
 from .calc_pipeline_strategy import alpa_dp

-__all__ = ['AlphaBetaProfiler', 'alpa_dp']
+__all__ = ["AlphaBetaProfiler", "alpa_dp"]
@@ -13,7 +13,7 @@ FRAMEWORK_LATENCY = 0


 class AlphaBetaProfiler:
-    '''
+    """
     Profile alpha and beta value for a given device list.

     Usage:
@@ -27,17 +27,19 @@ class AlphaBetaProfiler:
         (1, 4): (1.9010603427886962e-05, 7.077968863788975e-11), (1, 5): (1.9807778298854827e-05, 6.928845708992215e-11), (4, 5): (1.8681809306144713e-05, 4.7522367291330524e-12),
         (1, 0): (1.9641406834125518e-05, 4.74049549614719e-12), (4, 0): (1.9506998360157013e-05, 6.97421973297474e-11), (5, 0): (2.293858677148819e-05, 7.129930361393644e-11),
         (4, 1): (1.9010603427886962e-05, 7.077968863788975e-11), (5, 1): (1.9807778298854827e-05, 6.928845708992215e-11), (5, 4): (1.8681809306144713e-05, 4.7522367291330524e-12)}
-    '''
+    """

-    def __init__(self,
-                 physical_devices: List[int],
-                 alpha_beta_dict: Dict[Tuple[int, int], Tuple[float, float]] = None,
-                 ctype: str = 'a',
-                 warmup: int = 5,
-                 repeat: int = 25,
-                 latency_iters: int = 5,
-                 homogeneous_tolerance: float = 0.1):
-        '''
+    def __init__(
+        self,
+        physical_devices: List[int],
+        alpha_beta_dict: Dict[Tuple[int, int], Tuple[float, float]] = None,
+        ctype: str = "a",
+        warmup: int = 5,
+        repeat: int = 25,
+        latency_iters: int = 5,
+        homogeneous_tolerance: float = 0.1,
+    ):
+        """
         Args:
             physical_devices: A list of device id, each element inside it is the global rank of that device.
             alpha_beta_dict: A dict which maps a process group to alpha-beta value pairs.
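For context, the class being reformatted here is a communication profiler. A minimal usage sketch, not part of this commit: it assumes torch.distributed is launched across the listed ranks (e.g. via torchrun with 4 processes), and the import path is inferred from the __init__.py hunk above.

```python
# Hypothetical usage sketch for AlphaBetaProfiler (assumptions noted above).
import torch.distributed as dist

from colossalai.device import AlphaBetaProfiler  # import path assumed from the __init__.py hunk

dist.init_process_group(backend="nccl")  # relies on torchrun-provided env vars

profiler = AlphaBetaProfiler(physical_devices=[0, 1, 2, 3])
alpha_beta_dict = profiler.profile_ab()
# Maps each rank pair to (alpha, beta): latency in seconds and inverse
# bandwidth in s/byte, matching the Usage example in the class docstring.
print(alpha_beta_dict)
```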
@@ -45,7 +47,7 @@ class AlphaBetaProfiler:
             warmup: Number of warmup iterations.
             repeat: Number of iterations to measure.
             latency_iters: Number of iterations to measure latency.
-        '''
+        """
         self.physical_devices = physical_devices
         self.ctype = ctype
         self.world_size = len(physical_devices)
@@ -123,7 +125,7 @@ class AlphaBetaProfiler:
         return (None, None)

     def profile_latency(self, process_group, pg_handler):
-        '''
+        """
         This function is used to profile the latency of the given process group with a series of bytes.

         Args:
@@ -132,7 +134,7 @@ class AlphaBetaProfiler:

         Returns:
             latency: None if the latency is not measured, otherwise the median of the latency_list.
-        '''
+        """
         latency_list = []
         for i in range(self.latency_iters):
             nbytes = int(BYTE << i)
@@ -148,26 +150,26 @@ class AlphaBetaProfiler:
         return latency

     def profile_bandwidth(self, process_group, pg_handler, maxbytes=(1 * GB)):
-        '''
+        """
         This function is used to profile the bandwidth of the given process group.

         Args:
             process_group: A tuple of global rank of the process group.
             pg_handler: The handler of the process group.
-        '''
+        """
         (_, bandwidth) = self._profile(process_group, pg_handler, maxbytes)
         return bandwidth

     def profile_ab(self):
-        '''
+        """
         This method is used to profiling the alpha and beta value for a given device list.

         Returns:
             alpha_beta_dict: A dict which maps process group to its alpha and beta value.
-        '''
+        """
         alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = {}
         rank = dist.get_rank()
-        global_pg_handler = dist.new_group(self.physical_devices)
+        dist.new_group(self.physical_devices)

         def get_max_nbytes(process_group: Tuple[int], pg_handler: dist.ProcessGroup):
             assert rank in process_group
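Both profile_latency and profile_bandwidth feed the classic alpha-beta model, in which sending n bytes costs roughly t(n) = alpha + beta * n (alpha: fixed latency, beta: inverse bandwidth). A self-contained sketch of how such a pair can be recovered from timed transfers; the measurements below are made up and only the least-squares fit is the point:

```python
# Illustrative sketch (not from this commit): recover alpha and beta from
# timed transfers under the linear cost model t(n) = alpha + beta * n.
import numpy as np

nbytes = np.array([1 << 20, 1 << 24, 1 << 28], dtype=np.float64)  # message sizes in bytes
times = np.array([2.1e-5, 1.2e-3, 1.9e-2], dtype=np.float64)      # measured seconds (fabricated)

# Least-squares fit of t = alpha + beta * n.
A = np.stack([np.ones_like(nbytes), nbytes], axis=1)
(alpha, beta), *_ = np.linalg.lstsq(A, times, rcond=None)
print(f"alpha ~ {alpha:.3e} s, beta ~ {beta:.3e} s/byte")
```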
@@ -208,7 +210,7 @@ class AlphaBetaProfiler:
         return alpha_beta_dict

     def search_best_logical_mesh(self):
-        '''
+        """
         This method is used to search the best logical mesh for the given device list.

         The best logical mesh is searched in following steps:
@@ -232,19 +234,19 @@ class AlphaBetaProfiler:
         >>> best_logical_mesh = profiler.search_best_logical_mesh()
         >>> print(best_logical_mesh)
         [[0, 1], [2, 3]]
-        '''
+        """

         def _power_of_two(integer):
             return integer & (integer - 1) == 0

         def _detect_homogeneous_device(alpha_beta_dict):
-            '''
+            """
             This function is used to detect whether the devices in the alpha_beta_dict are homogeneous.

             Note: we assume that the devices in the alpha_beta_dict are homogeneous if the beta value
             of the devices are in range of [(1 - self.homogeneous_tolerance), (1 + self.homogeneous_tolerance)]
             * base_beta.
-            '''
+            """
             homogeneous_device_dict: Dict[float, List[Tuple[int]]] = {}
             for process_group, (_, beta) in alpha_beta_dict.items():
                 if homogeneous_device_dict is None:
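The helper _power_of_two above relies on the standard bit trick: for a positive integer n, clearing the lowest set bit via n & (n - 1) leaves zero exactly when n is a power of two. A quick check (note that 0 also passes, so callers must exclude it; here the caller asserts on a nonzero world size):

```python
# Quick illustration of the bit trick used by _power_of_two above.
def _power_of_two(integer):
    return integer & (integer - 1) == 0

assert all(_power_of_two(n) for n in (1, 2, 4, 8, 1024))
assert not any(_power_of_two(n) for n in (3, 6, 12))
assert _power_of_two(0)  # edge case: 0 also passes, so callers must exclude it
```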
@@ -254,7 +256,8 @@ class AlphaBetaProfiler:
                 match_beta = None
                 for beta_value in homogeneous_device_dict.keys():
                     if beta <= beta_value * (1 + self.homogeneous_tolerance) and beta >= beta_value * (
-                        1 - self.homogeneous_tolerance):
+                        1 - self.homogeneous_tolerance
+                    ):
                         match_beta = beta_value
                         break

@@ -267,9 +270,9 @@ class AlphaBetaProfiler:
             return homogeneous_device_dict

         def _check_contain_all_devices(homogeneous_group: List[Tuple[int]]):
-            '''
+            """
             This function is used to check whether the homogeneous_group contains all physical devices.
-            '''
+            """
             flatten_mesh = []
             for process_group in homogeneous_group:
                 flatten_mesh.extend(process_group)
@@ -277,9 +280,9 @@ class AlphaBetaProfiler:
             return len(non_duplicated_flatten_mesh) == len(self.physical_devices)

         def _construct_largest_ring(homogeneous_group: List[Tuple[int]]):
-            '''
+            """
             This function is used to construct the largest ring in the homogeneous_group for each rank.
-            '''
+            """
             # Construct the ring
             ring = []
             ranks_in_ring = []
@@ -300,7 +303,9 @@ class AlphaBetaProfiler:
                 check_rank = check_rank_list.pop()
                 for process_group in homogeneous_group:
                     if check_rank in process_group:
-                        rank_to_append = process_group[0] if process_group[1] == check_rank else process_group[1]
+                        rank_to_append = (
+                            process_group[0] if process_group[1] == check_rank else process_group[1]
+                        )
                         if rank_to_append not in ring_for_rank:
                             stable_status = False
                             rank_to_check_list.append(rank_to_append)
@@ -314,7 +319,7 @@ class AlphaBetaProfiler:
         assert _power_of_two(self.world_size)
         power_of_two = int(math.log2(self.world_size))
         median = power_of_two // 2
-        balanced_logical_mesh_shape = (2**median, 2**(power_of_two - median))
+        balanced_logical_mesh_shape = (2**median, 2 ** (power_of_two - median))
         row_size, column_size = balanced_logical_mesh_shape[0], balanced_logical_mesh_shape[1]
         balanced_logical_mesh = []
        for row_index in range(row_size):
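The balanced mesh shape above splits the exponent of the (power-of-two) world size as evenly as possible between the two axes. A small worked example:

```python
# Worked example of the balanced-mesh computation above (values chosen for illustration).
import math

world_size = 8
power_of_two = int(math.log2(world_size))          # 3
median = power_of_two // 2                         # 1
shape = (2**median, 2 ** (power_of_two - median))  # (2, 4)
print(shape)  # the most balanced 2D factorization of 8 into powers of two
```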
@@ -348,7 +353,7 @@ class AlphaBetaProfiler:
         return best_logical_mesh

     def extract_alpha_beta_for_device_mesh(self):
-        '''
+        """
         Extract the mesh_alpha list and mesh_beta list based on the
         best logical mesh, which will be used to initialize the device mesh.

@@ -360,7 +365,7 @@ class AlphaBetaProfiler:
         [2.5917552411556242e-05, 0.00010312341153621673]
         >>> print(mesh_beta)
         [5.875573704655635e-11, 4.7361584445959614e-12]
-        '''
+        """
         best_logical_mesh = self.search_best_logical_mesh()

         first_axis = [row[0] for row in best_logical_mesh]
@@ -10,8 +10,10 @@ def get_submesh_choices(num_hosts, num_devices_per_host, mode="new"):
     while i <= num_devices_per_host:
         i *= 2
         p += 1
-    assert pow(2, p) == num_devices_per_host, ("Only supports the cases where num_devices_per_host is power of two, "
-                                               f"while now num_devices_per_host = {num_devices_per_host}")
+    assert pow(2, p) == num_devices_per_host, (
+        "Only supports the cases where num_devices_per_host is power of two, "
+        f"while now num_devices_per_host = {num_devices_per_host}"
+    )
     if mode == "alpa":
         for i in range(p + 1):
             submesh_choices.append((1, pow(2, i)))
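In "alpa" mode the function enumerates single-host submeshes whose widths are the powers of two up to num_devices_per_host. A standalone sketch of just that enumeration (the helper name is hypothetical):

```python
# Standalone sketch of the "alpa"-mode enumeration in get_submesh_choices above.
def submesh_choices_alpa(num_devices_per_host: int):
    p = num_devices_per_host.bit_length() - 1
    assert 1 << p == num_devices_per_host, "num_devices_per_host must be a power of two"
    return [(1, 1 << i) for i in range(p + 1)]

print(submesh_choices_alpa(8))  # [(1, 1), (1, 2), (1, 4), (1, 8)]
```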
@@ -24,18 +26,19 @@ def get_submesh_choices(num_hosts, num_devices_per_host, mode="new"):
     return submesh_choices


-def alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, compute_cost, max_stage_cost,
-                 best_configs):
+def alpa_dp_impl(
+    num_layers, num_devices, num_microbatches, submesh_choices, compute_cost, max_stage_cost, best_configs
+):
     """Implementation of Alpa DP for pipeline strategy
-    Paper reference: https://www.usenix.org/system/files/osdi22-zheng-lianmin.pdf
+    Paper reference: https://www.usenix.org/system/files/osdi22-zheng-lianmin.pdf

-    Arguments:
-        num_layers: K
-        num_devices: N*M
-        num_microbatches: B
-        submesh_choices: List[(n_i,m_i)]
-        compute_cost: t_intra
-    """
+    Arguments:
+        num_layers: K
+        num_devices: N*M
+        num_microbatches: B
+        submesh_choices: List[(n_i,m_i)]
+        compute_cost: t_intra
+    """
     # For f, layer ID start from 0
     # f[#pipeline stages, layer id that is currently being considered, number of devices used]
     f = np.full((num_layers + 1, num_layers + 1, num_devices + 1), np.inf, dtype=np.float32)
@@ -54,7 +57,7 @@ def alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, com
                 for i in range(num_layers, k, -1):
                     stage_cost = compute_cost[k, i, m]
                     new_cost = f[s - 1, k, d - n_submesh_devices] + stage_cost
-                    if (stage_cost <= max_stage_cost and new_cost < f[s, k, d]):
+                    if stage_cost <= max_stage_cost and new_cost < f[s, k, d]:
                         f[s, k, d] = new_cost
                         f_stage_max[s, k, d] = max(stage_cost, f_stage_max[s - 1, i, d - n_submesh_devices])
                         f_argmin[s, k, d] = (i, m, best_configs[k, i, m])
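The loop above is the inner step of the stage-construction dynamic program from the Alpa paper. Written out (a reconstruction from the surrounding code and the paper, not a quote): F(s, k; d) is the minimal total compute cost of assigning layers k..K to s pipeline stages using d devices, subject to a cap t_max on the per-stage cost.

```latex
F(s, k; d) = \min_{\substack{k < i \le K \\ \text{submesh } m,\ |m| \le d}}
    \big\{ t_{\text{intra}}(k, i, m) + F(s - 1, i; d - |m|) \big\},
\qquad \text{subject to } t_{\text{intra}}(k, i, m) \le t_{\max}
```

The driver alpa_dp (further below) then minimizes F(s, 0; N·M) + (B - 1)·t_max over all stage counts s and caps t_max, where B is the number of microbatches.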
@@ -75,34 +78,34 @@ def alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, com

     res = []
     while current_s > 0 and current_layer < num_layers and current_devices > 0:
-        next_start_layer, submesh_choice, autosharding_choice = (f_argmin[current_s, current_layer, current_devices])
+        next_start_layer, submesh_choice, autosharding_choice = f_argmin[current_s, current_layer, current_devices]
         assert next_start_layer != -1 and current_devices != -1
         res.append(((current_layer, next_start_layer), submesh_choice, autosharding_choice))
         current_s -= 1
         current_layer = next_start_layer
         current_devices -= np.prod(np.array(submesh_choices[submesh_choice]))
-    assert (current_s == 0 and current_layer == num_layers and current_devices == 0)
+    assert current_s == 0 and current_layer == num_layers and current_devices == 0

     return total_cost, res


-def alpa_dp(num_layers,
-            num_devices,
-            num_microbatches,
-            submesh_choices,
-            num_autosharding_configs,
-            compute_cost,
-            gap=1e-6):
+def alpa_dp(
+    num_layers, num_devices, num_microbatches, submesh_choices, num_autosharding_configs, compute_cost, gap=1e-6
+):
     """Alpa auto stage dynamic programming.
-    Code reference: https://github.com/alpa-projects/alpa/blob/main/alpa/pipeline_parallel/stage_construction.py
+    Code reference: https://github.com/alpa-projects/alpa/blob/main/alpa/pipeline_parallel/stage_construction.py

     Arguments:
         submesh_choices: List[(int,int)]
         num_autosharding_configs: Max number of t_intra(start_layer, end_layer, LogicalMesh)
         compute_cost: np.array(num_layers,num_layers,num_submesh_choices,num_autosharding_configs)
     """
-    assert np.shape(compute_cost) == (num_layers, num_layers, len(submesh_choices),
-                                      num_autosharding_configs), "Cost shape wrong."
+    assert np.shape(compute_cost) == (
+        num_layers,
+        num_layers,
+        len(submesh_choices),
+        num_autosharding_configs,
+    ), "Cost shape wrong."
     all_possible_stage_costs = np.sort(np.unique(compute_cost))
     best_cost = np.inf
     best_solution = None
@@ -117,8 +120,9 @@ def alpa_dp(num_layers,
             break
         if max_stage_cost - last_max_stage_cost < gap:
             continue
-        cost, solution = alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, best_compute_cost,
-                                      max_stage_cost, best_configs)
+        cost, solution = alpa_dp_impl(
+            num_layers, num_devices, num_microbatches, submesh_choices, best_compute_cost, max_stage_cost, best_configs
+        )
         if cost < best_cost:
             best_cost = cost
             best_solution = solution
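The loop above enumerates candidate per-stage cost caps in ascending order over the distinct values of compute_cost, skipping caps that differ from the previous one by less than `gap`, and re-runs the DP for each. A minimal sketch of just that search pattern (the helper name is hypothetical):

```python
# Sketch of the outer search pattern in alpa_dp above: try each distinct
# per-stage cost cap in ascending order, skipping near-duplicates within `gap`.
import numpy as np

def candidate_caps(compute_cost: np.ndarray, gap: float = 1e-6):
    last = -np.inf
    for cap in np.sort(np.unique(compute_cost)):
        if cap - last < gap:
            continue  # a nearly identical cap cannot unlock a different solution
        last = cap
        yield cap  # alpa_dp runs the DP once per yielded cap and keeps the best

# The middle value is within gap of 1.0, so only the caps 1.0 and 2.0 are tried.
print(list(candidate_caps(np.array([1.0, 1.0 + 1e-9, 2.0]))))
```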
@@ -40,14 +40,16 @@ class DeviceMesh:

     _DIST_BACKEND = {"cuda": "nccl", "cpu": "gloo"}

-    def __init__(self,
-                 physical_mesh_id: torch.Tensor,
-                 mesh_shape: torch.Size = None,
-                 logical_mesh_id: torch.Tensor = None,
-                 mesh_alpha: List[float] = None,
-                 mesh_beta: List[float] = None,
-                 init_process_group: bool = False,
-                 device: str = 'cuda'):
+    def __init__(
+        self,
+        physical_mesh_id: torch.Tensor,
+        mesh_shape: torch.Size = None,
+        logical_mesh_id: torch.Tensor = None,
+        mesh_alpha: List[float] = None,
+        mesh_beta: List[float] = None,
+        init_process_group: bool = False,
+        device: str = "cuda",
+    ):
         # ============================
         # Physical & Logical Mesh IDs
         # ============================
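To connect the two halves of this commit: the per-axis alpha and beta lists produced by AlphaBetaProfiler.extract_alpha_beta_for_device_mesh are what the mesh_alpha/mesh_beta parameters above expect. A hypothetical construction sketch (not from this commit; the import path and numeric values are illustrative):

```python
# Hypothetical DeviceMesh construction sketch using the signature shown above.
import torch

from colossalai.device.device_mesh import DeviceMesh  # import path assumed

physical_mesh_id = torch.arange(4)   # global ranks 0..3
mesh_shape = torch.Size((2, 2))      # 2 x 2 logical mesh
mesh = DeviceMesh(
    physical_mesh_id,
    mesh_shape,
    mesh_alpha=[2.6e-05, 1.0e-04],   # per-axis latency in seconds (illustrative)
    mesh_beta=[5.9e-11, 4.7e-12],    # per-axis inverse bandwidth in s/byte (illustrative)
    init_process_group=False,        # process groups can be created later if needed
)
```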
@@ -57,9 +59,10 @@ class DeviceMesh:
         # logical mesh ids can be obtained via two ways
         # 1. provide physical mesh id and provide mesh shape
         # 2. directly supply the logical mesh id
-        assert mesh_shape is None or logical_mesh_id is None, \
-            "Only one of mesh_shape and logical_mesh_id can be specified." \
+        assert mesh_shape is None or logical_mesh_id is None, (
+            "Only one of mesh_shape and logical_mesh_id can be specified."
             "Logical mesh IDs are obtained from either mesh_shape + physical_mesh_id or directly from the user-supplied logical_mesh_id"
+        )

         if logical_mesh_id is None:
             self._mesh_shape = mesh_shape
@@ -71,12 +74,15 @@ class DeviceMesh:
         # ensure two things:
         # 1. logical and physical mesh IDs should contain the same elements
         # 2. there is no duplicate IDs in each mesh, e.g. [2, 2] is not allowed
-        assert torch.equal(torch.unique(self._physical_mesh_id), torch.unique(self.logical_mesh_id)), \
-            "physical and logical mesh IDs should contain the same elements, please check if you have consistent physical_mesh_id and logical_mesh_id."
-        assert torch.unique(self._physical_mesh_id).numel() == self._physical_mesh_id.numel(), \
-            "Found duplicate IDs in the physical_mesh_id and this is not allowed, please check your physical_mesh_id again."
-        assert torch.unique(self.logical_mesh_id).numel() == self.logical_mesh_id.numel(), \
-            "Found duplicate IDs in the logical_mesh_id and this is not allowed, please check your logical_mesh_id again."
+        assert torch.equal(
+            torch.unique(self._physical_mesh_id), torch.unique(self.logical_mesh_id)
+        ), "physical and logical mesh IDs should contain the same elements, please check if you have consistent physical_mesh_id and logical_mesh_id."
+        assert (
+            torch.unique(self._physical_mesh_id).numel() == self._physical_mesh_id.numel()
+        ), "Found duplicate IDs in the physical_mesh_id and this is not allowed, please check your physical_mesh_id again."
+        assert (
+            torch.unique(self.logical_mesh_id).numel() == self.logical_mesh_id.numel()
+        ), "Found duplicate IDs in the logical_mesh_id and this is not allowed, please check your logical_mesh_id again."

         # ===============================================
         # coefficient for alpha-beta communication model
@@ -92,8 +98,9 @@ class DeviceMesh:
         self.mesh_beta = tuple(mesh_beta)

         # ensure the alpha and beta have the same shape
-        assert len(self.mesh_alpha) == len(self.mesh_beta), \
-            "mesh_alpha and mesh_beta should have the same length, please check your mesh_alpha and mesh_beta again."
+        assert len(self.mesh_alpha) == len(
+            self.mesh_beta
+        ), "mesh_alpha and mesh_beta should have the same length, please check your mesh_alpha and mesh_beta again."

         # =========================
         # Device for Process Group
@@ -109,8 +116,9 @@ class DeviceMesh:
         #   <global-rank>: [ <local-rank-on-axis-0>, <local-rank-on-axis-1>, <local-rank-on-axis-2>, ...]
         # }
         self._global_to_local_rank_mapping = dict()
-        self._init_global_to_logical_rank_mapping(mapping=self._global_to_local_rank_mapping,
-                                                  tensor=self.logical_mesh_id)
+        self._init_global_to_logical_rank_mapping(
+            mapping=self._global_to_local_rank_mapping, tensor=self.logical_mesh_id
+        )

         # create process group
         self._process_group_dict = {}
@@ -194,8 +202,9 @@ class DeviceMesh:
         device_list = [_get_device_by_backend(pg) for pg in process_group]

         # make sure all devices are the same
-        assert all([device == device_list[0] for device in device_list]), \
-            "All devices should be the same, please check your input process groups are created with the same distributed backend."
+        assert all(
+            [device == device_list[0] for device in device_list]
+        ), "All devices should be the same, please check your input process groups are created with the same distributed backend."

         # create a fake physical mesh id
         # as we only get the process group associated with the current process,
@@ -270,7 +279,7 @@ class DeviceMesh:
         result = cls.__new__(cls)
         memo[id(self)] = result
         for k, v in self.__dict__.items():
-            if k != '_process_group_dict':
+            if k != "_process_group_dict":
                 setattr(result, k, __import__("copy").deepcopy(v, memo))
             else:
                 # process group cannot be copied
@@ -278,10 +287,9 @@ class DeviceMesh:
                 setattr(result, k, v)
         return result

-    def _init_global_to_logical_rank_mapping(self,
-                                             mapping: Dict,
-                                             tensor: torch.Tensor,
-                                             index_list: List[int] = []) -> Dict[int, List[int]]:
+    def _init_global_to_logical_rank_mapping(
+        self, mapping: Dict, tensor: torch.Tensor, index_list: List[int] = []
+    ) -> Dict[int, List[int]]:
         """
         Build a global rank to local rank mapping for each process group in different axis in the logical device mesh.

@@ -311,15 +319,19 @@ class DeviceMesh:
             self._init_global_to_logical_rank_mapping(mapping, inner_tensor, index_list + [index])

     def init_logical_process_group(self):
-        '''
+        """
         This method is used to initialize the logical process groups which will be used in communications
         among logical device mesh.
         Note: if init_process_group set to False, you have to call this method manually. Otherwise,
         the communication related function, such as ShapeConsistencyManager.apply will raise errors.
-        '''
+        """
         # sanity check
-        assert dist.is_initialized, "The torch.distributed should be initialized before calling init_logical_process_group"
-        assert not self._is_initialized, "The logical process group has been initialized, do not call init_logical_process_group twice"
+        assert (
+            dist.is_initialized
+        ), "The torch.distributed should be initialized before calling init_logical_process_group"
+        assert (
+            not self._is_initialized
+        ), "The logical process group has been initialized, do not call init_logical_process_group twice"

         # update the global rank of the current process
         self._global_rank_of_current_process = dist.get_rank()
@@ -389,7 +401,7 @@ class DeviceMesh:
         return local_ranks

     def _collate_global_ranks_in_same_process_group(self, global_rank):
-        '''
+        """
         Give a global rank and return all global ranks involved in its associated process group in each axis.

         Example:
@@ -414,7 +426,7 @@ class DeviceMesh:
         #     0: [0, 4, 8, 12],
         #     1: [0, 1, 2, 3]
         # }
-        '''
+        """
         # We have init the global rank to local rank by calling _init_global_to_logical_rank_mapping
         # for self._global_to_local_rank_mapping
         # the key is the global rank
@@ -437,7 +449,6 @@ class DeviceMesh:
             # in the same process group in the given axis
             # the _local_rank refers to the local rank of the current process
             for _local_rank in range(self.logical_mesh_id.shape[dim]):
-
                 # if this dimension is not initialized yet,
                 # initialize it with an empty array
                 if dim not in processes_in_the_same_process_group:
@@ -478,29 +489,37 @@ class DeviceMesh:

         flatten_mesh_shape_size = len(self._mesh_shape)
         flatten_mesh_shape = [self.num_devices]
-        return DeviceMesh(self._physical_mesh_id,
-                          tuple(flatten_mesh_shape),
-                          mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1),
-                          mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1),
-                          init_process_group=self._init_process_group)
+        return DeviceMesh(
+            self._physical_mesh_id,
+            tuple(flatten_mesh_shape),
+            mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1),
+            mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1),
+            init_process_group=self._init_process_group,
+        )

     def all_gather_cost(self, num_bytes, mesh_dim):
         num_devices = self.logical_mesh_id.shape[mesh_dim]
-        return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes +
-                0.1)
+        return self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes + 0.1

     def all_reduce_cost(self, num_bytes, mesh_dim):
         num_devices = self.logical_mesh_id.shape[mesh_dim]
-        return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * 2 * (num_devices - 1) / num_devices * num_bytes +
-                0.01)
+        return (
+            self.mesh_alpha[mesh_dim]
+            + self.mesh_beta[mesh_dim] * 2 * (num_devices - 1) / num_devices * num_bytes
+            + 0.01
+        )

     def reduce_scatter_cost(self, num_bytes, mesh_dim):
         num_devices = self.logical_mesh_id.shape[mesh_dim]
-        return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes +
-                0.001)
+        return (
+            self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes + 0.001
+        )

     def all_to_all_cost(self, num_bytes, mesh_dim):
         num_devices = self.logical_mesh_id.shape[mesh_dim]
         penalty_factor = num_devices / 2.0
-        return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] *
-                (num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor + 0.001)
+        return (
+            self.mesh_alpha[mesh_dim]
+            + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor
+            + 0.001
+        )
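The four cost methods above implement the standard alpha-beta estimates for ring-based collectives. With p devices on the axis and n bytes transferred (a reconstruction from the code above; the trailing constants 0.1, 0.01, and 0.001 are small tie-breaking penalties rather than physical terms):

```latex
T_{\text{all-gather}}(n)     = \alpha + \beta \, \tfrac{p-1}{p} \, n
T_{\text{reduce-scatter}}(n) = \alpha + \beta \, \tfrac{p-1}{p} \, n
T_{\text{all-reduce}}(n)     = \alpha + 2\beta \, \tfrac{p-1}{p} \, n
T_{\text{all-to-all}}(n)     = \alpha + \beta \, \tfrac{p-1}{p^{2}} \, n \cdot \tfrac{p}{2}
```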