[autoparallel] autoparallel initialize (#2238)

YuliangLiu0306
2022-12-31 01:02:14 +08:00
committed by GitHub
parent 85178a397a
commit 8897b8f753
2 changed files with 260 additions and 38 deletions

@@ -17,6 +17,7 @@ from torch.profiler import ProfilerActivity, profile, record_function, schedule,
 from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
 from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
 from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
+from colossalai.auto_parallel.tensor_shard.initialize import autoparallelize, initialize_model
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import ShardingSpec
 from colossalai.auto_parallel.tensor_shard.solver import (
     CostGraph,
@@ -80,12 +81,9 @@ def main():
     model = GPT2LMHeadModel(config=config).to('cuda')
     global_numel = sum([p.numel() for p in model.parameters()])
-    input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
-    attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
     meta_input_sample = {
-        'input_ids': input_ids.to('meta'),
-        'attention_mask': attention_mask.to('meta'),
+        'input_ids': torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64).to('meta'),
+        'attention_mask': torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64).to('meta'),
     }
     physical_mesh_id = torch.arange(0, 4)
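
A note on the pattern in this hunk: the sample inputs are built on PyTorch's meta device, so the symbolic trace sees correct shapes and dtypes without allocating any real storage. A minimal self-contained sketch, with placeholder sizes rather than the script's actual hyperparameters:

    import torch

    BATCH_SIZE, SEQ_LENGTH = 8, 1024  # placeholder values for illustration

    # Meta tensors carry shape/dtype metadata only; no GPU or host memory
    # is allocated, however large the model's inputs are.
    meta_input_sample = {
        'input_ids': torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64, device='meta'),
        'attention_mask': torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64, device='meta'),
    }

    assert meta_input_sample['input_ids'].is_meta  # metadata only, no storage

Building with device='meta' is equivalent to the .to('meta') form used in the diff; both produce storage-free tensors.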
@@ -93,39 +91,8 @@ def main():
     # [[0, 1]
     # [2, 3]]
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-    shape_consistency_manager = ShapeConsistencyManager()
-    tracer = ColoTracer()
-    graph = tracer.trace(root=model, meta_args=meta_input_sample)
-    gm = GraphModule(model, graph, model.__class__.__name__)
-    gm.recompile()
-    graph_analyser = GraphAnalyser(gm)
-    liveness_list = graph_analyser.liveness_analysis()
-    solver_options = SolverOptions()
-    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
-    strategies_constructor.build_strategies_and_cost()
-    cost_graph = CostGraph(strategies_constructor.leaf_strategies)
-    cost_graph.simplify_graph()
-    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=-1)
-    ret = solver.call_solver_serialized_args()
-    solution = list(ret[0])
-    # solution = [0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 2, 13, 8, 9, 0, 2, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 12, 8, 8, 8, 0, 0, 20, 12, 12, 12, 6, 6, 6, 6, 2, 6, 0, 0, 4, 0, 0, 0, 4, 0, 4, 3, 3, 12, 3, 3, 8, 8, 8, 8, 8, 8, 8, 8, 3, 8, 2, 2, 11, 4, 4, 0, 0, 2, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 12, 8, 8, 8, 0, 0, 20, 12, 12, 12, 6, 6, 6, 6, 2, 6, 0, 0, 4, 0, 0, 0, 4, 0, 4, 3, 3, 12, 3, 3, 8, 8, 8, 8, 8, 8, 8, 8, 3, 8, 2, 2, 11, 4, 4, 0, 0, 2, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 12, 8, 8, 8, 0, 0, 20, 12, 12, 12, 6, 6, 6, 6, 2, 6, 0, 0, 4, 0, 0, 0, 4, 0, 4, 3, 3, 12, 3, 3, 8, 8, 8, 8, 8, 8, 8, 8, 3, 8, 2, 2, 11, 4, 4, 0, 0, 2, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 12, 8, 8, 8, 0, 0, 20, 12, 12, 12, 6, 6, 6, 6, 2, 6, 0, 0, 4, 0, 0, 0, 4, 0, 4, 3, 3, 12, 3, 3, 8, 8, 8, 8, 8, 8, 8, 8, 3, 8, 2, 2, 11, 4, 4, 9, 0, 0, 8, 0]
-    print(solution)
-    gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
-        gm, solution, device_mesh, strategies_constructor)
-    gm = runtime_apply_pass(gm)
-    gm.recompile()
-    # *******************strategy selected*******************
-    print("*******************strategy selected*******************")
-    strategies_list = solution
-    nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
-    for index, node in enumerate(nodes):
-        print(node.name, node.strategies_vector[strategies_list[index]].name)
+    gm = initialize_model(model, meta_input_sample, device_mesh)
     # build criterion
     criterion = GPTLMLoss()
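
The single initialize_model call replaces the hand-rolled trace/solve/rewrite pipeline deleted above. Purely as orientation, the deleted lines can be condensed into the following sketch of what the new entry point appears to subsume; the helper name below is hypothetical, the other names are the ones the example script already imports, and the actual implementation in colossalai.auto_parallel.tensor_shard.initialize may differ:

    def initialize_model_sketch(model, meta_args, device_mesh):
        # 1. Symbolically trace the model using meta inputs (no real allocation).
        tracer = ColoTracer()
        graph = tracer.trace(root=model, meta_args=meta_args)
        gm = GraphModule(model, graph, model.__class__.__name__)
        gm.recompile()

        # 2. Enumerate per-node sharding strategies and their costs.
        graph_analyser = GraphAnalyser(gm)
        strategies_constructor = StrategiesConstructor(graph, device_mesh, SolverOptions())
        strategies_constructor.build_strategies_and_cost()

        # 3. Simplify the cost graph and solve for one strategy per node.
        cost_graph = CostGraph(strategies_constructor.leaf_strategies)
        cost_graph.simplify_graph()
        solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=-1)
        solution = list(solver.call_solver_serialized_args()[0])

        # 4. Rewrite the graph so the chosen sharding is applied at runtime.
        #    (The real initialize_model must also bind the spec/comm dicts
        #    internally, since the new forward call no longer threads them through.)
        gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
            gm, solution, device_mesh, strategies_constructor)
        gm = runtime_apply_pass(gm)
        gm.recompile()
        return gm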
@@ -146,7 +113,7 @@ def main():
         input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LENGTH, VOCAB_SIZE)
         optimizer.zero_grad()
         start = time()
-        outputs = gm(input_ids, attn_mask, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
+        outputs = gm(input_ids, attn_mask)
         loss = criterion(outputs, input_ids)
         loss.backward()
         optimizer.step()
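
Taken together, the example's flow after this commit condenses to roughly the following sketch. Here config, mesh_shape, optimizer, get_data, GPTLMLoss, and the hyperparameter constants come from unchanged parts of the script, and NUM_STEPS is a placeholder loop bound, not a name from the diff:

    model = GPT2LMHeadModel(config=config).to('cuda')
    meta_input_sample = {
        'input_ids': torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64).to('meta'),
        'attention_mask': torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64).to('meta'),
    }
    device_mesh = DeviceMesh(torch.arange(0, 4), mesh_shape, init_process_group=True)

    # One call now covers tracing, strategy solving, and the runtime passes.
    gm = initialize_model(model, meta_input_sample, device_mesh)

    criterion = GPTLMLoss()
    for _ in range(NUM_STEPS):  # placeholder loop bound
        input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LENGTH, VOCAB_SIZE)
        optimizer.zero_grad()
        outputs = gm(input_ids, attn_mask)  # spec dicts no longer passed explicitly
        loss = criterion(outputs, input_ids)
        loss.backward()
        optimizer.step()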