mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 20:10:17 +00:00
[autoparallel] apply repeat block to reduce solving time (#2912)
This commit is contained in:
@@ -15,13 +15,13 @@ from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2
|
||||
|
||||
BATCH_SIZE = 1
|
||||
SEQ_LENGTH = 32
|
||||
HIDDEN_DIM = 768
|
||||
HIDDEN_DIM = 384
|
||||
|
||||
|
||||
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||
@parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP, GPT2Model])
|
||||
def test_self_attention_block(model_cls):
|
||||
config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=16, n_embd=HIDDEN_DIM)
|
||||
config = transformers.GPT2Config(n_position=64, n_layer=12, n_head=16, n_embd=HIDDEN_DIM)
|
||||
if model_cls == GPT2MLP:
|
||||
model = model_cls(intermediate_size=4 * config.hidden_size, config=config)
|
||||
else:
|
||||
@@ -54,15 +54,13 @@ def test_self_attention_block(model_cls):
|
||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||
print(gm.graph)
|
||||
gm.recompile()
|
||||
graph_analyser = GraphAnalyser(gm)
|
||||
liveness_list = graph_analyser.liveness_analysis()
|
||||
solver_options = SolverOptions()
|
||||
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
|
||||
strategies_constructor.build_strategies_and_cost()
|
||||
|
||||
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
|
||||
cost_graph.simplify_graph()
|
||||
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=-1)
|
||||
solver = Solver(gm.graph, strategies_constructor, cost_graph, memory_budget=-1)
|
||||
ret = solver.call_solver_serialized_args()
|
||||
strategies_list = solver.last_s_val
|
||||
nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
|
||||
|
@@ -9,7 +9,6 @@ from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_pre
|
||||
from colossalai.auto_parallel.tensor_shard.options import SolverOptions
|
||||
from colossalai.auto_parallel.tensor_shard.solver import StrategiesConstructor
|
||||
from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
|
||||
from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
|
||||
from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx.tracer.tracer import ColoTracer
|
||||
@@ -109,8 +108,7 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
|
||||
# solution construction
|
||||
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
|
||||
cost_graph.simplify_graph()
|
||||
graph_analyser = GraphAnalyser(gm)
|
||||
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, verbose=False)
|
||||
solver = Solver(gm.graph, strategies_constructor, cost_graph, verbose=False)
|
||||
ret = solver.call_solver_serialized_args()
|
||||
solution = list(ret[0])
|
||||
gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
|
||||
|
@@ -51,15 +51,14 @@ def test_cost_graph():
|
||||
# return fc
|
||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||
gm.recompile()
|
||||
graph_analyser = GraphAnalyser(gm)
|
||||
liveness_list = graph_analyser.liveness_analysis()
|
||||
|
||||
solver_options = SolverOptions()
|
||||
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
|
||||
strategies_constructor.build_strategies_and_cost()
|
||||
|
||||
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
|
||||
cost_graph.simplify_graph()
|
||||
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
|
||||
solver = Solver(gm.graph, strategies_constructor, cost_graph)
|
||||
|
||||
ret = solver.call_solver_serialized_args()
|
||||
print(ret[0])
|
||||
|
Reference in New Issue
Block a user