[Pipeline Middleware] Adapt scheduler for Topo (#2066)

* adapt scheduler for Topo

* remoove comment

* fix set input

Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
This commit is contained in:
Ziyue Jiang
2022-12-05 20:23:41 +08:00
committed by GitHub
parent b3b89865e2
commit 597cdd3006
4 changed files with 160 additions and 128 deletions

View File

@@ -108,7 +108,7 @@ def get_topology(gm: GraphModule):
p_output_val = find_output_in_partition(cur_node, partitions, output_partitions)
topo_input_partition.add_output_val(p_output_val)
topo.set_partitions(partition_id=0, partition=topo_input_partition)
topo.set_input_partition(partition_id=0)
topo.set_input_partition_id(partition_id=0)
for i, partition in enumerate(partitions):
topo_mid_partition = Partition()
@@ -140,6 +140,6 @@ def get_topology(gm: GraphModule):
torch.fx.graph.map_arg(partition.args[0], lambda n: topo_output_partition.add_input_val(
find_input_in_partition(n, partitions, input_partitions)))
topo.set_partitions(partition_id=1, partition=topo_output_partition)
topo.set_output_partition(partition_id=1)
topo.set_output_partition_id(partition_id=1)
return topo