[Pipeline Middleware] Adapt scheduler for Topo (#2066)

* adapt scheduler for Topo

* remove comment

* fix set input

Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
Author: Ziyue Jiang
Committer: GitHub
Date: 2022-12-05 20:23:41 +08:00
Commit: 597cdd3006
Parent: b3b89865e2

4 changed files with 160 additions and 128 deletions


@@ -4,6 +4,7 @@ from torch import nn
 from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine
 from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
 from colossalai.fx import ColoTracer
+from colossalai.pipeline.middleware.adaptor import get_fx_topology
 from rpc_test_utils import rpc_run, parse_args, MLP
 from functools import partial
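
The next hunk puts this import to work inside create_partition_module. For orientation, here is a minimal, self-contained sketch of the same trace/split/topology flow; TinyMLP, its shapes, and the two-stage split are hypothetical stand-ins, while the colossalai calls are the same ones used in the diff:

    import torch
    from colossalai.fx import ColoTracer
    from colossalai.fx.passes.adding_split_node_pass import (
        balanced_split_pass,
        split_with_split_nodes_pass,
    )
    from colossalai.pipeline.middleware.adaptor import get_fx_topology

    class TinyMLP(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = torch.nn.Linear(16, 16)
            self.fc2 = torch.nn.Linear(16, 16)

        def forward(self, x):
            return self.fc2(torch.relu(self.fc1(x)))

    model = TinyMLP()
    tracer = ColoTracer()
    # meta_args keys must match the names of forward()'s arguments
    meta_args = {'x': torch.rand(4, 16, device='meta')}
    graph = tracer.trace(root=model, meta_args=meta_args)
    gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)

    annotated = balanced_split_pass(gm, 2)  # mark split points for 2 stages
    top_module, submodules = split_with_split_nodes_pass(annotated, merge_output=True)
    topo = get_fx_topology(top_module)  # data-flow topology between partitions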
@@ -18,8 +19,12 @@ def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs):
     graph = tracer.trace(root=model, meta_args=meta_args)
     gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)
     annotated_model = balanced_split_pass(gm, stage_num)
-    split_model, _ = split_with_split_nodes_pass(annotated_model, merge_output=True)
-    return list(split_model.children())[pp_rank]
+    top_module, split_submodules = split_with_split_nodes_pass(annotated_model, merge_output=True)
+    topo = get_fx_topology(top_module)
+    for submodule in split_submodules:
+        if isinstance(submodule, torch.fx.GraphModule):
+            setattr(submodule, '_topo', topo)
+    return split_submodules[pp_rank+1]
 
 def partition(data_kwargs: dict, pp_rank: int, chunk: int, stage_num: int):
     torch.manual_seed(1024)
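
Because the topology is attached to every split GraphModule, each pipeline stage can recover the full inter-partition wiring from its own module. A hedged sketch of the consumer side (get_stage_topo is a hypothetical helper; only the '_topo' attribute comes from this commit):

    def get_stage_topo(pp_rank: int, stage_num: int, model, data_kwargs):
        # Hypothetical helper: build one stage, then read back the shared
        # topology that create_partition_module attached above.
        partition = create_partition_module(pp_rank, stage_num, model, data_kwargs)
        topo = getattr(partition, '_topo', None)
        assert topo is not None, 'every split GraphModule should carry the Topo'
        return topo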