mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 01:55:12 +00:00
[rpc] split with dag (#2028)
* add DAG to split_module * add comment * add test case for DAG * remove print * add DAG middleware in scheduler * add test case for scheduler * remove break * recover old lifecycle Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
This commit is contained in:
@@ -199,24 +199,17 @@ def find_user_in_partition(node, partitions, output_partitions=None, direct=Fals
|
||||
for partition in partitions:
|
||||
if node == partition:
|
||||
user_partition_names.append(partition.name)
|
||||
|
||||
# find user with getitem call
|
||||
else:
|
||||
for partition in partitions:
|
||||
if node in partition.args:
|
||||
user_partition_names.append(partition.name)
|
||||
|
||||
is_output = False
|
||||
def find_output(def_node, output_node):
|
||||
nonlocal is_output
|
||||
if def_node == output_node:
|
||||
is_output = True
|
||||
|
||||
|
||||
if output_partitions is not None:
|
||||
output_node = output_partitions[0]
|
||||
torch.fx.graph.map_arg(output_node.args[0], lambda n: find_output(node, n))
|
||||
|
||||
if is_output:
|
||||
user_partition_names.append('MODEL_OUTPUT')
|
||||
if node.op == output_node.op:
|
||||
user_partition_names.append('MODEL_OUTPUT')
|
||||
|
||||
if len(user_partition_names) > 0:
|
||||
return user_partition_names
|
||||
|
Reference in New Issue
Block a user