mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-12 12:47:21 +00:00
[pipeline] refactor the pipeline module (#1087)
* [pipeline] refactor the pipeline module * polish code
This commit is contained in:
4
colossalai/pipeline/__init__.py
Normal file
4
colossalai/pipeline/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .pipelinable import PipelinableContext, PipelinableModel
|
||||
from .layer_sepc import LayerSpec
|
||||
|
||||
__all__ = ['PipelinableModel', 'PipelinableContext', 'LayerSpec']
|
55
colossalai/pipeline/layer_sepc.py
Normal file
55
colossalai/pipeline/layer_sepc.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import torch
|
||||
from colossalai.utils.model.utils import call_to_str
|
||||
|
||||
class LayerSpec:
|
||||
"""
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, typename, *module_args, **module_kwargs):
|
||||
self.typename = typename
|
||||
self.module_args = module_args
|
||||
self.module_kwargs = module_kwargs
|
||||
self.children = None
|
||||
self._param_count = 0
|
||||
|
||||
if not issubclass(typename, torch.nn.Module):
|
||||
raise RuntimeError('LayerSpec only supports torch.nn.Module types.')
|
||||
|
||||
def __repr__(self):
|
||||
return call_to_str(self.typename.__name__, self.module_args, self.module_kwargs)
|
||||
|
||||
@property
|
||||
def param_count(self):
|
||||
return self._param_count
|
||||
|
||||
def build(self):
|
||||
"""Build the stored specification."""
|
||||
|
||||
recovered_args = []
|
||||
for obj in self.module_args:
|
||||
if isinstance(obj, LayerSpec):
|
||||
obj = obj.build()
|
||||
recovered_args.append(obj)
|
||||
recovered_args = tuple(recovered_args)
|
||||
|
||||
recovered_kwargs = {}
|
||||
for k, v in self.module_kwargs.items():
|
||||
if isinstance(v, LayerSpec):
|
||||
v = v.build()
|
||||
recovered_kwargs[k] = v
|
||||
|
||||
return self.typename(*recovered_args, **recovered_kwargs)
|
||||
|
||||
def set_children(self, children):
|
||||
self.children = children
|
||||
|
||||
def count_params(self):
|
||||
self._param_count = 0
|
||||
layer = self.build()
|
||||
for param in layer.parameters():
|
||||
self._param_count += param.numel()
|
||||
return self._param_count
|
||||
|
||||
def reset_param_count(self):
|
||||
self._param_count = 0
|
260
colossalai/pipeline/pipelinable.py
Normal file
260
colossalai/pipeline/pipelinable.py
Normal file
@@ -0,0 +1,260 @@
|
||||
import torch
|
||||
import inspect
|
||||
from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
|
||||
from .utils import partition_uniform, partition_balanced, build_kwargs_for_function, build_kwargs_for_module, exec_func_with_kwargs, exec_funcs_with_kwargs
|
||||
from colossalai.nn.layer.utils import CheckpointModule
|
||||
from colossalai.tensor import ColoParameter
|
||||
from .layer_sepc import LayerSpec
|
||||
|
||||
|
||||
class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
|
||||
"""
|
||||
A context manager to split the model into pipeline stages.
|
||||
"""
|
||||
|
||||
def __init__(self, policy: str="balanced"):
|
||||
super().__init__()
|
||||
self._layer_spec_dict = {}
|
||||
self._root_children = None
|
||||
self._model = None
|
||||
self._layer_spec_list = []
|
||||
self._func_dict = {}
|
||||
self._policy = policy
|
||||
|
||||
@property
|
||||
def policy(self):
|
||||
return self._policy
|
||||
|
||||
@policy.setter
|
||||
def policy(self, policy: str):
|
||||
self._policy = policy
|
||||
|
||||
@property
|
||||
def layers_count(self):
|
||||
return len(self._layer_spec_list)
|
||||
|
||||
@property
|
||||
def funcs_count(self):
|
||||
return len(self._func_dict)
|
||||
|
||||
def _pre_context_exec(self):
|
||||
"""
|
||||
The Callback function when entering the context
|
||||
"""
|
||||
# reserve rng states
|
||||
self.cpu_rng_state = torch.get_rng_state()
|
||||
self.cuda_rng_state = torch.cuda.get_rng_state()
|
||||
|
||||
def _post_context_exec(self):
|
||||
"""
|
||||
The callback function when exiting context.
|
||||
"""
|
||||
|
||||
# reset rng states
|
||||
torch.set_rng_state(self.cpu_rng_state)
|
||||
torch.cuda.set_rng_state(self.cuda_rng_state)
|
||||
|
||||
def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
|
||||
"""
|
||||
The function to call at the end of the constructor of each module.
|
||||
NOTE() The module may be passed to this function multiple times.
|
||||
"""
|
||||
# iterate over the positional arguments
|
||||
# to check if an argument is a torch Module
|
||||
# if found any torch Module, replace it with its layer spec
|
||||
# for storage purpose
|
||||
modified_args = []
|
||||
for arg in args:
|
||||
if isinstance(arg, torch.nn.Module):
|
||||
arg = self._layer_spec_dict[id(arg)]
|
||||
modified_args.append(arg)
|
||||
|
||||
# to the same for the keyword arguments
|
||||
modified_kwargs = {}
|
||||
for k, v in kwargs.items():
|
||||
if isinstance(v, torch.nn.Module):
|
||||
v = self._layer_spec_dict[id(v)]
|
||||
# (lyl)TODO: analyse ColoTensor as well
|
||||
modified_kwargs[k] = v
|
||||
|
||||
# keep track of the module children
|
||||
# as torch.nn.Module.__init__ is called from inner module to outer module,
|
||||
# the final value of self._model will be the outermost model
|
||||
# e.g. if the model is torchvision.models.resnet18, then the final value of self._model
|
||||
# will be the ``ResNet`` object.
|
||||
self._root_children = list(module.children())
|
||||
self._model = module
|
||||
|
||||
# store the children to keep the module hierarchy
|
||||
layer_spec = LayerSpec(module.__class__, *modified_args, **modified_kwargs)
|
||||
layer_spec.set_children(module.children())
|
||||
|
||||
# store the layer spec in this context
|
||||
module_id = id(module)
|
||||
self._layer_spec_dict[module_id] = layer_spec
|
||||
|
||||
# convert all torch.nn.Parameter to colossalai.tensor.ColoParameter
|
||||
name_list = []
|
||||
for name, param in module.named_parameters():
|
||||
if isinstance(param, ColoParameter):
|
||||
continue
|
||||
name_list.append((name, param))
|
||||
|
||||
for name, param in name_list:
|
||||
delattr(module, name)
|
||||
setattr(module, name, ColoParameter.from_torch_tensor(tensor=param.data, requires_grad=param.requires_grad))
|
||||
|
||||
def to_layer_list(self, exec_seq=None):
|
||||
"""
|
||||
Create a layer spec list and func list with execution sequence given by user.
|
||||
If exec_seq is None, we will take the module initizing order as execution order.
|
||||
"""
|
||||
if exec_seq is None:
|
||||
# if user do not provide the model executing sequence, we use the initialization order as the executing order.
|
||||
children_name = []
|
||||
for child in self._root_children:
|
||||
layer_spec = self._layer_spec_dict[id(child)]
|
||||
if layer_spec.typename in (torch.nn.modules.container.ModuleList,
|
||||
torch.nn.modules.container.Sequential):
|
||||
for child_in_container in layer_spec.children:
|
||||
self._layer_spec_list.append(self._layer_spec_dict[id(child_in_container)])
|
||||
for name, module in self._model.named_modules():
|
||||
if id(module) == id(child_in_container):
|
||||
children_name.append(name)
|
||||
break
|
||||
else:
|
||||
self._layer_spec_list.append(layer_spec)
|
||||
for name, module in self._model.named_modules():
|
||||
if id(module) == id(child):
|
||||
children_name.append(name)
|
||||
break
|
||||
|
||||
else:
|
||||
front_funcs_list = []
|
||||
named_modules = dict(self._model.named_modules())
|
||||
for index, element in enumerate(exec_seq):
|
||||
if isinstance(element, str):
|
||||
assert element in named_modules, f'Found invalid module name {element}, please check if you spell the module name correctly.'
|
||||
|
||||
# get the layer spec based on the module ID
|
||||
module = named_modules[element]
|
||||
layer_spec = self._layer_spec_dict[id(module)]
|
||||
|
||||
# check whether there are functions which should be executed before this module
|
||||
if len(front_funcs_list) != 0:
|
||||
func_key = (layer_spec, "front")
|
||||
if func_key not in self._func_dict:
|
||||
self._func_dict[func_key] = []
|
||||
for f in front_funcs_list:
|
||||
self._func_dict[func_key].append(f)
|
||||
front_funcs_list = []
|
||||
|
||||
func_key = (layer_spec, "behind")
|
||||
self._layer_spec_list.append(layer_spec)
|
||||
elif isinstance(element, tuple) and element[1] == "front":
|
||||
front_funcs_list.append(element[0])
|
||||
else:
|
||||
if func_key not in self._func_dict:
|
||||
self._func_dict[func_key] = []
|
||||
if isinstance(element, tuple):
|
||||
self._func_dict[func_key].append(element[0])
|
||||
else:
|
||||
self._func_dict[func_key].append(element)
|
||||
|
||||
def partition(self, num_chunks, pipeline_size, rank):
|
||||
"""
|
||||
Partitioned model will be built respect to partion policy.
|
||||
The real module instance will be built in this method.
|
||||
"""
|
||||
if isinstance(self._policy, str):
|
||||
if self._policy == "uniform":
|
||||
parts = partition_uniform(len(self._layer_spec_list), pipeline_size, num_chunks)[rank]
|
||||
elif self._policy == "balanced":
|
||||
param_counts = []
|
||||
for layer_spec in self._layer_spec_list:
|
||||
param_counts.append(layer_spec.count_params())
|
||||
parts = partition_balanced(param_counts, pipeline_size, num_chunks)[rank]
|
||||
else:
|
||||
raise ValueError("A string partition policy should be one of ['uniform', 'balanced'].")
|
||||
elif isinstance(self._policy, dict):
|
||||
parts = self._policy[rank]
|
||||
else:
|
||||
raise ValueError("A partition policy should be either a string or a dictionary.")
|
||||
|
||||
layers_to_build = []
|
||||
for start, end in parts:
|
||||
layers_to_build += self._layer_spec_list[start:end]
|
||||
behind_func_dict_in_partition = {}
|
||||
front_func_dict_in_partition = {}
|
||||
module_list_in_partition = []
|
||||
for layer in layers_to_build:
|
||||
module = layer.build()
|
||||
module_list_in_partition.append(module)
|
||||
if (layer, "front") in self._func_dict:
|
||||
front_func_dict_in_partition[id(module)] = self._func_dict[(layer, "front")]
|
||||
elif (layer, "behind") in self._func_dict:
|
||||
behind_func_dict_in_partition[id(module)] = self._func_dict[(layer, "behind")]
|
||||
module_list_in_partition = torch.nn.ModuleList(module_list_in_partition)
|
||||
pipeline_model = PipelinableModel(module_list_in_partition, front_func_dict_in_partition,
|
||||
behind_func_dict_in_partition)
|
||||
|
||||
return pipeline_model
|
||||
|
||||
|
||||
class PipelinableModel(torch.nn.Module):
|
||||
|
||||
def __init__(self, module_list, front_func_dict, behind_func_dict):
|
||||
super().__init__()
|
||||
self._module_list = module_list
|
||||
self._front_func_dict = front_func_dict
|
||||
self._behind_func_dict = behind_func_dict
|
||||
|
||||
def forward(self, input_tensor, **kwargs):
|
||||
|
||||
for module in self._module_list:
|
||||
|
||||
if id(module) in self._front_func_dict:
|
||||
input_tensor = exec_funcs_with_kwargs(self._front_func_dict, id(module), input_tensor, kwargs)
|
||||
|
||||
if isinstance(module, CheckpointModule):
|
||||
forward_func = module._forward
|
||||
else:
|
||||
forward_func = module.forward
|
||||
if input_tensor is None:
|
||||
module_kwargs = build_kwargs_for_function(forward_func, kwargs)
|
||||
else:
|
||||
module_kwargs = build_kwargs_for_module(forward_func, kwargs)
|
||||
if module_kwargs is not None and input_tensor is not None:
|
||||
if isinstance(module, CheckpointModule):
|
||||
convert_kwargs_to_args = []
|
||||
for v in module_kwargs.values():
|
||||
convert_kwargs_to_args.append(v)
|
||||
rst = module(input_tensor, *convert_kwargs_to_args)
|
||||
else:
|
||||
rst = module(input_tensor, **module_kwargs)
|
||||
if isinstance(rst, tuple):
|
||||
input_tensor = rst[0]
|
||||
else:
|
||||
input_tensor = rst
|
||||
elif module_kwargs is not None and input_tensor is None:
|
||||
if isinstance(module, CheckpointModule):
|
||||
convert_kwargs_to_args = []
|
||||
for v in module_kwargs.values():
|
||||
convert_kwargs_to_args.append(v)
|
||||
rst = module(input_tensor, *convert_kwargs_to_args)
|
||||
else:
|
||||
rst = module(**module_kwargs)
|
||||
if isinstance(rst, tuple):
|
||||
input_tensor = rst[0]
|
||||
else:
|
||||
input_tensor = rst
|
||||
else:
|
||||
input_tensor = module(input_tensor)
|
||||
|
||||
if id(module) in self._behind_func_dict:
|
||||
input_tensor = exec_funcs_with_kwargs(self._behind_func_dict, id(module), input_tensor, kwargs)
|
||||
|
||||
return input_tensor
|
||||
|
||||
|
||||
|
207
colossalai/pipeline/utils.py
Normal file
207
colossalai/pipeline/utils.py
Normal file
@@ -0,0 +1,207 @@
|
||||
import heapq
|
||||
import inspect
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
from typing import List
|
||||
|
||||
def _binary_partition(weights: List, start: int, end: int):
|
||||
"""Returns the binary partition position of `weights`, given the start
|
||||
position `st` and the end position `ed`.
|
||||
|
||||
Args:
|
||||
weights (list): A python list to be binary partitioned
|
||||
start (int): the start position of the binary partition
|
||||
end (int): the end position of the binary partition
|
||||
|
||||
Returns:
|
||||
int: the binary partition position of `weights`
|
||||
"""
|
||||
w_sum = weights[end - 1]
|
||||
prefix = 0
|
||||
if start > 0:
|
||||
w_sum -= weights[start - 1]
|
||||
prefix = weights[start - 1]
|
||||
minimum = float("inf")
|
||||
for idx in range(start + 1, end):
|
||||
front = weights[idx - 1] - prefix
|
||||
diff = abs(w_sum - 2 * front)
|
||||
if diff < minimum:
|
||||
pos = idx
|
||||
minimum = diff
|
||||
|
||||
return start, pos, end
|
||||
|
||||
|
||||
def _heap_addition(weights: List, intervals: int, add_cnt: int):
|
||||
"""
|
||||
"""
|
||||
|
||||
def _heap_push(heap, st, ed):
|
||||
value = weights[ed - 1]
|
||||
if st > 0:
|
||||
value -= weights[st - 1]
|
||||
heapq.heappush(heap, (-value, st, ed))
|
||||
|
||||
ret_intervals = []
|
||||
heap = []
|
||||
|
||||
for st, ed in intervals:
|
||||
_heap_push(heap, st, ed)
|
||||
|
||||
while add_cnt > 0:
|
||||
_, st, ed = heapq.heappop(heap)
|
||||
if ed - st == 1:
|
||||
ret_intervals.append((st, ed))
|
||||
else:
|
||||
l, m, r = _binary_partition(weights, st, ed)
|
||||
_heap_push(heap, l, m)
|
||||
_heap_push(heap, m, r)
|
||||
add_cnt -= 1
|
||||
|
||||
while heap:
|
||||
_, st, ed = heapq.heappop(heap)
|
||||
ret_intervals.append((st, ed))
|
||||
|
||||
ret_intervals.sort()
|
||||
return ret_intervals
|
||||
|
||||
|
||||
def _calc_partitions(weights, value):
|
||||
prev = 0
|
||||
prefix = 0
|
||||
num_block = 0
|
||||
intervals = []
|
||||
|
||||
for idx, w in enumerate(weights):
|
||||
if weights[idx] - prefix > value:
|
||||
intervals.append((prev, idx))
|
||||
prev = idx
|
||||
prefix = weights[idx - 1]
|
||||
num_block += 1
|
||||
|
||||
intervals.append((prev, len(weights)))
|
||||
return num_block + 1, intervals
|
||||
|
||||
|
||||
def _binary_search(weights, num):
|
||||
length = len(weights)
|
||||
prefix = [1 if w == 0 else w for w in weights]
|
||||
for i in range(1, length):
|
||||
prefix[i] += prefix[i - 1]
|
||||
|
||||
lower_bound = max(weights)
|
||||
upper_bound = prefix[length - 1]
|
||||
|
||||
while upper_bound > lower_bound:
|
||||
mid = (upper_bound + lower_bound) // 2
|
||||
number, _ = _calc_partitions(prefix, mid)
|
||||
if number <= num:
|
||||
upper_bound = mid
|
||||
else:
|
||||
lower_bound = mid + 1
|
||||
|
||||
num_block, intervals = _calc_partitions(prefix, upper_bound)
|
||||
if num_block < num:
|
||||
intervals = _heap_addition(prefix, intervals, num - num_block)
|
||||
|
||||
return intervals
|
||||
|
||||
|
||||
def partition_uniform(num_items, pipeline_parallel_size, num_chunks):
|
||||
assert num_items % num_chunks == 0, \
|
||||
"Layer length should be divided by the number of chunks, otherwise parameter method is recomended"
|
||||
|
||||
logger = get_dist_logger()
|
||||
parts = [[] for _ in range(pipeline_parallel_size)]
|
||||
partition_items = num_items // num_chunks
|
||||
for idx in range(num_chunks):
|
||||
base_idx = idx * partition_items
|
||||
chunk_size = partition_items // pipeline_parallel_size
|
||||
left = pipeline_parallel_size - partition_items % pipeline_parallel_size
|
||||
if chunk_size == 0:
|
||||
logger.warning("Some nodes in Pipeline have no requests")
|
||||
|
||||
for p in range(pipeline_parallel_size):
|
||||
st = base_idx
|
||||
base_idx += chunk_size + (p >= left)
|
||||
parts[p].append((st, base_idx))
|
||||
|
||||
return parts
|
||||
|
||||
|
||||
def partition_balanced(weights, pipeline_parallel_size, num_chunks):
|
||||
num_total = pipeline_parallel_size * num_chunks
|
||||
num_items = len(weights)
|
||||
if num_items <= num_total:
|
||||
return partition_uniform(num_items, pipeline_parallel_size, num_chunks)
|
||||
|
||||
intervals = _binary_search(weights, num_total)
|
||||
|
||||
current = 0
|
||||
parts = [[] for _ in range(pipeline_parallel_size)]
|
||||
for inter in intervals:
|
||||
parts[current].append(inter)
|
||||
current = (current + 1) % pipeline_parallel_size
|
||||
|
||||
return parts
|
||||
|
||||
|
||||
def build_kwargs_for_module(function, kw_dict):
|
||||
"""
|
||||
Generally, the first argument of module.forward is an input tensor come from the previous layer.
|
||||
Therefore, we just filter the kwargs from second element of the dictionary.
|
||||
"""
|
||||
sig = inspect.signature(function)
|
||||
if len(sig.parameters) <= 1:
|
||||
return None
|
||||
args_name_list = list(sig.parameters.keys())
|
||||
kw_dict = {k: v for k, v in kw_dict.items() if k in args_name_list[1:]}
|
||||
return kw_dict
|
||||
|
||||
|
||||
def build_kwargs_for_function(function, kw_dict):
|
||||
sig = inspect.signature(function)
|
||||
kw_dict = {k: v for k, v in kw_dict.items() if k in sig.parameters}
|
||||
if len(kw_dict) == 0:
|
||||
return None
|
||||
return kw_dict
|
||||
|
||||
|
||||
def exec_func_with_kwargs(func, kw_dict, input_tensor, kwargs):
|
||||
"""
|
||||
We suppose the callable object passed to to_layer_list method in two purpose:
|
||||
a. use the callable object to modify input tensor, such as \
|
||||
lambda x: torch.flatten(x, 1)
|
||||
b. use the callable object to modify kwargs value, such as \
|
||||
def foo(attention_mask=None):
|
||||
if attention_mask is not None:
|
||||
batch_size = input_ids.shape[0]
|
||||
attention_mask = attention_mask.view(batch_size, -1)
|
||||
return attention_mask
|
||||
"""
|
||||
|
||||
if kw_dict is not None:
|
||||
rst = func(**kw_dict)
|
||||
if isinstance(rst, tuple):
|
||||
for i, k in enumerate(kw_dict.keys()):
|
||||
kwargs[k] = rst[i]
|
||||
else:
|
||||
for k in kw_dict.keys():
|
||||
kwargs[k] = rst
|
||||
return input_tensor
|
||||
return func(input_tensor)
|
||||
|
||||
|
||||
def exec_funcs_with_kwargs(func_dict, func_key, input_tensor, kwargs):
|
||||
|
||||
assert func_key in func_dict, f"{func_key} is not in the function_dict."
|
||||
funcs_to_exec = func_dict[func_key]
|
||||
if isinstance(funcs_to_exec, list):
|
||||
for f in funcs_to_exec:
|
||||
f_kwargs = build_kwargs_for_function(f, kwargs)
|
||||
input_tensor = exec_func_with_kwargs(f, f_kwargs, input_tensor, kwargs)
|
||||
else:
|
||||
f_kwargs = build_kwargs_for_function(funcs_to_exec, kwargs)
|
||||
input_tensor = exec_func_with_kwargs(funcs_to_exec, f_kwargs, input_tensor, kwargs)
|
||||
|
||||
return input_tensor
|
Reference in New Issue
Block a user