mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-16 14:41:53 +00:00
[autoparallel] autoparallel initialize (#2238)
This commit is contained in:
@@ -17,6 +17,7 @@ from torch.profiler import ProfilerActivity, profile, record_function, schedule,
|
||||
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
|
||||
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
|
||||
from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
|
||||
from colossalai.auto_parallel.tensor_shard.initialize import autoparallelize, initialize_model
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import ShardingSpec
|
||||
from colossalai.auto_parallel.tensor_shard.solver import (
|
||||
CostGraph,
|
||||
@@ -80,12 +81,9 @@ def main():
|
||||
model = GPT2LMHeadModel(config=config).to('cuda')
|
||||
global_numel = sum([p.numel() for p in model.parameters()])
|
||||
|
||||
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
|
||||
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
|
||||
|
||||
meta_input_sample = {
|
||||
'input_ids': input_ids.to('meta'),
|
||||
'attention_mask': attention_mask.to('meta'),
|
||||
'input_ids': torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64).to('meta'),
|
||||
'attention_mask': torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64).to('meta'),
|
||||
}
|
||||
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
@@ -93,39 +91,8 @@ def main():
|
||||
# [[0, 1]
|
||||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
shape_consistency_manager = ShapeConsistencyManager()
|
||||
|
||||
tracer = ColoTracer()
|
||||
|
||||
graph = tracer.trace(root=model, meta_args=meta_input_sample)
|
||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||
gm.recompile()
|
||||
|
||||
graph_analyser = GraphAnalyser(gm)
|
||||
liveness_list = graph_analyser.liveness_analysis()
|
||||
solver_options = SolverOptions()
|
||||
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
|
||||
strategies_constructor.build_strategies_and_cost()
|
||||
|
||||
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
|
||||
cost_graph.simplify_graph()
|
||||
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=-1)
|
||||
ret = solver.call_solver_serialized_args()
|
||||
|
||||
solution = list(ret[0])
|
||||
# solution = [0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 2, 13, 8, 9, 0, 2, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 12, 8, 8, 8, 0, 0, 20, 12, 12, 12, 6, 6, 6, 6, 2, 6, 0, 0, 4, 0, 0, 0, 4, 0, 4, 3, 3, 12, 3, 3, 8, 8, 8, 8, 8, 8, 8, 8, 3, 8, 2, 2, 11, 4, 4, 0, 0, 2, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 12, 8, 8, 8, 0, 0, 20, 12, 12, 12, 6, 6, 6, 6, 2, 6, 0, 0, 4, 0, 0, 0, 4, 0, 4, 3, 3, 12, 3, 3, 8, 8, 8, 8, 8, 8, 8, 8, 3, 8, 2, 2, 11, 4, 4, 0, 0, 2, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 12, 8, 8, 8, 0, 0, 20, 12, 12, 12, 6, 6, 6, 6, 2, 6, 0, 0, 4, 0, 0, 0, 4, 0, 4, 3, 3, 12, 3, 3, 8, 8, 8, 8, 8, 8, 8, 8, 3, 8, 2, 2, 11, 4, 4, 0, 0, 2, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 12, 8, 8, 8, 0, 0, 20, 12, 12, 12, 6, 6, 6, 6, 2, 6, 0, 0, 4, 0, 0, 0, 4, 0, 4, 3, 3, 12, 3, 3, 8, 8, 8, 8, 8, 8, 8, 8, 3, 8, 2, 2, 11, 4, 4, 9, 0, 0, 8, 0]
|
||||
print(solution)
|
||||
gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
|
||||
gm, solution, device_mesh, strategies_constructor)
|
||||
gm = runtime_apply_pass(gm)
|
||||
gm.recompile()
|
||||
# *******************strategy selected*******************
|
||||
print("*******************strategy selected*******************")
|
||||
strategies_list = solution
|
||||
|
||||
nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
|
||||
for index, node in enumerate(nodes):
|
||||
print(node.name, node.strategies_vector[strategies_list[index]].name)
|
||||
gm = initialize_model(model, meta_input_sample, device_mesh)
|
||||
|
||||
# build criterion
|
||||
criterion = GPTLMLoss()
|
||||
@@ -146,7 +113,7 @@ def main():
|
||||
input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LENGTH, VOCAB_SIZE)
|
||||
optimizer.zero_grad()
|
||||
start = time()
|
||||
outputs = gm(input_ids, attn_mask, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
|
||||
outputs = gm(input_ids, attn_mask)
|
||||
loss = criterion(outputs, input_ids)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
Reference in New Issue
Block a user