[rpc] split with dag (#2028)

* add DAG to split_module

* add comment

* add test case for DAG

* remove print

* add DAG middleware in scheduler

* add test case for scheduler

* remove break

* recover old lifecycle

Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
This commit is contained in:
Ziyue Jiang
2022-11-29 11:36:28 +08:00
committed by GitHub
parent 96134e7be3
commit b0936e4a44
5 changed files with 337 additions and 53 deletions

View File

@@ -117,7 +117,7 @@ def uniform_split_pass(gm: torch.fx.GraphModule, pp_size: int):
return gm
def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule):
def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule, merge_output=False):
# TODO(lyl): use partition IR to assign partition ID to each node.
# Currently: analyzing graph -> annotate graph by inserting split node -> use split module pass to split graph
# In future: graph to partitions -> analyzing partition IR -> recombining partitions to get best performance -> assign partition ID to each node
@@ -129,7 +129,7 @@ def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule):
part_idx += 1
return part_idx
split_mod = split_module(annotated_gm, None, split_callback)
split_mod = split_module(annotated_gm, None, split_callback, merge_output)
split_submodules = []
for name, submodule in split_mod.named_modules():
if isinstance(submodule, torch.fx.GraphModule):

View File

@@ -199,24 +199,17 @@ def find_user_in_partition(node, partitions, output_partitions=None, direct=Fals
for partition in partitions:
if node == partition:
user_partition_names.append(partition.name)
# find user with getitem call
else:
for partition in partitions:
if node in partition.args:
user_partition_names.append(partition.name)
is_output = False
def find_output(def_node, output_node):
nonlocal is_output
if def_node == output_node:
is_output = True
if output_partitions is not None:
output_node = output_partitions[0]
torch.fx.graph.map_arg(output_node.args[0], lambda n: find_output(node, n))
if is_output:
user_partition_names.append('MODEL_OUTPUT')
if node.op == output_node.op:
user_partition_names.append('MODEL_OUTPUT')
if len(user_partition_names) > 0:
return user_partition_names