mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 19:13:01 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -10,8 +10,10 @@ def get_submesh_choices(num_hosts, num_devices_per_host, mode="new"):
|
||||
while i <= num_devices_per_host:
|
||||
i *= 2
|
||||
p += 1
|
||||
assert pow(2, p) == num_devices_per_host, ("Only supports the cases where num_devices_per_host is power of two, "
|
||||
f"while now num_devices_per_host = {num_devices_per_host}")
|
||||
assert pow(2, p) == num_devices_per_host, (
|
||||
"Only supports the cases where num_devices_per_host is power of two, "
|
||||
f"while now num_devices_per_host = {num_devices_per_host}"
|
||||
)
|
||||
if mode == "alpa":
|
||||
for i in range(p + 1):
|
||||
submesh_choices.append((1, pow(2, i)))
|
||||
@@ -24,18 +26,19 @@ def get_submesh_choices(num_hosts, num_devices_per_host, mode="new"):
|
||||
return submesh_choices
|
||||
|
||||
|
||||
def alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, compute_cost, max_stage_cost,
|
||||
best_configs):
|
||||
def alpa_dp_impl(
|
||||
num_layers, num_devices, num_microbatches, submesh_choices, compute_cost, max_stage_cost, best_configs
|
||||
):
|
||||
"""Implementation of Alpa DP for pipeline strategy
|
||||
Paper reference: https://www.usenix.org/system/files/osdi22-zheng-lianmin.pdf
|
||||
Paper reference: https://www.usenix.org/system/files/osdi22-zheng-lianmin.pdf
|
||||
|
||||
Arguments:
|
||||
num_layers: K
|
||||
num_devices: N*M
|
||||
num_microbatches: B
|
||||
submesh_choices: List[(n_i,m_i)]
|
||||
compute_cost: t_intra
|
||||
"""
|
||||
Arguments:
|
||||
num_layers: K
|
||||
num_devices: N*M
|
||||
num_microbatches: B
|
||||
submesh_choices: List[(n_i,m_i)]
|
||||
compute_cost: t_intra
|
||||
"""
|
||||
# For f, layer ID start from 0
|
||||
# f[#pipeline stages, layer id that is currently being considered, number of devices used]
|
||||
f = np.full((num_layers + 1, num_layers + 1, num_devices + 1), np.inf, dtype=np.float32)
|
||||
@@ -54,7 +57,7 @@ def alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, com
|
||||
for i in range(num_layers, k, -1):
|
||||
stage_cost = compute_cost[k, i, m]
|
||||
new_cost = f[s - 1, k, d - n_submesh_devices] + stage_cost
|
||||
if (stage_cost <= max_stage_cost and new_cost < f[s, k, d]):
|
||||
if stage_cost <= max_stage_cost and new_cost < f[s, k, d]:
|
||||
f[s, k, d] = new_cost
|
||||
f_stage_max[s, k, d] = max(stage_cost, f_stage_max[s - 1, i, d - n_submesh_devices])
|
||||
f_argmin[s, k, d] = (i, m, best_configs[k, i, m])
|
||||
@@ -75,34 +78,34 @@ def alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, com
|
||||
|
||||
res = []
|
||||
while current_s > 0 and current_layer < num_layers and current_devices > 0:
|
||||
next_start_layer, submesh_choice, autosharding_choice = (f_argmin[current_s, current_layer, current_devices])
|
||||
next_start_layer, submesh_choice, autosharding_choice = f_argmin[current_s, current_layer, current_devices]
|
||||
assert next_start_layer != -1 and current_devices != -1
|
||||
res.append(((current_layer, next_start_layer), submesh_choice, autosharding_choice))
|
||||
current_s -= 1
|
||||
current_layer = next_start_layer
|
||||
current_devices -= np.prod(np.array(submesh_choices[submesh_choice]))
|
||||
assert (current_s == 0 and current_layer == num_layers and current_devices == 0)
|
||||
assert current_s == 0 and current_layer == num_layers and current_devices == 0
|
||||
|
||||
return total_cost, res
|
||||
|
||||
|
||||
def alpa_dp(num_layers,
|
||||
num_devices,
|
||||
num_microbatches,
|
||||
submesh_choices,
|
||||
num_autosharding_configs,
|
||||
compute_cost,
|
||||
gap=1e-6):
|
||||
def alpa_dp(
|
||||
num_layers, num_devices, num_microbatches, submesh_choices, num_autosharding_configs, compute_cost, gap=1e-6
|
||||
):
|
||||
"""Alpa auto stage dynamic programming.
|
||||
Code reference: https://github.com/alpa-projects/alpa/blob/main/alpa/pipeline_parallel/stage_construction.py
|
||||
Code reference: https://github.com/alpa-projects/alpa/blob/main/alpa/pipeline_parallel/stage_construction.py
|
||||
|
||||
Arguments:
|
||||
submesh_choices: List[(int,int)]
|
||||
num_autosharding_configs: Max number of t_intra(start_layer, end_layer, LogicalMesh)
|
||||
compute_cost: np.array(num_layers,num_layers,num_submesh_choices,num_autosharding_configs)
|
||||
"""
|
||||
assert np.shape(compute_cost) == (num_layers, num_layers, len(submesh_choices),
|
||||
num_autosharding_configs), "Cost shape wrong."
|
||||
assert np.shape(compute_cost) == (
|
||||
num_layers,
|
||||
num_layers,
|
||||
len(submesh_choices),
|
||||
num_autosharding_configs,
|
||||
), "Cost shape wrong."
|
||||
all_possible_stage_costs = np.sort(np.unique(compute_cost))
|
||||
best_cost = np.inf
|
||||
best_solution = None
|
||||
@@ -117,8 +120,9 @@ def alpa_dp(num_layers,
|
||||
break
|
||||
if max_stage_cost - last_max_stage_cost < gap:
|
||||
continue
|
||||
cost, solution = alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, best_compute_cost,
|
||||
max_stage_cost, best_configs)
|
||||
cost, solution = alpa_dp_impl(
|
||||
num_layers, num_devices, num_microbatches, submesh_choices, best_compute_cost, max_stage_cost, best_configs
|
||||
)
|
||||
if cost < best_cost:
|
||||
best_cost = cost
|
||||
best_solution = solution
|
||||
|
Reference in New Issue
Block a user