[autoparallel] apply repeat block to reduce solving time (#2912)
parent a848091141
commit 197d0bf4ed
@@ -112,11 +112,13 @@ def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstructor
     This method is used to solve the best solution for the given graph.
     The solution is a list of integers, each integer represents the best strategy index of the corresponding node.
     '''
-    graph_analyser = GraphAnalyser(gm)
-    liveness_list = graph_analyser.liveness_analysis()
+    # temporarily we use all nodes as liveness list, we count the backward memory cost together with
+    # forward memory cost into the node memory cost, and no activation checkpoint is used in this phase.
+    # graph_analyser = GraphAnalyser(gm)
+    # liveness_list = graph_analyser.liveness_analysis()
     cost_graph = CostGraph(strategy_constructor.leaf_strategies)
     cost_graph.simplify_graph()
-    solver = Solver(gm.graph, strategy_constructor, cost_graph, graph_analyser, memory_budget=memory_budget)
+    solver = Solver(gm.graph, strategy_constructor, cost_graph, memory_budget=memory_budget)
     ret = solver.call_solver_serialized_args()
     solution = list(ret[0])

@@ -32,7 +32,7 @@ class Solver:
                  graph: Graph,
                  strategies_constructor: StrategiesConstructor,
                  cost_graph: CostGraph,
-                 graph_analyser: GraphAnalyser,
+                 graph_analyser: GraphAnalyser = None,
                  memory_budget: float = -1.0,
                  solution_numbers: int = 1,
                  forward_only: bool = False,
@@ -63,7 +63,10 @@ class Solver:
             self.memory_increasing_coefficient = memory_increasing_coefficient
         else:
             self.memory_increasing_coefficient = 1
-        self.liveness_list = self.graph_analyser.liveness_analysis()
+        # temporarily we use all nodes as liveness list, we count the backward memory cost together with
+        # forward memory cost into the node memory cost, and no activation checkpoint is used in this phase.
+        # self.liveness_list = self.graph_analyser.liveness_analysis()
+        self.liveness_list = self.nodes
         self.node_index_dict = self._generate_node_index_dict()
         # The last solution vector of auto sharding.
         self.last_s_val = None
@@ -140,7 +143,7 @@ class Solver:
         liveness_set = self.liveness_list

         # omit alias_set now
-        alias_set = None
+        alias_set = self.strategies_constructor.alias_set
         alias_convert_costs = None

         # prepare compute_costs, communication_costs and memory_costs
@@ -230,6 +233,7 @@ class Solver:

         # 0. Unpack flatten numpy arrays
         s_follow = following_nodes
+        s_alias = alias_set

         E = edge_pairs.reshape((-1, 2))  # noqa
         r = []
@@ -294,8 +298,11 @@ class Solver:
                 if strategies_len[i] == 1:
                     s.append([1])
                 else:
-                    num_nodes += 1
-                    s.append(LpVariable.matrix(f"s[{i}]", (range(strategies_len[i]),), cat="Binary"))
+                    if i not in s_alias:
+                        num_nodes += 1
+                        s.append(LpVariable.matrix(f"s[{i}]", (range(strategies_len[i]),), cat="Binary"))
+                    else:
+                        s.append(s[s_alias[i]])
             else:
                 if s_follow[i] < len(s):
                     s.append(s[s_follow[i]])
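The hunk above only allocates a fresh vector of binary decision variables for the first occurrence of a repeated node; every aliased node reuses its representative's vector, so the ILP grows with the number of distinct blocks rather than the number of layers. Below is a minimal standalone sketch of that idea with toy data (illustrative code with made-up sizes, not part of the commit):

from pulp import LpVariable

strategies_len = [4, 4, 4]    # three nodes, 4 sharding strategies each
s_alias = {1: 0, 2: 0}        # nodes 1 and 2 are repeats of node 0

s, num_nodes = [], 0
for i in range(len(strategies_len)):
    if i not in s_alias:
        num_nodes += 1
        s.append(LpVariable.matrix(f"s[{i}]", (range(strategies_len[i]),), cat="Binary"))
    else:
        # an aliased node shares the exact variable objects of its representative
        s.append(s[s_alias[i]])

assert s[1] is s[0] and s[2] is s[0] and num_nodes == 1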
@@ -311,15 +318,20 @@ class Solver:
         #############################
         e = []
         num_edges = 0
+        map_edge_to_idx = {}
         for (idx, (i, j)) in enumerate(E):
             if len(s[i]) == 1:
                 e.append(s[j])
             elif len(s[j]) == 1:
                 e.append(s[i])
             else:
-                num_edges += 1
-                e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary"))
+                if i in s_alias and j in s_alias and (s_alias[i], s_alias[j]) in map_edge_to_idx:
+                    e.append(e[map_edge_to_idx[(s_alias[i], s_alias[j])]])
+                else:
+                    num_edges += 1
+                    e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary"))
             assert len(e[idx]) == len(r[idx])
+            map_edge_to_idx[(i, j)] = idx
         for element in s:
             assert len(element) > 0
         # 2. Set initial value
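Edge variables follow the same pattern: an edge whose endpoints are both aliases of an already-seen pair reuses the variables recorded in map_edge_to_idx instead of creating a new matrix. A toy sketch of just that reuse branch (illustrative only, the len(s[i]) == 1 shortcuts are omitted and the data is made up):

from pulp import LpVariable

s = [[0, 1], [0, 1], [0, 1], [0, 1]]    # pretend every node has 2 strategies
E = [(0, 1), (2, 3)]                    # edge (2, 3) is a repeat of edge (0, 1)
s_alias = {2: 0, 3: 1}

e, num_edges, map_edge_to_idx = [], 0, {}
for idx, (i, j) in enumerate(E):
    if i in s_alias and j in s_alias and (s_alias[i], s_alias[j]) in map_edge_to_idx:
        # reuse the edge variables created for the representative pair
        e.append(e[map_edge_to_idx[(s_alias[i], s_alias[j])]])
    else:
        num_edges += 1
        e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary"))
    map_edge_to_idx[(i, j)] = idx

assert e[1] is e[0] and num_edges == 1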
@@ -371,13 +383,12 @@ class Solver:
         # compute memory consumption with liveness set #
         #################################################
         if memory_budget > 0:
-            for liveness_stage in liveness_set:
-                mem = 0
-                for live_variable in liveness_stage.unique_live_vars:
-                    if live_variable.node not in self.node_index_dict:
-                        continue
-                    node_index = self.node_index_dict[live_variable.node]
-                    mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index])))
-                prob += mem <= memory_budget
+            mem = 0
+            for node in liveness_set:
+                if node not in self.node_index_dict:
+                    continue
+                node_index = self.node_index_dict[node]
+                mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index])))
+            prob += mem <= memory_budget

         # (d). specified by `cat="Binary"`
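With the liveness list temporarily replaced by the plain node list, the memory budget becomes a single aggregate constraint: each node's selected strategy contributes its memory cost once, and the total must stay under the budget. A self-contained toy version of that constraint (assumed example costs and sizes, not the real values):

from pulp import LpMinimize, LpProblem, LpVariable, lpSum

prob = LpProblem("memory_budget_sketch", LpMinimize)
s = [LpVariable.matrix(f"s[{i}]", (range(2),), cat="Binary") for i in range(3)]
m = [[4.0, 2.0], [6.0, 3.0], [8.0, 4.0]]    # per-strategy memory cost of each node
memory_budget = 12.0

mem = 0
for node_index in range(len(s)):
    # elsewhere the ILP forces exactly one strategy per node, so this sums the chosen costs
    mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index])))
prob += mem <= memory_budget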
@@ -15,6 +15,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler import (
 )
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
 from colossalai.auto_parallel.tensor_shard.utils import generate_resharding_costs, generate_sharding_spec
+from colossalai.auto_parallel.tensor_shard.utils.factory import find_repeat_blocks
 from colossalai.device.device_mesh import DeviceMesh

 from ..options import DataloaderOption, SolverOptions
@@ -42,6 +43,7 @@ class StrategiesConstructor:
         self.strategy_map = {}
         self.solver_options = solver_options
         self.no_strategy_nodes = []
+        self.alias_set = None

     def remove_duplicated_strategy(self, strategies_vector):
         '''
@@ -59,6 +61,22 @@ class StrategiesConstructor:
         for strategy in remove_list:
             strategies_vector.remove(strategy)

+    def generate_alias_set(self):
+
+        node_list = [strategy_vector.node for strategy_vector in self.leaf_strategies]
+        common_blocks = find_repeat_blocks(node_list, self.root_module, common_length_threshold=10)
+
+        repeat_block_nums = len(common_blocks)
+        alias_set = {}
+
+        if repeat_block_nums == 0:
+            return alias_set
+
+        for index, common_node in enumerate(common_blocks[0]):
+            for i in range(1, repeat_block_nums):
+                alias_set[node_list.index(common_blocks[i][index])] = node_list.index(common_node)
+        return alias_set
+
     def build_strategies_and_cost(self):
         """
         This method is to build the strategy vector for each node in the computation graph.
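generate_alias_set maps every node of a repeated block (block 1, 2, ...) back to the corresponding node of the first block, keyed by position in leaf_strategies; the Solver then treats those indices as aliases. A toy run of the same index bookkeeping, using strings in place of fx nodes (the shape of find_repeat_blocks' output is assumed from the code above):

node_list = ['ln0', 'attn0', 'mlp0', 'ln1', 'attn1', 'mlp1', 'ln2', 'attn2', 'mlp2']
common_blocks = [node_list[0:3], node_list[3:6], node_list[6:9]]    # three repeated blocks

alias_set = {}
for index, common_node in enumerate(common_blocks[0]):
    for i in range(1, len(common_blocks)):
        # the node at the same position in a later block aliases the first block's node
        alias_set[node_list.index(common_blocks[i][index])] = node_list.index(common_node)

print(alias_set)    # {3: 0, 6: 0, 4: 1, 7: 1, 5: 2, 8: 2}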
@@ -175,3 +193,6 @@ class StrategiesConstructor:
                 self.leaf_strategies.remove(node.strategies_vector)
             if node in self.strategy_map:
                 self.strategy_map.pop(node)
+
+        alias_set = self.generate_alias_set()
+        self.alias_set = alias_set
@@ -15,13 +15,13 @@ from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2

 BATCH_SIZE = 1
 SEQ_LENGTH = 32
-HIDDEN_DIM = 768
+HIDDEN_DIM = 384


 @run_on_environment_flag(name='AUTO_PARALLEL')
 @parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP, GPT2Model])
 def test_self_attention_block(model_cls):
-    config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=16, n_embd=HIDDEN_DIM)
+    config = transformers.GPT2Config(n_position=64, n_layer=12, n_head=16, n_embd=HIDDEN_DIM)
     if model_cls == GPT2MLP:
         model = model_cls(intermediate_size=4 * config.hidden_size, config=config)
     else:
@@ -54,15 +54,13 @@ def test_self_attention_block(model_cls):
     gm = GraphModule(model, graph, model.__class__.__name__)
     print(gm.graph)
     gm.recompile()
-    graph_analyser = GraphAnalyser(gm)
-    liveness_list = graph_analyser.liveness_analysis()
     solver_options = SolverOptions()
     strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()

     cost_graph = CostGraph(strategies_constructor.leaf_strategies)
     cost_graph.simplify_graph()
-    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=-1)
+    solver = Solver(gm.graph, strategies_constructor, cost_graph, memory_budget=-1)
     ret = solver.call_solver_serialized_args()
     strategies_list = solver.last_s_val
     nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
@@ -9,7 +9,6 @@ from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
 from colossalai.auto_parallel.tensor_shard.options import SolverOptions
 from colossalai.auto_parallel.tensor_shard.solver import StrategiesConstructor
 from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
-from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
 from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx.tracer.tracer import ColoTracer
|
|||||||
# solution construction
|
# solution construction
|
||||||
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
|
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
|
||||||
cost_graph.simplify_graph()
|
cost_graph.simplify_graph()
|
||||||
graph_analyser = GraphAnalyser(gm)
|
solver = Solver(gm.graph, strategies_constructor, cost_graph, verbose=False)
|
||||||
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, verbose=False)
|
|
||||||
ret = solver.call_solver_serialized_args()
|
ret = solver.call_solver_serialized_args()
|
||||||
solution = list(ret[0])
|
solution = list(ret[0])
|
||||||
gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
|
gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
|
||||||
|
@@ -51,15 +51,14 @@ def test_cost_graph():
     # return fc
     gm = GraphModule(model, graph, model.__class__.__name__)
     gm.recompile()
-    graph_analyser = GraphAnalyser(gm)
-    liveness_list = graph_analyser.liveness_analysis()
     solver_options = SolverOptions()
     strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()

     cost_graph = CostGraph(strategies_constructor.leaf_strategies)
     cost_graph.simplify_graph()
-    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
+    solver = Solver(gm.graph, strategies_constructor, cost_graph)

     ret = solver.call_solver_serialized_args()
     print(ret[0])