mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 01:28:31 +00:00
[Pipeline Middleware] Adapt scheduler for Topo (#2066)
* adapt scheduler for Topo * remoove comment * fix set input Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
This commit is contained in:
@@ -4,6 +4,7 @@ from torch import nn
|
||||
from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine
|
||||
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
|
||||
from colossalai.fx import ColoTracer
|
||||
from colossalai.pipeline.middleware.adaptor import get_fx_topology
|
||||
from rpc_test_utils import rpc_run, parse_args, MLP
|
||||
from functools import partial
|
||||
|
||||
@@ -18,8 +19,12 @@ def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs):
|
||||
graph = tracer.trace(root=model, meta_args=meta_args)
|
||||
gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)
|
||||
annotated_model = balanced_split_pass(gm, stage_num)
|
||||
split_model, _ = split_with_split_nodes_pass(annotated_model, merge_output=True)
|
||||
return list(split_model.children())[pp_rank]
|
||||
top_module, split_submodules = split_with_split_nodes_pass(annotated_model, merge_output=True)
|
||||
topo = get_fx_topology(top_module)
|
||||
for submodule in split_submodules:
|
||||
if isinstance(submodule, torch.fx.GraphModule):
|
||||
setattr(submodule, '_topo', topo)
|
||||
return split_submodules[pp_rank+1]
|
||||
|
||||
def partition(data_kwargs: dict, pp_rank: int, chunk: int, stage_num: int):
|
||||
torch.manual_seed(1024)
|
||||
|
Reference in New Issue
Block a user