Mirror of https://github.com/hpcaitech/ColossalAI.git
add interleaved pipeline, fix naive amp and update pipeline model initializer (#80)
colossalai/builder/__init__.py
@@ -1,10 +1,10 @@
 from .builder import (build_schedule, build_lr_scheduler, build_model, build_optimizer, build_layer,
                       build_loss, build_hooks, build_dataset, build_transform, build_data_sampler,
                       build_gradient_handler)
-from .pipeline import PipelineModelInitializer
+from .pipeline import build_pipeline_model, build_pipeline_model_from_cfg

 __all__ = [
     'build_schedule', 'build_lr_scheduler', 'build_model', 'build_optimizer',
     'build_layer', 'build_loss', 'build_hooks', 'build_dataset', 'build_transform', 'build_data_sampler',
-    'build_gradient_handler', 'PipelineModelInitializer'
+    'build_gradient_handler', 'build_pipeline_model', 'build_pipeline_model_from_cfg'
 ]
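For downstream code, the effect of this export change is that the class-based entry point is replaced by plain functions. A minimal migration sketch, assuming a distributed environment with pipeline parallelism already initialized and a hypothetical config dict `model_cfg` accepted by `build_model`:

    # Old API (removed in this commit): build the stage via the initializer class.
    from colossalai.builder import PipelineModelInitializer
    model = PipelineModelInitializer(model_cfg, num_chunks=1, verbose=True).initialize(partition_method='parameter')

    # New API: a single function call returns the partitioned stage.
    from colossalai.builder import build_pipeline_model_from_cfg
    model = build_pipeline_model_from_cfg(model_cfg, num_chunks=1, partition_method='parameter', verbose=True)

One behavioural difference is visible in the diff below: the old initialize() moved the result to CUDA via set_to_cuda, while the new function returns the partitioned module(s) directly (an nn.ModuleList when a stage holds more than one chunk), leaving device placement to the caller.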
colossalai/builder/pipeline.py
@@ -1,11 +1,12 @@
 import copy
 import heapq

 from colossalai.builder import build_model, build_layer
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.utils import set_to_cuda
+import torch.nn as nn


 def _binary_partition(weights, st, ed):
@@ -150,7 +151,19 @@ def _partition_balanced(weights, pipeline_parallel_size, num_chunks):
     return parts


-class PipelineModelInitializer():
+def _count_layer_params(layers):
+    """Count the number of parameters in each layer
+    """
+    param_counts = [0] * len(layers)
+    for idx, cfg in enumerate(layers):
+        layer = build_layer(cfg)
+        params = filter(lambda p: p.requires_grad, layer.parameters())
+        param_counts[idx] = sum(p.numel() for p in params)
+
+    return param_counts
+
+
+def build_pipeline_model_from_cfg(config, num_chunks: int = 1, partition_method: str = 'parameter', verbose: bool = False):
     """An initializer to split the model into different stages for pipeline parallelism.

     An example for the model config is shown below. The class VisionTransformerFromConfig should
@@ -168,88 +181,86 @@ class PipelineModelInitializer():
     :param num_chunks: the number of chunks you want to have on the current stage. This value should be 1
                        in most cases unless you are using virtual pipeline parallelism.
     :type num_chunks: int
+    :param partition_method: this parameter determines how you want to split your model layers into stages,
+                             you can set it as 'layer' or 'parameter'
+    :type partition_method: str
     :param verbose: whether to print the logs
     :type verbose: bool

     """
-
-    def __init__(self, config, num_chunks, verbose=False):
-        self.num_chunks = num_chunks
-        self.ori_model = build_model(config)
-        self.layers = self.ori_model.layers_cfg
-        layer_length = len(self.layers)
-        self.verbose = verbose
-        self._logger = get_dist_logger()
-        self._logger.info(f"The total length of layers is {layer_length}", ranks=[0])
-
-    def initialize(self, partition_method='parameter'):
-        """Initialize the model object from the config passed
-
-        :param partition_method: this parameter determines how you want to split your model layers into stages,
-                                 you can set it as 'layer' or 'parameter'
-        :type partition_method: str
-
-        """
-        # Some space for initializing communication groups
-        self._interval = None
-        self._partition_layers(method=partition_method)
-        models = self._build()
-        model = set_to_cuda(models)
-
-        return model
-
-    def _partition_layers(self, method):
-        pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
-        pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
-
-        method = method.lower()
-        # Make a partition
-        if method == 'layer':
-            num_layers = len(self.layers)
-            self.parts = _partition_uniform(num_layers, pipeline_parallel_size, self.num_chunks)
-        elif method == 'parameter':
-            param_counts = self._count_layer_params()
-            # print_rank_0(param_counts)
-            self.parts = _partition_balanced(param_counts, pipeline_parallel_size, self.num_chunks)
-        else:
-            raise ValueError("Method should be a pre-set string in [layer, parameter]")
-
-        # Display the partition
-        if gpc.get_global_rank() == 0 and self.verbose:
-            log_str = 'Layer allocation after partitioning: \n'
-            for stage in range(pipeline_parallel_size):
-
-                num_layers = 0
-                for st, ed in self.parts[stage]:
-                    num_layers += ed - st
-
-                log_str += f'\n===== stage={stage}, layers={num_layers} =====\n'
-                for st, ed in self.parts[stage]:
-                    for idx, layer in enumerate(self.layers[st: ed]):
-                        log_str += f'\t{idx + st:2d}: {layer}\n'
-            self._logger.info(log_str, ranks=[0])
-
-        # Save the partition
-        self._interval = self.parts[pipeline_rank]
-
-    def _build(self):
-        """Build model from the layer cfg according to the partition
-        """
-        models = []
-        for st, ed in self._interval:
-            model = copy.copy(self.ori_model)
-            model.build_from_cfg(st, ed)
-            models.append(model)
-
-        return models
-
-    def _count_layer_params(self):
-        """Count the number of parameters in each layer
-        """
-        param_counts = [0] * len(self.layers)
-        for idx, cfg in enumerate(self.layers):
-            layer = build_layer(cfg)
-            params = filter(lambda p: p.requires_grad, layer.parameters())
-            param_counts[idx] = sum(p.numel() for p in params)
-
-        return param_counts
+    ori_model = build_model(config)
+    layers = ori_model.layers_cfg
+    layer_length = len(layers)
+    logger = get_dist_logger()
+    if verbose:
+        logger.info(f"The total length of layers is {layer_length}", ranks=[0])
+
+    pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
+    pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
+
+    method = partition_method.lower()
+    # Make a partition
+    if method == 'layer':
+        num_layers = len(layers)
+        parts = _partition_uniform(num_layers, pipeline_parallel_size, num_chunks)
+    elif method == 'parameter':
+        param_counts = _count_layer_params(layers)
+        # print_rank_0(param_counts)
+        parts = _partition_balanced(param_counts, pipeline_parallel_size, num_chunks)
+    else:
+        raise ValueError("Method should be a pre-set string in [layer, parameter]")
+
+    # Display the partition
+    if verbose:
+        log_str = 'Layer allocation after partitioning: \n'
+        for stage in range(pipeline_parallel_size):
+
+            num_layers = 0
+            for st, ed in parts[stage]:
+                num_layers += ed - st
+
+            log_str += f'\n===== stage={stage}, layers={num_layers} =====\n'
+            for st, ed in parts[stage]:
+                for idx, layer in enumerate(layers[st: ed]):
+                    log_str += f'\t{idx + st:2d}: {layer}\n'
+        logger.info(log_str, ranks=[0])
+
+    # Save the partition
+    interval = parts[pipeline_rank]
+
+    models = []
+    for st, ed in interval:
+        model = copy.deepcopy(ori_model)
+        model.build_from_cfg(st, ed)
+        models.append(model)
+
+    return nn.ModuleList(models) if len(models) > 1 else models[0]
+
+
+def build_pipeline_model(layers: nn.Sequential, num_chunks: int = 1, verbose: bool = False):
+    """An initializer to split the model into different stages for pipeline parallelism.
+    Note that `layers` must be `torch.nn.Sequential`.
+
+    :param layers: layers of model
+    :type layers: `torch.nn.Sequential`
+    :param num_chunks: the number of chunks you want to have on the current stage. This value should be 1
+                       in most cases unless you are using virtual pipeline parallelism.
+    :type num_chunks: int
+    :param verbose: whether to print the logs
+    :type verbose: bool
+    """
+    pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
+    pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
+    partitions = _partition_uniform(len(layers), pipeline_parallel_size, num_chunks)
+    module_list = []
+    for start, end in partitions[pipeline_rank]:
+        module_list.append(nn.Sequential(*layers[start:end]))
+    if verbose:
+        logger = get_dist_logger()
+        logger.info(f'Total {len(layers)} layers', ranks=[0])
+        for rank, part in enumerate(partitions):
+            log_str = f'===== stage={rank} =====\n'
+            for chunk, (start, end) in enumerate(part):
+                log_str += f'===== chunk={chunk}, layer=[{start}-{end}] =====\n'
+                log_str += '\n'.join([str(layer) for layer in layers[start:end]]) + '\n'
+            logger.info(log_str, ranks=[0])
+    return nn.ModuleList(module_list) if len(module_list) > 1 else module_list[0]
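build_pipeline_model is the lighter-weight of the two new entry points: it slices an existing torch.nn.Sequential uniformly across pipeline stages instead of building layers from a config. A minimal usage sketch, assuming colossalai has been launched with a pipeline-parallel group already initialized; the toy MLP is purely illustrative:

    import torch.nn as nn
    from colossalai.builder import build_pipeline_model

    # Any nn.Sequential works; this stack is only an example.
    layers = nn.Sequential(
        nn.Linear(1024, 4096),
        nn.GELU(),
        nn.Linear(4096, 1024),
    )

    # Each pipeline rank keeps only its own slice of `layers`. With num_chunks > 1
    # (interleaved / virtual pipeline), the call returns an nn.ModuleList with one
    # nn.Sequential per chunk; otherwise it returns a single nn.Sequential.
    pipeline_stage = build_pipeline_model(layers, num_chunks=1, verbose=True)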