mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 18:19:58 +00:00
[Pipeline] Add Topo Class (#2059)
* use Topo class to rewrite DAG * polish code * polish code * polish code * add comment * add else to unended if Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
This commit is contained in:
@@ -1,8 +1,7 @@
|
||||
import torch
|
||||
from typing import Dict, Set
|
||||
from typing import Dict
|
||||
from torch.fx.node import Node, map_arg
|
||||
from torch.fx.graph import Graph
|
||||
from torch.fx.graph_module import GraphModule
|
||||
|
||||
def get_comm_size(prev_partition, next_partition):
|
||||
"""
|
||||
@@ -171,161 +170,3 @@ def get_node_module(node) -> torch.nn.Module:
|
||||
module = node.graph.owning_module.get_submodule(node.target)
|
||||
return module
|
||||
|
||||
def find_def_in_partition(node, partitions, input_partitions=None, direct=False):
    """Return the name of the partition (or model input) that defines ``node``.

    Args:
        node: the fx node whose producer is being looked up.
        partitions: all submod call nodes of the partitioned graph.
        input_partitions: graph placeholder nodes (model inputs), if any.
        direct: when True, ``node`` is expected to be a partition node itself;
            when False, ``node`` is a getitem-style consumer and its producer
            is the partition that lists it among its users.

    Returns:
        ``'MODEL_INPUT'`` if a placeholder with the same name exists, the
        matching partition's name otherwise, or ``None`` (after printing a
        diagnostic) when no definition is found.
    """
    # A placeholder sharing the node's name means the value comes straight
    # from the model's input.
    if input_partitions is not None:
        for placeholder in input_partitions:
            if placeholder.name == node.name:
                return 'MODEL_INPUT'

    if direct:
        # The node should be one of the partition call nodes itself.
        candidates = (p.name for p in partitions if node == p)
    else:
        # The node is consumed via a getitem call; its producer is the
        # partition that records it as a user.
        candidates = (p.name for p in partitions if node in p.users.keys())

    found = next(candidates, None)
    if found is not None:
        return found

    print(f'Not found def in partition {node.name}')
    return None
|
||||
|
||||
def find_user_in_partition(node, partitions, output_partitions=None, direct=False):
    """Return the names of all partitions (and/or the model output) that consume ``node``.

    Args:
        node: the fx node whose consumers are being looked up.
        partitions: all submod call nodes of the partitioned graph.
        output_partitions: graph output nodes, if any.
        direct: when True, ``node`` is matched against the partition nodes
            themselves; when False, consumers are partitions holding ``node``
            among their args (the getitem-call case).

    Returns:
        A non-empty list of consumer names — possibly including
        ``'MODEL_OUTPUT'`` — or ``None`` (after printing a diagnostic) when
        no consumer is found.
    """
    consumer_names = []

    if direct:
        # The node is itself one of the partition call nodes.
        consumer_names.extend(p.name for p in partitions if node == p)
    else:
        # The node feeds a partition through that partition's args.
        consumer_names.extend(p.name for p in partitions if node in p.args)

    # A node sharing the graph output's op is consumed by the model output.
    if output_partitions is not None:
        if node.op == output_partitions[0].op:
            consumer_names.append('MODEL_OUTPUT')

    if consumer_names:
        return consumer_names

    print(f'Not found user in partition {node.name}')
    return None
|
||||
|
||||
def get_partition_depends(partition, partitions, input_partitions=None, output_partitions=None):
    """Collect the input/output dependencies of one partition node.

    Args:
        partition: the submod call node whose dependencies are collected.
        partitions: all submod call nodes of the partitioned graph.
        input_partitions: graph placeholder nodes (model inputs), if any.
        output_partitions: graph output nodes, if any.

    Returns:
        A tuple ``(input, output, output_len)``: ``input`` maps a producer
        partition name to the arg offsets it feeds; ``output`` maps a
        consumer partition name to the output offsets it reads;
        ``output_len`` is the number of outputs this partition produces
        (derived from the final value of ``offset`` below).
    """
    # e.g. Partition2: {input: {Partition0: [sub1_1], Partition1: [sub2_0]}, output:{Output: [sub3_0]}},
    # NOTE(review): `input`/`output` shadow builtins; kept as-is to preserve behavior.
    input = {}
    output = {}

    # Each positional arg is produced either directly by another partition /
    # a model input, or indirectly through a getitem node.
    for offset, arg in enumerate(partition.args):
        def_partition_name = None
        if not arg.name.startswith('getitem'):
            def_partition_name = find_def_in_partition(arg, partitions, input_partitions, direct=True)
        else:
            def_partition_name = find_def_in_partition(arg, partitions, input_partitions, direct=False)
        if def_partition_name is None:
            # Producer could not be resolved; skip this arg.
            continue
        if def_partition_name not in input:
            input[def_partition_name] = []
        input[def_partition_name].append(offset)

    # `offset` is reused to enumerate this partition's output values:
    # a direct user pins it to 0 (single output), while each getitem user
    # bumps it by one (one output slot per getitem).
    offset = -1
    for user in partition.users.keys():
        user_partition_names = None
        if input_partitions is None or not user.name.startswith('getitem'):
            user_partition_names = find_user_in_partition(user, partitions, output_partitions, direct=True)
            # Direct user: the partition produces exactly one value.
            offset = 0
        else:
            user_partition_names = find_user_in_partition(user, partitions, output_partitions, direct=False)
            offset += 1
        if user_partition_names is None:
            # Consumer could not be resolved; skip this user.
            continue
        for user_partition_name in user_partition_names:
            if user_partition_name not in output:
                output[user_partition_name] = []
            output[user_partition_name].append(offset)

    # offset is the index of the last output seen, so offset+1 is the count.
    return input, output, offset+1
|
||||
|
||||
# DAG just looks like following case.
|
||||
# the int in every list represents the offset of the partition's input arg or output arg.
|
||||
# {
|
||||
# 'input_partition': {
|
||||
# 'input_ids': {
|
||||
# 'input': {},
|
||||
# 'output': {'submod_0': [0], 'submod_1': [1]},
|
||||
# 'output_len': 0},
|
||||
# 'attention_mask': {
|
||||
# 'input': {},
|
||||
# 'output': {'submod_2': [0]},
|
||||
# 'output_len': 0}},
|
||||
# 'submod_0': {
|
||||
# 'input': {'MODEL_INPUT': [0]},
|
||||
# 'output': {'submod_1': [0], 'submod_2': [0, 1]},
|
||||
# 'output_len': 2},
|
||||
# 'submod_1': {
|
||||
# 'input': {'submod_0': [0], 'MODEL_INPUT': [1]},
|
||||
# 'output': {'submod_2': [0]},
|
||||
# 'output_len': 1},
|
||||
# 'submod_2': {
|
||||
# 'input': {'MODEL_INPUT': [0], 'submod_0': [1, 2]},
|
||||
# 'output': {'submod_3': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
|
||||
# 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
|
||||
# 22, 23, 24]},
|
||||
# 'output_len': 25},
|
||||
# 'submod_3': {
|
||||
# 'input': {'submod_2': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
|
||||
# 12, 13, 14, 15, 16, 17, 18, 19, 20,
|
||||
# 21, 22, 23, 24]},
|
||||
# 'output': {'MODEL_OUTPUT': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
|
||||
# 11, 12, 13, 14, 15, 16, 17, 18, 19,
|
||||
# 20, 21, 22, 23, 24]},
|
||||
# 'output_len': 25},
|
||||
# 'output_partition': {
|
||||
# 'input': {'logits': 'submod_3', 'past_key_values': (('submod_3', 'submod_3'), ('submod_3', 'submod_3'),
|
||||
# ('submod_3', 'submod_3'), ('submod_3', 'submod_3'),
|
||||
# ('submod_3', 'submod_3'), ('submod_3', 'submod_3'),
|
||||
# ('submod_3', 'submod_3'), ('submod_3', 'submod_3'),
|
||||
# ('submod_3', 'submod_3'), ('submod_3', 'submod_3'),
|
||||
# ('submod_3', 'submod_3'), ('submod_3', 'submod_3'))},
|
||||
# 'output': {}, 'output_len': 0}
|
||||
# }
|
||||
|
||||
# TODO(jiangziyue) Define a Class for DAG.
|
||||
def get_DAG(gm: GraphModule):
    """Build the DAG dictionary describing a partitioned fx GraphModule.

    The result maps ``'input_partition'`` (one entry per placeholder),
    each ``submod_*`` partition name, and ``'output_partition'`` to dicts
    with ``'input'``, ``'output'`` and ``'output_len'`` keys — see the
    example layout in the comment block above this function.

    Args:
        gm: the traced and partitioned ``GraphModule``.

    Returns:
        The DAG description dictionary.
    """
    dag = {}

    # Bucket the graph nodes: model inputs, submod_* partitions, model outputs.
    input_nodes, partition_nodes, output_nodes = [], [], []
    for node in gm.graph.nodes:
        if node.op == 'placeholder':
            input_nodes.append(node)
        elif node.name.startswith('submod_'):
            partition_nodes.append(node)
        elif node.op == 'output':
            output_nodes.append(node)

    # Model inputs: only the consumers matter; each placeholder yields one value.
    for placeholder in input_nodes:
        _, consumers, _ = get_partition_depends(placeholder, partition_nodes, None, output_nodes)
        entry = {'input': {}, 'output': consumers, 'output_len': 1}
        dag.setdefault('input_partition', {})[placeholder.name] = entry

    # Submodule partitions: record producers, consumers and output arity.
    for partition in partition_nodes:
        deps_in, deps_out, out_len = get_partition_depends(partition, partition_nodes, input_nodes, output_nodes)
        dag[partition.name] = {'input': deps_in, 'output': deps_out, 'output_len': out_len}

    # Model output: map every returned node back to its defining partition.
    for out_node in output_nodes:
        producers = torch.fx.graph.map_arg(out_node.args[0],
                                           lambda n: find_def_in_partition(n, partition_nodes, input_nodes))
        dag['output_partition'] = {'input': producers, 'output': {}, 'output_len': 0}

    return dag
|
Reference in New Issue
Block a user