mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-25 11:44:03 +00:00
[pipeline] refactor the pipeline module (#1087)
* [pipeline] refactor the pipeline module * polish code
This commit is contained in:
@@ -1,12 +1,5 @@
|
||||
from .builder import (build_schedule, build_lr_scheduler, build_model,
|
||||
build_optimizer, build_layer, build_loss, build_hooks,
|
||||
build_dataset, build_transform, build_data_sampler,
|
||||
build_gradient_handler, build_ophooks)
|
||||
from .pipeline import build_pipeline_model, build_pipeline_model_from_cfg
|
||||
from .builder import build_from_config, build_from_registry, build_gradient_handler
|
||||
|
||||
__all__ = [
|
||||
'build_schedule', 'build_lr_scheduler', 'build_model', 'build_optimizer',
|
||||
'build_layer', 'build_loss', 'build_hooks', 'build_dataset',
|
||||
'build_transform', 'build_data_sampler', 'build_gradient_handler',
|
||||
'build_pipeline_model', 'build_pipeline_model_from_cfg', 'build_ophooks'
|
||||
'build_gradient_handler', 'build_from_config', 'build_from_registry'
|
||||
]
|
||||
|
@@ -2,7 +2,6 @@
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import inspect
|
||||
from collections.abc import Iterable
|
||||
|
||||
from colossalai.registry import *
|
||||
|
||||
@@ -64,84 +63,6 @@ def build_from_registry(config, registry: Registry):
|
||||
|
||||
return obj
|
||||
|
||||
|
||||
|
||||
def build_layer(config):
|
||||
"""Returns a layer object of :class:`nn.Module` constructed from `config`.
|
||||
|
||||
Args:
|
||||
config (dict or :class:`colossalai.context.Config`): A python dict or
|
||||
a :class:`colossalai.context.Config` object containing information
|
||||
used in the construction of the ``LAYERS``.
|
||||
|
||||
Returns:
|
||||
An object of :class:`torch.nn.Module`
|
||||
"""
|
||||
return build_from_registry(config, LAYERS)
|
||||
|
||||
|
||||
def build_loss(config):
|
||||
"""Returns a loss function object of :class:`torch.autograd.Function` constructed
|
||||
from `config`.
|
||||
|
||||
Args:
|
||||
config (dict or :class:`colossalai.context.Config`): A python dict or
|
||||
a :class:`colossalai.context.Config` object containing information
|
||||
used in the construction of the ``LOSSES``.
|
||||
|
||||
Returns:
|
||||
An object of :class:`torch.nn.modules.loss._Loss`
|
||||
"""
|
||||
return build_from_registry(config, LOSSES)
|
||||
|
||||
|
||||
def build_model(config):
|
||||
"""Returns a model object of :class:`nn.Module` constructed from `config`.
|
||||
|
||||
Args:
|
||||
config (dict or :class:`colossalai.context.Config`): A python dict or
|
||||
a :class:`colossalai.context.Config` object containing information
|
||||
used in the construction of the ``MODELS``.
|
||||
|
||||
Returns:
|
||||
An object of :class:`torch.nn.Module`
|
||||
"""
|
||||
return build_from_registry(config, MODELS)
|
||||
|
||||
|
||||
def build_dataset(config):
|
||||
"""Returns a dataset object of :class:`torch.utils.data.Dataset` constructed
|
||||
from `config`.
|
||||
|
||||
Args:
|
||||
config (dict or :class:`colossalai.context.Config`): A python dict or
|
||||
a :class:`colossalai.context.Config` object containing information
|
||||
used in the construction of the ``DATASETS``.
|
||||
|
||||
Returns:
|
||||
An object of :class:`torch.utils.data.Dataset`
|
||||
"""
|
||||
return build_from_registry(config, DATASETS)
|
||||
|
||||
|
||||
def build_optimizer(config, model):
|
||||
"""Returns an optimizer object of :class:`torch.optim.Optimizer` constructed from `config`,
|
||||
'model' and 'params'.
|
||||
|
||||
Args:
|
||||
config (dict or :class:`colossalai.context.Config`): A python dict or
|
||||
a :class:`colossalai.context.Config` object containing information
|
||||
used in the construction of the ``OPTIMIZERS``.
|
||||
model (:class:`nn.Module`): A model containing parameters for the optimizer
|
||||
|
||||
Returns:
|
||||
An object of :class:`torch.optim.Optimizer`
|
||||
"""
|
||||
config_ = config.copy()
|
||||
config_['params'] = model.parameters()
|
||||
return build_from_registry(config_, OPTIMIZERS)
|
||||
|
||||
|
||||
def build_gradient_handler(config, model, optimizer):
|
||||
"""Returns a gradient handler object of :class:`BaseGradientHandler` constructed from `config`,
|
||||
`model` and `optimizer`.
|
||||
@@ -160,100 +81,3 @@ def build_gradient_handler(config, model, optimizer):
|
||||
config_['model'] = model
|
||||
config_['optimizer'] = optimizer
|
||||
return build_from_registry(config_, GRADIENT_HANDLER)
|
||||
|
||||
|
||||
def build_hooks(config, trainer):
|
||||
"""Returns a hook object of :class:`BaseHook` constructed from `config` and `trainer`.
|
||||
|
||||
Args:
|
||||
config (dict or :class:`colossalai.context.Config`): A python dict or
|
||||
a :class:`colossalai.context.Config` object containing information
|
||||
used in the construction of the ``HOOKS``.
|
||||
|
||||
Returns:
|
||||
An object of :class:`colossalai.trainer.hooks.BaseHook`
|
||||
"""
|
||||
config_ = config.copy()
|
||||
config_['trainer'] = trainer
|
||||
return build_from_registry(config_, HOOKS)
|
||||
|
||||
|
||||
def build_ophooks(config):
|
||||
"""Returns a hook object of :class:`BaseOpHook` constructed from `config`.
|
||||
|
||||
Args:
|
||||
config (dict or :class:`colossalai.context.Config`): A python dict or
|
||||
a :class:`colossalai.context.Config` object containing information
|
||||
used in the construction of the ``OPHOOKS``.
|
||||
|
||||
Returns:
|
||||
An object of :class:`colossalai.trainer.hooks.BaseOpHook`
|
||||
"""
|
||||
config_ = config.copy()
|
||||
return build_from_registry(config_, OPHOOKS)
|
||||
|
||||
|
||||
def build_transform(config):
|
||||
"""Returns a transformation object of :class:`torchvision.transforms` constructed
|
||||
from `config`.
|
||||
|
||||
Args:
|
||||
config (dict or :class:`colossalai.context.Config`): A python dict or
|
||||
a :class:`colossalai.context.Config` object containing information
|
||||
used in the construction of the ``TRANSFORMS``.
|
||||
|
||||
Returns:
|
||||
An object of :class:`torchvision.transforms`
|
||||
"""
|
||||
return build_from_registry(config, TRANSFORMS)
|
||||
|
||||
|
||||
def build_data_sampler(config, dataset):
|
||||
"""Returns a data sampler object of :class:`colossalai.nn.data.sampler.BaseSampler`
|
||||
constructed from `config`.
|
||||
|
||||
Args:
|
||||
config (dict or :class:`colossalai.context.Config`): A python dict or
|
||||
a :class:`colossalai.context.Config` object containing information
|
||||
used in the construction of the ``DATA_SAMPLERS``.
|
||||
dataset (:class:`torch.utils.data.Dataset`): An object of
|
||||
:class:`torch.utils.data.Dataset` containing information
|
||||
used in the construction of the return object
|
||||
Returns:
|
||||
An object of :class:`colossalai.utils.data_sampler.BaseSampler`
|
||||
"""
|
||||
config_ = config.copy()
|
||||
config_['dataset'] = dataset
|
||||
return build_from_registry(config_, DATA_SAMPLERS)
|
||||
|
||||
|
||||
def build_lr_scheduler(config, optimizer):
|
||||
"""Returns a learning rate scheduler object of :class:`torch.optim.lr_scheduler`
|
||||
constructed from `config`, `optimizer`, `total_steps` and `num_steps_per_epoch`.
|
||||
|
||||
Args:
|
||||
config (dict or :class:`colossalai.context.Config`): A python dict or
|
||||
a :class:`colossalai.context.Config` object containing information
|
||||
used in the construction of the ``lr_schedule``.
|
||||
optimizer (:class:`torch.optim.Optimizer`): An optimizer object containing
|
||||
parameters for the learning rate scheduler.
|
||||
|
||||
Returns:
|
||||
An object of :class:`torch.optim.lr_scheduler`
|
||||
"""
|
||||
config_ = config.copy()
|
||||
config_['optimizer'] = optimizer
|
||||
return build_from_registry(config_, LR_SCHEDULERS)
|
||||
|
||||
def build_schedule(config):
|
||||
"""Returns a schedule of :class:`colossalai.engine.schedule.BaseSchedule`.
|
||||
|
||||
Args:
|
||||
config (dict or :class:`colossalai.context.Config`): A python dict or
|
||||
a :class:`colossalai.context.Config` object containing information
|
||||
used in the construction of the ``Schedule``.
|
||||
|
||||
Returns:
|
||||
An object of :class:`colossalai.engine.schedule.BaseSchedule`
|
||||
"""
|
||||
return build_from_registry(config, SCHEDULE)
|
||||
|
@@ -1,266 +0,0 @@
|
||||
import copy
|
||||
import heapq
|
||||
|
||||
from colossalai.builder import build_model, build_layer
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.logging import get_dist_logger
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def _binary_partition(weights, st, ed):
|
||||
"""Returns the binary partition position of `weights`, given the start
|
||||
position `st` and the end position `ed`.
|
||||
|
||||
Args:
|
||||
weights (list): A python list to be binary partitioned
|
||||
st (int): the start position of the binary partition
|
||||
ed (int): the end position of the binary partition
|
||||
|
||||
Returns:
|
||||
int: the binary partition position of `weights`
|
||||
"""
|
||||
w_sum = weights[ed - 1]
|
||||
prefix = 0
|
||||
if st > 0:
|
||||
w_sum -= weights[st - 1]
|
||||
prefix = weights[st - 1]
|
||||
minimum = float("inf")
|
||||
for idx in range(st + 1, ed):
|
||||
front = weights[idx - 1] - prefix
|
||||
diff = abs(w_sum - 2 * front)
|
||||
if diff < minimum:
|
||||
pos = idx
|
||||
minimum = diff
|
||||
|
||||
return st, pos, ed
|
||||
|
||||
|
||||
def _heap_addition(weights, intervals, add_cnt):
|
||||
"""
|
||||
"""
|
||||
|
||||
def _heap_push(heap, st, ed):
|
||||
value = weights[ed - 1]
|
||||
if st > 0:
|
||||
value -= weights[st - 1]
|
||||
heapq.heappush(heap, (-value, st, ed))
|
||||
|
||||
ret_intervals = []
|
||||
heap = []
|
||||
|
||||
for st, ed in intervals:
|
||||
_heap_push(heap, st, ed)
|
||||
|
||||
while add_cnt > 0:
|
||||
_, st, ed = heapq.heappop(heap)
|
||||
if ed - st == 1:
|
||||
ret_intervals.append((st, ed))
|
||||
else:
|
||||
l, m, r = _binary_partition(weights, st, ed)
|
||||
_heap_push(heap, l, m)
|
||||
_heap_push(heap, m, r)
|
||||
add_cnt -= 1
|
||||
|
||||
while heap:
|
||||
_, st, ed = heapq.heappop(heap)
|
||||
ret_intervals.append((st, ed))
|
||||
|
||||
ret_intervals.sort()
|
||||
return ret_intervals
|
||||
|
||||
|
||||
def _calc_partitions(weights, value):
|
||||
prev = 0
|
||||
prefix = 0
|
||||
num_block = 0
|
||||
intervals = []
|
||||
|
||||
for idx, w in enumerate(weights):
|
||||
if weights[idx] - prefix > value:
|
||||
intervals.append((prev, idx))
|
||||
prev = idx
|
||||
prefix = weights[idx - 1]
|
||||
num_block += 1
|
||||
|
||||
intervals.append((prev, len(weights)))
|
||||
return num_block + 1, intervals
|
||||
|
||||
|
||||
def _binary_search(weights, num):
|
||||
length = len(weights)
|
||||
prefix = [1 if w == 0 else w for w in weights]
|
||||
for i in range(1, length):
|
||||
prefix[i] += prefix[i - 1]
|
||||
|
||||
lower_bound = max(weights)
|
||||
upper_bound = prefix[length - 1]
|
||||
|
||||
while upper_bound > lower_bound:
|
||||
mid = (upper_bound + lower_bound) // 2
|
||||
number, _ = _calc_partitions(prefix, mid)
|
||||
if number <= num:
|
||||
upper_bound = mid
|
||||
else:
|
||||
lower_bound = mid + 1
|
||||
|
||||
num_block, intervals = _calc_partitions(prefix, upper_bound)
|
||||
if num_block < num:
|
||||
intervals = _heap_addition(prefix, intervals, num - num_block)
|
||||
|
||||
return intervals
|
||||
|
||||
|
||||
def partition_uniform(num_items, pipeline_parallel_size, num_chunks):
|
||||
assert num_items % num_chunks == 0, \
|
||||
"Layer length should be divided by the number of chunks, otherwise parameter method is recomended"
|
||||
|
||||
logger = get_dist_logger()
|
||||
parts = [[] for _ in range(pipeline_parallel_size)]
|
||||
partition_items = num_items // num_chunks
|
||||
for idx in range(num_chunks):
|
||||
base_idx = idx * partition_items
|
||||
chunk_size = partition_items // pipeline_parallel_size
|
||||
left = pipeline_parallel_size - partition_items % pipeline_parallel_size
|
||||
if chunk_size == 0:
|
||||
logger.warning("Some nodes in Pipeline have no requests")
|
||||
|
||||
for p in range(pipeline_parallel_size):
|
||||
st = base_idx
|
||||
base_idx += chunk_size + (p >= left)
|
||||
parts[p].append((st, base_idx))
|
||||
|
||||
return parts
|
||||
|
||||
|
||||
def partition_balanced(weights, pipeline_parallel_size, num_chunks):
|
||||
num_total = pipeline_parallel_size * num_chunks
|
||||
num_items = len(weights)
|
||||
if num_items <= num_total:
|
||||
return partition_uniform(num_items, pipeline_parallel_size, num_chunks)
|
||||
|
||||
intervals = _binary_search(weights, num_total)
|
||||
|
||||
current = 0
|
||||
parts = [[] for _ in range(pipeline_parallel_size)]
|
||||
for inter in intervals:
|
||||
parts[current].append(inter)
|
||||
current = (current + 1) % pipeline_parallel_size
|
||||
|
||||
return parts
|
||||
|
||||
|
||||
def count_layer_params(layers):
|
||||
"""Count the number of parameters in each layer
|
||||
"""
|
||||
param_counts = [0] * len(layers)
|
||||
for idx, cfg in enumerate(layers):
|
||||
layer = build_layer(cfg)
|
||||
params = filter(lambda p: p.requires_grad, layer.parameters())
|
||||
param_counts[idx] = sum(p.numel() for p in params)
|
||||
|
||||
return param_counts
|
||||
|
||||
|
||||
def build_pipeline_model_from_cfg(config,
|
||||
num_chunks: int = 1,
|
||||
partition_method: str = 'parameter',
|
||||
verbose: bool = False):
|
||||
"""An initializer to split the model into different stages for pipeline parallelism.
|
||||
|
||||
An example for the model config is shown below. The class VisionTransformerFromConfig should
|
||||
inherit colossalai.nn.model.ModelFromConfig to allow this initializer to build model from a sequence
|
||||
of layer configurations.
|
||||
|
||||
::
|
||||
|
||||
model_config = dict(
|
||||
type='VisionTransformerFromConfig',
|
||||
embedding_cfg=dict(...),
|
||||
...
|
||||
)
|
||||
|
||||
Args:
|
||||
config (dict): Configuration of the model.
|
||||
num_chunks (int, optional): The number of chunks you want to have on the current stage.
|
||||
This value should be 1 in most cases unless you are using virtual pipeline parallelism.
|
||||
partition_method (str, optional): This parameter determines how you want to split your model
|
||||
layers into stages, you can set it as 'layer' or 'parameter'.
|
||||
verbose (bool, optional): Whether to print the logs.
|
||||
"""
|
||||
ori_model = build_model(config)
|
||||
layers = ori_model.layers_cfg
|
||||
layer_length = len(layers)
|
||||
logger = get_dist_logger()
|
||||
if verbose:
|
||||
logger.info(f"The total length of layers is {layer_length}", ranks=[0])
|
||||
|
||||
pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
|
||||
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
||||
|
||||
method = partition_method.lower()
|
||||
# Make a partition
|
||||
if method == 'layer':
|
||||
num_layers = len(layers)
|
||||
parts = partition_uniform(num_layers, pipeline_parallel_size, num_chunks)
|
||||
elif method == 'parameter':
|
||||
param_counts = count_layer_params(layers)
|
||||
# print_rank_0(param_counts)
|
||||
parts = partition_balanced(param_counts, pipeline_parallel_size, num_chunks)
|
||||
else:
|
||||
raise ValueError("Method should be a pre-set string in [layer, parameter]")
|
||||
|
||||
# Display the partition
|
||||
if verbose:
|
||||
log_str = 'Layer allocation after partitioning: \n'
|
||||
for stage in range(pipeline_parallel_size):
|
||||
|
||||
num_layers = 0
|
||||
for st, ed in parts[stage]:
|
||||
num_layers += ed - st
|
||||
|
||||
log_str += f'\n===== stage={stage}, layers={num_layers} =====\n'
|
||||
for st, ed in parts[stage]:
|
||||
for idx, layer in enumerate(layers[st:ed]):
|
||||
log_str += f'\t{idx + st:2d}: {layer}\n'
|
||||
logger.info(log_str, ranks=[0])
|
||||
|
||||
# Save the partition
|
||||
interval = parts[pipeline_rank]
|
||||
|
||||
models = []
|
||||
for st, ed in interval:
|
||||
model = copy.deepcopy(ori_model)
|
||||
model.build_from_cfg(st, ed)
|
||||
models.append(model)
|
||||
|
||||
return nn.ModuleList(models) if len(models) > 1 else models[0]
|
||||
|
||||
|
||||
def build_pipeline_model(layers: nn.Sequential, num_chunks: int = 1, verbose: bool = False):
|
||||
"""An intializer to split the model into different stages for pipeline parallelism.
|
||||
Note that `layer` must be `torch.nn.Sequential`.
|
||||
Args:
|
||||
layers (`torch.nn.Sequential`): Layers of model
|
||||
num_chunks: The number of chunks you want to have on the current stage. This value should be 1
|
||||
in most cases unless you are using virtual pipeline parallelism.
|
||||
verbose (bool, optional): Whether to print the logs.
|
||||
"""
|
||||
pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
|
||||
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
||||
partitions = partition_uniform(len(layers), pipeline_parallel_size, num_chunks)
|
||||
module_list = []
|
||||
for start, end in partitions[pipeline_rank]:
|
||||
module_list.append(
|
||||
nn.Sequential(*[nn.Identity() for _ in range(start)], *layers[start:end],
|
||||
*[nn.Identity() for _ in range(len(layers) - end)]))
|
||||
if verbose:
|
||||
logger = get_dist_logger()
|
||||
logger.info(f'Total {len(layers)} layers', ranks=[0])
|
||||
for rank, part in enumerate(partitions):
|
||||
log_str = f'===== stage={rank} =====\n'
|
||||
for chunk, (start, end) in enumerate(part):
|
||||
log_str += f'===== chunk={chunk}, layer=[{start}-{end}] =====\n'
|
||||
log_str += '\n'.join([str(layer) for layer in layers[start:end]]) + '\n'
|
||||
logger.info(log_str, ranks=[0])
|
||||
return nn.ModuleList(module_list) if len(module_list) > 1 else module_list[0]
|
Reference in New Issue
Block a user