[legacy] clean up legacy code (#4743)

* [legacy] remove outdated codes of pipeline (#4692)

* [legacy] remove cli of benchmark and update optim (#4690)

* [legacy] remove cli of benchmark and update optim

* [doc] fix cli doc test

* [legacy] fix engine clip grad norm

* [legacy] remove outdated colo tensor (#4694)

* [legacy] remove outdated colo tensor

* [test] fix test import

* [legacy] move outdated zero to legacy (#4696)

* [legacy] clean up utils (#4700)

* [legacy] clean up utils

* [example] update examples

* [legacy] clean up amp

* [legacy] fix amp module

* [legacy] clean up gpc (#4742)

* [legacy] clean up context

* [legacy] clean core, constants and global vars

* [legacy] refactor initialize

* [example] fix examples ci

* [example] fix examples ci

* [legacy] fix tests

* [example] fix gpt example

* [example] fix examples ci

* [devops] fix ci installation

* [example] fix examples ci
This commit is contained in:
Hongxin Liu
2023-09-18 16:31:06 +08:00
committed by GitHub
parent 32e7f99416
commit b5f9e37c70
342 changed files with 2919 additions and 4182 deletions

View File

@@ -0,0 +1,3 @@
from .fx import get_topology as get_fx_topology
__all__ = ['get_fx_topology']

View File

@@ -0,0 +1,153 @@
import torch
from torch.fx.graph_module import GraphModule
from colossalai.legacy.pipeline.middleware.topo import Partition, PartitionInputVal, PartitionOutputVal, Topo
def partition_name_to_id(partition_name, is_input=False, is_output=False):
if is_input:
partition_id = 0
elif is_output:
partition_id = 1
else:
prefix = 'submod_'
partition_id = int(partition_name.split(prefix)[-1]) + 2
return partition_id
# There are two kinds of def in fx.graph
# 1. non direct_use & non direct_def, which means the output is used by next partition with a temporary mid value.
# e.g. submod1 = call_module(...)
# temporary_val = submod1[0]
# submod2 = call_module(temporary_val, ...)
# 2. direct_use & direct_def, which means the output is used by next partition directly.
# e.g. submod1 = call_module(...)
# submod2 = call_module(submod1, ...)
def find_input_in_partition(node, partitions, input_partitions=None):
p_input_val = None
direct_def = not node.name.startswith('getitem')
# search in input
if direct_def and input_partitions is not None:
partition_id = partition_name_to_id('', is_input=True)
for i, input_node in enumerate(input_partitions):
if input_node == node:
p_input_val = PartitionInputVal(partition_id=partition_id, offset=i)
return p_input_val
# search submod in mid part
if direct_def:
for partition in partitions:
if partition == node:
partition_id = partition_name_to_id(partition.name)
p_input_val = PartitionInputVal(partition_id=partition_id, offset=0)
return p_input_val
# search temporary value in graph
else:
for partition in partitions:
for offset, mid_val in enumerate(partition.users):
if mid_val == node:
partition_id = partition_name_to_id(partition.name)
p_input_val = PartitionInputVal(partition_id=partition_id, offset=offset)
return p_input_val
return p_input_val
def find_output_in_partition(node, partitions, output_partitions=None):
p_output_val = PartitionOutputVal()
for user in node.users:
direct_use = not user.name.startswith('getitem')
# user is mid partition
for partition in partitions:
# direct call
if direct_use:
if user == partition:
partition_id = partition_name_to_id(partition.name)
for i, arg in enumerate(partition.args):
if arg == node:
p_output_val.add(partition_id=partition_id, offset=i)
break
# getitem call
else:
if user in partition.args:
partition_id = partition_name_to_id(partition.name)
for i, arg in enumerate(partition.args):
if arg == user:
p_output_val.add(partition_id=partition_id, offset=i)
break
# user is output
if output_partitions is not None:
output_node = output_partitions[0]
if user.op == output_node.op:
output_keys = {}
partition_id = partition_name_to_id('', is_output=True)
torch.fx.graph.map_arg(output_node.args[0], lambda n: output_keys.setdefault(n))
for i, arg in enumerate(output_keys):
if arg == node:
p_output_val.add(partition_id=partition_id, offset=i)
break
return p_output_val
def get_topology(gm: GraphModule):
topo = Topo()
topo_output_partition = Partition()
input_partitions = []
partitions = []
output_partitions = []
for node in gm.graph.nodes:
if node.op == 'placeholder':
input_partitions.append(node)
elif node.name.startswith('submod_'):
partitions.append(node)
elif node.op == 'output':
output_partitions.append(node)
else:
continue
# set output for input_partition
topo_input_partition = Partition()
for partition in input_partitions:
cur_node = partition
p_output_val = find_output_in_partition(cur_node, partitions, output_partitions)
topo_input_partition.add_output_val(p_output_val)
topo.set_partitions(partition_id=0, partition=topo_input_partition)
topo.set_input_partition_id(partition_id=0)
for i, partition in enumerate(partitions):
topo_mid_partition = Partition()
# set input for submodule
for arg in partition.args:
cur_node = arg
p_input_val = find_input_in_partition(cur_node, partitions, input_partitions)
topo_mid_partition.add_input_val(p_input_val)
# set output for submodule
direct_use = True
for user in partition.users:
if user.name.startswith('getitem'):
direct_use = False
break
if direct_use:
cur_node = partition
p_output_val = find_output_in_partition(cur_node, partitions, output_partitions)
topo_mid_partition.add_output_val(p_output_val)
else:
for user in partition.users:
cur_node = user
p_output_val = find_output_in_partition(cur_node, partitions, output_partitions)
topo_mid_partition.add_output_val(p_output_val)
topo.set_partitions(partition_id=i + 2, partition=topo_mid_partition)
# set input for output_partition
for partition in output_partitions:
topo_output_partition = Partition()
torch.fx.graph.map_arg(
partition.args[0],
lambda n: topo_output_partition.add_input_val(find_input_in_partition(n, partitions, input_partitions)))
topo.set_partitions(partition_id=1, partition=topo_output_partition)
topo.set_output_partition_id(partition_id=1)
return topo