[Pipeline] Add Topo Class (#2059)

* use Topo class to rewrite DAG

* polish code

* polish code

* polish code

* add comment

* add else to unended if

Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
This commit is contained in:
Ziyue Jiang
2022-12-02 18:13:20 +08:00
committed by GitHub
parent e4293e5077
commit 44ea461890
10 changed files with 451 additions and 283 deletions

View File

@@ -3,7 +3,6 @@ from torch.fx.graph_module import GraphModule
from typing import Callable, List, Dict, Any, Optional
from torch.fx._compatibility import compatibility
from packaging import version
from colossalai.fx.passes.utils import get_DAG
import inspect
@@ -294,11 +293,5 @@ def split_module(
partition = partitions[partition_name]
new_gm = torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
DAG = get_DAG(new_gm)
for _, submodule in new_gm.named_modules():
if isinstance(submodule, torch.fx.GraphModule):
setattr(submodule, '_DAG', DAG)
return new_gm