[pipeline] Add Simplified Alpa DP Partition (#2507)

* add alpa dp split

* add alpa dp split

* use fwd+bwd instead of fwd only

---------

Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
Ziyue Jiang
2023-03-07 10:34:31 +08:00
committed by GitHub
parent b42d3d28ed
commit 400f63012e
4 changed files with 197 additions and 15 deletions

@@ -1,4 +1,6 @@
import numpy as np
import torch
import tqdm
from torch.fx import symbolic_trace
from torch.fx.node import Node
@@ -9,6 +11,165 @@ def pipe_split():
    pass


def block_split():
    pass


# Construct blocks with the condition that (block_flops / total_flops) >= limit.
def construct_blocks(gm: torch.fx.GraphModule, limit=0.01):
    total_fwd_flop = 0
    total_bwd_flop = 0
    for node in gm.graph.nodes:
        total_fwd_flop += node.fwd_flop
        total_bwd_flop += node.bwd_flop

    total_flop = total_fwd_flop + total_bwd_flop
    per_block_flop = total_flop * limit

    accumulate_fwd_flop = 0
    accumulate_bwd_flop = 0
    block_nodes = []
    for node in gm.graph.nodes:
        if 'block_split' in node.name:
            continue
        accumulate_fwd_flop += node.fwd_flop
        accumulate_bwd_flop += node.bwd_flop
        if accumulate_fwd_flop + accumulate_bwd_flop >= per_block_flop:
            with gm.graph.inserting_after(node):
                block_node = gm.graph.create_node('call_function', block_split)
                setattr(block_node, 'fwd_flop', accumulate_fwd_flop)
                setattr(block_node, 'bwd_flop', accumulate_bwd_flop)
            accumulate_fwd_flop = 0
            accumulate_bwd_flop = 0
            block_nodes.append(block_node)

    return block_nodes
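
# With the default limit=0.01, a block is closed only after it has accumulated
# at least 1% of the model's total fwd+bwd flop, so the DP below works on at
# most roughly 100 coarse blocks instead of every individual fx node.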


def remove_blocks(gm: torch.fx.GraphModule):
    for node in gm.graph.nodes:
        if (node.op, node.target) == ('call_function', block_split):
            gm.graph.erase_node(node)


def get_compute_costs(node_list):
    num_nodes = len(node_list)
    all_compute_cost = np.full((num_nodes, num_nodes), np.inf, dtype=np.float64)

    for start in tqdm.tqdm(range(num_nodes), desc='start pos', position=0):
        for end in tqdm.tqdm(range(start, num_nodes), desc='end pos', position=1, leave=False):
            selected_flops = [(node_list[i].fwd_flop + node_list[i].bwd_flop) for i in range(start, end + 1)]
            all_compute_cost[start, end] = sum(selected_flops)

    return all_compute_cost
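
# Note: all_compute_cost[start, end] holds the summed fwd+bwd flop of the
# contiguous span node_list[start..end], i.e. the cost of placing that span
# into a single pipeline stage; the DP below only ever looks up such spans.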


def do_dp_split_gpipe_impl(num_nodes, num_stages, num_microbatches, compute_costs, max_compute_cost):
    """The core implementation of the DP algorithm."""
    # Adapted from the Alpa DP formulation.
    # For f, node IDs start from 0.
    # f[number of stages, node id that is currently being considered]
    # records the time cost (assessed by fwd+bwd flop for now).
    f = np.full((num_stages + 1, num_nodes + 1), np.inf, dtype=np.float32)

    # Record the max stage compute cost among all stages in this partition.
    f_stage_max = np.full((num_stages + 1, num_nodes + 1), 0.0, dtype=np.float32)
    # Record the start node index of the next stage in this partition.
    f_argmin = np.full((num_stages + 1, num_nodes + 1), -1, dtype=np.int32)
    f[0, num_nodes] = 0

    for s in tqdm.tqdm(range(1, num_stages + 1), desc='stage', position=2, leave=False):    # pylint: disable=too-many-nested-blocks
        for i in tqdm.tqdm(range(num_nodes - 1, -1, -1), desc='start node', position=3, leave=False):
            for k in tqdm.tqdm(range(num_nodes, i, -1), desc='mid node', position=4, leave=False):
                stage_cost = compute_costs[i, k - 1]
                new_cost = f[s - 1, k] + stage_cost
                if (stage_cost <= max_compute_cost and new_cost < f[s, i]):
                    f[s, i] = new_cost
                    f_stage_max[s, i] = max(f_stage_max[s - 1, k], stage_cost)
                    f_argmin[s, i] = k

    best_total_cost = f[num_stages, 0]
    if np.isinf(best_total_cost):
        return np.inf, None

    total_cost = f[num_stages, 0] + (num_microbatches - 1) * f_stage_max[num_stages, 0]

    current_s = num_stages
    current_node = 0

    res = []
    while current_s > 0 and current_node < num_nodes:
        next_start_node = f_argmin[current_s, current_node]
        res.append((current_node, next_start_node))
        current_s -= 1
        current_node = next_start_node

    return total_cost, res
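
# For reference, the recurrence solved above is Alpa's inter-op DP restricted
# to compute cost only:
#   f[s, i] = min over k in (i, n] of { f[s - 1, k] + C[i, k - 1] },  subject to C[i, k - 1] <= max_compute_cost
#   total   = f[S, 0] + (num_microbatches - 1) * (max stage cost along the chosen path)
# where C[i, j] is compute_costs[i, j] (the summed fwd+bwd flop of nodes i..j),
# S is num_stages and n is num_nodes.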


def do_dp_split_gpipe(node_list, compute_costs, num_stages: int, num_microbatches: int):
    # Ignore the memory cost profiling in Alpa's design for convenience.
    max_compute_costs = np.sort(np.unique(compute_costs))
    best_cost = np.inf
    best_solution = None
    last_max_compute_cost = 0.0
    gap = 1e6    # temporary magic number, unit: flops

    for max_compute_cost in tqdm.tqdm(max_compute_costs):
        # Pruning to reduce search space.
        if max_compute_cost * num_microbatches >= best_cost:
            break
        if max_compute_cost - last_max_compute_cost < gap:
            continue

        cost, solution = do_dp_split_gpipe_impl(len(node_list), num_stages, num_microbatches, compute_costs,
                                                max_compute_cost)

        if cost < best_cost:
            best_cost = cost
            best_solution = solution
        last_max_compute_cost = max_compute_cost

    return best_cost, best_solution
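
# Note on the pruning above (mirroring Alpa's outer search): a partition whose
# largest stage costs t takes at least num_microbatches * t in the GPipe model,
# since the summed stage cost is itself >= t. The candidate caps are swept in
# ascending order, so once max_compute_cost * num_microbatches >= best_cost the
# remaining caps are not worth exploring and the sweep stops early; the
# gap-based skip is a separate heuristic that thins out near-identical caps.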


# Auto DP partition based on Alpa.
# Adapted to the GPipe scheduler.
# split_mode:
#   'node': each fx node is a partition unit
#   'block': many fx nodes are grouped into a block
def gpipe_dp_split_pass(gm: torch.fx.GraphModule, pp_size: int, num_microbatches: int, mode='block', block_limit=0.01):
    assert mode in ['node', 'block']

    # Nodes or blocks to be used in the partition.
    node_list = []
    if mode == 'node':
        for node in gm.graph.nodes:
            node_list.append(node)
    elif mode == 'block':
        node_list = construct_blocks(gm, limit=block_limit)
    else:
        pass

    compute_costs = get_compute_costs(node_list)

    best_cost, best_solution = do_dp_split_gpipe(node_list, compute_costs, pp_size, num_microbatches)

    for (_, next_start_node) in best_solution:
        if pp_size <= 1:
            break
        node = node_list[next_start_node]
        with gm.graph.inserting_before(node):
            split_node = gm.graph.create_node('call_function', pipe_split)
        pp_size -= 1

    # Remove the temporary block nodes if needed.
    if mode == 'block':
        remove_blocks(gm)

    gm.recompile()
    return gm
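
# Hedged usage sketch (illustration only, not part of this commit): shows how
# gpipe_dp_split_pass could be driven end to end, assuming every fx node has
# already been annotated with `fwd_flop` / `bwd_flop`. The toy Sequential model
# and the constant flop values below are placeholders, not profiled data.
def _example_gpipe_dp_split(pp_size: int = 4, num_microbatches: int = 8):
    model = torch.nn.Sequential(*[torch.nn.Linear(64, 64) for _ in range(8)])
    gm = symbolic_trace(model)
    for node in gm.graph.nodes:
        # A real run would use profiled flop counts; these constants keep the
        # sketch runnable and large enough to clear the 1e6-flop gap used in
        # do_dp_split_gpipe.
        node.fwd_flop = 2e6
        node.bwd_flop = 4e6
    gm = gpipe_dp_split_pass(gm, pp_size, num_microbatches, mode='block')
    # After recompile, pipe_split() markers separate the pp_size pipeline stages.
    return gm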


def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int):
    """
    In avgcompute_split_pass, we split the module by the fwd flops.