Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-12 12:47:21 +00:00)
[autoparallel] integrate auto parallel feature with new tracer (#3408)

* [autoparallel] integrate new analyzer in module level
* unify the profiling method
* polish
* fix no codegen bug
* fix pass bug
* fix liveness test
* polish
@@ -6,6 +6,10 @@ import torch.nn as nn
 from torch.fx import GraphModule
 from torch.fx.graph import Graph
 
+from colossalai._analyzer.fx.codegen import ActivationCheckpointCodeGen
+from colossalai._analyzer.fx.graph_module import ColoGraphModule
+from colossalai._analyzer.fx.passes import shape_prop_pass
+from colossalai._analyzer.fx.tracer.tracer import ColoTracer
 from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
 from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
 from colossalai.auto_parallel.tensor_shard.options import DataloaderOption, ShardOption, SolverOptions, SolverPerference
@@ -13,8 +17,6 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction
 from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
 from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler
 from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.graph_module import ColoGraphModule
-from colossalai.fx.tracer import ColoTracer
 from colossalai.tensor.sharding_spec import ShardingSpec
 
 
@@ -126,6 +128,7 @@ def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstructor
 
 
 def transform_to_sharded_model(gm: ColoGraphModule,
+                               meta_args: Dict,
                                solution: List[int],
                                device_mesh: DeviceMesh,
                                strategies_constructor: StrategiesConstructor,
@@ -142,6 +145,7 @@ def transform_to_sharded_model(gm: ColoGraphModule,
                                                  strategies_constructor,
                                                  overlap=overlap)
     gm = runtime_apply_pass(gm)
+    shape_prop_pass(gm, *meta_args.values(), sharding_spec_dict, origin_spec_dict, comm_actions_dict)
     gm.recompile()
     sharding_spec_dicts = (sharding_spec_dict, origin_spec_dict, comm_actions_dict)
 
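The extra parameter is not cosmetic: runtime_apply_pass rewrites the graph, so shape metadata must be re-propagated over the transformed module, and that requires the original meta inputs. A minimal sketch of the updated call site, assuming `gm`, `solution`, `device_mesh`, and `strategies_constructor` were produced as in `initialize_model` (the meta input shape is illustrative, not from the diff):

    # hypothetical meta input matching the traced model's forward signature
    meta_args = {'x': torch.rand(4, 3, 32, 32, device='meta')}
    gm, sharding_spec_dicts = transform_to_sharded_model(gm, meta_args, solution, device_mesh,
                                                         strategies_constructor, overlap=False)
    sharding_spec_dict, origin_spec_dict, comm_actions_dict = sharding_spec_dicts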
@@ -243,10 +247,13 @@ def initialize_model(model: nn.Module,
         solution will be used to debug or help to analyze the sharding result. Therefore, we will not just
         return a series of integers, but return the best strategies.
     '''
-    tracer = ColoTracer(trace_act_ckpt=True)
+    tracer = ColoTracer(trace_act_ckpt=True, bias_addition_split=True)
 
     graph = tracer.trace(root=model, meta_args=meta_args)
     graph.set_codegen(ActivationCheckpointCodeGen())
     gm = ColoGraphModule(model, graph, model.__class__.__name__)
+
+    shape_prop_pass(gm, *meta_args.values())
     gm.recompile()
 
     strategies_constructor = build_strategy_constructor(graph,
@@ -261,7 +268,9 @@ def initialize_model(model: nn.Module,
     if save_solver_solution:
         torch.save(solution, solution_path)
 
-    gm, sharding_spec_dicts = transform_to_sharded_model(gm, solution, device_mesh, strategies_constructor, overlap)
+    gm, sharding_spec_dicts = transform_to_sharded_model(gm, meta_args, solution, device_mesh, strategies_constructor,
+                                                         overlap)
+
     model_to_return = ModuleWrapper(gm, *sharding_spec_dicts)
 
     if return_solution:
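Taken together, initialize_model now builds the graph with the new analyzer stack end to end: trace with bias-addition splitting, install the activation-checkpoint codegen, wrap the result in the analyzer's ColoGraphModule, and run shape propagation as an explicit pass. A self-contained sketch of that pipeline (the model and input shape are illustrative, not from the diff):

    import torch
    import torch.nn as nn

    from colossalai._analyzer.fx.codegen import ActivationCheckpointCodeGen
    from colossalai._analyzer.fx.graph_module import ColoGraphModule
    from colossalai._analyzer.fx.passes import shape_prop_pass
    from colossalai._analyzer.fx.tracer.tracer import ColoTracer

    model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))
    meta_args = {'input': torch.rand(8, 32, device='meta')}

    # bias_addition_split=True asks the tracer to split bias addition out of
    # linear/conv calls so weight and bias can be sharded independently
    tracer = ColoTracer(trace_act_ckpt=True, bias_addition_split=True)
    graph = tracer.trace(root=model, meta_args=meta_args)
    graph.set_codegen(ActivationCheckpointCodeGen())
    gm = ColoGraphModule(model, graph, model.__class__.__name__)

    # shape propagation is now an explicit pass over the meta inputs
    shape_prop_pass(gm, *meta_args.values())
    gm.recompile()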
@@ -2,8 +2,6 @@ from typing import Dict, List
 
 import torch
 
-from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo
-
 from ..sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from .node_handler import MetaInfoModuleHandler, ModuleHandler
 from .registry import operator_registry
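The operator handler no longer imports MetaInfo directly; profiled-cost rewriting now lives entirely in the MetaInfoModuleHandler base class it pulls in. A hedged sketch of how such a handler plugs in, following the registry decorator pattern suggested by the imports above (the target module and class body are illustrative):

    # illustrative registration; real handlers implement the operation-data
    # mapping and strategy generators, and inherit cost rewriting from the base
    @operator_registry.register(torch.nn.Conv2d)
    class ConvModuleHandler(MetaInfoModuleHandler):
        ...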
@@ -4,7 +4,7 @@ from typing import Dict, List, Tuple, Union
 import torch
 from torch.fx.node import Node
 
-from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo, meta_register
+from colossalai.auto_parallel.meta_profiler.shard_metainfo import ShardMetaInfo, meta_register
 from colossalai.auto_parallel.tensor_shard.options import ShardOption, SolverPerference
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     OperationData,
@@ -258,7 +258,7 @@ class MetaInfoNodeHandler(NodeHandler):
     def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesVector:
         """
         This method is inherited from NodeHandler. It will register the strategies first,
-        and rewrite the memory_cost and compute_cost of the strategy using the MetaInfo class.
+        and rewrite the memory_cost and compute_cost of the strategy using the ShardMetaInfo class.
         """
         super().register_strategy(compute_resharding_cost=compute_resharding_cost)
         target = self.get_target_function()
@@ -266,15 +266,15 @@ class MetaInfoNodeHandler(NodeHandler):
         # is not patched, we will use the default cost model to compute the cost.
         # TODO: patch all torch functions and modules to make it clean
         if meta_register.has(target.__class__) or meta_register.has(target):
-            metainfo_vector = []
+            strategies_info = []
             for strategy in self.strategies_vector:
-                metainfo = MetaInfo(strategy, target)
+                metainfo = ShardMetaInfo(strategy, target)
                 strategy.compute_cost = metainfo.compute_cost
                 strategy.memory_cost = metainfo.memory_cost
-                metainfo_vector.append(metainfo)
+                strategies_info.append(metainfo)
 
             # attach metainfos to the handler
-            setattr(self, "metainfo_vector", metainfo_vector)
+            setattr(self, "strategies_info", strategies_info)
 
         else:
             logger = get_dist_logger()
@@ -313,7 +313,7 @@ class MetaInfoModuleHandler(ModuleHandler):
     def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesVector:
         """
         This method is inherited from NodeHandler. It will register the strategies first,
-        and rewrite the memory_cost and compute_cost of the strategy using the MetaInfo class.
+        and rewrite the memory_cost and compute_cost of the strategy using the ShardMetaInfo class.
         """
         super().register_strategy(compute_resharding_cost=compute_resharding_cost)
         target = self.get_target_function()
@@ -321,15 +321,15 @@ class MetaInfoModuleHandler(ModuleHandler):
         # is not patched, we will use the default cost model to compute the cost.
         # TODO: patch all torch functions and modules to make it clean
         if meta_register.has(target.__class__) or meta_register.has(target):
-            metainfo_vector = []
+            strategies_info = []
             for strategy in self.strategies_vector:
-                metainfo = MetaInfo(strategy, target)
+                metainfo = ShardMetaInfo(strategy, target)
                 strategy.compute_cost = metainfo.compute_cost
                 strategy.memory_cost = metainfo.memory_cost
-                metainfo_vector.append(metainfo)
+                strategies_info.append(metainfo)
 
             # attach metainfos to the handler
-            setattr(self, "metainfo_vector", metainfo_vector)
+            setattr(self, "strategies_info", strategies_info)
 
         else:
             logger = get_dist_logger()
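The rename from metainfo_vector to strategies_info is mechanical, but the invariant it establishes is worth spelling out: after register_strategy(), each ShardingStrategy's estimated costs have been overwritten with the profiled ShardMetaInfo values, index-aligned with strategies_vector. A condensed check of that invariant (attribute and field names follow the diff; the handler instance is assumed to exist):

    handler.register_strategy()
    if hasattr(handler, 'strategies_info'):    # the profiled path was taken
        for strategy, info in zip(handler.strategies_vector, handler.strategies_info):
            # costs on the strategy now mirror the profiled ShardMetaInfo values
            assert strategy.compute_cost is info.compute_cost
            assert strategy.memory_cost is info.memory_cost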
@@ -137,9 +137,9 @@ class StrategiesConstructor:
                                                shard_option=self.solver_options.shard_option,
                                                solver_perference=self.solver_options.solver_perference)
                 handler.register_strategy()
-                # attach metainfo_vector to node
-                if hasattr(handler, 'metainfo_vector'):
-                    setattr(node, 'metainfo_vector', handler.metainfo_vector)
+                # attach strategies_info to node
+                if hasattr(handler, 'strategies_info'):
+                    setattr(node, 'strategies_info', handler.strategies_info)
 
             # call_function node
             elif node.op == 'call_function':
@@ -150,9 +150,9 @@ class StrategiesConstructor:
                                                shard_option=self.solver_options.shard_option,
                                                solver_perference=self.solver_options.solver_perference)
                 handler.register_strategy()
-                # attach metainfo_vector to node
-                if hasattr(handler, 'metainfo_vector'):
-                    setattr(node, 'metainfo_vector', handler.metainfo_vector)
+                # attach strategies_info to node
+                if hasattr(handler, 'strategies_info'):
+                    setattr(node, 'strategies_info', handler.strategies_info)
 
             # call_method node
             elif node.op == 'call_method':
@@ -163,9 +163,9 @@ class StrategiesConstructor:
                                                shard_option=self.solver_options.shard_option,
                                                solver_perference=self.solver_options.solver_perference)
                 handler.register_strategy()
-                # attach metainfo_vector to node
-                if hasattr(handler, 'metainfo_vector'):
-                    setattr(node, 'metainfo_vector', handler.metainfo_vector)
+                # attach strategies_info to node
+                if hasattr(handler, 'strategies_info'):
+                    setattr(node, 'strategies_info', handler.strategies_info)
 
             # output node
             elif node.op == 'output':
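With this change the per-node profiling results travel on the FX nodes themselves, so any later pass can inspect them without holding a handler reference. A minimal downstream sketch, assuming `gm` is the traced ColoGraphModule from above (the printout is illustrative):

    for node in gm.graph.nodes:
        if hasattr(node, 'strategies_info'):
            # one ShardMetaInfo per candidate strategy for this node
            for idx, info in enumerate(node.strategies_info):
                print(node.name, idx, info.compute_cost, info.memory_cost)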