[autoparallel]integrate auto parallel feature with new tracer (#3408)

* [autoparallel] integrate new analyzer in module level

* unify the profiling method

* polish

* fix no codegen bug

* fix pass bug

* fix liveness test

* polish
This commit is contained in:
YuliangLiu0306
2023-04-04 17:40:45 +08:00
committed by GitHub
parent 573af84184
commit ffcdbf0f65
46 changed files with 396 additions and 470 deletions

View File

@@ -6,6 +6,10 @@ import torch.nn as nn
from torch.fx import GraphModule
from torch.fx.graph import Graph
from colossalai._analyzer.fx.codegen import ActivationCheckpointCodeGen
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
from colossalai.auto_parallel.tensor_shard.options import DataloaderOption, ShardOption, SolverOptions, SolverPerference
@@ -13,8 +17,6 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction
from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec
@@ -126,6 +128,7 @@ def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstruc
def transform_to_sharded_model(gm: ColoGraphModule,
meta_args: Dict,
solution: List[int],
device_mesh: DeviceMesh,
strategies_constructor: StrategiesConstructor,
@@ -142,6 +145,7 @@ def transform_to_sharded_model(gm: ColoGraphModule,
strategies_constructor,
overlap=overlap)
gm = runtime_apply_pass(gm)
shape_prop_pass(gm, *meta_args.values(), sharding_spec_dict, origin_spec_dict, comm_actions_dict)
gm.recompile()
sharding_spec_dicts = (sharding_spec_dict, origin_spec_dict, comm_actions_dict)
@@ -243,10 +247,13 @@ def initialize_model(model: nn.Module,
solution will be used to debug or help to analyze the sharding result. Therefore, we will not just
return a series of integers, but return the best strategies.
'''
tracer = ColoTracer(trace_act_ckpt=True)
tracer = ColoTracer(trace_act_ckpt=True, bias_addition_split=True)
graph = tracer.trace(root=model, meta_args=meta_args)
graph.set_codegen(ActivationCheckpointCodeGen())
gm = ColoGraphModule(model, graph, model.__class__.__name__)
shape_prop_pass(gm, *meta_args.values())
gm.recompile()
strategies_constructor = build_strategy_constructor(graph,
@@ -261,7 +268,9 @@ def initialize_model(model: nn.Module,
if save_solver_solution:
torch.save(solution, solution_path)
gm, sharding_spec_dicts = transform_to_sharded_model(gm, solution, device_mesh, strategies_constructor, overlap)
gm, sharding_spec_dicts = transform_to_sharded_model(gm, meta_args, solution, device_mesh, strategies_constructor,
overlap)
model_to_return = ModuleWrapper(gm, *sharding_spec_dicts)
if return_solution: