Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-18 07:31:19 +00:00)
[rpc] split with dag (#2028)
* add DAG to split_module
* add comment
* add test case for DAG
* remove print
* add DAG middleware in scheduler
* add test case for scheduler
* remove break
* recover old lifecycle

Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
@@ -117,7 +117,7 @@ def uniform_split_pass(gm: torch.fx.GraphModule, pp_size: int):
     return gm
 
 
-def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule):
+def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule, merge_output=False):
     # TODO(lyl): use partition IR to assign partition ID to each node.
     # Currently: analyzing graph -> annotate graph by inserting split node -> use split module pass to split graph
     # In future: graph to partitions -> analyzing partition IR -> recombining partitions to get best performance -> assign partition ID to each node
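For context, the split_callback handed to split_module maps each FX node to a partition index, and the pass increments that index whenever it reaches an inserted split node. Below is a minimal sketch of that contract using stock torch.fx; the TwoStage module and the 'fc2' boundary are made up for illustration, standing in for the split nodes the annotation passes insert.

import torch
from torch.fx import symbolic_trace
from torch.fx.passes.split_module import split_module

class TwoStage(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(4, 4)
        self.fc2 = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.fc2(self.fc1(x))

gm = symbolic_trace(TwoStage())
part_idx = 0

def split_callback(node):
    # The real pass bumps part_idx when it meets an inserted split node;
    # matching on the hypothetical 'fc2' name just forces two partitions here.
    global part_idx
    if node.name == 'fc2':
        part_idx += 1
    return part_idx

# None for root_m, mirroring the call in this diff.
split_mod = split_module(gm, None, split_callback)
print([name for name, _ in split_mod.named_children()])    # ['submod_0', 'submod_1']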
@@ -129,7 +129,7 @@ def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule):
             part_idx += 1
         return part_idx
 
-    split_mod = split_module(annotated_gm, None, split_callback)
+    split_mod = split_module(annotated_gm, None, split_callback, merge_output)
     split_submodules = []
     for name, submodule in split_mod.named_modules():
         if isinstance(submodule, torch.fx.GraphModule):
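A usage sketch of the updated entry point follows. The import path, the two-value return (the split module plus the split_submodules list collected above), and the effect of merge_output (merging the partitions' outputs into a single output node, inferred from the flag's name) are assumptions; uniform_split_pass is the annotator named in the first hunk header.

import torch
from torch.fx import symbolic_trace
# Assumed import path for the file this diff touches.
from colossalai.fx.passes.adding_split_node_pass import (
    uniform_split_pass,
    split_with_split_nodes_pass,
)

model = torch.nn.Sequential(
    torch.nn.Linear(8, 8),
    torch.nn.ReLU(),
    torch.nn.Linear(8, 8),
)
gm = symbolic_trace(model)

# Insert split nodes at uniform intervals, then split into pipeline stages.
annotated_gm = uniform_split_pass(gm, pp_size=2)
split_mod, split_submodules = split_with_split_nodes_pass(annotated_gm, merge_output=True)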
@@ -199,24 +199,17 @@ def find_user_in_partition(node, partitions, output_partitions=None, direct=Fals
         for partition in partitions:
             if node == partition:
                 user_partition_names.append(partition.name)
 
     # find user with getitem call
     else:
         for partition in partitions:
             if node in partition.args:
                 user_partition_names.append(partition.name)
 
-    is_output = False
-
-    def find_output(def_node, output_node):
-        nonlocal is_output
-        if def_node == output_node:
-            is_output = True
-
     if output_partitions is not None:
         output_node = output_partitions[0]
-        torch.fx.graph.map_arg(output_node.args[0], lambda n: find_output(node, n))
-
-    if is_output:
-        user_partition_names.append('MODEL_OUTPUT')
+        if node.op == output_node.op:
+            user_partition_names.append('MODEL_OUTPUT')
 
     if len(user_partition_names) > 0:
         return user_partition_names