mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 09:38:05 +00:00
[pipeline] Add Simplified Alpa DP Partition (#2507)
* add alpa dp split * add alpa dp split * use fwd+bwd instead of fwd only --------- Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
This commit is contained in:
@@ -1,4 +1,6 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import tqdm
|
||||
from torch.fx import symbolic_trace
|
||||
from torch.fx.node import Node
|
||||
|
||||
@@ -9,6 +11,165 @@ def pipe_split():
|
||||
pass
|
||||
|
||||
|
||||
def block_split():
|
||||
pass
|
||||
|
||||
|
||||
# Construct blocks with the condition that (block_flops / total_flops) >= limit.
|
||||
def construct_blocks(gm: torch.fx.GraphModule, limit=0.01):
|
||||
total_fwd_flop = 0
|
||||
total_bwd_flop = 0
|
||||
for node in gm.graph.nodes:
|
||||
total_fwd_flop += node.fwd_flop
|
||||
total_bwd_flop += node.bwd_flop
|
||||
|
||||
total_flop = total_fwd_flop + total_bwd_flop
|
||||
per_block_flop = total_flop * limit
|
||||
accumulate_fwd_flop = 0
|
||||
accumulate_bwd_flop = 0
|
||||
block_nodes = []
|
||||
for node in gm.graph.nodes:
|
||||
if 'block_split' in node.name:
|
||||
continue
|
||||
accumulate_fwd_flop += node.fwd_flop
|
||||
accumulate_bwd_flop += node.bwd_flop
|
||||
if accumulate_fwd_flop + accumulate_bwd_flop >= per_block_flop:
|
||||
with gm.graph.inserting_after(node):
|
||||
block_node = gm.graph.create_node('call_function', block_split)
|
||||
setattr(block_node, 'fwd_flop', accumulate_fwd_flop)
|
||||
setattr(block_node, 'bwd_flop', accumulate_bwd_flop)
|
||||
accumulate_fwd_flop = 0
|
||||
accumulate_bwd_flop = 0
|
||||
block_nodes.append(block_node)
|
||||
|
||||
return block_nodes
|
||||
|
||||
|
||||
def remove_blocks(gm: torch.fx.GraphModule):
|
||||
for node in gm.graph.nodes:
|
||||
if (node.op, node.target) == ('call_function', block_split):
|
||||
gm.graph.erase_node(node)
|
||||
|
||||
|
||||
def get_compute_costs(node_list):
|
||||
num_nodes = len(node_list)
|
||||
all_compute_cost = np.full((num_nodes, num_nodes), np.inf, dtype=np.float64)
|
||||
|
||||
for start in tqdm.tqdm(range(num_nodes), desc='start pos', position=0):
|
||||
for end in tqdm.tqdm(range(start, num_nodes), desc='end pos', position=1, leave=False):
|
||||
selected_flops = [(node_list[i].fwd_flop + node_list[i].bwd_flop) for i in range(start, end + 1)]
|
||||
all_compute_cost[start, end] = sum(selected_flops)
|
||||
|
||||
return all_compute_cost
|
||||
|
||||
|
||||
def do_dp_split_gpipe_impl(num_nodes, num_stages, num_microbatches, compute_costs, max_compute_cost):
|
||||
"""The core implementation of the DP algorithm."""
|
||||
# Adapted from Alpa DP Formulation.
|
||||
# For f, node ID start from 0
|
||||
# f[number of stages,
|
||||
# node id that is currently being considered]
|
||||
|
||||
# record time cost(assess by fwd+bwd flop now)
|
||||
f = np.full((num_stages + 1, num_nodes + 1), np.inf, dtype=np.float32)
|
||||
|
||||
# record max stage compute cost among all stages in this partition.
|
||||
f_stage_max = np.full((num_stages + 1, num_nodes + 1), 0.0, dtype=np.float32)
|
||||
# record start node index for next stage in this partition
|
||||
f_argmin = np.full((num_stages + 1, num_nodes + 1), -1, dtype=np.int32)
|
||||
f[0, num_nodes] = 0
|
||||
for s in tqdm.tqdm(range(1, num_stages + 1), desc='stage', position=2, leave=False): # pylint: disable=too-many-nested-blocks
|
||||
for i in tqdm.tqdm(range(num_nodes - 1, -1, -1), desc='start node', position=3, leave=False):
|
||||
for k in tqdm.tqdm(range(num_nodes, i, -1), desc='mid node', position=4, leave=False):
|
||||
stage_cost = compute_costs[i, k - 1]
|
||||
new_cost = f[s - 1, k] + stage_cost
|
||||
if (stage_cost <= max_compute_cost and new_cost < f[s, i]):
|
||||
f[s, i] = new_cost
|
||||
f_stage_max[s, i] = max(f_stage_max[s - 1, k], stage_cost)
|
||||
f_argmin[s, i] = k
|
||||
|
||||
best_total_cost = f[num_stages, 0]
|
||||
if np.isinf(best_total_cost):
|
||||
return np.inf, None
|
||||
|
||||
total_cost = f[num_stages, 0] + (num_microbatches - 1) * f_stage_max[num_stages, 0]
|
||||
|
||||
current_s = num_stages
|
||||
current_node = 0
|
||||
|
||||
res = []
|
||||
while current_s > 0 and current_node < num_nodes:
|
||||
next_start_node = f_argmin[current_s, current_node]
|
||||
res.append((current_node, next_start_node))
|
||||
current_s -= 1
|
||||
current_node = next_start_node
|
||||
|
||||
return total_cost, res
|
||||
|
||||
|
||||
def do_dp_split_gpipe(node_list, compute_costs, num_stages: int, num_microbatches: int):
|
||||
# Ignore the memory cost profiling in Alpa's design for convenience.
|
||||
max_compute_costs = np.sort(np.unique(compute_costs))
|
||||
best_cost = np.inf
|
||||
best_solution = None
|
||||
last_max_compute_cost = 0.0
|
||||
gap = 1e6 # temporary magic number, unit: flops
|
||||
|
||||
for max_compute_cost in tqdm.tqdm(max_compute_costs):
|
||||
# Pruning to reduce search space.
|
||||
if max_compute_cost * num_microbatches >= best_cost:
|
||||
break
|
||||
if max_compute_cost - last_max_compute_cost < gap:
|
||||
continue
|
||||
|
||||
cost, solution = do_dp_split_gpipe_impl(len(node_list), num_stages, num_microbatches, compute_costs,
|
||||
max_compute_cost)
|
||||
|
||||
if cost < best_cost:
|
||||
best_cost = cost
|
||||
best_solution = solution
|
||||
last_max_compute_cost = max_compute_cost
|
||||
return best_cost, best_solution
|
||||
|
||||
|
||||
# Auto DP partition based on Alpa.
|
||||
# Adapted to Gpipe Scheduler
|
||||
# split_mode:
|
||||
# 'node': fx_node
|
||||
# 'block': many fx_nodes construct a block
|
||||
def gpipe_dp_split_pass(gm: torch.fx.GraphModule, pp_size: int, num_microbatches: int, mode='block', block_limit=0.01):
|
||||
assert mode in ['node', 'block']
|
||||
|
||||
# nodes or blocks will be used in partition.
|
||||
node_list = []
|
||||
if mode == 'node':
|
||||
for node in gm.graph.nodes:
|
||||
node_list.append(node)
|
||||
elif mode == 'block':
|
||||
node_list = construct_blocks(gm, limit=block_limit)
|
||||
else:
|
||||
pass
|
||||
|
||||
compute_costs = get_compute_costs(node_list)
|
||||
|
||||
best_cost, best_solution = do_dp_split_gpipe(node_list, compute_costs, pp_size, num_microbatches)
|
||||
|
||||
for (_, next_start_node) in best_solution:
|
||||
if pp_size <= 1:
|
||||
break
|
||||
node = node_list[next_start_node]
|
||||
with gm.graph.inserting_before(node):
|
||||
split_node = gm.graph.create_node('call_function', pipe_split)
|
||||
pp_size -= 1
|
||||
|
||||
# remove block node if possible
|
||||
if mode == 'block':
|
||||
remove_blocks(gm)
|
||||
|
||||
gm.recompile()
|
||||
return gm
|
||||
|
||||
|
||||
def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int):
|
||||
"""
|
||||
In avgcompute_split_pass, we split module by the fwd flops.
|
||||
|
Reference in New Issue
Block a user