[autoparallel] apply repeat block to reduce solving time (#2912)

authored by YuliangLiu0306 on 2023-02-28 11:03:30 +08:00
committed by GitHub
parent a848091141
commit 197d0bf4ed
6 changed files with 57 additions and 28 deletions
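
The "repeat block" optimization named in the title, informally: GPT-style models stack structurally identical blocks, so the strategy search only needs to be solved for one representative block and its solution can be replayed across the repeats, instead of solving over every layer. A toy sketch of that idea, using hypothetical names rather than the actual solver code touched by this commit:

from typing import Callable, Dict, List

def solve_with_repeat_blocks(
        block_node_groups: List[List[str]],                      # node names of each repeated block, in order
        solve_one_block: Callable[[List[str]], Dict[str, int]],  # the expensive per-block strategy solve
) -> Dict[str, int]:
    """Solve one representative block, then replay its choices for every repeat."""
    representative = block_node_groups[0]
    solved = solve_one_block(representative)  # the only expensive solver call
    strategies: Dict[str, int] = {}
    for group in block_node_groups:
        # Node i of every repeat gets the same strategy index as node i of the
        # representative block, since the blocks are structurally identical.
        for node, rep_node in zip(group, representative):
            strategies[node] = solved[rep_node]
    return strategies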


@@ -15,13 +15,13 @@ from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2
BATCH_SIZE = 1
SEQ_LENGTH = 32
-HIDDEN_DIM = 768
+HIDDEN_DIM = 384
@run_on_environment_flag(name='AUTO_PARALLEL')
@parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP, GPT2Model])
def test_self_attention_block(model_cls):
-config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=16, n_embd=HIDDEN_DIM)
+config = transformers.GPT2Config(n_position=64, n_layer=12, n_head=16, n_embd=HIDDEN_DIM)
if model_cls == GPT2MLP:
model = model_cls(intermediate_size=4 * config.hidden_size, config=config)
else:
@@ -54,15 +54,13 @@ def test_self_attention_block(model_cls):
gm = GraphModule(model, graph, model.__class__.__name__)
print(gm.graph)
gm.recompile()
-graph_analyser = GraphAnalyser(gm)
-liveness_list = graph_analyser.liveness_analysis()
solver_options = SolverOptions()
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
cost_graph.simplify_graph()
-solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=-1)
+solver = Solver(gm.graph, strategies_constructor, cost_graph, memory_budget=-1)
ret = solver.call_solver_serialized_args()
strategies_list = solver.last_s_val
nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
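
For reference, the call pattern these tests use after the change: no GraphAnalyser is built or passed, and the Solver is constructed from just the graph, the strategies constructor, and the cost graph (liveness analysis is presumably handled inside the Solver now; that part of the change is not shown in this diff). A minimal sketch, assuming a traced gm, its fx graph, and a device_mesh are prepared as in the test above:

from colossalai.auto_parallel.tensor_shard.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.solver import StrategiesConstructor
from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
from colossalai.auto_parallel.tensor_shard.solver.solver import Solver

# Build sharding strategies and their costs for every node in the traced graph.
solver_options = SolverOptions()
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()

# Build and simplify the cost graph over the leaf strategies.
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
cost_graph.simplify_graph()

# New signature: no graph_analyser argument.
solver = Solver(gm.graph, strategies_constructor, cost_graph, memory_budget=-1)
ret = solver.call_solver_serialized_args()
strategies_list = solver.last_s_val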


@@ -9,7 +9,6 @@ from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_pre
from colossalai.auto_parallel.tensor_shard.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.solver import StrategiesConstructor
from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
-from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
@@ -109,8 +108,7 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
# solution construction
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
cost_graph.simplify_graph()
-graph_analyser = GraphAnalyser(gm)
-solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, verbose=False)
+solver = Solver(gm.graph, strategies_constructor, cost_graph, verbose=False)
ret = solver.call_solver_serialized_args()
solution = list(ret[0])
gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(


@@ -51,15 +51,14 @@ def test_cost_graph():
# return fc
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
-graph_analyser = GraphAnalyser(gm)
-liveness_list = graph_analyser.liveness_analysis()
solver_options = SolverOptions()
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
cost_graph.simplify_graph()
-solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
+solver = Solver(gm.graph, strategies_constructor, cost_graph)
ret = solver.call_solver_serialized_args()
print(ret[0])