[legacy] clean up legacy code (#4743)

* [legacy] remove outdated codes of pipeline (#4692)

* [legacy] remove cli of benchmark and update optim (#4690)

* [legacy] remove cli of benchmark and update optim

* [doc] fix cli doc test

* [legacy] fix engine clip grad norm

* [legacy] remove outdated colo tensor (#4694)

* [legacy] remove outdated colo tensor

* [test] fix test import

* [legacy] move outdated zero to legacy (#4696)

* [legacy] clean up utils (#4700)

* [legacy] clean up utils

* [example] update examples

* [legacy] clean up amp

* [legacy] fix amp module

* [legacy] clean up gpc (#4742)

* [legacy] clean up context

* [legacy] clean core, constants and global vars

* [legacy] refactor initialize

* [example] fix examples ci

* [example] fix examples ci

* [legacy] fix tests

* [example] fix gpt example

* [example] fix examples ci

* [devops] fix ci installation

* [example] fix examples ci
This commit is contained in:
Hongxin Liu
2023-09-18 16:31:06 +08:00
committed by GitHub
parent 32e7f99416
commit b5f9e37c70
342 changed files with 2919 additions and 4182 deletions

View File

@@ -1,4 +1,11 @@
from .pipelinable import PipelinableContext, PipelinableModel
from .layer_spec import LayerSpec
from .p2p import PipelineP2PCommunication
from .schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, PipelineSchedule
from .stage_manager import PipelineStageManager
__all__ = ['PipelinableModel', 'PipelinableContext', 'LayerSpec']
__all__ = [
'PipelineSchedule',
'OneForwardOneBackwardSchedule',
'InterleavedSchedule',
'PipelineP2PCommunication',
'PipelineStageManager',
]

View File

@@ -1,55 +0,0 @@
import torch
from colossalai.utils.model.utils import call_to_str
class LayerSpec:
"""
"""
def __init__(self, typename, *module_args, **module_kwargs):
self.typename = typename
self.module_args = module_args
self.module_kwargs = module_kwargs
self.children = None
self._param_count = 0
if not issubclass(typename, torch.nn.Module):
raise RuntimeError('LayerSpec only supports torch.nn.Module types.')
def __repr__(self):
return call_to_str(self.typename.__name__, self.module_args, self.module_kwargs)
@property
def param_count(self):
return self._param_count
def build(self):
"""Build the stored specification."""
recovered_args = []
for obj in self.module_args:
if isinstance(obj, LayerSpec):
obj = obj.build()
recovered_args.append(obj)
recovered_args = tuple(recovered_args)
recovered_kwargs = {}
for k, v in self.module_kwargs.items():
if isinstance(v, LayerSpec):
v = v.build()
recovered_kwargs[k] = v
return self.typename(*recovered_args, **recovered_kwargs)
def set_children(self, children):
self.children = children
def count_params(self):
self._param_count = 0
layer = self.build()
for param in layer.parameters():
self._param_count += param.numel()
return self._param_count
def reset_param_count(self):
self._param_count = 0

View File

@@ -1,3 +0,0 @@
from .topo import Topo, Partition, PartitionOutputVal, PartitionInputVal
__all__ = ['Topo', 'Partition', 'PartitionOutputVal', 'PartitionInputVal']

View File

@@ -1,3 +0,0 @@
from .fx import get_topology as get_fx_topology
__all__ = ['get_fx_topology']

View File

@@ -1,145 +0,0 @@
from torch.fx.graph_module import GraphModule
from colossalai.pipeline.middleware.topo import Partition, PartitionInputVal, PartitionOutputVal, Topo
import torch
def partition_name_to_id(partition_name, is_input=False, is_output=False):
if is_input:
partition_id = 0
elif is_output:
partition_id = 1
else:
prefix = 'submod_'
partition_id = int(partition_name.split(prefix)[-1]) + 2
return partition_id
# There are two kinds of def in fx.graph
# 1. non direct_use & non direct_def, which means the output is used by next partition with a temporary mid value.
# e.g. submod1 = call_module(...)
# temporary_val = submod1[0]
# submod2 = call_module(temporary_val, ...)
# 2. direct_use & direct_def, which means the output is used by next partition directly.
# e.g. submod1 = call_module(...)
# submod2 = call_module(submod1, ...)
def find_input_in_partition(node, partitions, input_partitions=None):
p_input_val = None
direct_def = not node.name.startswith('getitem')
# search in input
if direct_def and input_partitions is not None:
partition_id = partition_name_to_id('', is_input=True)
for i, input_node in enumerate(input_partitions):
if input_node == node:
p_input_val = PartitionInputVal(partition_id=partition_id, offset=i)
return p_input_val
# search submod in mid part
if direct_def:
for partition in partitions:
if partition == node:
partition_id = partition_name_to_id(partition.name)
p_input_val = PartitionInputVal(partition_id=partition_id, offset=0)
return p_input_val
# search temporary value in graph
else:
for partition in partitions:
for offset, mid_val in enumerate(partition.users):
if mid_val == node:
partition_id = partition_name_to_id(partition.name)
p_input_val = PartitionInputVal(partition_id=partition_id, offset=offset)
return p_input_val
return p_input_val
def find_output_in_partition(node, partitions, output_partitions=None):
p_output_val = PartitionOutputVal()
for user in node.users:
direct_use = not user.name.startswith('getitem')
# user is mid partition
for partition in partitions:
# direct call
if direct_use:
if user == partition:
partition_id = partition_name_to_id(partition.name)
for i, arg in enumerate(partition.args):
if arg == node:
p_output_val.add(partition_id=partition_id, offset=i)
break
# getitem call
else:
if user in partition.args:
partition_id = partition_name_to_id(partition.name)
for i, arg in enumerate(partition.args):
if arg == user:
p_output_val.add(partition_id=partition_id, offset=i)
break
# user is output
if output_partitions is not None:
output_node = output_partitions[0]
if user.op == output_node.op:
output_keys = {}
partition_id = partition_name_to_id('', is_output=True)
torch.fx.graph.map_arg(output_node.args[0], lambda n: output_keys.setdefault(n))
for i, arg in enumerate(output_keys):
if arg == node:
p_output_val.add(partition_id=partition_id, offset=i)
break
return p_output_val
def get_topology(gm: GraphModule):
topo = Topo()
topo_output_partition = Partition()
input_partitions = []
partitions = []
output_partitions = []
for node in gm.graph.nodes:
if node.op == 'placeholder':
input_partitions.append(node)
elif node.name.startswith('submod_'):
partitions.append(node)
elif node.op == 'output':
output_partitions.append(node)
else:
continue
# set output for input_partition
topo_input_partition = Partition()
for partition in input_partitions:
cur_node = partition
p_output_val = find_output_in_partition(cur_node, partitions, output_partitions)
topo_input_partition.add_output_val(p_output_val)
topo.set_partitions(partition_id=0, partition=topo_input_partition)
topo.set_input_partition_id(partition_id=0)
for i, partition in enumerate(partitions):
topo_mid_partition = Partition()
# set input for submodule
for arg in partition.args:
cur_node = arg
p_input_val = find_input_in_partition(cur_node, partitions, input_partitions)
topo_mid_partition.add_input_val(p_input_val)
# set output for submodule
direct_use = True
for user in partition.users:
if user.name.startswith('getitem'):
direct_use = False
break
if direct_use:
cur_node = partition
p_output_val = find_output_in_partition(cur_node, partitions, output_partitions)
topo_mid_partition.add_output_val(p_output_val)
else:
for user in partition.users:
cur_node = user
p_output_val = find_output_in_partition(cur_node, partitions, output_partitions)
topo_mid_partition.add_output_val(p_output_val)
topo.set_partitions(partition_id=i+2, partition=topo_mid_partition)
# set input for output_partition
for partition in output_partitions:
topo_output_partition = Partition()
torch.fx.graph.map_arg(partition.args[0], lambda n: topo_output_partition.add_input_val(
find_input_in_partition(n, partitions, input_partitions)))
topo.set_partitions(partition_id=1, partition=topo_output_partition)
topo.set_output_partition_id(partition_id=1)
return topo

View File

@@ -1,206 +0,0 @@
from typing import Dict, List
from dataclasses import dataclass
# This file includes data structure used by Pipeline Middleware.
@dataclass
class ValPosition:
partition_id: int
offset: int
def __str__(self) -> str:
res = f'[partition_id:{self.partition_id},offset:{self.offset}]'
return res
def __repr__(self) -> str:
return self.__str__()
class PartitionInputVal(object):
def __init__(self, partition_id, offset) -> None:
# every input from which partition_id and which offset
val_pos = ValPosition(partition_id, offset)
self._from_partition_and_offset: ValPosition = val_pos
def get(self):
return self._from_partition_and_offset
def __str__(self) -> str:
res = ''
res += f'<-({self._from_partition_and_offset})'
return res
def __repr__(self) -> str:
return self.__str__()
class PartitionOutputVal(object):
def __init__(self) -> None:
# every output to which partition_id and which offset
self._to_partition_and_offset: List[ValPosition] = []
def add(self, partition_id, offset):
val_pos = ValPosition(partition_id, offset)
self._to_partition_and_offset.append(val_pos)
def get(self):
return self._to_partition_and_offset
def __str__(self) -> str:
res = ''
res += '->('
for val_pos in self._to_partition_and_offset:
res += f'{val_pos},'
res += ')'
return res
def __repr__(self) -> str:
return self.__str__()
class Partition(object):
def __init__(self) -> None:
self._input_vals: List[PartitionInputVal] = []
self._output_vals: List[PartitionOutputVal] = []
def add_input_val(self, input_val: PartitionInputVal):
self._input_vals.append(input_val)
def add_output_val(self, output_val: PartitionOutputVal):
self._output_vals.append(output_val)
def get_input_vals(self):
return self._input_vals
def get_output_vals(self):
return self._output_vals
# get the output offsets sent to dst_partition_id
def get_output_offsets(self, dst_partition_id):
res = []
for offset, output_val in enumerate(self._output_vals):
outputs = output_val.get()
for val_pos in outputs:
if val_pos.partition_id == dst_partition_id:
res.append(offset)
return res
# get all input dst partition_ids
def get_input_partition_ids(self):
res = []
for input_val in self._input_vals:
val_pos = input_val.get()
if val_pos.partition_id not in res:
res.append(val_pos.partition_id)
return res
# get all output dst partition_ids
def get_output_partition_ids(self):
res = []
for output_val in self._output_vals:
outputs = output_val.get()
for val_pos in outputs:
if val_pos.partition_id not in res:
res.append(val_pos.partition_id)
return res
def __str__(self) -> str:
res = ''
res += f' input:\n'
res += f' length:{len(self._input_vals)}\n'
for i, input_val in enumerate(self._input_vals):
res += f' offset={i}:{input_val}\n'
res += f' output:\n'
res += f' length:{len(self._output_vals)}\n'
for i, output_val in enumerate(self._output_vals):
res += f' offset={i}:{output_val}\n'
return res
def __repr__(self) -> str:
return self.__str__()
# This class is a middleware between partition splitter
# and Pipeline Scheduler. It records the graph info about
# partition input/output and provides it to scheduler.
# There are three kinds of partition in Pipeline Middleware Design
# which represents the whole process of a model execution: input-fwd-output
# 1. input_partition: records the input of a model.
# 2. mid_partition: record the splitted forwards execution of a model.
# 3. output_partition: records the output of a model.
# attributes:
# _partitions: include all partitions
# _input_partition_id: the key represents input_partition
# _output_partition_id: the key represents output_partition
class Topo(object):
def __init__(self, input_partition_id=None, output_partition_id=None) -> None:
self._partitions: Dict[int, Partition] = {}
self._input_partition_id = input_partition_id
self._output_partition_id = output_partition_id
def set_input_partition_id(self, partition_id: int):
self._input_partition_id = partition_id
def set_output_partition_id(self, partition_id: int):
self._output_partition_id = partition_id
def get_input_partition_id(self):
return self._input_partition_id
def get_output_partition_id(self):
return self._output_partition_id
def set_partitions(self, partition_id: int, partition: Partition):
self._partitions[partition_id] = partition
def get_mid_partitions(self):
res = {} #{partition_id: Partition}
for partition_id, partition in self._partitions.items():
if self._input_partition_id == partition_id or self._output_partition_id == partition_id:
continue
res[partition_id] = partition
return res
def get_mid_partition_ids(self):
return list(self.get_mid_partitions().keys())
def get_input_partition(self):
if self._input_partition_id is not None:
return self._partitions[self._input_partition_id]
return None
def get_output_partition(self):
if self._output_partition_id is not None:
return self._partitions[self._output_partition_id]
return None
def get_partition_by_id(self, partition_id):
return self._partitions[partition_id]
def __str__(self) -> str:
res = ''
if len(self._partitions) == 0:
return 'Empty Topo Graph.'
input_part = self.get_input_partition()
if input_part is not None:
res += '{\n'
res += f'InputPartition:\n partition_id={self._input_partition_id}\n{input_part}'
res += '}\n'
mid_parts = self.get_mid_partitions()
for i, (partition_id, part) in enumerate(mid_parts.items()):
res += '{\n'
res += f'SubPartition_{i}:\n partition_id={partition_id}\n {part}'
res += '}\n'
output_part = self.get_output_partition()
if output_part is not None:
res += '{\n'
res += f'OutputPartition:\n partition_id={self._output_partition_id}\n{output_part}'
res += '}\n'
return res
def __repr__(self) -> str:
return self.__str__()

View File

@@ -1,263 +0,0 @@
import inspect
import torch
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.legacy.nn.layer.utils import CheckpointModule
from colossalai.tensor import ColoParameter
from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
from .layer_spec import LayerSpec
from .utils import (
build_kwargs_for_function,
build_kwargs_for_module,
call_module,
customized_partition,
exec_func_with_kwargs,
exec_funcs_with_kwargs,
partition_balanced,
partition_uniform,
)
class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
"""
A context manager to split the model into pipeline stages.
"""
def __init__(self, policy: str = "balanced"):
super().__init__()
self._layer_spec_dict = {}
self._root_children = None
self._model = None
self._layer_spec_list = []
self._func_dict = {}
self._policy = policy
@property
def policy(self):
return self._policy
@policy.setter
def policy(self, policy: str):
self._policy = policy
@property
def layers_count(self):
return len(self._layer_spec_list)
@property
def funcs_count(self):
return len(self._func_dict)
def _pre_context_exec(self):
"""
The Callback function when entering the context
"""
# reserve rng states
self.cpu_rng_state = torch.get_rng_state()
self.cuda_rng_state = torch.cuda.get_rng_state()
def _post_context_exec(self):
"""
The callback function when exiting context.
"""
# reset rng states
torch.set_rng_state(self.cpu_rng_state)
torch.cuda.set_rng_state(self.cuda_rng_state)
def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
"""
The function to call at the end of the constructor of each module.
NOTE() The module may be passed to this function multiple times.
"""
# iterate over the positional arguments
# to check if an argument is a torch Module
# if found any torch Module, replace it with its layer spec
# for storage purpose
modified_args = []
for arg in args:
if isinstance(arg, torch.nn.Module):
# if nn.Module is an argument of a non-root module, then we should convert it to layer spec, which make sure the correct init method used in the real build.
# if nn.Module is an argument of the root module, then we should just record the module instance itself, because those instance has been built outside of the context.
if id(arg) in self._layer_spec_dict:
arg = self._layer_spec_dict[id(arg)]
modified_args.append(arg)
# to the same for the keyword arguments
modified_kwargs = {}
for k, v in kwargs.items():
if isinstance(v, torch.nn.Module):
v = self._layer_spec_dict[id(v)]
# (lyl)TODO: analyze ColoTensor as well
modified_kwargs[k] = v
# keep track of the module children
# as torch.nn.Module.__init__ is called from inner module to outer module,
# the final value of self._model will be the outermost model
# e.g. if the model is torchvision.models.resnet18, then the final value of self._model
# will be the ``ResNet`` object.
self._root_children = list(module.children())
self._model = module
# store the children to keep the module hierarchy
layer_spec = LayerSpec(module.__class__, *modified_args, **modified_kwargs)
layer_spec.set_children(module.children())
# store the layer spec in this context
module_id = id(module)
self._layer_spec_dict[module_id] = layer_spec
# convert all torch.nn.Parameter to colossalai.tensor.ColoParameter
name_list = []
for name, param in module.named_parameters():
if isinstance(param, ColoParameter):
continue
name_list.append((name, param))
for name, param in name_list:
if hasattr(module, name):
delattr(module, name)
setattr(module, name, ColoParameter.from_torch_tensor(tensor=param.data, requires_grad=param.requires_grad))
def to_layer_list(self, exec_seq=None):
"""
Create a layer spec list and func list with execution sequence given by user.
If exec_seq is None, we will take the module initializing order as execution order.
"""
self._exec_seq = exec_seq
if exec_seq is None:
# if user do not provide the model executing sequence, we use the initialization order as the executing order.
children_name = []
for child in self._root_children:
layer_spec = self._layer_spec_dict[id(child)]
if layer_spec.typename in (torch.nn.modules.container.ModuleList,
torch.nn.modules.container.Sequential):
for child_in_container in layer_spec.children:
self._layer_spec_list.append(self._layer_spec_dict[id(child_in_container)])
for name, module in self._model.named_modules():
if id(module) == id(child_in_container):
children_name.append(name)
break
else:
self._layer_spec_list.append(layer_spec)
for name, module in self._model.named_modules():
if id(module) == id(child):
children_name.append(name)
break
else:
front_funcs_list = []
named_modules = dict(self._model.named_modules())
for index, element in enumerate(exec_seq):
if isinstance(element, str):
if element == 'SPLIT_NODE':
continue
assert element in named_modules, f'Found invalid module name {element}, please check if you spell the module name correctly.'
# get the layer spec based on the module ID
module = named_modules[element]
layer_spec = self._layer_spec_dict[id(module)]
# check whether there are functions which should be executed before this module
if len(front_funcs_list) != 0:
func_key = (layer_spec, "front")
if func_key not in self._func_dict:
self._func_dict[func_key] = []
for f in front_funcs_list:
self._func_dict[func_key].append(f)
front_funcs_list = []
func_key = (layer_spec, "behind")
self._layer_spec_list.append(layer_spec)
elif isinstance(element, tuple) and element[1] == "front":
front_funcs_list.append(element[0])
else:
if func_key not in self._func_dict:
self._func_dict[func_key] = []
if isinstance(element, tuple):
self._func_dict[func_key].append(element[0])
else:
self._func_dict[func_key].append(element)
def partition(self, num_chunks, pipeline_size, rank):
"""
Partitioned model will be built respect to partition policy.
The real module instance will be built in this method.
"""
if isinstance(self._policy, str):
if self._policy == "uniform":
parts = partition_uniform(len(self._layer_spec_list), pipeline_size, num_chunks)[rank]
elif self._policy == "balanced":
param_counts = []
for layer_spec in self._layer_spec_list:
param_counts.append(layer_spec.count_params())
parts = partition_balanced(param_counts, pipeline_size, num_chunks)[rank]
elif self._policy == "customized":
assert self._exec_seq is not None, f'An explicit exec_seq must be defined by user in customized policy mode.'
self.customized_parts = customized_partition(self._exec_seq)
assert len(self.customized_parts) == gpc.get_world_size(
ParallelMode.PIPELINE
), f'World size is {gpc.get_world_size(ParallelMode.PIPELINE)}, but the number of partitions is {len(self.customized_parts)}'
parts = self.customized_parts[rank]
else:
raise ValueError("A string partition policy should be one of ['uniform', 'balanced', 'customized'].")
elif isinstance(self._policy, dict):
parts = self._policy[rank]
else:
raise ValueError("A partition policy should be either a string or a dictionary.")
layers_to_build = []
for start, end in parts:
layers_to_build += self._layer_spec_list[start:end]
behind_func_dict_in_partition = {}
front_func_dict_in_partition = {}
module_list_in_partition = []
for layer in layers_to_build:
module = layer.build()
module_list_in_partition.append(module)
if (layer, "front") in self._func_dict:
front_func_dict_in_partition[id(module)] = self._func_dict[(layer, "front")]
elif (layer, "behind") in self._func_dict:
behind_func_dict_in_partition[id(module)] = self._func_dict[(layer, "behind")]
module_list_in_partition = torch.nn.ModuleList(module_list_in_partition)
pipeline_model = PipelinableModel(module_list_in_partition, front_func_dict_in_partition,
behind_func_dict_in_partition)
return pipeline_model
class PipelinableModel(torch.nn.Module):
def __init__(self, module_list, front_func_dict, behind_func_dict):
super().__init__()
self._module_list = module_list
self._front_func_dict = front_func_dict
self._behind_func_dict = behind_func_dict
def forward(self, *input_tensor, **kwargs):
for module in self._module_list:
if id(module) in self._front_func_dict:
input_tensor = exec_funcs_with_kwargs(self._front_func_dict, id(module), input_tensor, kwargs)
if isinstance(module, CheckpointModule):
forward_func = module._forward
else:
forward_func = module.forward
module_kwargs = build_kwargs_for_module(forward_func, input_tensor, kwargs)
if input_tensor is None:
input_tensor = call_module(module, kwargs=module_kwargs)
elif isinstance(input_tensor, torch.Tensor):
input_tensor = call_module(module, args=(input_tensor,), kwargs=module_kwargs)
else:
input_tensor = call_module(module, args=input_tensor, kwargs=module_kwargs)
if id(module) in self._behind_func_dict:
input_tensor = exec_funcs_with_kwargs(self._behind_func_dict, id(module), input_tensor, kwargs)
return input_tensor

View File

@@ -1,168 +0,0 @@
from typing import List, Dict, Tuple
import os
import threading
from torch.distributed import rpc
import torch.distributed as dist
from colossalai.tensor import ProcessGroup
class PipelineProcessGroup:
# TODO : flexible API for DP size and TP size
# In the future design mode, dp_degree and tp_degree should be removed
def __init__(self) -> None:
self.is_initialize = False
def set_global_info(self,
rank: int,
world_size: int,
dp_degree: int = 1,
tp_degree: int = 1,
num_worker_threads: int = 1,
device: str = "cuda") -> None:
device_mesh_size = dp_degree * tp_degree
assert world_size % device_mesh_size == 0, "world_size must be the multiple of dp_degree * tp_degree !!!"
self._num_worker_threads = num_worker_threads
self._device_mesh_size = device_mesh_size
self._rank = rank
self._world_size = world_size
self._dp_degree = dp_degree
self._tp_degree = tp_degree
self.device = device
self._stage_num = world_size // device_mesh_size
self._pp_rank = rank // device_mesh_size
self._pp_ranks = [(rank % device_mesh_size) + i * device_mesh_size for i in range(self._stage_num)]
self._local_stage_ranks = [(rank // device_mesh_size * device_mesh_size) + i for i in range(device_mesh_size)]
# pp_ranks
self._initialize_pp_process_group()
# initialise tp dp process groups
self._initialize_tp_dp_process_group()
# status
self._is_first_pp_rank = self._pp_rank == 0
self._is_last_pp_rank = self._pp_rank == self._stage_num - 1
self.is_initialize = True
# lock
self.initialise_lock = threading.Lock()
self.chimera_lock = threading.Lock()
def _initialize_process_group(self):
stage_num = self.get_stage_num()
if stage_num == 1:
return
device = self.device
world_size = self.get_world_size()
rank = self.get_global_rank()
backend = 'nccl' if device == 'cuda' else 'gloo'
dist.init_process_group(backend, world_size=world_size, rank=rank, group_name='main_group')
def _initialize_pp_process_group(self) -> None:
rank = self.get_global_rank()
world_size = self.get_world_size()
# build rpc connection
options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=self._num_worker_threads)
for pp_rank in self._pp_ranks:
options.set_device_map(f'work{pp_rank}', {rank: pp_rank})
rpc.init_rpc(name=f'work{rank}', rank=rank, world_size=world_size, rpc_backend_options=options)
def _initialize_tp_dp_process_group(self) -> None:
rank = self.get_global_rank()
local_stage_ranks = self.get_local_stage_global_ranks()
dp_degree = self.get_dp_degree()
tp_degree = self.get_tp_degree()
self._tp_dp_process_group = ProcessGroup(rank, local_stage_ranks, tp_degree, dp_degree)
def get_global_rank(self):
return self._rank
def get_world_size(self):
return self._world_size
def get_dp_degree(self) -> int:
return self._dp_degree
def get_tp_degree(self) -> int:
return self._tp_degree
def get_local_device_mesh_size(self) -> int:
return self._device_mesh_size
def get_device_mesh_num(self) -> int:
pass
def get_stage_num(self) -> int:
return self._stage_num
def is_first_stage(self) -> bool:
return self._is_first_pp_rank
def is_last_stage(self) -> bool:
return self._is_last_pp_rank
def check_pp_rank_valid(self, pp_rank: int) -> bool:
return -1 < pp_rank < self._stage_num
def get_local_pp_rank(self) -> int:
return self._pp_rank
def get_prev_pp_rank(self) -> int:
prev_pp_rank = self._pp_rank - 1
if not self.check_pp_rank_valid(prev_pp_rank):
assert ValueError(f"current rank's pp_rank: {self._pp_rank} doesn't have a previous stage!")
return prev_pp_rank
def get_next_pp_rank(self) -> int:
next_pp_rank = self._pp_rank + 1
if not self.check_pp_rank_valid(next_pp_rank):
assert ValueError(f"current rank's pp_rank: {self._pp_rank} doesn't have a next stage!")
return next_pp_rank
def get_local_stage_global_ranks(self) -> List[int]:
return self._local_stage_ranks
def local_dp_rank(self) -> int:
return self._tp_dp_process_group.dp_local_rank()
def local_tp_rank(self) -> int:
return self._tp_dp_process_group.tp_local_rank()
def get_pp_global_ranks(self) -> int:
return self._pp_ranks
def get_dp_global_ranks(self):
pass
def get_tp_global_ranks(self):
pass
def get_chimera_all_reduce_group(self, pp_rank: int):
with self.chimera_lock:
if not hasattr(self, 'chimera_groups'):
world_size = self.get_world_size()
stage_num = self.get_stage_num()
assert world_size % 2 == 0, 'world_size must be even in chimera!'
self.chimera_groups = {}
for rank in range(world_size // 2):
pair = [rank, world_size - 1 - rank]
group = dist.new_group(pair)
self.chimera_groups[pair[0]] = group
self.chimera_groups[pair[1]] = group
self.chimera_groups[pair[0] + stage_num] = group
self.chimera_groups[pair[1] + stage_num] = group
self.chimera_step_lock = threading.Lock()
self.chimera_step_lock.acquire()
return self.chimera_groups[pp_rank]
ppg = PipelineProcessGroup()

View File

@@ -1,4 +0,0 @@
from ._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine, ChimeraPipelineEngine
from .utils import pytree_map
__all__ = ['FillDrainPipelineEngine', 'OneFOneBPipelineEngine', 'ChimeraPipelineEngine', 'pytree_map']

File diff suppressed because it is too large Load Diff

View File

@@ -1,346 +0,0 @@
import threading
from typing import Callable, Dict, List
import torch
import torch.distributed as dist
from torch._C._distributed_rpc import PyRRef
from torch.futures import Future
from colossalai.pipeline.pipeline_process_group import ppg
from colossalai.pipeline.rpc._pipeline_base import Phase, PipelineEngineBase, UniqueKey, WorkerBase, WorkItem
# Implementation of different Pipeline schedule
# <strategy>Worker defines the worker for each stage
# <strategy>PipelineEngine is the class for use
class FillDrainWorker(WorkerBase):
def _get_work_item_key(self) -> UniqueKey:
# execute backward first (if backward phase in work_list)
num_microbatches = self.num_microbatches
if self.forward_times < num_microbatches:
target_phase = Phase.FORWARD
target_microbatch_id = self.forward_times
else:
target_phase = Phase.BACKWARD
target_microbatch_id = self.backward_times
target_key = UniqueKey(target_microbatch_id, target_phase)
return target_key
class FillDrainPipelineEngine(PipelineEngineBase):
def __init__(self,
partition_fn: Callable,
stage_num: int,
num_microbatches: int,
device: str,
chunk: int = 1,
criterion: Callable = None,
metric: Callable = None,
checkpoint: bool = False,
data_process_func: Callable = None) -> None:
if chunk > 1:
assert num_microbatches % stage_num == 0, \
"if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"
use_1F1B = False
super().__init__(FillDrainWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
metric, checkpoint, data_process_func)
class OneFOneBWorker(WorkerBase):
def _get_work_item_key(self) -> UniqueKey:
# execute backward first (if backward phase in work_list)
pp_rank = self.pp_rank
actual_stage_num = self.actual_stage_num
num_microbatches = self.num_microbatches
is_last_stage = pp_rank == actual_stage_num - 1
if self.outstanding <= self.outstanding_range[0]:
target_phase = Phase.FORWARD
target_microbatch_id = self.forward_times
elif self.outstanding >= self.outstanding_range[1]:
target_phase = Phase.BACKWARD
target_microbatch_id = self.backward_times
else:
raise ValueError("outstanding_range[1] - outstanding_range[0] must be in [0, 1]")
target_key = UniqueKey(target_microbatch_id, target_phase)
# change outstanding_range at:
# 1. forward times reach actual_stage_num, this is the end of continuous forward
# 2. forward times reach num_microbatches, this is the end of 1F1B mode
if not is_last_stage and \
target_key.phase == Phase.FORWARD:
if target_key.microbatch_id == actual_stage_num - 1 and num_microbatches > 2:
# Why need num_microbatches > 2 ? Because there is no steady stage when num_microbatches <= 2
outstanding_min = actual_stage_num - pp_rank - 1
outstanding_max = actual_stage_num - pp_rank
self.outstanding_range = (outstanding_min, outstanding_max)
if target_key.microbatch_id == num_microbatches - 1:
self.outstanding_range = (0, 0)
return target_key
class OneFOneBPipelineEngine(PipelineEngineBase):
def __init__(self,
partition_fn: Callable,
stage_num: int,
num_microbatches: int,
device: str,
chunk: int = 1,
criterion: Callable = None,
metric: Callable = None,
checkpoint: bool = False,
data_process_func: Callable = None) -> None:
if chunk > 1:
assert num_microbatches % stage_num == 0, \
"if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"
# assert num_microbatches > stage_num * chunk, "num_microbatches must be greater than stage_num * chunk"
use_1F1B = True
super().__init__(OneFOneBWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
metric, checkpoint, data_process_func)
class ChimeraWorker(WorkerBase):
def _get_producer_consumer(self) -> None:
rank = self.pp_rank
min_pp_rank = (rank // self.actual_stage_num) * self.actual_stage_num
max_pp_rank = min_pp_rank + self.actual_stage_num - 1
assert self.producer_stage_ids is None, f"all the producers of rank {rank} has been subscribed"
assert self.consumer_stage_ids is None, f"all the consumers of rank {rank} has been subscribed"
# should be arranged in order, the order of the input of current forward
self.producer_stage_ids = []
self.consumer_stage_ids = []
# Just for demo
prev_rank = rank - 1
next_rank = rank + 1
if prev_rank >= min_pp_rank:
self.producer_stage_ids.append(prev_rank)
if next_rank <= max_pp_rank:
self.consumer_stage_ids.append(next_rank)
def _get_work_item_key(self) -> UniqueKey:
pp_rank = self.pp_rank
stage_num = self.actual_stage_num
real_microbatch_num = self.num_microbatches // 2
forward_block_size = 1 if self.num_microbatches < stage_num else self.num_microbatches // stage_num
forward_block_num = self.forward_times // forward_block_size
if self.forward_times >= real_microbatch_num or \
((pp_rank + 1) % stage_num == 0 and forward_block_num > self.backward_times):
target_phase = Phase.BACKWARD
target_microbatch_id = self.backward_times
else: # others
target_phase = Phase.FORWARD
target_microbatch_id = self.forward_times
# In up pipeline, microbatch_id to consume is 0, 2, 4 (2n)
# In down pipeline, microbatch_id to consume is 1, 3, 5 (2n + 1)
real_target_microbatch_id = target_microbatch_id * 2
if pp_rank >= stage_num:
real_target_microbatch_id += 1
target_key = UniqueKey(real_target_microbatch_id, target_phase)
with self.work_list_condition_lock:
self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list)
return target_key
def _initialize_partition(self):
# In order to ensure the down pipeline share the same parameter
# with the up pipeline, partition of down partition will be copied
# from corresponding up stage
pp_rank = self.pp_rank
stage_num = self.actual_stage_num
device = self.device
if pp_rank < stage_num:
super()._initialize_partition()
else:
# if it is down pipeline, create partition by origin method
co_up_pp_worker_rref = self.pp_rank_to_worker_rref[pp_rank - stage_num]
# get the corresponding model state dict and wait for its init
state_dict = co_up_pp_worker_rref.rpc_sync().get_partition_state_dict()
super()._initialize_partition()
self.module_partition.load_state_dict(state_dict)
# init group for chimera in ppg
ppg.get_chimera_all_reduce_group(pp_rank)
# lock for step sync
self.step_sync_lock = threading.Lock()
self.step_sync_lock.acquire()
self.have_grad_lock = threading.Lock()
self.have_grad_lock.acquire()
def _get_lock_gradient(self):
self.have_grad_lock.acquire()
grads = self.get_parameter_gradients()
self.step_sync_lock.release()
return grads
def is_first_stage(self):
return (self.pp_rank % self.actual_stage_num) == 0
def is_last_stage(self):
return (self.pp_rank % self.actual_stage_num) == self.actual_stage_num - 1
def _is_last_step(self, work_item: WorkItem) -> bool:
if work_item.forward_only:
last_phase = Phase.FORWARD
else:
last_phase = Phase.BACKWARD
is_last_phase = work_item.phase == last_phase
last_microbatch_id = self.num_microbatches - 1
if self.pp_rank < self.actual_stage_num:
last_microbatch_id -= 1
is_last_microbatch = work_item.microbatch_id == last_microbatch_id
return is_last_phase and is_last_microbatch
def _get_step_order(self) -> List[int]:
# TODO : If you want to extend it to multi head chimera, overwrite here
stage_num = self.actual_stage_num
pp_rank = self.pp_rank
# pp_rank in the same device
local_device_pp_ranks = [pp_rank, stage_num * 2 - pp_rank - 1]
local_device_pp_ranks.sort(reverse=min(local_device_pp_ranks) < stage_num // 2)
return local_device_pp_ranks
def _hook_before_step(self):
self.have_grad_lock.release()
pp_rank = self.pp_rank
stage_num = self.actual_stage_num
co_pp_rank = (pp_rank + stage_num) % (2 * stage_num)
# if current pp_rank is not the first to do step
# wait its previous pp_rank finish step
grads = self.get_parameter_gradients()
# send
co_worker = self.pp_rank_to_worker_rref[co_pp_rank]
co_grads = co_worker.rpc_sync()._get_lock_gradient()
# sync
self.step_sync_lock.acquire()
for i in range(len(grads)):
grads[i] += co_grads[i]
class ChimeraPipelineEngine(PipelineEngineBase):
def __init__(self,
partition_fn: Callable,
stage_num: int,
num_microbatches: int,
device: str,
criterion: Callable = None,
metric: Callable = None,
checkpoint: bool = False,
data_process_func: Callable = None) -> None:
assert num_microbatches % stage_num == 0, \
"In Chimera, num_microbatches must be the multiply of stage_num!"
use_1F1B = False
chunk = 1
super().__init__(ChimeraWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
metric, checkpoint, data_process_func)
def _consume_constraint(self, microbatch_id: int, forward_only: bool, input_pp_ranks: List[int],
output_pp_ranks: List[int], ret_future):
pass
def _create_pp_rank_to_rpc_worker_id(self) -> None:
stage_num = self.stage_num
self.pp_rank_to_rpc_worker_id = [0] * (stage_num * 2)
for pp_rank in range(stage_num):
self.pp_rank_to_rpc_worker_id[pp_rank] = pp_rank
self.pp_rank_to_rpc_worker_id[pp_rank + stage_num] = stage_num - pp_rank - 1
def _create_pp_rank_to_module_partition_id(self) -> None:
stage_num = self.stage_num
self.pp_rank_to_module_partition_id = [0] * (stage_num * 2)
for pp_rank in range(stage_num):
self.pp_rank_to_module_partition_id[pp_rank] = pp_rank
self.pp_rank_to_module_partition_id[pp_rank + stage_num] = pp_rank
def _create_ret_future(self, output_pp_ranks: List[int]) -> Dict[int, List[Future]]:
num_microbatches = self.num_microbatches
stage_num = self.stage_num
up_ret_future = {pp_rank: [None] * num_microbatches for pp_rank in output_pp_ranks}
down_ret_future = {pp_rank + stage_num: [None] * num_microbatches for pp_rank in output_pp_ranks}
# merge up and down
return {**up_ret_future, **down_ret_future}
def _set_input(self, input_pp_ranks: List[int], microbatch_id: int, microbatch, forward_only: bool):
# offset is 0 for all the ranks in up pipeline
# offset is stage_num for all the ranks in down pipeline
offset = (microbatch_id % 2) * self.stage_num
for pp_rank in input_pp_ranks:
worker_rref = self.pp_rank_to_worker_rref[pp_rank + offset]
worker_rref.remote().set_input(microbatch_id, microbatch, forward_only)
def _set_labels(self, output_pp_ranks: List[int], microbatch_id: int, microlabels):
# offset is 0 for all the ranks in up pipeline
# offset is stage_num for all the ranks in down pipeline
offset = (microbatch_id % 2) * self.stage_num
for pp_rank in output_pp_ranks:
worker_rref = self.pp_rank_to_worker_rref[pp_rank + offset]
worker_rref.remote().set_labels(microbatch_id, microlabels)
def _subscribe_forward(self, microbatch_id: int, output_pp_ranks: List[int], ret_future: Dict[int, List[Future]]):
key = UniqueKey(microbatch_id, Phase.FORWARD)
offset = (microbatch_id % 2) * self.stage_num
for pp_rank in output_pp_ranks:
worker_rref = self.pp_rank_to_worker_rref[pp_rank + offset]
ret_future[pp_rank + offset][microbatch_id] = worker_rref.rpc_async().get_output_by_key(key)
def _ensure_backward(self, forward_only: bool, input_pp_ranks: List[int]):
stage_num = self.stage_num
num_microbatches = self.num_microbatches
if not forward_only:
for pp_rank in input_pp_ranks:
up_last_microbatch_id = num_microbatches - 2
down_last_microbatch_id = num_microbatches - 1
up_worker_rref = self.pp_rank_to_worker_rref[pp_rank]
down_worker_rref = self.pp_rank_to_worker_rref[pp_rank + stage_num]
up_key = UniqueKey(up_last_microbatch_id, Phase.BACKWARD)
down_key = UniqueKey(down_last_microbatch_id, Phase.BACKWARD)
up_worker_rref.rpc_sync().get_output_by_key(up_key)
down_worker_rref.rpc_sync().get_output_by_key(down_key)
def _collect_forward_result(self, output_pp_ranks: List[int], ret_future: Dict[PyRRef, List[Future]]):
"""Logic of collection of forward in Chimera.
Currently, only one input one output model is supported
"""
stage_num = self.stage_num
forward_result = []
for pp_rank in output_pp_ranks:
worker_forward_result = [None] * self.num_microbatches
for microbatch_id in range(self.num_microbatches):
offset = (microbatch_id % 2) * stage_num
ret = ret_future[pp_rank + offset][microbatch_id].wait()
ret = [ret] if isinstance(ret, torch.Tensor) else ret
worker_forward_result[microbatch_id] = ret
worker_forward_result = list(zip(*worker_forward_result))
forward_result.extend(worker_forward_result)
return forward_result

View File

@@ -1,155 +0,0 @@
import argparse
import os
import warnings
from typing import Any, Callable, Dict, List, Tuple, Type, Union
import torch
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
from torch._C._distributed_rpc import _is_current_rpc_agent_set
from torch.futures import Future
from colossalai.initialize import launch
from colossalai.pipeline.pipeline_process_group import ppg
def pyobj_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = ()) -> Any:
if isinstance(obj, process_types):
return fn(obj)
elif type(obj) is dict:
return {k: pyobj_map(obj[k], fn, process_types) for k in obj}
elif type(obj) is tuple:
return tuple(pyobj_map(o, fn, process_types) for o in obj)
elif type(obj) is list:
return list(pyobj_map(o, fn, process_types) for o in obj)
else:
return obj
def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = (), map_all: bool = False) -> Any:
"""process object recursively, like pytree
Args:
obj (:class:`Any`): object to process
fn (:class:`Callable`): a function to process subobject in obj
process_types (:class: `type | tuple[type]`): types to determine the type to process
map_all (:class: `bool`): if map_all is True, then any type of element will use fn
Returns:
:class:`Any`: returns have the same structure of `obj` and type in process_types after map of `fn`
"""
if isinstance(obj, dict):
return {k: pytree_map(obj[k], fn, process_types, map_all) for k in obj}
elif isinstance(obj, tuple):
return tuple(pytree_map(o, fn, process_types, map_all) for o in obj)
elif isinstance(obj, list):
return list(pytree_map(o, fn, process_types, map_all) for o in obj)
elif isinstance(obj, process_types):
return fn(obj)
else:
return fn(obj) if map_all else obj
def tensor_shape_list(obj):
return pytree_map(obj, fn=lambda x: x.shape, process_types=torch.Tensor)
def get_batch_lengths(batch):
lengths = []
pytree_map(batch, fn=lambda x: lengths.append(len(x)), process_types=torch.Tensor)
return lengths
def split_batch(batch: Any, start, stop, device: str):
if device == 'cuda':
fn = lambda x: x[start:stop].cuda()
else:
fn = lambda x: x[start:stop]
return pytree_map(batch, fn=fn, process_types=torch.Tensor)
def type_detail(obj):
return pytree_map(obj, lambda x: type(x), map_all=True)
def pytree_filter(fn, obj, process_types):
if obj is None:
return None
filters = []
def condition_append(obj):
if fn(obj):
filters.append(obj)
pytree_map(obj, fn=condition_append, process_types=process_types)
return filters
def get_real_args_kwargs(args_or_kwargs):
args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
# TODO : combine producer and consumer
# by default, merge all args in the output args or kwargs
if args_or_kwargs is not None:
if isinstance(args_or_kwargs, dict):
pass
else:
flatten_args = []
pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
args_or_kwargs = flatten_args
return args_or_kwargs
def run_worker(rank, args, master_func):
os.environ['MASTER_ADDR'] = args.master_addr
os.environ['MASTER_PORT'] = args.master_port
device = args.device
world_size = args.world_size
dp_degree = args.dp_degree
tp_degree = args.tp_degree
num_worker_threads = args.num_worker_threads
host = args.master_addr
port = args.master_port
backend = 'nccl' if device == 'cuda' else 'gloo'
launch(dict(), rank, world_size, host, int(port), backend, verbose=False)
ppg.set_global_info(rank=rank,
world_size=world_size,
dp_degree=dp_degree,
tp_degree=tp_degree,
num_worker_threads=num_worker_threads,
device=device)
ppg.args = args
# in rpc mode, only rank 0 is needed to be coded
if rank == 0:
master_func(args)
# barrier here
if _is_current_rpc_agent_set():
rpc.shutdown()
else:
warnings.warn("RPC has not been initialized")
def rpc_run(args, master_func):
world_size = args.world_size
mp.spawn(run_worker, args=(args, master_func), nprocs=world_size)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--epoch', type=int, default=1)
parser.add_argument('--world_size', type=int, default=2)
parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--dp_degree', type=int, default=1)
parser.add_argument('--tp_degree', type=int, default=1)
parser.add_argument('--num_microbatches', type=int, default=2)
parser.add_argument('--chunk', type=int, default=1)
parser.add_argument('--use_checkpoint', action='store_true')
parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'RMSprop'], default='SGD')
parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda')
parser.add_argument('--master_addr', type=str, default='localhost')
parser.add_argument('--master_port', type=str, default='29020')
parser.add_argument('--num_worker_threads', type=int, default=128)
return parser.parse_args()

View File

@@ -1,7 +1,9 @@
from .base import PipelineSchedule
from .interleaved_pp import InterleavedSchedule
from .one_f_one_b import OneForwardOneBackwardSchedule
__all__ = [
'PipelineSchedule',
'OneForwardOneBackwardSchedule',
'InterleavedSchedule',
]

View File

@@ -1,276 +0,0 @@
import heapq
import inspect
from collections import OrderedDict
from typing import List
import torch
from colossalai.legacy.nn.layer.utils import CheckpointModule
from colossalai.logging import get_dist_logger
def _binary_partition(weights: List, start: int, end: int):
"""Returns the binary partition position of `weights`, given the start
position `st` and the end position `ed`.
Args:
weights (list): A python list to be binary partitioned
start (int): the start position of the binary partition
end (int): the end position of the binary partition
Returns:
int: the binary partition position of `weights`
"""
w_sum = weights[end - 1]
prefix = 0
if start > 0:
w_sum -= weights[start - 1]
prefix = weights[start - 1]
minimum = float("inf")
for idx in range(start + 1, end):
front = weights[idx - 1] - prefix
diff = abs(w_sum - 2 * front)
if diff < minimum:
pos = idx
minimum = diff
return start, pos, end
def _heap_addition(weights: List, intervals: int, add_cnt: int):
"""
"""
def _heap_push(heap, st, ed):
value = weights[ed - 1]
if st > 0:
value -= weights[st - 1]
heapq.heappush(heap, (-value, st, ed))
ret_intervals = []
heap = []
for st, ed in intervals:
_heap_push(heap, st, ed)
while add_cnt > 0:
_, st, ed = heapq.heappop(heap)
if ed - st == 1:
ret_intervals.append((st, ed))
else:
l, m, r = _binary_partition(weights, st, ed)
_heap_push(heap, l, m)
_heap_push(heap, m, r)
add_cnt -= 1
while heap:
_, st, ed = heapq.heappop(heap)
ret_intervals.append((st, ed))
ret_intervals.sort()
return ret_intervals
def _calc_partitions(weights, value):
prev = 0
prefix = 0
num_block = 0
intervals = []
for idx, w in enumerate(weights):
if weights[idx] - prefix > value:
intervals.append((prev, idx))
prev = idx
prefix = weights[idx - 1]
num_block += 1
intervals.append((prev, len(weights)))
return num_block + 1, intervals
def _binary_search(weights, num):
length = len(weights)
prefix = [1 if w == 0 else w for w in weights]
for i in range(1, length):
prefix[i] += prefix[i - 1]
lower_bound = max(weights)
upper_bound = prefix[length - 1]
while upper_bound > lower_bound:
mid = (upper_bound + lower_bound) // 2
number, _ = _calc_partitions(prefix, mid)
if number <= num:
upper_bound = mid
else:
lower_bound = mid + 1
num_block, intervals = _calc_partitions(prefix, upper_bound)
if num_block < num:
intervals = _heap_addition(prefix, intervals, num - num_block)
return intervals
def partition_uniform(num_items, pipeline_parallel_size, num_chunks):
assert num_items % num_chunks == 0, \
"Layer length should be divided by the number of chunks, otherwise parameter method is recommended"
logger = get_dist_logger()
parts = [[] for _ in range(pipeline_parallel_size)]
partition_items = num_items // num_chunks
for idx in range(num_chunks):
base_idx = idx * partition_items
chunk_size = partition_items // pipeline_parallel_size
left = pipeline_parallel_size - partition_items % pipeline_parallel_size
if chunk_size == 0:
logger.warning("Some nodes in Pipeline have no requests")
for p in range(pipeline_parallel_size):
st = base_idx
base_idx += chunk_size + (p >= left)
parts[p].append((st, base_idx))
return parts
def partition_balanced(weights, pipeline_parallel_size, num_chunks):
num_total = pipeline_parallel_size * num_chunks
num_items = len(weights)
if num_items <= num_total:
return partition_uniform(num_items, pipeline_parallel_size, num_chunks)
intervals = _binary_search(weights, num_total)
current = 0
parts = [[] for _ in range(pipeline_parallel_size)]
for inter in intervals:
parts[current].append(inter)
current = (current + 1) % pipeline_parallel_size
return parts
def build_kwargs_for_module(function, input_tensor, kw_dict):
"""
Generally, the first argument of module.forward is an input tensor come from the previous layer.
Therefore, we just filter the kwargs from second element of the dictionary.
"""
sig = inspect.signature(function)
if input_tensor is None:
kwargs_offset = 0
elif isinstance(input_tensor, torch.Tensor):
kwargs_offset = 1
elif isinstance(input_tensor, (tuple, OrderedDict)):
#assert isinstance(input_tensor, tuple), f'input_tensor should be a torch.Tensor or a tuple object.'
# Huggingface will take their own structures based on OrderedDict as the output
# between layers so we've to close this check.
kwargs_offset = len(input_tensor)
args_name_list = list(sig.parameters.keys())
kw_dict = {k: v for k, v in kw_dict.items() if k in args_name_list[kwargs_offset:]}
if len(kw_dict) == 0:
return None
return kw_dict
def build_kwargs_for_function(function, kw_dict):
sig = inspect.signature(function)
kw_dict = {k: v for k, v in kw_dict.items() if k in sig.parameters}
if len(kw_dict) == 0:
return None
return kw_dict
def exec_func_with_kwargs(func, kw_dict, input_tensor, kwargs):
"""
We suppose the callable object passed to to_layer_list method in two purpose:
a. use the callable object to modify input tensor, such as \
lambda x: torch.flatten(x, 1)
b. use the callable object to modify kwargs value, such as \
def foo(attention_mask=None):
if attention_mask is not None:
batch_size = input_ids.shape[0]
attention_mask = attention_mask.view(batch_size, -1)
return attention_mask
"""
if kw_dict is not None:
rst = func(**kw_dict)
if isinstance(rst, tuple):
for i, k in enumerate(kw_dict.keys()):
kwargs[k] = rst[i]
else:
for k in kw_dict.keys():
kwargs[k] = rst
return input_tensor
if isinstance(input_tensor, tuple):
assert len(input_tensor) > 0, f'input_tensor should not be empty, when kw_dict is None.'
sig = inspect.signature(func)
func_args_num = len(sig.parameters)
assert func_args_num <= len(
input_tensor), f'func requires {func_args_num} arguments, but input_tensors only have {len(input_tensor)}.'
if func_args_num < len(input_tensor):
return func(*input_tensor[:func_args_num])
else:
return func(*input_tensor)
assert isinstance(input_tensor, torch.Tensor), 'input_tensor should be a type of torch.Tensor or tuple.'
return func(input_tensor)
def exec_funcs_with_kwargs(func_dict, func_key, input_tensor, kwargs):
assert func_key in func_dict, f"{func_key} is not in the function_dict."
funcs_to_exec = func_dict[func_key]
if isinstance(funcs_to_exec, list):
for f in funcs_to_exec:
f_kwargs = build_kwargs_for_function(f, kwargs)
input_tensor = exec_func_with_kwargs(f, f_kwargs, input_tensor, kwargs)
else:
f_kwargs = build_kwargs_for_function(funcs_to_exec, kwargs)
input_tensor = exec_func_with_kwargs(funcs_to_exec, f_kwargs, input_tensor, kwargs)
return input_tensor
def call_module(module, args=None, kwargs=None):
if args is None:
args = ()
if kwargs is None:
kwargs = {}
if isinstance(module, CheckpointModule):
forward_func = module._forward
else:
forward_func = module.forward
sig = inspect.signature(forward_func)
param_nums = len(sig.parameters)
feed_nums = len(args) + len(kwargs)
args_needed_nums = param_nums - len(kwargs)
args_needed = args[:args_needed_nums]
if isinstance(module, CheckpointModule):
convert_kwargs_to_args = []
for v in kwargs.values():
convert_kwargs_to_args.append(v)
return module(*args_needed, *convert_kwargs_to_args)
else:
return module(*args_needed, **kwargs)
def customized_partition(exec_seq):
'''
This function will analyze the exec_seq. In the exec_seq, users will use 'SPLIT_NODE' as an
annotation to note the partition point.
'''
customized_parts = {}
start = 0
stop = 0
rank = 0
for element in exec_seq:
if isinstance(element, str):
if element == 'SPLIT_NODE':
customized_parts[rank] = [(start, stop)]
start = stop
rank += 1
else:
stop += 1
customized_parts[rank] = [(start, stop)]
return customized_parts