[rpc] split with dag (#2028)

* add DAG to split_module

* add comment

* add test case for DAG

* remove print

* add DAG middleware in scheduler

* add test case for scheduler

* remove break

* recover old lifecycle

Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
This commit is contained in:
Ziyue Jiang
2022-11-29 11:36:28 +08:00
committed by GitHub
parent 96134e7be3
commit b0936e4a44
5 changed files with 337 additions and 53 deletions

View File

@@ -199,24 +199,17 @@ def find_user_in_partition(node, partitions, output_partitions=None, direct=Fals
for partition in partitions:
if node == partition:
user_partition_names.append(partition.name)
# find user with getitem call
else:
for partition in partitions:
if node in partition.args:
user_partition_names.append(partition.name)
is_output = False
def find_output(def_node, output_node):
nonlocal is_output
if def_node == output_node:
is_output = True
if output_partitions is not None:
output_node = output_partitions[0]
torch.fx.graph.map_arg(output_node.args[0], lambda n: find_output(node, n))
if is_output:
user_partition_names.append('MODEL_OUTPUT')
if node.op == output_node.op:
user_partition_names.append('MODEL_OUTPUT')
if len(user_partition_names) > 0:
return user_partition_names