Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-12 20:54:35 +00:00)

Commit: Migrated project
colossalai/builder/__init__.py (normal file, 2 lines added)
@@ -0,0 +1,2 @@
from .builder import *
from .pipeline import ModelInitializer
colossalai/builder/builder.py (normal file, 262 lines added)
@@ -0,0 +1,262 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import inspect
from collections.abc import Iterable

from colossalai.registry import *
def build_from_config(module, config: dict):
    """Returns an object of :class:`module` constructed from `config`.

    :param module: A python or user-defined class
    :type module: class
    :param config: A python dict containing information used in the construction
        of the return object
    :type config: dict
    :raises AssertionError: Raises an AssertionError if `module` is not a class
    :return: An object of :class:`module`
    :rtype: :class:`module`
    """
    assert inspect.isclass(module), 'module must be a class'
    return module(**config)
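A minimal sketch of how build_from_config is used; DummyLayer is a hypothetical class introduced only for illustration, since any class whose constructor accepts the dict's keys works the same way:

# Illustrative sketch: the config dict is expanded into constructor kwargs.
class DummyLayer:
    def __init__(self, in_features, out_features):
        self.shape = (in_features, out_features)

layer = build_from_config(DummyLayer, {'in_features': 16, 'out_features': 32})
assert layer.shape == (16, 32)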
def build_from_registry(config, registry: Registry):
    """Returns an object constructed from `config`, the type of the object
    is specified by `registry`.

    :param config: A python dict or a :class:`colossalai.context.Config` object
        containing information used in the construction of the return object
    :type config: dict or :class:`colossalai.context.Config`
    :param registry: A registry specifying the type of the return object
    :type registry: :class:`Registry`
    :raises AssertionError: Raises an AssertionError if `registry` is not an object
        of :class:`Registry` or `mod_type` in `config` is not found in `registry`
    :raises Exception: Raises an Exception if an error occurred when building
        from registry
    :return: An object specified by `registry`
    :rtype: Python object specified by `registry`
    """
    config_ = config.copy()  # keep the original config untouched
    assert isinstance(
        registry, Registry), f'Expected type Registry but got {type(registry)}'

    mod_type = config_.pop('type')
    assert registry.has(
        mod_type), f'{mod_type} is not found in registry {registry.name}'
    try:
        obj = registry.get_module(mod_type)(**config_)
    except Exception as e:
        print(
            f'An error occurred when building {mod_type} from registry {registry.name}', flush=True)
        raise e

    return obj
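A sketch of the registry-driven construction pattern, assuming a layer class has been registered in LAYERS under the hypothetical name 'MyLinear'; the name and keyword arguments are illustrative only:

# Illustrative sketch: 'type' selects the registered class, the remaining
# keys become constructor keyword arguments.
cfg = dict(type='MyLinear', in_features=1024, out_features=1024)
layer = build_from_registry(cfg, LAYERS)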
def build_layer(config):
    """Returns a layer object of :class:`nn.Module` constructed from `config`.

    :param config: A python dict or a :class:`colossalai.context.Config` object
        containing information used in the construction of the return object
    :type config: dict or :class:`colossalai.context.Config`
    :return: An object of :class:`nn.Module`
    :rtype: :class:`nn.Module`
    """
    return build_from_registry(config, LAYERS)


def build_loss(config):
    """Returns a loss function object of :class:`torch.autograd.Function` constructed
    from `config`.

    :param config: A python dict or a :class:`colossalai.context.Config` object
        containing information used in the construction of the return object
    :type config: dict or :class:`colossalai.context.Config`
    :return: An object of :class:`torch.autograd.Function`
    :rtype: :class:`torch.autograd.Function`
    """
    return build_from_registry(config, LOSSES)


def build_model(config):
    """Returns a model object of :class:`nn.Module` constructed from `config`.

    :param config: A python dict or a :class:`colossalai.context.Config` object
        containing information used in the construction of the return object
    :type config: dict or :class:`colossalai.context.Config`
    :return: An object of :class:`nn.Module`
    :rtype: :class:`nn.Module`
    """
    return build_from_registry(config, MODELS)


def build_dataset(config):
    """Returns a dataset object of :class:`torch.utils.data.Dataset` constructed
    from `config`.

    :param config: A python dict or a :class:`colossalai.context.Config` object
        containing information used in the construction of the return object
    :type config: dict or :class:`colossalai.context.Config`
    :return: An object of :class:`torch.utils.data.Dataset`
    :rtype: :class:`torch.utils.data.Dataset`
    """
    return build_from_registry(config, DATASETS)
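These four helpers only differ in the registry they consult, so one config sketch covers them all; the registered names and keyword arguments below are hypothetical placeholders, not names from this commit:

# Illustrative sketch: each wrapper routes its dict to a different registry.
model = build_model(dict(type='MyRegisteredModel', num_classes=10))
criterion = build_loss(dict(type='MyRegisteredLoss'))
train_set = build_dataset(dict(type='MyRegisteredDataset', root='./data'))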
def build_optimizer(config, model, params: Iterable = None, need_module=False):
    """Returns an optimizer object of :class:`torch.optim.Optimizer` constructed from `config`,
    `model` and `params`.

    :param config: A python dict or a :class:`colossalai.context.Config` object
        containing information used in the construction of the return object
    :type config: dict or :class:`colossalai.context.Config`
    :param model: A model containing parameters for the optimizer
    :type model: :class:`nn.Module`
    :param params: A dict containing parameters for the optimizer
    :type params: dict, optional
    :param need_module: Indicates whether the optimizer needs a module
    :type need_module: bool, optional
    :raises AssertionError: Raises an AssertionError if both `model` and `params` are None
    :return: An object of :class:`torch.optim.Optimizer`
    :rtype: :class:`torch.optim.Optimizer`
    """
    assert model is not None or params is not None, 'arguments model and params can not both be None'
    if need_module:
        config['module'] = model
    elif model is not None:
        config['params'] = model.parameters()
    elif params is not None:
        config['params'] = params

    return build_from_registry(config, OPTIMIZERS)
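A sketch of the optimizer path, reusing the hypothetical model from the sketch above and assuming an SGD-style optimizer has been registered in OPTIMIZERS under the name used here; the name and hyper-parameters are placeholders:

# Illustrative sketch: by default the builder pulls model.parameters().
optim_cfg = dict(type='SGD', lr=0.01, momentum=0.9)
optimizer = build_optimizer(optim_cfg, model)
# Alternatively, pass an explicit parameter iterable instead of a model:
# optimizer = build_optimizer(optim_cfg, model=None, params=my_param_groups)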
def build_gradient_handler(config, model, optimizer):
    """Returns a gradient handler object of :class:`BaseGradientHandler` constructed from `config`,
    `model` and `optimizer`.

    :param config: A python dict or a :class:`colossalai.context.Config` object
        containing information used in the construction of the return object
    :type config: dict or :class:`colossalai.context.Config`
    :param model: A model containing parameters for the gradient handler
    :type model: :class:`nn.Module`
    :param optimizer: An optimizer object containing parameters for the gradient handler
    :type optimizer: :class:`torch.optim.Optimizer`
    :return: An object of :class:`BaseGradientHandler`
    :rtype: :class:`BaseGradientHandler`
    """
    config_ = config.copy()
    mod_type = config_.pop('type')
    return GRADIENT_HANDLER.get_module(mod_type)(model, optimizer, **config_)
def build_hooks(config, trainer):
    """Returns a hook object of :class:`BaseHook` constructed from `config` and `trainer`.

    :param config: A python dict or a :class:`colossalai.context.Config` object
        containing information used in the construction of the return object
    :type config: dict or :class:`colossalai.context.Config`
    :param trainer: A :class:`Trainer` object containing parameters for the hook
    :type trainer: :class:`Trainer`
    :return: An object of :class:`BaseHook`
    :rtype: :class:`BaseHook`
    """
    config['trainer'] = trainer
    return build_from_registry(config, HOOKS)
def build_transform(config):
    """Returns a transformation object of :class:`torchvision.transforms` constructed
    from `config`.

    :param config: A python dict or a :class:`colossalai.context.Config` object
        containing information used in the construction of the return object
    :type config: dict or :class:`colossalai.context.Config`
    :return: An object of :class:`torchvision.transforms`
    :rtype: :class:`torchvision.transforms`
    """
    return build_from_registry(config, TRANSFORMS)
def build_pipe_alloc_policy(config):
    """Returns a pipeline allocation policy object constructed from `config`.

    :param config: A python dict or a :class:`colossalai.context.Config` object
        containing information used in the construction of the return object
    :type config: dict or :class:`colossalai.context.Config`
    :return: A pipeline allocation policy object
    :rtype: A pipeline allocation policy object registered in ``PIPE_ALLOC_POLICY``
    """
    return build_from_registry(config, PIPE_ALLOC_POLICY)
def build_data_sampler(config, dataset):
    """Returns a data sampler object of :class:`colossalai.nn.data.sampler.BaseSampler`
    constructed from `config`.

    :param config: A python dict or a :class:`colossalai.context.Config` object
        containing information used in the construction of the return object
    :type config: dict or :class:`colossalai.context.Config`
    :param dataset: An object of :class:`torch.utils.data.Dataset` containing information
        used in the construction of the return object
    :type dataset: :class:`torch.utils.data.Dataset`
    :return: An object of :class:`colossalai.nn.data.sampler.BaseSampler`
    :rtype: :class:`colossalai.nn.data.sampler.BaseSampler`
    """
    config_ = config.copy()
    mod_type = config_.pop('type')
    return SAMPLERS.get_module(mod_type)(dataset, **config_)
def build_optimizer_wrapper(config, optimizer, model=None):
    """Returns an optimizer wrapper object of :class:`torch.optim.Optimizer` constructed
    from `config`, `model` and `optimizer`.

    :param config: A python dict or a :class:`colossalai.context.Config` object
        containing information used in the construction of the return object
    :type config: dict or :class:`colossalai.context.Config`
    :param optimizer: An optimizer object containing parameters for the optimizer wrapper
    :type optimizer: :class:`torch.optim.Optimizer`
    :param model: A model containing parameters for the optimizer wrapper
    :type model: :class:`nn.Module`, optional
    :return: An object of :class:`torch.optim.Optimizer`
    :rtype: :class:`torch.optim.Optimizer`
    """
    config_ = config.copy()
    mod_type = config_.pop('type')

    # LSG: special treatment for zero level 3
    if mod_type == 'ZeroRedundancyOptimizer_Level_3':
        return OPTIMIZER_WRAPPERS.get_module(mod_type)(model, optimizer, **config_)
    else:
        return OPTIMIZER_WRAPPERS.get_module(mod_type)(optimizer, **config_)
def build_lr_scheduler(config, optimizer, total_steps, num_steps_per_epoch):
    """Returns a learning rate scheduler object of :class:`torch.optim.lr_scheduler`
    constructed from `config`, `optimizer`, `total_steps` and `num_steps_per_epoch`.

    :param config: A python dict or a :class:`colossalai.context.Config` object
        containing information used in the construction of the return object
    :type config: dict or :class:`colossalai.context.Config`
    :param optimizer: An optimizer object containing parameters for the learning rate
        scheduler
    :type optimizer: :class:`torch.optim.Optimizer`
    :param total_steps: Number of total steps of the learning rate scheduler
    :type total_steps: int
    :param num_steps_per_epoch: Number of steps per epoch of the learning rate scheduler
    :type num_steps_per_epoch: int
    :return: An object of :class:`torch.optim.lr_scheduler`
    :rtype: :class:`torch.optim.lr_scheduler`
    """
    config_ = config.copy()
    mod_type = config_.pop('type')
    # warmup epochs will overwrite warmup steps
    if 'warmup_epochs' in config_:
        warmup_epochs = config_.pop('warmup_epochs')
        config_['warmup_steps'] = int(num_steps_per_epoch * warmup_epochs)
    return LR_SCHEDULERS.get_module(mod_type)(optimizer, total_steps, num_steps_per_epoch=num_steps_per_epoch,
                                              **config_)
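The warmup conversion is plain arithmetic; a sketch with hypothetical numbers, reusing the optimizer from the sketch above (the scheduler name and values are illustrative):

# Illustrative sketch: with 100 steps per epoch and warmup_epochs=2, the
# scheduler receives warmup_steps = int(100 * 2) = 200 instead of epochs.
sched_cfg = dict(type='MyWarmupScheduler', warmup_epochs=2)
lr_scheduler = build_lr_scheduler(sched_cfg, optimizer,
                                  total_steps=1000, num_steps_per_epoch=100)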
colossalai/builder/pipeline.py (normal file, 226 lines added)
@@ -0,0 +1,226 @@
import copy
import heapq

from colossalai.builder import build_model, build_layer
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger
from colossalai.utils import set_to_cuda
def _binary_partition(weights, st, ed):
    """Returns the binary partition position of `weights`, given the start
    position `st` and the end position `ed`.

    :param weights: A python list of cumulative weights (prefix sums) to be binary partitioned
    :type weights: list
    :param st: the start position of the binary partition
    :type st: int
    :param ed: the end position of the binary partition
    :type ed: int
    :return: the start position, the binary partition position and the end position
    :rtype: tuple
    """
    w_sum = weights[ed - 1]
    prefix = 0
    if st > 0:
        w_sum -= weights[st - 1]
        prefix = weights[st - 1]
    minimum = float("inf")
    for idx in range(st + 1, ed):
        front = weights[idx - 1] - prefix
        diff = abs(w_sum - 2 * front)
        if diff < minimum:
            pos = idx
            minimum = diff

    return st, pos, ed
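A worked sketch with illustrative numbers: for per-layer weights [3, 1, 4, 1, 5] the callers pass the prefix sums [3, 4, 8, 9, 14], and the function picks the split that best halves the total weight:

# Illustrative trace: total weight is 14; splitting at index 3 gives
# 3 + 1 + 4 = 8 on the left and 1 + 5 = 6 on the right, the closest to even.
prefix = [3, 4, 8, 9, 14]
assert _binary_partition(prefix, 0, 5) == (0, 3, 5)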
def _heap_addition(weights, intervals, add_cnt):
    """Repeatedly splits the heaviest of the given intervals (single-element
    intervals are kept as-is) until `add_cnt` additional intervals have been
    created, and returns the resulting intervals in sorted order.
    """
    def _heap_push(heap, st, ed):
        value = weights[ed - 1]
        if st > 0:
            value -= weights[st - 1]
        heapq.heappush(heap, (-value, st, ed))

    ret_intervals = []
    heap = []

    for st, ed in intervals:
        _heap_push(heap, st, ed)

    while add_cnt > 0:
        _, st, ed = heapq.heappop(heap)
        if ed - st == 1:
            ret_intervals.append((st, ed))
        else:
            l, m, r = _binary_partition(weights, st, ed)
            _heap_push(heap, l, m)
            _heap_push(heap, m, r)
            add_cnt -= 1

    while heap:
        _, st, ed = heapq.heappop(heap)
        ret_intervals.append((st, ed))

    ret_intervals.sort()
    return ret_intervals
def _calc_partitions(weights, value):
    """Greedily cuts the prefix-sum array `weights` into consecutive blocks
    whose sums do not exceed `value`, and returns the number of blocks
    together with their index intervals.
    """
    prev = 0
    prefix = 0
    num_block = 0
    intervals = []

    for idx, w in enumerate(weights):
        if weights[idx] - prefix > value:
            intervals.append((prev, idx))
            prev = idx
            prefix = weights[idx - 1]
            num_block += 1

    intervals.append((prev, len(weights)))
    return num_block + 1, intervals
def _binary_search(weights, num):
    """Binary-searches the smallest per-block weight budget that splits
    `weights` into at most `num` consecutive blocks, then splits heavy blocks
    further if needed so that exactly `num` intervals are returned.
    """
    length = len(weights)
    prefix = [1 if w == 0 else w for w in weights]
    for i in range(1, length):
        prefix[i] += prefix[i - 1]

    lower_bound = max(weights)
    upper_bound = prefix[length - 1]

    while upper_bound > lower_bound:
        mid = (upper_bound + lower_bound) // 2
        number, _ = _calc_partitions(prefix, mid)
        if number <= num:
            upper_bound = mid
        else:
            lower_bound = mid + 1

    num_block, intervals = _calc_partitions(prefix, upper_bound)
    if num_block < num:
        intervals = _heap_addition(prefix, intervals, num - num_block)

    return intervals
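A worked sketch with illustrative numbers: for weights [3, 1, 4, 1, 5] and num=2, the prefix sums are [3, 4, 8, 9, 14]; the search tries budgets 9, 7 and 8 and settles on 8, which yields an 8-versus-6 split:

# Illustrative trace: budget 8 fits two blocks, budget 7 would need three,
# so layers [0, 3) and [3, 5) are returned.
assert _binary_search([3, 1, 4, 1, 5], 2) == [(0, 3), (3, 5)]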
def _partition_uniform(num_items, num_parts, num_chunks):
    assert num_items % num_chunks == 0, \
        "Layer length should be divisible by the number of chunks, otherwise the parameter method is recommended"

    logger = get_global_dist_logger()
    parts = [[] for _ in range(num_parts)]
    partition_items = num_items // num_chunks
    for idx in range(num_chunks):
        base_idx = idx * partition_items
        chunk_size = partition_items // num_parts
        left = num_parts - partition_items % num_parts
        if chunk_size == 0:
            logger.warning("Some nodes in Pipeline have no requests")

        for p in range(num_parts):
            st = base_idx
            base_idx += chunk_size + (p >= left)
            parts[p].append((st, base_idx))

    return parts
def _partition_balanced(weights, num_parts, num_chunks):
    num_total = num_parts * num_chunks
    num_items = len(weights)
    if num_items <= num_total:
        return _partition_uniform(num_items, num_parts, num_chunks)

    intervals = _binary_search(weights, num_total)

    current = 0
    parts = [[] for _ in range(num_parts)]
    for inter in intervals:
        parts[current].append(inter)
        current = (current + 1) % num_parts

    return parts
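A sketch with illustrative weights: eight layers weighted [2, 2, 2, 2, 1, 1, 1, 1] split across 2 pipeline stages with 2 chunks each produce 4 intervals, dealt round-robin so that both stages end up with a total weight of 6:

# Illustrative result: stage 0 gets layers [0, 1) and [2, 4) (weight 2 + 4),
# stage 1 gets layers [1, 2) and [4, 8) (weight 2 + 4).
parts = _partition_balanced([2, 2, 2, 2, 1, 1, 1, 1], num_parts=2, num_chunks=2)
assert parts == [[(0, 1), (2, 4)], [(1, 2), (4, 8)]]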
class ModelInitializer():
    def __init__(self, config, num_chunks, verbose=False):
        self.num_chunks = num_chunks
        self.ori_model = build_model(config)
        self.layers = self.ori_model.layers_cfg
        layer_length = len(self.layers)
        self.verbose = verbose
        self._logger = get_global_dist_logger()
        self._logger.info(f"The total length of layers is {layer_length}", ranks=[0])

    def model_initialize(self, partition_method='parameter'):
        # Some space for initializing communication groups
        self._interval = None
        self._partition_layers(method=partition_method)
        models = self._build()
        model = set_to_cuda(models)

        return model
    def _partition_layers(self, method):
        pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
        pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)

        method = method.lower()
        # Make a partition
        if method == 'layer':
            num_layers = len(self.layers)
            self.parts = _partition_uniform(num_layers, pipeline_parallel_size, self.num_chunks)
        elif method == 'parameter':
            param_counts = self._count_layer_params()
            # print_rank_0(param_counts)
            self.parts = _partition_balanced(param_counts, pipeline_parallel_size, self.num_chunks)
        else:
            assert method == 'layer', "Method should be a pre-set string"

        # Display the partition
        if gpc.get_global_rank() == 0 and self.verbose:
            log_str = 'Layer allocation after partitioning: \n'
            for stage in range(pipeline_parallel_size):

                num_layers = 0
                for st, ed in self.parts[stage]:
                    num_layers += ed - st

                log_str += f'\n===== stage={stage}, layers={num_layers} =====\n'
                for st, ed in self.parts[stage]:
                    for idx, layer in enumerate(self.layers[st: ed]):
                        log_str += f'\t{idx + st:2d}: {layer}\n'
            self._logger.info(log_str)

        # Save the partition
        self._interval = self.parts[pipeline_rank]
    def _build(self):
        """Build model from the layer cfg according to the partition
        """
        models = []
        for st, ed in self._interval:
            model = copy.copy(self.ori_model)
            model.build_from_cfg(st, ed)
            models.append(model)

        return models
    def _count_layer_params(self):
        """Count the number of parameters in each layer
        """
        param_counts = [0] * len(self.layers)
        for idx, cfg in enumerate(self.layers):
            layer = build_layer(cfg)
            params = filter(lambda p: p.requires_grad, layer.parameters())
            param_counts[idx] = sum(p.numel() for p in params)

        return param_counts
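A usage sketch of the pipeline model builder; model_cfg is a hypothetical config, and the call assumes the distributed context has already been initialised:

# Hypothetical sketch: model_cfg must describe a model whose class exposes
# layers_cfg and build_from_cfg, since ModelInitializer relies on both.
initializer = ModelInitializer(model_cfg, num_chunks=1, verbose=True)
pipeline_model = initializer.model_initialize(partition_method='parameter')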