Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-07 12:01:39 +00:00
[legacy] clean up legacy code (#4743)
* [legacy] remove outdated codes of pipeline (#4692)
* [legacy] remove cli of benchmark and update optim (#4690)
* [legacy] remove cli of benchmark and update optim
* [doc] fix cli doc test
* [legacy] fix engine clip grad norm
* [legacy] remove outdated colo tensor (#4694)
* [legacy] remove outdated colo tensor
* [test] fix test import
* [legacy] move outdated zero to legacy (#4696)
* [legacy] clean up utils (#4700)
* [legacy] clean up utils
* [example] update examples
* [legacy] clean up amp
* [legacy] fix amp module
* [legacy] clean up gpc (#4742)
* [legacy] clean up context
* [legacy] clean core, constants and global vars
* [legacy] refactor initialize
* [example] fix examples ci
* [example] fix examples ci
* [legacy] fix tests
* [example] fix gpt example
* [example] fix examples ci
* [devops] fix ci installation
* [example] fix examples ci
3  colossalai/legacy/pipeline/middleware/__init__.py  Normal file
@@ -0,0 +1,3 @@
from .topo import Partition, PartitionInputVal, PartitionOutputVal, Topo

__all__ = ['Topo', 'Partition', 'PartitionOutputVal', 'PartitionInputVal']
3  colossalai/legacy/pipeline/middleware/adaptor/__init__.py  Normal file
@@ -0,0 +1,3 @@
from .fx import get_topology as get_fx_topology

__all__ = ['get_fx_topology']
153  colossalai/legacy/pipeline/middleware/adaptor/fx.py  Normal file
@@ -0,0 +1,153 @@
import torch
from torch.fx.graph_module import GraphModule

from colossalai.legacy.pipeline.middleware.topo import Partition, PartitionInputVal, PartitionOutputVal, Topo


def partition_name_to_id(partition_name, is_input=False, is_output=False):
    if is_input:
        partition_id = 0
    elif is_output:
        partition_id = 1
    else:
        prefix = 'submod_'
        partition_id = int(partition_name.split(prefix)[-1]) + 2
    return partition_id


# There are two kinds of defs in an fx.graph:
# 1. non direct_use & non direct_def: the output is used by the next partition via a temporary mid value.
#    e.g. submod1 = call_module(...)
#         temporary_val = submod1[0]
#         submod2 = call_module(temporary_val, ...)
# 2. direct_use & direct_def: the output is used by the next partition directly.
#    e.g. submod1 = call_module(...)
#         submod2 = call_module(submod1, ...)


def find_input_in_partition(node, partitions, input_partitions=None):
    p_input_val = None
    direct_def = not node.name.startswith('getitem')
    # search in input
    if direct_def and input_partitions is not None:
        partition_id = partition_name_to_id('', is_input=True)
        for i, input_node in enumerate(input_partitions):
            if input_node == node:
                p_input_val = PartitionInputVal(partition_id=partition_id, offset=i)
                return p_input_val
    # search submod in mid part
    if direct_def:
        for partition in partitions:
            if partition == node:
                partition_id = partition_name_to_id(partition.name)
                p_input_val = PartitionInputVal(partition_id=partition_id, offset=0)
                return p_input_val
    # search temporary value in graph
    else:
        for partition in partitions:
            for offset, mid_val in enumerate(partition.users):
                if mid_val == node:
                    partition_id = partition_name_to_id(partition.name)
                    p_input_val = PartitionInputVal(partition_id=partition_id, offset=offset)
                    return p_input_val

    return p_input_val


def find_output_in_partition(node, partitions, output_partitions=None):
    p_output_val = PartitionOutputVal()
    for user in node.users:
        direct_use = not user.name.startswith('getitem')
        # user is mid partition
        for partition in partitions:
            # direct call
            if direct_use:
                if user == partition:
                    partition_id = partition_name_to_id(partition.name)
                    for i, arg in enumerate(partition.args):
                        if arg == node:
                            p_output_val.add(partition_id=partition_id, offset=i)
                            break
            # getitem call
            else:
                if user in partition.args:
                    partition_id = partition_name_to_id(partition.name)
                    for i, arg in enumerate(partition.args):
                        if arg == user:
                            p_output_val.add(partition_id=partition_id, offset=i)
                            break

        # user is output
        if output_partitions is not None:
            output_node = output_partitions[0]
            if user.op == output_node.op:
                output_keys = {}
                partition_id = partition_name_to_id('', is_output=True)
                torch.fx.graph.map_arg(output_node.args[0], lambda n: output_keys.setdefault(n))
                for i, arg in enumerate(output_keys):
                    if arg == node:
                        p_output_val.add(partition_id=partition_id, offset=i)
                        break
    return p_output_val


def get_topology(gm: GraphModule):
    topo = Topo()
    topo_output_partition = Partition()

    input_partitions = []
    partitions = []
    output_partitions = []
    for node in gm.graph.nodes:
        if node.op == 'placeholder':
            input_partitions.append(node)
        elif node.name.startswith('submod_'):
            partitions.append(node)
        elif node.op == 'output':
            output_partitions.append(node)
        else:
            continue

    # set output for input_partition
    topo_input_partition = Partition()
    for partition in input_partitions:
        cur_node = partition
        p_output_val = find_output_in_partition(cur_node, partitions, output_partitions)
        topo_input_partition.add_output_val(p_output_val)
    topo.set_partitions(partition_id=0, partition=topo_input_partition)
    topo.set_input_partition_id(partition_id=0)

    for i, partition in enumerate(partitions):
        topo_mid_partition = Partition()
        # set input for submodule
        for arg in partition.args:
            cur_node = arg
            p_input_val = find_input_in_partition(cur_node, partitions, input_partitions)
            topo_mid_partition.add_input_val(p_input_val)
        # set output for submodule
        direct_use = True
        for user in partition.users:
            if user.name.startswith('getitem'):
                direct_use = False
                break
        if direct_use:
            cur_node = partition
            p_output_val = find_output_in_partition(cur_node, partitions, output_partitions)
            topo_mid_partition.add_output_val(p_output_val)
        else:
            for user in partition.users:
                cur_node = user
                p_output_val = find_output_in_partition(cur_node, partitions, output_partitions)
                topo_mid_partition.add_output_val(p_output_val)
        topo.set_partitions(partition_id=i + 2, partition=topo_mid_partition)

    # set input for output_partition
    for partition in output_partitions:
        topo_output_partition = Partition()
        torch.fx.graph.map_arg(
            partition.args[0],
            lambda n: topo_output_partition.add_input_val(find_input_in_partition(n, partitions, input_partitions)))
        topo.set_partitions(partition_id=1, partition=topo_output_partition)
        topo.set_output_partition_id(partition_id=1)

    return topo
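For context, a minimal usage sketch of the adaptor above (not part of this commit): it traces a toy model, splits it with torch.fx's split_module so that the resulting GraphModule contains call_module nodes named submod_0/submod_1, and then extracts the topology. ToyModel, the split policy, and all variable names are illustrative assumptions.

import torch
import torch.nn as nn
from torch.fx import symbolic_trace
from torch.fx.passes.split_module import split_module

from colossalai.legacy.pipeline.middleware.adaptor.fx import get_topology


class ToyModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 8)
        self.fc2 = nn.Linear(8, 4)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


model = ToyModel()
traced = symbolic_trace(model)

# Assign fc1/relu to partition 0 and fc2 to partition 1; split_module then emits
# call_module nodes named 'submod_0' and 'submod_1', which get_topology recognizes
# by their 'submod_' prefix.
part_idx = 0


def split_callback(node):
    global part_idx
    if node.name == 'fc2':
        part_idx = 1
    return part_idx


split_gm = split_module(traced, model, split_callback)
topo = get_topology(split_gm)    # Topo with input partition 0, output partition 1, submods at 2+
print(topo)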
214  colossalai/legacy/pipeline/middleware/topo.py  Normal file
@@ -0,0 +1,214 @@
from dataclasses import dataclass
from typing import Dict, List

# This file includes the data structures used by the Pipeline Middleware.


@dataclass
class ValPosition:
    partition_id: int
    offset: int

    def __str__(self) -> str:
        res = f'[partition_id:{self.partition_id},offset:{self.offset}]'
        return res

    def __repr__(self) -> str:
        return self.__str__()


class PartitionInputVal(object):

    def __init__(self, partition_id, offset) -> None:
        # records which partition_id and which offset this input comes from
        val_pos = ValPosition(partition_id, offset)
        self._from_partition_and_offset: ValPosition = val_pos

    def get(self):
        return self._from_partition_and_offset

    def __str__(self) -> str:
        res = ''
        res += f'<-({self._from_partition_and_offset})'
        return res

    def __repr__(self) -> str:
        return self.__str__()


class PartitionOutputVal(object):

    def __init__(self) -> None:
        # records which partition_id and which offset each output goes to
        self._to_partition_and_offset: List[ValPosition] = []

    def add(self, partition_id, offset):
        val_pos = ValPosition(partition_id, offset)
        self._to_partition_and_offset.append(val_pos)

    def get(self):
        return self._to_partition_and_offset

    def __str__(self) -> str:
        res = ''
        res += '->('
        for val_pos in self._to_partition_and_offset:
            res += f'{val_pos},'
        res += ')'
        return res

    def __repr__(self) -> str:
        return self.__str__()


class Partition(object):

    def __init__(self) -> None:
        self._input_vals: List[PartitionInputVal] = []
        self._output_vals: List[PartitionOutputVal] = []

    def add_input_val(self, input_val: PartitionInputVal):
        self._input_vals.append(input_val)

    def add_output_val(self, output_val: PartitionOutputVal):
        self._output_vals.append(output_val)

    def get_input_vals(self):
        return self._input_vals

    def get_output_vals(self):
        return self._output_vals

    # get the output offsets sent to dst_partition_id
    def get_output_offsets(self, dst_partition_id):
        res = []
        for offset, output_val in enumerate(self._output_vals):
            outputs = output_val.get()
            for val_pos in outputs:
                if val_pos.partition_id == dst_partition_id:
                    res.append(offset)

        return res

    # get all input src partition_ids
    def get_input_partition_ids(self):
        res = []
        for input_val in self._input_vals:
            val_pos = input_val.get()
            if val_pos.partition_id not in res:
                res.append(val_pos.partition_id)
        return res

    # get all output dst partition_ids
    def get_output_partition_ids(self):
        res = []
        for output_val in self._output_vals:
            outputs = output_val.get()
            for val_pos in outputs:
                if val_pos.partition_id not in res:
                    res.append(val_pos.partition_id)
        return res

    def __str__(self) -> str:
        res = ''
        res += f' input:\n'
        res += f' length:{len(self._input_vals)}\n'
        for i, input_val in enumerate(self._input_vals):
            res += f' offset={i}:{input_val}\n'

        res += f' output:\n'
        res += f' length:{len(self._output_vals)}\n'
        for i, output_val in enumerate(self._output_vals):
            res += f' offset={i}:{output_val}\n'

        return res

    def __repr__(self) -> str:
        return self.__str__()


# This class is a middleware between the partition splitter
# and the Pipeline Scheduler. It records the graph info about
# partition inputs/outputs and provides it to the scheduler.
# There are three kinds of partitions in the Pipeline Middleware design,
# which together represent the whole process of a model execution: input-fwd-output.
# 1. input_partition: records the input of a model.
# 2. mid_partition: records the split forward executions of a model.
# 3. output_partition: records the output of a model.
# Attributes:
#     _partitions: includes all partitions
#     _input_partition_id: the key that represents the input_partition
#     _output_partition_id: the key that represents the output_partition
class Topo(object):

    def __init__(self, input_partition_id=None, output_partition_id=None) -> None:
        self._partitions: Dict[int, Partition] = {}
        self._input_partition_id = input_partition_id
        self._output_partition_id = output_partition_id

    def set_input_partition_id(self, partition_id: int):
        self._input_partition_id = partition_id

    def set_output_partition_id(self, partition_id: int):
        self._output_partition_id = partition_id

    def get_input_partition_id(self):
        return self._input_partition_id

    def get_output_partition_id(self):
        return self._output_partition_id

    def set_partitions(self, partition_id: int, partition: Partition):
        self._partitions[partition_id] = partition

    def get_mid_partitions(self):
        res = {}    # {partition_id: Partition}
        for partition_id, partition in self._partitions.items():
            if self._input_partition_id == partition_id or self._output_partition_id == partition_id:
                continue
            res[partition_id] = partition
        return res

    def get_mid_partition_ids(self):
        return list(self.get_mid_partitions().keys())

    def get_input_partition(self):
        if self._input_partition_id is not None:
            return self._partitions[self._input_partition_id]
        return None

    def get_output_partition(self):
        if self._output_partition_id is not None:
            return self._partitions[self._output_partition_id]
        return None

    def get_partition_by_id(self, partition_id):
        return self._partitions[partition_id]

    def __str__(self) -> str:
        res = ''
        if len(self._partitions) == 0:
            return 'Empty Topo Graph.'

        input_part = self.get_input_partition()
        if input_part is not None:
            res += '{\n'
            res += f'InputPartition:\n partition_id={self._input_partition_id}\n{input_part}'
            res += '}\n'

        mid_parts = self.get_mid_partitions()
        for i, (partition_id, part) in enumerate(mid_parts.items()):
            res += '{\n'
            res += f'SubPartition_{i}:\n partition_id={partition_id}\n {part}'
            res += '}\n'

        output_part = self.get_output_partition()
        if output_part is not None:
            res += '{\n'
            res += f'OutputPartition:\n partition_id={self._output_partition_id}\n{output_part}'
            res += '}\n'

        return res

    def __repr__(self) -> str:
        return self.__str__()
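As a rough illustration of how a pipeline scheduler might consume this data structure, the sketch below walks the mid partitions of a Topo and prints their communication peers. The function name describe_communication is hypothetical and not part of this commit; it only uses the accessors defined above.

from colossalai.legacy.pipeline.middleware.topo import Topo


def describe_communication(topo: Topo):
    # Mid partitions carry the actual forward computation; by the convention of the
    # fx adaptor above, id 0 is the input partition and id 1 is the output partition.
    for partition_id in topo.get_mid_partition_ids():
        partition = topo.get_partition_by_id(partition_id)
        # partitions this stage receives its inputs from
        src_ids = partition.get_input_partition_ids()
        # partitions this stage sends outputs to, and at which output offsets
        for dst_id in partition.get_output_partition_ids():
            offsets = partition.get_output_offsets(dst_id)
            print(f'partition {partition_id}: recv from {src_ids}, '
                  f'send output offsets {offsets} to partition {dst_id}')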