Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-09 04:50:17 +00:00)
[legacy] clean up legacy code (#4743)
* [legacy] remove outdated codes of pipeline (#4692)
* [legacy] remove cli of benchmark and update optim (#4690)
* [legacy] remove cli of benchmark and update optim
* [doc] fix cli doc test
* [legacy] fix engine clip grad norm
* [legacy] remove outdated colo tensor (#4694)
* [legacy] remove outdated colo tensor
* [test] fix test import
* [legacy] move outdated zero to legacy (#4696)
* [legacy] clean up utils (#4700)
* [legacy] clean up utils
* [example] update examples
* [legacy] clean up amp
* [legacy] fix amp module
* [legacy] clean up gpc (#4742)
* [legacy] clean up context
* [legacy] clean core, constants and global vars
* [legacy] refactor initialize
* [example] fix examples ci
* [example] fix examples ci
* [legacy] fix tests
* [example] fix gpt example
* [example] fix examples ci
* [devops] fix ci installation
* [example] fix examples ci
colossalai/legacy/pipeline/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
from .layer_spec import LayerSpec
from .pipelinable import PipelinableContext, PipelinableModel

__all__ = ['PipelinableModel', 'PipelinableContext', 'LayerSpec']
colossalai/legacy/pipeline/layer_spec.py (new file, 57 lines)
@@ -0,0 +1,57 @@
import torch

from colossalai.utils.model.utils import call_to_str


class LayerSpec:
    """
    A spec that records a module class together with its constructor arguments,
    so that the real module can be built lazily on the pipeline stage that owns it.
    """

    def __init__(self, typename, *module_args, **module_kwargs):
        self.typename = typename
        self.module_args = module_args
        self.module_kwargs = module_kwargs
        self.children = None
        self._param_count = 0

        if not issubclass(typename, torch.nn.Module):
            raise RuntimeError('LayerSpec only supports torch.nn.Module types.')

    def __repr__(self):
        return call_to_str(self.typename.__name__, self.module_args, self.module_kwargs)

    @property
    def param_count(self):
        return self._param_count

    def build(self):
        """Build the stored specification."""
        # recursively build nested LayerSpec positional arguments
        recovered_args = []
        for obj in self.module_args:
            if isinstance(obj, LayerSpec):
                obj = obj.build()
            recovered_args.append(obj)
        recovered_args = tuple(recovered_args)

        # recursively build nested LayerSpec keyword arguments
        recovered_kwargs = {}
        for k, v in self.module_kwargs.items():
            if isinstance(v, LayerSpec):
                v = v.build()
            recovered_kwargs[k] = v

        return self.typename(*recovered_args, **recovered_kwargs)

    def set_children(self, children):
        self.children = children

    def count_params(self):
        self._param_count = 0
        layer = self.build()
        for param in layer.parameters():
            self._param_count += param.numel()
        return self._param_count

    def reset_param_count(self):
        self._param_count = 0
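As an illustration (not part of the changed files), a minimal sketch of how a LayerSpec defers building the module it describes; the Linear layer here is only an example:

import torch

from colossalai.legacy.pipeline.layer_spec import LayerSpec

# describe a Linear layer without instantiating it yet
spec = LayerSpec(torch.nn.Linear, 16, 32, bias=True)
print(spec)                    # human-readable form of the stored call, via call_to_str
print(spec.count_params())     # builds a throwaway instance to count parameters: 16 * 32 + 32 = 544

# the real module is only materialized on the stage that needs it
layer = spec.build()
assert isinstance(layer, torch.nn.Linear)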
colossalai/legacy/pipeline/middleware/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
from .topo import Partition, PartitionInputVal, PartitionOutputVal, Topo

__all__ = ['Topo', 'Partition', 'PartitionOutputVal', 'PartitionInputVal']
colossalai/legacy/pipeline/middleware/adaptor/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
from .fx import get_topology as get_fx_topology

__all__ = ['get_fx_topology']
colossalai/legacy/pipeline/middleware/adaptor/fx.py (new file, 153 lines)
@@ -0,0 +1,153 @@
import torch
from torch.fx.graph_module import GraphModule

from colossalai.legacy.pipeline.middleware.topo import Partition, PartitionInputVal, PartitionOutputVal, Topo


def partition_name_to_id(partition_name, is_input=False, is_output=False):
    if is_input:
        partition_id = 0
    elif is_output:
        partition_id = 1
    else:
        prefix = 'submod_'
        partition_id = int(partition_name.split(prefix)[-1]) + 2
    return partition_id


# There are two kinds of defs in fx.graph:
# 1. non direct_use & non direct_def: the output is passed to the next partition through a temporary mid value.
#       e.g. submod1 = call_module(...)
#            temporary_val = submod1[0]
#            submod2 = call_module(temporary_val, ...)
# 2. direct_use & direct_def: the output is used by the next partition directly.
#       e.g. submod1 = call_module(...)
#            submod2 = call_module(submod1, ...)


def find_input_in_partition(node, partitions, input_partitions=None):
    p_input_val = None
    direct_def = not node.name.startswith('getitem')
    # search in input
    if direct_def and input_partitions is not None:
        partition_id = partition_name_to_id('', is_input=True)
        for i, input_node in enumerate(input_partitions):
            if input_node == node:
                p_input_val = PartitionInputVal(partition_id=partition_id, offset=i)
                return p_input_val
    # search submod in mid part
    if direct_def:
        for partition in partitions:
            if partition == node:
                partition_id = partition_name_to_id(partition.name)
                p_input_val = PartitionInputVal(partition_id=partition_id, offset=0)
                return p_input_val
    # search temporary value in graph
    else:
        for partition in partitions:
            for offset, mid_val in enumerate(partition.users):
                if mid_val == node:
                    partition_id = partition_name_to_id(partition.name)
                    p_input_val = PartitionInputVal(partition_id=partition_id, offset=offset)
                    return p_input_val

    return p_input_val


def find_output_in_partition(node, partitions, output_partitions=None):
    p_output_val = PartitionOutputVal()
    for user in node.users:
        direct_use = not user.name.startswith('getitem')
        # user is a mid partition
        for partition in partitions:
            # direct call
            if direct_use:
                if user == partition:
                    partition_id = partition_name_to_id(partition.name)
                    for i, arg in enumerate(partition.args):
                        if arg == node:
                            p_output_val.add(partition_id=partition_id, offset=i)
                            break
            # getitem call
            else:
                if user in partition.args:
                    partition_id = partition_name_to_id(partition.name)
                    for i, arg in enumerate(partition.args):
                        if arg == user:
                            p_output_val.add(partition_id=partition_id, offset=i)
                            break

        # user is output
        if output_partitions is not None:
            output_node = output_partitions[0]
            if user.op == output_node.op:
                output_keys = {}
                partition_id = partition_name_to_id('', is_output=True)
                torch.fx.graph.map_arg(output_node.args[0], lambda n: output_keys.setdefault(n))
                for i, arg in enumerate(output_keys):
                    if arg == node:
                        p_output_val.add(partition_id=partition_id, offset=i)
                        break
    return p_output_val


def get_topology(gm: GraphModule):
    topo = Topo()
    topo_output_partition = Partition()

    input_partitions = []
    partitions = []
    output_partitions = []
    for node in gm.graph.nodes:
        if node.op == 'placeholder':
            input_partitions.append(node)
        elif node.name.startswith('submod_'):
            partitions.append(node)
        elif node.op == 'output':
            output_partitions.append(node)
        else:
            continue

    # set output for input_partition
    topo_input_partition = Partition()
    for partition in input_partitions:
        cur_node = partition
        p_output_val = find_output_in_partition(cur_node, partitions, output_partitions)
        topo_input_partition.add_output_val(p_output_val)
    topo.set_partitions(partition_id=0, partition=topo_input_partition)
    topo.set_input_partition_id(partition_id=0)

    for i, partition in enumerate(partitions):
        topo_mid_partition = Partition()
        # set input for submodule
        for arg in partition.args:
            cur_node = arg
            p_input_val = find_input_in_partition(cur_node, partitions, input_partitions)
            topo_mid_partition.add_input_val(p_input_val)
        # set output for submodule
        direct_use = True
        for user in partition.users:
            if user.name.startswith('getitem'):
                direct_use = False
                break
        if direct_use:
            cur_node = partition
            p_output_val = find_output_in_partition(cur_node, partitions, output_partitions)
            topo_mid_partition.add_output_val(p_output_val)
        else:
            for user in partition.users:
                cur_node = user
                p_output_val = find_output_in_partition(cur_node, partitions, output_partitions)
                topo_mid_partition.add_output_val(p_output_val)
        topo.set_partitions(partition_id=i + 2, partition=topo_mid_partition)

    # set input for output_partition
    for partition in output_partitions:
        topo_output_partition = Partition()
        torch.fx.graph.map_arg(
            partition.args[0],
            lambda n: topo_output_partition.add_input_val(find_input_in_partition(n, partitions, input_partitions)))
        topo.set_partitions(partition_id=1, partition=topo_output_partition)
        topo.set_output_partition_id(partition_id=1)

    return topo
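As an illustration (not part of the changed files), a rough sketch of how get_topology might be exercised. It assumes a graph produced by torch.fx.passes.split_module, which names its submodules submod_0, submod_1, ... in the way this adaptor expects; the TinyModel and the partition callback are made up for the example:

import torch
from torch.fx import symbolic_trace
from torch.fx.passes.split_module import split_module

from colossalai.legacy.pipeline.middleware.adaptor import get_fx_topology


class TinyModel(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(8, 8)
        self.fc2 = torch.nn.Linear(8, 4)

    def forward(self, x):
        return self.fc2(self.fc1(x))


model = TinyModel()
gm = symbolic_trace(model)

# assign fc1 to partition 0 and fc2 to partition 1; split_module then emits call_module
# nodes named submod_0 and submod_1 in the parent graph, which get_topology looks for
partition_map = {'fc1': 0, 'fc2': 1}
split_gm = split_module(gm, model, lambda node: partition_map.get(node.target, 0))

topo = get_fx_topology(split_gm)
print(topo)    # prints the input/mid/output partitions and how values flow between them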
colossalai/legacy/pipeline/middleware/topo.py (new file, 214 lines)
@@ -0,0 +1,214 @@
from dataclasses import dataclass
from typing import Dict, List

# This file includes the data structures used by the Pipeline Middleware.


@dataclass
class ValPosition:
    partition_id: int
    offset: int

    def __str__(self) -> str:
        res = f'[partition_id:{self.partition_id},offset:{self.offset}]'
        return res

    def __repr__(self) -> str:
        return self.__str__()


class PartitionInputVal(object):

    def __init__(self, partition_id, offset) -> None:
        # records which partition_id and which offset this input comes from
        val_pos = ValPosition(partition_id, offset)
        self._from_partition_and_offset: ValPosition = val_pos

    def get(self):
        return self._from_partition_and_offset

    def __str__(self) -> str:
        res = ''
        res += f'<-({self._from_partition_and_offset})'
        return res

    def __repr__(self) -> str:
        return self.__str__()


class PartitionOutputVal(object):

    def __init__(self) -> None:
        # records which partition_ids and offsets this output goes to
        self._to_partition_and_offset: List[ValPosition] = []

    def add(self, partition_id, offset):
        val_pos = ValPosition(partition_id, offset)
        self._to_partition_and_offset.append(val_pos)

    def get(self):
        return self._to_partition_and_offset

    def __str__(self) -> str:
        res = ''
        res += '->('
        for val_pos in self._to_partition_and_offset:
            res += f'{val_pos},'
        res += ')'
        return res

    def __repr__(self) -> str:
        return self.__str__()


class Partition(object):

    def __init__(self) -> None:
        self._input_vals: List[PartitionInputVal] = []
        self._output_vals: List[PartitionOutputVal] = []

    def add_input_val(self, input_val: PartitionInputVal):
        self._input_vals.append(input_val)

    def add_output_val(self, output_val: PartitionOutputVal):
        self._output_vals.append(output_val)

    def get_input_vals(self):
        return self._input_vals

    def get_output_vals(self):
        return self._output_vals

    # get the output offsets sent to dst_partition_id
    def get_output_offsets(self, dst_partition_id):
        res = []
        for offset, output_val in enumerate(self._output_vals):
            outputs = output_val.get()
            for val_pos in outputs:
                if val_pos.partition_id == dst_partition_id:
                    res.append(offset)

        return res

    # get all input src partition_ids
    def get_input_partition_ids(self):
        res = []
        for input_val in self._input_vals:
            val_pos = input_val.get()
            if val_pos.partition_id not in res:
                res.append(val_pos.partition_id)
        return res

    # get all output dst partition_ids
    def get_output_partition_ids(self):
        res = []
        for output_val in self._output_vals:
            outputs = output_val.get()
            for val_pos in outputs:
                if val_pos.partition_id not in res:
                    res.append(val_pos.partition_id)
        return res

    def __str__(self) -> str:
        res = ''
        res += f'  input:\n'
        res += f'    length:{len(self._input_vals)}\n'
        for i, input_val in enumerate(self._input_vals):
            res += f'      offset={i}:{input_val}\n'

        res += f'  output:\n'
        res += f'    length:{len(self._output_vals)}\n'
        for i, output_val in enumerate(self._output_vals):
            res += f'      offset={i}:{output_val}\n'

        return res

    def __repr__(self) -> str:
        return self.__str__()


# This class is a middleware between the partition splitter
# and the Pipeline Scheduler. It records the graph info about
# partition input/output and provides it to the scheduler.
# There are three kinds of partition in the Pipeline Middleware design,
# which together represent the whole process of a model execution: input-fwd-output
#   1. input_partition: records the input of a model.
#   2. mid_partition: records the split forward executions of a model.
#   3. output_partition: records the output of a model.
# attributes:
#   _partitions: include all partitions
#   _input_partition_id: the key that represents the input_partition
#   _output_partition_id: the key that represents the output_partition
class Topo(object):

    def __init__(self, input_partition_id=None, output_partition_id=None) -> None:
        self._partitions: Dict[int, Partition] = {}
        self._input_partition_id = input_partition_id
        self._output_partition_id = output_partition_id

    def set_input_partition_id(self, partition_id: int):
        self._input_partition_id = partition_id

    def set_output_partition_id(self, partition_id: int):
        self._output_partition_id = partition_id

    def get_input_partition_id(self):
        return self._input_partition_id

    def get_output_partition_id(self):
        return self._output_partition_id

    def set_partitions(self, partition_id: int, partition: Partition):
        self._partitions[partition_id] = partition

    def get_mid_partitions(self):
        res = {}    # {partition_id: Partition}
        for partition_id, partition in self._partitions.items():
            if self._input_partition_id == partition_id or self._output_partition_id == partition_id:
                continue
            res[partition_id] = partition
        return res

    def get_mid_partition_ids(self):
        return list(self.get_mid_partitions().keys())

    def get_input_partition(self):
        if self._input_partition_id is not None:
            return self._partitions[self._input_partition_id]
        return None

    def get_output_partition(self):
        if self._output_partition_id is not None:
            return self._partitions[self._output_partition_id]
        return None

    def get_partition_by_id(self, partition_id):
        return self._partitions[partition_id]

    def __str__(self) -> str:
        res = ''
        if len(self._partitions) == 0:
            return 'Empty Topo Graph.'

        input_part = self.get_input_partition()
        if input_part is not None:
            res += '{\n'
            res += f'InputPartition:\n partition_id={self._input_partition_id}\n{input_part}'
            res += '}\n'

        mid_parts = self.get_mid_partitions()
        for i, (partition_id, part) in enumerate(mid_parts.items()):
            res += '{\n'
            res += f'SubPartition_{i}:\n partition_id={partition_id}\n {part}'
            res += '}\n'

        output_part = self.get_output_partition()
        if output_part is not None:
            res += '{\n'
            res += f'OutputPartition:\n partition_id={self._output_partition_id}\n{output_part}'
            res += '}\n'

        return res

    def __repr__(self) -> str:
        return self.__str__()
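As an illustration (not part of the changed files), a hand-built two-stage topology using only the classes above; the partition-id convention (0 = input, 1 = output, 2+ = mid) follows the fx adaptor:

from colossalai.legacy.pipeline.middleware.topo import Partition, PartitionInputVal, PartitionOutputVal, Topo

topo = Topo()

# partition 0: the model input, whose single value is consumed by partition 2 at arg offset 0
input_part = Partition()
out_val = PartitionOutputVal()
out_val.add(partition_id=2, offset=0)
input_part.add_output_val(out_val)
topo.set_partitions(partition_id=0, partition=input_part)
topo.set_input_partition_id(partition_id=0)

# partition 2: one forward stage; reads from the input partition, writes to the output partition
mid_part = Partition()
mid_part.add_input_val(PartitionInputVal(partition_id=0, offset=0))
mid_out = PartitionOutputVal()
mid_out.add(partition_id=1, offset=0)
mid_part.add_output_val(mid_out)
topo.set_partitions(partition_id=2, partition=mid_part)

# partition 1: the model output, fed by partition 2
output_part = Partition()
output_part.add_input_val(PartitionInputVal(partition_id=2, offset=0))
topo.set_partitions(partition_id=1, partition=output_part)
topo.set_output_partition_id(partition_id=1)

print(topo)    # pretty-prints the input/mid/output partitions and their connections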
colossalai/legacy/pipeline/pipelinable.py (new file, 263 lines)
@@ -0,0 +1,263 @@
import torch

from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.nn.layer.utils import CheckpointModule
from colossalai.tensor import ColoParameter
from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses

from .layer_spec import LayerSpec
from .utils import (
    build_kwargs_for_module,
    call_module,
    customized_partition,
    exec_funcs_with_kwargs,
    partition_balanced,
    partition_uniform,
)


class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
    """
    A context manager to split the model into pipeline stages.
    """

    def __init__(self, policy: str = "balanced"):
        super().__init__()
        self._layer_spec_dict = {}
        self._root_children = None
        self._model = None
        self._layer_spec_list = []
        self._func_dict = {}
        self._policy = policy

    @property
    def policy(self):
        return self._policy

    @policy.setter
    def policy(self, policy: str):
        self._policy = policy

    @property
    def layers_count(self):
        return len(self._layer_spec_list)

    @property
    def funcs_count(self):
        return len(self._func_dict)

    def _pre_context_exec(self):
        """
        The callback function called when entering the context.
        """
        # reserve rng states
        self.cpu_rng_state = torch.get_rng_state()
        self.cuda_rng_state = torch.cuda.get_rng_state()

    def _post_context_exec(self):
        """
        The callback function called when exiting the context.
        """

        # reset rng states
        torch.set_rng_state(self.cpu_rng_state)
        torch.cuda.set_rng_state(self.cuda_rng_state)

    def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
        """
        The function to call at the end of the constructor of each module.
        NOTE() The module may be passed to this function multiple times.
        """
        # iterate over the positional arguments to check if an argument is a torch Module;
        # if any torch Module is found, replace it with its layer spec for storage purposes
        modified_args = []
        for arg in args:
            if isinstance(arg, torch.nn.Module):
                # if the nn.Module is an argument of a non-root module, convert it to a layer spec,
                # which makes sure the correct init method is used in the real build.
                # if the nn.Module is an argument of the root module, just record the module instance itself,
                # because such instances have been built outside of the context.
                if id(arg) in self._layer_spec_dict:
                    arg = self._layer_spec_dict[id(arg)]

            modified_args.append(arg)

        # do the same for the keyword arguments
        modified_kwargs = {}
        for k, v in kwargs.items():
            if isinstance(v, torch.nn.Module):
                v = self._layer_spec_dict[id(v)]
            # (lyl)TODO: analyze ColoTensor as well
            modified_kwargs[k] = v

        # keep track of the module children
        # as torch.nn.Module.__init__ is called from inner module to outer module,
        # the final value of self._model will be the outermost model
        # e.g. if the model is torchvision.models.resnet18, then the final value of self._model
        # will be the ``ResNet`` object.
        self._root_children = list(module.children())
        self._model = module

        # store the children to keep the module hierarchy
        layer_spec = LayerSpec(module.__class__, *modified_args, **modified_kwargs)
        layer_spec.set_children(module.children())

        # store the layer spec in this context
        module_id = id(module)
        self._layer_spec_dict[module_id] = layer_spec

        # convert all torch.nn.Parameter to colossalai.tensor.ColoParameter
        name_list = []
        for name, param in module.named_parameters():
            if isinstance(param, ColoParameter):
                continue
            name_list.append((name, param))

        for name, param in name_list:
            if hasattr(module, name):
                delattr(module, name)
            setattr(module, name,
                    ColoParameter.from_torch_tensor(tensor=param.data, requires_grad=param.requires_grad))

    def to_layer_list(self, exec_seq=None):
        """
        Create a layer spec list and func list with the execution sequence given by the user.
        If exec_seq is None, the module initialization order is taken as the execution order.
        """

        self._exec_seq = exec_seq
        if exec_seq is None:
            # if the user does not provide an execution sequence, use the initialization order as the execution order.
            children_name = []
            for child in self._root_children:
                layer_spec = self._layer_spec_dict[id(child)]
                if layer_spec.typename in (
                        torch.nn.modules.container.ModuleList,
                        torch.nn.modules.container.Sequential,
                ):
                    for child_in_container in layer_spec.children:
                        self._layer_spec_list.append(self._layer_spec_dict[id(child_in_container)])
                        for name, module in self._model.named_modules():
                            if id(module) == id(child_in_container):
                                children_name.append(name)
                                break
                else:
                    self._layer_spec_list.append(layer_spec)
                    for name, module in self._model.named_modules():
                        if id(module) == id(child):
                            children_name.append(name)
                            break

        else:
            front_funcs_list = []
            named_modules = dict(self._model.named_modules())
            for index, element in enumerate(exec_seq):
                if isinstance(element, str):
                    if element == "SPLIT_NODE":
                        continue
                    assert (
                        element in named_modules
                    ), f"Found invalid module name {element}, please check if you spell the module name correctly."

                    # get the layer spec based on the module ID
                    module = named_modules[element]
                    layer_spec = self._layer_spec_dict[id(module)]

                    # check whether there are functions which should be executed before this module
                    if len(front_funcs_list) != 0:
                        func_key = (layer_spec, "front")
                        if func_key not in self._func_dict:
                            self._func_dict[func_key] = []
                        for f in front_funcs_list:
                            self._func_dict[func_key].append(f)
                        front_funcs_list = []

                    func_key = (layer_spec, "behind")
                    self._layer_spec_list.append(layer_spec)
                elif isinstance(element, tuple) and element[1] == "front":
                    front_funcs_list.append(element[0])
                else:
                    if func_key not in self._func_dict:
                        self._func_dict[func_key] = []
                    if isinstance(element, tuple):
                        self._func_dict[func_key].append(element[0])
                    else:
                        self._func_dict[func_key].append(element)

    def partition(self, num_chunks, pipeline_size, rank):
        """
        The partitioned model will be built with respect to the partition policy.
        The real module instances are built in this method.
        """
        if isinstance(self._policy, str):
            if self._policy == "uniform":
                parts = partition_uniform(len(self._layer_spec_list), pipeline_size, num_chunks)[rank]
            elif self._policy == "balanced":
                param_counts = []
                for layer_spec in self._layer_spec_list:
                    param_counts.append(layer_spec.count_params())
                parts = partition_balanced(param_counts, pipeline_size, num_chunks)[rank]
            elif self._policy == "customized":
                assert (self._exec_seq
                        is not None), f"An explicit exec_seq must be defined by the user in customized policy mode."
                self.customized_parts = customized_partition(self._exec_seq)
                assert len(self.customized_parts) == gpc.get_world_size(
                    ParallelMode.PIPELINE
                ), f"World size is {gpc.get_world_size(ParallelMode.PIPELINE)}, but the number of partitions is {len(self.customized_parts)}"
                parts = self.customized_parts[rank]
            else:
                raise ValueError("A string partition policy should be one of ['uniform', 'balanced', 'customized'].")
        elif isinstance(self._policy, dict):
            parts = self._policy[rank]
        else:
            raise ValueError("A partition policy should be either a string or a dictionary.")

        layers_to_build = []
        for start, end in parts:
            layers_to_build += self._layer_spec_list[start:end]
        behind_func_dict_in_partition = {}
        front_func_dict_in_partition = {}
        module_list_in_partition = []
        for layer in layers_to_build:
            module = layer.build()
            module_list_in_partition.append(module)
            if (layer, "front") in self._func_dict:
                front_func_dict_in_partition[id(module)] = self._func_dict[(layer, "front")]
            elif (layer, "behind") in self._func_dict:
                behind_func_dict_in_partition[id(module)] = self._func_dict[(layer, "behind")]
        module_list_in_partition = torch.nn.ModuleList(module_list_in_partition)
        pipeline_model = PipelinableModel(module_list_in_partition, front_func_dict_in_partition,
                                          behind_func_dict_in_partition)

        return pipeline_model


class PipelinableModel(torch.nn.Module):

    def __init__(self, module_list, front_func_dict, behind_func_dict):
        super().__init__()
        self._module_list = module_list
        self._front_func_dict = front_func_dict
        self._behind_func_dict = behind_func_dict

    def forward(self, *input_tensor, **kwargs):
        for module in self._module_list:
            if id(module) in self._front_func_dict:
                input_tensor = exec_funcs_with_kwargs(self._front_func_dict, id(module), input_tensor, kwargs)

            if isinstance(module, CheckpointModule):
                forward_func = module._forward
            else:
                forward_func = module.forward
            module_kwargs = build_kwargs_for_module(forward_func, input_tensor, kwargs)
            if input_tensor is None:
                input_tensor = call_module(module, kwargs=module_kwargs)
            elif isinstance(input_tensor, torch.Tensor):
                input_tensor = call_module(module, args=(input_tensor,), kwargs=module_kwargs)
            else:
                input_tensor = call_module(module, args=input_tensor, kwargs=module_kwargs)

            if id(module) in self._behind_func_dict:
                input_tensor = exec_funcs_with_kwargs(self._behind_func_dict, id(module), input_tensor, kwargs)

        return input_tensor
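As an illustration (not part of the changed files), a usage sketch of PipelinableContext. It assumes a CUDA-enabled environment, since the context snapshots the CUDA RNG state on entry; the toy Sequential model and the hard-coded pipeline_size/rank are illustrative only (they would normally come from gpc):

import torch

from colossalai.legacy.pipeline import PipelinableContext

# record every submodule constructed inside the context as a LayerSpec
pipelinable = PipelinableContext(policy='uniform')
with pipelinable:
    model = torch.nn.Sequential(
        torch.nn.Linear(16, 16),
        torch.nn.ReLU(),
        torch.nn.Linear(16, 4),
    )

# use the module initialization order as the execution order
pipelinable.to_layer_list()
print(pipelinable.layers_count)    # 3

# build only the layers that belong to this pipeline stage
stage_module = pipelinable.partition(num_chunks=1, pipeline_size=2, rank=0)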
colossalai/legacy/pipeline/pipeline_process_group.py (new file, 168 lines)
@@ -0,0 +1,168 @@
import os
import threading
from typing import Dict, List, Tuple

import torch.distributed as dist
from torch.distributed import rpc

from colossalai.legacy.tensor import ProcessGroup


class PipelineProcessGroup:
    # TODO : flexible API for DP size and TP size
    # In the future design, dp_degree and tp_degree should be removed
    def __init__(self) -> None:
        self.is_initialize = False

    def set_global_info(self,
                        rank: int,
                        world_size: int,
                        dp_degree: int = 1,
                        tp_degree: int = 1,
                        num_worker_threads: int = 1,
                        device: str = "cuda") -> None:

        device_mesh_size = dp_degree * tp_degree
        assert world_size % device_mesh_size == 0, "world_size must be a multiple of dp_degree * tp_degree!"
        self._num_worker_threads = num_worker_threads

        self._device_mesh_size = device_mesh_size
        self._rank = rank
        self._world_size = world_size
        self._dp_degree = dp_degree
        self._tp_degree = tp_degree
        self.device = device
        self._stage_num = world_size // device_mesh_size
        self._pp_rank = rank // device_mesh_size
        self._pp_ranks = [(rank % device_mesh_size) + i * device_mesh_size for i in range(self._stage_num)]
        self._local_stage_ranks = [(rank // device_mesh_size * device_mesh_size) + i for i in range(device_mesh_size)]

        # pp_ranks
        self._initialize_pp_process_group()

        # initialize tp/dp process groups
        self._initialize_tp_dp_process_group()

        # status
        self._is_first_pp_rank = self._pp_rank == 0
        self._is_last_pp_rank = self._pp_rank == self._stage_num - 1

        self.is_initialize = True

        # lock
        self.initialise_lock = threading.Lock()
        self.chimera_lock = threading.Lock()

    def _initialize_process_group(self):
        stage_num = self.get_stage_num()
        if stage_num == 1:
            return
        device = self.device
        world_size = self.get_world_size()
        rank = self.get_global_rank()
        backend = 'nccl' if device == 'cuda' else 'gloo'
        dist.init_process_group(backend, world_size=world_size, rank=rank, group_name='main_group')

    def _initialize_pp_process_group(self) -> None:
        rank = self.get_global_rank()
        world_size = self.get_world_size()

        # build rpc connection
        options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=self._num_worker_threads)

        for pp_rank in self._pp_ranks:
            options.set_device_map(f'work{pp_rank}', {rank: pp_rank})

        rpc.init_rpc(name=f'work{rank}', rank=rank, world_size=world_size, rpc_backend_options=options)

    def _initialize_tp_dp_process_group(self) -> None:
        rank = self.get_global_rank()
        local_stage_ranks = self.get_local_stage_global_ranks()
        dp_degree = self.get_dp_degree()
        tp_degree = self.get_tp_degree()
        self._tp_dp_process_group = ProcessGroup(rank, local_stage_ranks, tp_degree, dp_degree)

    def get_global_rank(self):
        return self._rank

    def get_world_size(self):
        return self._world_size

    def get_dp_degree(self) -> int:
        return self._dp_degree

    def get_tp_degree(self) -> int:
        return self._tp_degree

    def get_local_device_mesh_size(self) -> int:
        return self._device_mesh_size

    def get_device_mesh_num(self) -> int:
        pass

    def get_stage_num(self) -> int:
        return self._stage_num

    def is_first_stage(self) -> bool:
        return self._is_first_pp_rank

    def is_last_stage(self) -> bool:
        return self._is_last_pp_rank

    def check_pp_rank_valid(self, pp_rank: int) -> bool:
        return -1 < pp_rank < self._stage_num

    def get_local_pp_rank(self) -> int:
        return self._pp_rank

    def get_prev_pp_rank(self) -> int:
        prev_pp_rank = self._pp_rank - 1
        if not self.check_pp_rank_valid(prev_pp_rank):
            raise ValueError(f"current rank's pp_rank: {self._pp_rank} doesn't have a previous stage!")
        return prev_pp_rank

    def get_next_pp_rank(self) -> int:
        next_pp_rank = self._pp_rank + 1
        if not self.check_pp_rank_valid(next_pp_rank):
            raise ValueError(f"current rank's pp_rank: {self._pp_rank} doesn't have a next stage!")
        return next_pp_rank

    def get_local_stage_global_ranks(self) -> List[int]:
        return self._local_stage_ranks

    def local_dp_rank(self) -> int:
        return self._tp_dp_process_group.dp_local_rank()

    def local_tp_rank(self) -> int:
        return self._tp_dp_process_group.tp_local_rank()

    def get_pp_global_ranks(self) -> List[int]:
        return self._pp_ranks

    def get_dp_global_ranks(self):
        pass

    def get_tp_global_ranks(self):
        pass

    def get_chimera_all_reduce_group(self, pp_rank: int):
        with self.chimera_lock:
            if not hasattr(self, 'chimera_groups'):
                world_size = self.get_world_size()
                stage_num = self.get_stage_num()
                assert world_size % 2 == 0, 'world_size must be even in chimera!'
                self.chimera_groups = {}
                for rank in range(world_size // 2):
                    pair = [rank, world_size - 1 - rank]
                    group = dist.new_group(pair)
                    self.chimera_groups[pair[0]] = group
                    self.chimera_groups[pair[1]] = group
                    self.chimera_groups[pair[0] + stage_num] = group
                    self.chimera_groups[pair[1] + stage_num] = group
                self.chimera_step_lock = threading.Lock()
                self.chimera_step_lock.acquire()

        return self.chimera_groups[pp_rank]


ppg = PipelineProcessGroup()
colossalai/legacy/pipeline/rpc/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
from ._pipeline_schedule import ChimeraPipelineEngine, FillDrainPipelineEngine, OneFOneBPipelineEngine
from .utils import pytree_map

__all__ = ['FillDrainPipelineEngine', 'OneFOneBPipelineEngine', 'ChimeraPipelineEngine', 'pytree_map']
colossalai/legacy/pipeline/rpc/_pipeline_base.py (new file, 1309 lines)
File diff suppressed because it is too large.
colossalai/legacy/pipeline/rpc/_pipeline_schedule.py (new file, 346 lines)
@@ -0,0 +1,346 @@
import threading
from typing import Callable, Dict, List

import torch
import torch.distributed as dist
from torch._C._distributed_rpc import PyRRef
from torch.futures import Future

from colossalai.legacy.pipeline.pipeline_process_group import ppg
from colossalai.legacy.pipeline.rpc._pipeline_base import Phase, PipelineEngineBase, UniqueKey, WorkerBase, WorkItem

# Implementation of different pipeline schedules
# <strategy>Worker defines the worker for each stage
# <strategy>PipelineEngine is the class for users


class FillDrainWorker(WorkerBase):

    def _get_work_item_key(self) -> UniqueKey:
        # execute backward first (if a backward phase is in work_list)
        num_microbatches = self.num_microbatches

        if self.forward_times < num_microbatches:
            target_phase = Phase.FORWARD
            target_microbatch_id = self.forward_times
        else:
            target_phase = Phase.BACKWARD
            target_microbatch_id = self.backward_times

        target_key = UniqueKey(target_microbatch_id, target_phase)

        return target_key


class FillDrainPipelineEngine(PipelineEngineBase):

    def __init__(self,
                 partition_fn: Callable,
                 stage_num: int,
                 num_microbatches: int,
                 device: str,
                 chunk: int = 1,
                 criterion: Callable = None,
                 metric: Callable = None,
                 checkpoint: bool = False,
                 data_process_func: Callable = None) -> None:

        if chunk > 1:
            assert num_microbatches % stage_num == 0, \
                "if you use the interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"
        use_1F1B = False

        super().__init__(FillDrainWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
                         metric, checkpoint, data_process_func)


class OneFOneBWorker(WorkerBase):

    def _get_work_item_key(self) -> UniqueKey:
        # execute backward first (if a backward phase is in work_list)
        pp_rank = self.pp_rank
        actual_stage_num = self.actual_stage_num
        num_microbatches = self.num_microbatches
        is_last_stage = pp_rank == actual_stage_num - 1

        if self.outstanding <= self.outstanding_range[0]:
            target_phase = Phase.FORWARD
            target_microbatch_id = self.forward_times
        elif self.outstanding >= self.outstanding_range[1]:
            target_phase = Phase.BACKWARD
            target_microbatch_id = self.backward_times
        else:
            raise ValueError("outstanding_range[1] - outstanding_range[0] must be in [0, 1]")

        target_key = UniqueKey(target_microbatch_id, target_phase)

        # change outstanding_range at:
        # 1. forward times reach actual_stage_num, which is the end of the continuous forward
        # 2. forward times reach num_microbatches, which is the end of 1F1B mode
        if not is_last_stage and \
                target_key.phase == Phase.FORWARD:
            if target_key.microbatch_id == actual_stage_num - 1 and num_microbatches > 2:
                # Why require num_microbatches > 2? Because there is no steady stage when num_microbatches <= 2
                outstanding_min = actual_stage_num - pp_rank - 1
                outstanding_max = actual_stage_num - pp_rank
                self.outstanding_range = (outstanding_min, outstanding_max)
            if target_key.microbatch_id == num_microbatches - 1:
                self.outstanding_range = (0, 0)

        return target_key


class OneFOneBPipelineEngine(PipelineEngineBase):

    def __init__(self,
                 partition_fn: Callable,
                 stage_num: int,
                 num_microbatches: int,
                 device: str,
                 chunk: int = 1,
                 criterion: Callable = None,
                 metric: Callable = None,
                 checkpoint: bool = False,
                 data_process_func: Callable = None) -> None:

        if chunk > 1:
            assert num_microbatches % stage_num == 0, \
                "if you use the interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"
        # assert num_microbatches > stage_num * chunk, "num_microbatches must be greater than stage_num * chunk"
        use_1F1B = True

        super().__init__(OneFOneBWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
                         metric, checkpoint, data_process_func)


class ChimeraWorker(WorkerBase):

    def _get_producer_consumer(self) -> None:
        rank = self.pp_rank
        min_pp_rank = (rank // self.actual_stage_num) * self.actual_stage_num
        max_pp_rank = min_pp_rank + self.actual_stage_num - 1

        assert self.producer_stage_ids is None, f"all the producers of rank {rank} have been subscribed"
        assert self.consumer_stage_ids is None, f"all the consumers of rank {rank} have been subscribed"

        # should be arranged in order, the order of the input of the current forward
        self.producer_stage_ids = []
        self.consumer_stage_ids = []

        # Just for demo
        prev_rank = rank - 1
        next_rank = rank + 1
        if prev_rank >= min_pp_rank:
            self.producer_stage_ids.append(prev_rank)
        if next_rank <= max_pp_rank:
            self.consumer_stage_ids.append(next_rank)

    def _get_work_item_key(self) -> UniqueKey:
        pp_rank = self.pp_rank
        stage_num = self.actual_stage_num
        real_microbatch_num = self.num_microbatches // 2

        forward_block_size = 1 if self.num_microbatches < stage_num else self.num_microbatches // stage_num
        forward_block_num = self.forward_times // forward_block_size

        if self.forward_times >= real_microbatch_num or \
                ((pp_rank + 1) % stage_num == 0 and forward_block_num > self.backward_times):
            target_phase = Phase.BACKWARD
            target_microbatch_id = self.backward_times
        else:    # others
            target_phase = Phase.FORWARD
            target_microbatch_id = self.forward_times

        # In the up pipeline, the microbatch_id to consume is 0, 2, 4 (2n)
        # In the down pipeline, the microbatch_id to consume is 1, 3, 5 (2n + 1)
        real_target_microbatch_id = target_microbatch_id * 2
        if pp_rank >= stage_num:
            real_target_microbatch_id += 1
        target_key = UniqueKey(real_target_microbatch_id, target_phase)

        with self.work_list_condition_lock:
            self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list)
        return target_key

    def _initialize_partition(self):
        # In order to ensure that the down pipeline shares the same parameters
        # as the up pipeline, the partition of the down pipeline is copied
        # from the corresponding up stage
        pp_rank = self.pp_rank
        stage_num = self.actual_stage_num
        device = self.device
        if pp_rank < stage_num:
            super()._initialize_partition()
        else:
            # if it is the down pipeline, create the partition by the original method
            co_up_pp_worker_rref = self.pp_rank_to_worker_rref[pp_rank - stage_num]
            # get the corresponding model state dict and wait for its init
            state_dict = co_up_pp_worker_rref.rpc_sync().get_partition_state_dict()
            super()._initialize_partition()
            self.module_partition.load_state_dict(state_dict)

        # init group for chimera in ppg
        ppg.get_chimera_all_reduce_group(pp_rank)

        # lock for step sync
        self.step_sync_lock = threading.Lock()
        self.step_sync_lock.acquire()

        self.have_grad_lock = threading.Lock()
        self.have_grad_lock.acquire()

    def _get_lock_gradient(self):
        self.have_grad_lock.acquire()
        grads = self.get_parameter_gradients()
        self.step_sync_lock.release()
        return grads

    def is_first_stage(self):
        return (self.pp_rank % self.actual_stage_num) == 0

    def is_last_stage(self):
        return (self.pp_rank % self.actual_stage_num) == self.actual_stage_num - 1

    def _is_last_step(self, work_item: WorkItem) -> bool:
        if work_item.forward_only:
            last_phase = Phase.FORWARD
        else:
            last_phase = Phase.BACKWARD
        is_last_phase = work_item.phase == last_phase
        last_microbatch_id = self.num_microbatches - 1
        if self.pp_rank < self.actual_stage_num:
            last_microbatch_id -= 1
        is_last_microbatch = work_item.microbatch_id == last_microbatch_id
        return is_last_phase and is_last_microbatch

    def _get_step_order(self) -> List[int]:
        # TODO : If you want to extend it to multi-head chimera, overwrite here
        stage_num = self.actual_stage_num
        pp_rank = self.pp_rank
        # pp_ranks on the same device
        local_device_pp_ranks = [pp_rank, stage_num * 2 - pp_rank - 1]
        local_device_pp_ranks.sort(reverse=min(local_device_pp_ranks) < stage_num // 2)
        return local_device_pp_ranks

    def _hook_before_step(self):
        self.have_grad_lock.release()
        pp_rank = self.pp_rank
        stage_num = self.actual_stage_num
        co_pp_rank = (pp_rank + stage_num) % (2 * stage_num)

        # if the current pp_rank is not the first to do step,
        # wait for its previous pp_rank to finish step
        grads = self.get_parameter_gradients()

        # send
        co_worker = self.pp_rank_to_worker_rref[co_pp_rank]
        co_grads = co_worker.rpc_sync()._get_lock_gradient()
        # sync
        self.step_sync_lock.acquire()
        for i in range(len(grads)):
            grads[i] += co_grads[i]


class ChimeraPipelineEngine(PipelineEngineBase):

    def __init__(self,
                 partition_fn: Callable,
                 stage_num: int,
                 num_microbatches: int,
                 device: str,
                 criterion: Callable = None,
                 metric: Callable = None,
                 checkpoint: bool = False,
                 data_process_func: Callable = None) -> None:

        assert num_microbatches % stage_num == 0, \
            "In Chimera, num_microbatches must be a multiple of stage_num!"
        use_1F1B = False
        chunk = 1

        super().__init__(ChimeraWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
                         metric, checkpoint, data_process_func)

    def _consume_constraint(self, microbatch_id: int, forward_only: bool, input_pp_ranks: List[int],
                            output_pp_ranks: List[int], ret_future):
        pass

    def _create_pp_rank_to_rpc_worker_id(self) -> None:
        stage_num = self.stage_num
        self.pp_rank_to_rpc_worker_id = [0] * (stage_num * 2)
        for pp_rank in range(stage_num):
            self.pp_rank_to_rpc_worker_id[pp_rank] = pp_rank
            self.pp_rank_to_rpc_worker_id[pp_rank + stage_num] = stage_num - pp_rank - 1

    def _create_pp_rank_to_module_partition_id(self) -> None:
        stage_num = self.stage_num
        self.pp_rank_to_module_partition_id = [0] * (stage_num * 2)
        for pp_rank in range(stage_num):
            self.pp_rank_to_module_partition_id[pp_rank] = pp_rank
            self.pp_rank_to_module_partition_id[pp_rank + stage_num] = pp_rank

    def _create_ret_future(self, output_pp_ranks: List[int]) -> Dict[int, List[Future]]:
        num_microbatches = self.num_microbatches
        stage_num = self.stage_num
        up_ret_future = {pp_rank: [None] * num_microbatches for pp_rank in output_pp_ranks}
        down_ret_future = {pp_rank + stage_num: [None] * num_microbatches for pp_rank in output_pp_ranks}
        # merge up and down
        return {**up_ret_future, **down_ret_future}

    def _set_input(self, input_pp_ranks: List[int], microbatch_id: int, microbatch, forward_only: bool):
        # offset is 0 for all the ranks in the up pipeline
        # offset is stage_num for all the ranks in the down pipeline
        offset = (microbatch_id % 2) * self.stage_num
        for pp_rank in input_pp_ranks:
            worker_rref = self.pp_rank_to_worker_rref[pp_rank + offset]
            worker_rref.remote().set_input(microbatch_id, microbatch, forward_only)

    def _set_labels(self, output_pp_ranks: List[int], microbatch_id: int, microlabels):
        # offset is 0 for all the ranks in the up pipeline
        # offset is stage_num for all the ranks in the down pipeline
        offset = (microbatch_id % 2) * self.stage_num
        for pp_rank in output_pp_ranks:
            worker_rref = self.pp_rank_to_worker_rref[pp_rank + offset]
            worker_rref.remote().set_labels(microbatch_id, microlabels)

    def _subscribe_forward(self, microbatch_id: int, output_pp_ranks: List[int], ret_future: Dict[int, List[Future]]):
        key = UniqueKey(microbatch_id, Phase.FORWARD)
        offset = (microbatch_id % 2) * self.stage_num
        for pp_rank in output_pp_ranks:
            worker_rref = self.pp_rank_to_worker_rref[pp_rank + offset]
            ret_future[pp_rank + offset][microbatch_id] = worker_rref.rpc_async().get_output_by_key(key)

    def _ensure_backward(self, forward_only: bool, input_pp_ranks: List[int]):
        stage_num = self.stage_num
        num_microbatches = self.num_microbatches
        if not forward_only:
            for pp_rank in input_pp_ranks:
                up_last_microbatch_id = num_microbatches - 2
                down_last_microbatch_id = num_microbatches - 1

                up_worker_rref = self.pp_rank_to_worker_rref[pp_rank]
                down_worker_rref = self.pp_rank_to_worker_rref[pp_rank + stage_num]

                up_key = UniqueKey(up_last_microbatch_id, Phase.BACKWARD)
                down_key = UniqueKey(down_last_microbatch_id, Phase.BACKWARD)
                up_worker_rref.rpc_sync().get_output_by_key(up_key)
                down_worker_rref.rpc_sync().get_output_by_key(down_key)

    def _collect_forward_result(self, output_pp_ranks: List[int], ret_future: Dict[PyRRef, List[Future]]):
        """Logic of collecting the forward results in Chimera.
        Currently, only single-input single-output models are supported.
        """
        stage_num = self.stage_num
        forward_result = []
        for pp_rank in output_pp_ranks:
            worker_forward_result = [None] * self.num_microbatches
            for microbatch_id in range(self.num_microbatches):
                offset = (microbatch_id % 2) * stage_num
                ret = ret_future[pp_rank + offset][microbatch_id].wait()
                ret = [ret] if isinstance(ret, torch.Tensor) else ret
                worker_forward_result[microbatch_id] = ret

            worker_forward_result = list(zip(*worker_forward_result))
            forward_result.extend(worker_forward_result)

        return forward_result
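For orientation (not part of the changed files), a sketch of constructing one of these engines. The constructor arguments are taken from the classes above, but the partition callback's signature and the forward_backward call belong to PipelineEngineBase in _pipeline_base.py, whose diff is suppressed, so treat them as assumptions. The engine also expects the RPC workers to have been initialized first (see rpc/utils.py below):

import torch

from colossalai.legacy.pipeline.rpc import OneFOneBPipelineEngine


def partition_fn(pp_rank: int, chunk: int, stage_num: int):
    # return the torch.nn.Module owned by this stage (the callback signature is assumed)
    layers = [torch.nn.Linear(16, 16), torch.nn.Linear(16, 4)]
    return layers[pp_rank]


engine = OneFOneBPipelineEngine(partition_fn=partition_fn,
                                stage_num=2,
                                num_microbatches=4,
                                device='cuda',
                                chunk=1,
                                criterion=torch.nn.MSELoss(),
                                checkpoint=False)

# assumed PipelineEngineBase interface (defined in the suppressed _pipeline_base.py):
# forward_results = engine.forward_backward(batch, labels, forward_only=False)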
colossalai/legacy/pipeline/rpc/utils.py (new file, 155 lines)
@@ -0,0 +1,155 @@
import argparse
import os
import warnings
from typing import Any, Callable, Dict, List, Tuple, Type, Union

import torch
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
from torch._C._distributed_rpc import _is_current_rpc_agent_set
from torch.futures import Future

from colossalai.initialize import launch
from colossalai.legacy.pipeline.pipeline_process_group import ppg


def pyobj_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = ()) -> Any:
    if isinstance(obj, process_types):
        return fn(obj)
    elif type(obj) is dict:
        return {k: pyobj_map(obj[k], fn, process_types) for k in obj}
    elif type(obj) is tuple:
        return tuple(pyobj_map(o, fn, process_types) for o in obj)
    elif type(obj) is list:
        return list(pyobj_map(o, fn, process_types) for o in obj)
    else:
        return obj


def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = (), map_all: bool = False) -> Any:
    """Process an object recursively, like a pytree.

    Args:
        obj (:class:`Any`): object to process
        fn (:class:`Callable`): a function to process sub-objects in obj
        process_types (:class:`type | tuple[type]`): types that determine which elements to process
        map_all (:class:`bool`): if map_all is True, then any type of element will be processed with fn

    Returns:
        :class:`Any`: a value with the same structure as `obj`, where elements of the given types are mapped by `fn`
    """
    if isinstance(obj, dict):
        return {k: pytree_map(obj[k], fn, process_types, map_all) for k in obj}
    elif isinstance(obj, tuple):
        return tuple(pytree_map(o, fn, process_types, map_all) for o in obj)
    elif isinstance(obj, list):
        return list(pytree_map(o, fn, process_types, map_all) for o in obj)
    elif isinstance(obj, process_types):
        return fn(obj)
    else:
        return fn(obj) if map_all else obj
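A quick illustration of pytree_map (not part of the changed files), importing it through the package __init__ shown earlier:

import torch

from colossalai.legacy.pipeline.rpc import pytree_map

batch = {'x': torch.zeros(4, 8), 'meta': [torch.zeros(4), 'tag']}

# map only the tensors; non-tensor leaves pass through unchanged
shapes = pytree_map(batch, fn=lambda t: t.shape, process_types=torch.Tensor)
print(shapes)    # {'x': torch.Size([4, 8]), 'meta': [torch.Size([4]), 'tag']}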
def tensor_shape_list(obj):
    return pytree_map(obj, fn=lambda x: x.shape, process_types=torch.Tensor)


def get_batch_lengths(batch):
    lengths = []
    pytree_map(batch, fn=lambda x: lengths.append(len(x)), process_types=torch.Tensor)
    return lengths


def split_batch(batch: Any, start, stop, device: str):
    if device == 'cuda':
        fn = lambda x: x[start:stop].cuda()
    else:
        fn = lambda x: x[start:stop]
    return pytree_map(batch, fn=fn, process_types=torch.Tensor)


def type_detail(obj):
    return pytree_map(obj, lambda x: type(x), map_all=True)


def pytree_filter(fn, obj, process_types):
    if obj is None:
        return None

    filters = []

    def condition_append(obj):
        if fn(obj):
            filters.append(obj)

    pytree_map(obj, fn=condition_append, process_types=process_types)
    return filters


def get_real_args_kwargs(args_or_kwargs):
    args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
    # TODO : combine producer and consumer
    # by default, merge all args in the output args or kwargs
    if args_or_kwargs is not None:
        if isinstance(args_or_kwargs, dict):
            pass
        else:
            flatten_args = []
            pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
            args_or_kwargs = flatten_args

    return args_or_kwargs


def run_worker(rank, args, master_func):
    os.environ['MASTER_ADDR'] = args.master_addr
    os.environ['MASTER_PORT'] = args.master_port

    device = args.device
    world_size = args.world_size
    dp_degree = args.dp_degree
    tp_degree = args.tp_degree
    num_worker_threads = args.num_worker_threads
    host = args.master_addr
    port = args.master_port
    backend = 'nccl' if device == 'cuda' else 'gloo'

    launch(dict(), rank, world_size, host, int(port), backend, verbose=False)
    ppg.set_global_info(rank=rank,
                        world_size=world_size,
                        dp_degree=dp_degree,
                        tp_degree=tp_degree,
                        num_worker_threads=num_worker_threads,
                        device=device)
    ppg.args = args
    # in rpc mode, only rank 0 needs to run the master function
    if rank == 0:
        master_func(args)
    # barrier here
    if _is_current_rpc_agent_set():
        rpc.shutdown()
    else:
        warnings.warn("RPC has not been initialized")


def rpc_run(args, master_func):
    world_size = args.world_size
    mp.spawn(run_worker, args=(args, master_func), nprocs=world_size)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--epoch', type=int, default=1)
    parser.add_argument('--world_size', type=int, default=2)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--dp_degree', type=int, default=1)
    parser.add_argument('--tp_degree', type=int, default=1)
    parser.add_argument('--num_microbatches', type=int, default=2)
    parser.add_argument('--chunk', type=int, default=1)
    parser.add_argument('--use_checkpoint', action='store_true')
    parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'RMSprop'], default='SGD')
    parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda')
    parser.add_argument('--master_addr', type=str, default='localhost')
    parser.add_argument('--master_port', type=str, default='29020')
    parser.add_argument('--num_worker_threads', type=int, default=128)
    return parser.parse_args()
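Putting the helpers together (not part of the changed files): parse_args and rpc_run come from this file, while the body of master_func is only a placeholder, since the engine-driving interface lives in the suppressed _pipeline_base.py diff. With the default device of 'cuda', GPUs are required; pass --device cpu to use the gloo backend instead:

from colossalai.legacy.pipeline.rpc.utils import parse_args, rpc_run


def master_func(args):
    # runs only on rank 0 after every worker has joined the RPC group (see run_worker above);
    # this is typically where a pipeline engine would be built and fed batches -- assumed usage
    print(f"pipeline launched: world_size={args.world_size}, "
          f"num_microbatches={args.num_microbatches}, device={args.device}")


if __name__ == '__main__':
    rpc_run(parse_args(), master_func)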
colossalai/legacy/pipeline/utils.py (new file, 276 lines)
@@ -0,0 +1,276 @@
|
||||
import heapq
|
||||
import inspect
|
||||
from collections import OrderedDict
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.legacy.nn.layer.utils import CheckpointModule
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
|
||||
def _binary_partition(weights: List, start: int, end: int):
|
||||
"""Returns the binary partition position of `weights`, given the start
|
||||
position `st` and the end position `ed`.
|
||||
|
||||
Args:
|
||||
weights (list): A python list to be binary partitioned
|
||||
start (int): the start position of the binary partition
|
||||
end (int): the end position of the binary partition
|
||||
|
||||
Returns:
|
||||
int: the binary partition position of `weights`
|
||||
"""
|
||||
w_sum = weights[end - 1]
|
||||
prefix = 0
|
||||
if start > 0:
|
||||
w_sum -= weights[start - 1]
|
||||
prefix = weights[start - 1]
|
||||
minimum = float("inf")
|
||||
for idx in range(start + 1, end):
|
||||
front = weights[idx - 1] - prefix
|
||||
diff = abs(w_sum - 2 * front)
|
||||
if diff < minimum:
|
||||
pos = idx
|
||||
minimum = diff
|
||||
|
||||
return start, pos, end
|
||||


def _heap_addition(weights: List, intervals: List, add_cnt: int):
    """Split the heaviest intervals until `add_cnt` additional intervals have been created."""

    def _heap_push(heap, st, ed):
        value = weights[ed - 1]
        if st > 0:
            value -= weights[st - 1]
        heapq.heappush(heap, (-value, st, ed))

    ret_intervals = []
    heap = []

    for st, ed in intervals:
        _heap_push(heap, st, ed)

    while add_cnt > 0:
        _, st, ed = heapq.heappop(heap)
        if ed - st == 1:
            ret_intervals.append((st, ed))
        else:
            l, m, r = _binary_partition(weights, st, ed)
            _heap_push(heap, l, m)
            _heap_push(heap, m, r)
            add_cnt -= 1

    while heap:
        _, st, ed = heapq.heappop(heap)
        ret_intervals.append((st, ed))

    ret_intervals.sort()
    return ret_intervals


def _calc_partitions(weights, value):
    prev = 0
    prefix = 0
    num_block = 0
    intervals = []

    for idx, w in enumerate(weights):
        if weights[idx] - prefix > value:
            intervals.append((prev, idx))
            prev = idx
            prefix = weights[idx - 1]
            num_block += 1

    intervals.append((prev, len(weights)))
    return num_block + 1, intervals


def _binary_search(weights, num):
    length = len(weights)
    prefix = [1 if w == 0 else w for w in weights]
    for i in range(1, length):
        prefix[i] += prefix[i - 1]

    lower_bound = max(weights)
    upper_bound = prefix[length - 1]

    while upper_bound > lower_bound:
        mid = (upper_bound + lower_bound) // 2
        number, _ = _calc_partitions(prefix, mid)
        if number <= num:
            upper_bound = mid
        else:
            lower_bound = mid + 1

    num_block, intervals = _calc_partitions(prefix, upper_bound)
    if num_block < num:
        intervals = _heap_addition(prefix, intervals, num - num_block)

    return intervals


def partition_uniform(num_items, pipeline_parallel_size, num_chunks):
    assert num_items % num_chunks == 0, \
        "num_items should be divisible by num_chunks; otherwise the parameter-based partition method is recommended"

    logger = get_dist_logger()
    parts = [[] for _ in range(pipeline_parallel_size)]
    partition_items = num_items // num_chunks
    for idx in range(num_chunks):
        base_idx = idx * partition_items
        chunk_size = partition_items // pipeline_parallel_size
        left = pipeline_parallel_size - partition_items % pipeline_parallel_size
        if chunk_size == 0:
            logger.warning("Some nodes in Pipeline have no requests")

        for p in range(pipeline_parallel_size):
            st = base_idx
            base_idx += chunk_size + (p >= left)
            parts[p].append((st, base_idx))

    return parts
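
A minimal usage sketch (not part of this diff) of the uniform partition with an interleaved schedule:

# 8 layers, 4 pipeline stages, 2 chunks per stage: every stage gets one
# (start, end) interval per chunk.
# partition_uniform(8, pipeline_parallel_size=4, num_chunks=2)
# == [[(0, 1), (4, 5)], [(1, 2), (5, 6)], [(2, 3), (6, 7)], [(3, 4), (7, 8)]]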


def partition_balanced(weights, pipeline_parallel_size, num_chunks):
    num_total = pipeline_parallel_size * num_chunks
    num_items = len(weights)
    if num_items <= num_total:
        return partition_uniform(num_items, pipeline_parallel_size, num_chunks)

    intervals = _binary_search(weights, num_total)

    current = 0
    parts = [[] for _ in range(pipeline_parallel_size)]
    for inter in intervals:
        parts[current].append(inter)
        current = (current + 1) % pipeline_parallel_size

    return parts
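
A minimal usage sketch (not part of this diff) of the balanced partition:

# Six layers with costs [2, 2, 2, 2, 4, 4] split across 2 stages (num_chunks=1):
# partition_balanced([2, 2, 2, 2, 4, 4], pipeline_parallel_size=2, num_chunks=1)
# == [[(0, 4)], [(4, 6)]]    # both stages carry a total weight of 8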


def build_kwargs_for_module(function, input_tensor, kw_dict):
    """
    Generally, the first argument of module.forward is an input tensor coming from the previous layer.
    Therefore, we just filter the kwargs starting from the second parameter of the signature.
    """
    sig = inspect.signature(function)
    if input_tensor is None:
        kwargs_offset = 0
    elif isinstance(input_tensor, torch.Tensor):
        kwargs_offset = 1
    elif isinstance(input_tensor, (tuple, OrderedDict)):
        # assert isinstance(input_tensor, tuple), f'input_tensor should be a torch.Tensor or a tuple object.'
        # Huggingface models use their own OrderedDict-based structures as the output
        # between layers, so we have to disable this check.
        kwargs_offset = len(input_tensor)
    args_name_list = list(sig.parameters.keys())
    kw_dict = {k: v for k, v in kw_dict.items() if k in args_name_list[kwargs_offset:]}
    if len(kw_dict) == 0:
        return None
    return kw_dict


def build_kwargs_for_function(function, kw_dict):
    sig = inspect.signature(function)
    kw_dict = {k: v for k, v in kw_dict.items() if k in sig.parameters}
    if len(kw_dict) == 0:
        return None
    return kw_dict
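
A filtering sketch (not part of this diff); the forward signature below is a hypothetical example:

# Assume module.forward(hidden_states, attention_mask=None) and a tensor input,
# so kwargs_offset == 1 and only names after the first parameter are kept:
# build_kwargs_for_module(module.forward, hidden_states, {'attention_mask': mask, 'labels': y})
# == {'attention_mask': mask}
# build_kwargs_for_function(lambda x: torch.flatten(x, 1), {'labels': y}) == None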


def exec_func_with_kwargs(func, kw_dict, input_tensor, kwargs):
    """
    We assume the callable object passed to the to_layer_list method serves one of two purposes:
    a. use the callable object to modify the input tensor, such as \
        lambda x: torch.flatten(x, 1)
    b. use the callable object to modify a kwargs value, such as \
        def foo(attention_mask=None):
            if attention_mask is not None:
                batch_size = input_ids.shape[0]
                attention_mask = attention_mask.view(batch_size, -1)
            return attention_mask
    """

    if kw_dict is not None:
        rst = func(**kw_dict)
        if isinstance(rst, tuple):
            for i, k in enumerate(kw_dict.keys()):
                kwargs[k] = rst[i]
        else:
            for k in kw_dict.keys():
                kwargs[k] = rst
        return input_tensor
    if isinstance(input_tensor, tuple):
        assert len(input_tensor) > 0, 'input_tensor should not be empty, when kw_dict is None.'
        sig = inspect.signature(func)
        func_args_num = len(sig.parameters)
        assert func_args_num <= len(
            input_tensor), f'func requires {func_args_num} arguments, but input_tensors only have {len(input_tensor)}.'
        if func_args_num < len(input_tensor):
            return func(*input_tensor[:func_args_num])
        else:
            return func(*input_tensor)
    assert isinstance(input_tensor, torch.Tensor), 'input_tensor should be a type of torch.Tensor or tuple.'
    return func(input_tensor)


def exec_funcs_with_kwargs(func_dict, func_key, input_tensor, kwargs):

    assert func_key in func_dict, f"{func_key} is not in the function_dict."
    funcs_to_exec = func_dict[func_key]
    if isinstance(funcs_to_exec, list):
        for f in funcs_to_exec:
            f_kwargs = build_kwargs_for_function(f, kwargs)
            input_tensor = exec_func_with_kwargs(f, f_kwargs, input_tensor, kwargs)
    else:
        f_kwargs = build_kwargs_for_function(funcs_to_exec, kwargs)
        input_tensor = exec_func_with_kwargs(funcs_to_exec, f_kwargs, input_tensor, kwargs)

    return input_tensor
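
A minimal usage sketch (not part of this diff); the function key below is hypothetical:

# Apply a registered reshaping function to the tensor flowing between layers:
# func_dict = {'flatten': lambda x: torch.flatten(x, 1)}
# x = torch.randn(2, 3, 4)
# exec_funcs_with_kwargs(func_dict, 'flatten', x, kwargs={}).shape == (2, 12)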


def call_module(module, args=None, kwargs=None):
    if args is None:
        args = ()
    if kwargs is None:
        kwargs = {}
    if isinstance(module, CheckpointModule):
        forward_func = module._forward
    else:
        forward_func = module.forward
    sig = inspect.signature(forward_func)
    param_nums = len(sig.parameters)
    feed_nums = len(args) + len(kwargs)
    args_needed_nums = param_nums - len(kwargs)
    args_needed = args[:args_needed_nums]
    if isinstance(module, CheckpointModule):
        convert_kwargs_to_args = []
        for v in kwargs.values():
            convert_kwargs_to_args.append(v)
        return module(*args_needed, *convert_kwargs_to_args)
    else:
        return module(*args_needed, **kwargs)
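
A minimal usage sketch (not part of this diff):

# call_module trims positional args to what the forward signature needs:
# layer = torch.nn.Linear(4, 4)
# x = torch.randn(2, 4)
# call_module(layer, args=(x,))    # equivalent to layer(x)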


def customized_partition(exec_seq):
    '''
    This function will analyze the exec_seq. In the exec_seq, users will use 'SPLIT_NODE' as an
    annotation to note the partition point.
    '''
    customized_parts = {}
    start = 0
    stop = 0
    rank = 0
    for element in exec_seq:
        if isinstance(element, str):
            if element == 'SPLIT_NODE':
                customized_parts[rank] = [(start, stop)]
                start = stop
                rank += 1
            else:
                stop += 1
    customized_parts[rank] = [(start, stop)]
    return customized_parts
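
A minimal usage sketch (not part of this diff); the layer names are hypothetical:

# 'SPLIT_NODE' marks a stage boundary in the execution sequence:
# customized_partition(['embedding', 'block', 'SPLIT_NODE', 'norm', 'head'])
# == {0: [(0, 2)], 1: [(2, 4)]}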