[autoparallel]integrate auto parallel feature with new tracer (#3408)

* [autoparallel] integrate new analyzer in module level

* unify the profiling method

* polish

* fix no codegen bug

* fix pass bug

* fix liveness test

* polish
This commit is contained in:
YuliangLiu0306
2023-04-04 17:40:45 +08:00
committed by GitHub
parent 573af84184
commit ffcdbf0f65
46 changed files with 396 additions and 470 deletions

View File

@@ -7,7 +7,7 @@ import torch.fx
from torch.fx import GraphModule
from torch.fx.node import Node
from colossalai.auto_parallel.meta_profiler import MetaInfo
from colossalai.auto_parallel.meta_profiler import ShardMetaInfo
from colossalai.auto_parallel.passes.constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
from colossalai.fx._compatibility import compatibility
from colossalai.fx.profiler import GraphInfo
@@ -96,12 +96,12 @@ class MetaInfoProp:
"""
Handle other kind of nodes
"""
assert hasattr(node, 'best_metainfo'), f"Cannot find best_metainfo in node {node}, {node.op}"
assert hasattr(node, 'best_strategy_info'), f"Cannot find best_strategy_info in node {node}, {node.op}"
graph_info = GraphInfo()
meta_info = node.best_metainfo
meta_info: MetaInfo
meta_info = node.best_strategy_info
meta_info: ShardMetaInfo
# set data_ptr for input_tensor in MetaInfo class
# set data_ptr for input_tensor in ShardMetaInfo class
input_tensors: List[torch.Tensor] = meta_info.fwd_in
buffer_tensors: List[torch.Tensor] = meta_info.fwd_buffer
output_tensors: List[torch.Tensor] = meta_info.fwd_out