Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-06 11:32:10 +00:00)
[autoparallel] integrate auto parallel feature with new tracer (#3408)
* [autoparallel] integrate new analyzer in module level
* unify the profiling method
* polish
* fix no codegen bug
* fix pass bug
* fix liveness test
* polish
@@ -7,7 +7,7 @@ import torch.fx
 from torch.fx import GraphModule
 from torch.fx.node import Node
 
-from colossalai.auto_parallel.meta_profiler import MetaInfo
+from colossalai.auto_parallel.meta_profiler import ShardMetaInfo
 from colossalai.auto_parallel.passes.constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
 from colossalai.fx._compatibility import compatibility
 from colossalai.fx.profiler import GraphInfo
@@ -96,12 +96,12 @@ class MetaInfoProp:
         """
         Handle other kind of nodes
         """
-        assert hasattr(node, 'best_metainfo'), f"Cannot find best_metainfo in node {node}, {node.op}"
+        assert hasattr(node, 'best_strategy_info'), f"Cannot find best_strategy_info in node {node}, {node.op}"
         graph_info = GraphInfo()
-        meta_info = node.best_metainfo
-        meta_info: MetaInfo
+        meta_info = node.best_strategy_info
+        meta_info: ShardMetaInfo
 
-        # set data_ptr for input_tensor in MetaInfo class
+        # set data_ptr for input_tensor in ShardMetaInfo class
         input_tensors: List[torch.Tensor] = meta_info.fwd_in
         buffer_tensors: List[torch.Tensor] = meta_info.fwd_buffer
         output_tensors: List[torch.Tensor] = meta_info.fwd_out
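
For readers following the rename, below is a minimal Python sketch of the node-handling path touched by the second hunk. It is not the actual MetaInfoProp implementation: it assumes only the names visible in the diff (best_strategy_info, fwd_in, fwd_buffer, fwd_out, and the GraphInfo import already present in the file), the ShardMetaInfoStub class and handle_other_node helper are hypothetical stand-ins, and the real data_ptr bookkeeping is omitted.

from typing import List

import torch

from colossalai.fx.profiler import GraphInfo  # same import as in the patched file


class ShardMetaInfoStub:
    # Hypothetical stand-in for colossalai.auto_parallel.meta_profiler.ShardMetaInfo;
    # only the three forward-phase tensor lists referenced in the diff are modeled.
    def __init__(self, fwd_in: List[torch.Tensor], fwd_buffer: List[torch.Tensor],
                 fwd_out: List[torch.Tensor]) -> None:
        self.fwd_in = fwd_in
        self.fwd_buffer = fwd_buffer
        self.fwd_out = fwd_out


def handle_other_node(node) -> GraphInfo:
    # After this patch the chosen strategy's metadata lives on `best_strategy_info`
    # (previously `best_metainfo`) and is typed as ShardMetaInfo.
    assert hasattr(node, 'best_strategy_info'), \
        f"Cannot find best_strategy_info in node {node}, {node.op}"
    graph_info = GraphInfo()
    meta_info = node.best_strategy_info

    # The pass then walks the forward-phase tensors recorded by the profiler
    # before filling in graph_info; the data_ptr bookkeeping over these lists
    # is omitted in this sketch.
    input_tensors: List[torch.Tensor] = meta_info.fwd_in
    buffer_tensors: List[torch.Tensor] = meta_info.fwd_buffer
    output_tensors: List[torch.Tensor] = meta_info.fwd_out
    return graph_info

Note that the only behavioral change in this hunk is the attribute and class rename; the tensor bookkeeping itself is untouched.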