add interleaved pipeline, fix naive amp and update pipeline model initializer (#80)

This commit is contained in:
ver217
2021-12-20 23:26:19 +08:00
committed by GitHub
parent 91c327cb44
commit 8f02a88db2
17 changed files with 544 additions and 170 deletions

View File

@@ -1,10 +1,10 @@
from .builder import (build_schedule, build_lr_scheduler, build_model, build_optimizer, build_layer,
build_loss, build_hooks, build_dataset, build_transform, build_data_sampler,
build_gradient_handler)
from .pipeline import PipelineModelInitializer
from .pipeline import build_pipeline_model, build_pipeline_model_from_cfg
__all__ = [
'build_schedule', 'build_lr_scheduler', 'build_model', 'build_optimizer',
'build_layer', 'build_loss', 'build_hooks', 'build_dataset', 'build_transform', 'build_data_sampler',
'build_gradient_handler', 'PipelineModelInitializer'
'build_gradient_handler', 'build_pipeline_model', 'build_pipeline_model_from_cfg'
]

View File

@@ -1,11 +1,12 @@
import copy
import heapq
from colossalai.builder import build_model, build_layer
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.utils import set_to_cuda
import torch.nn as nn
def _binary_partition(weights, st, ed):
@@ -150,7 +151,19 @@ def _partition_balanced(weights, pipeline_parallel_size, num_chunks):
return parts
class PipelineModelInitializer():
def _count_layer_params(layers):
"""Count the number of parameters in each layer
"""
param_counts = [0] * len(layers)
for idx, cfg in enumerate(layers):
layer = build_layer(cfg)
params = filter(lambda p: p.requires_grad, layer.parameters())
param_counts[idx] = sum(p.numel() for p in params)
return param_counts
def build_pipeline_model_from_cfg(config, num_chunks: int = 1, partition_method: str = 'parameter', verbose: bool = False):
"""An intializer to split the model into different stages for pipeline parallelism.
An example for the model config is shown below. The class VisionTransformerFromConfig should
@@ -168,88 +181,86 @@ class PipelineModelInitializer():
:param num_chunks: the number of chunks you want to have on the current stage. This value should be 1
in most cases unless you are using virutal pipeline parallelism.
:type num_chunks: int
:param partition_method: this parameter determines how you want to split your model layers into stages,
you can set it as 'layer' or 'parameter'
:type partition_method: str
:param verbose: whether to print the logs
:type verbose: bool
"""
ori_model = build_model(config)
layers = ori_model.layers_cfg
layer_length = len(layers)
logger = get_dist_logger()
if verbose:
logger.info(f"The total length of layers is {layer_length}", ranks=[0])
def __init__(self, config, num_chunks, verbose=False):
self.num_chunks = num_chunks
self.ori_model = build_model(config)
self.layers = self.ori_model.layers_cfg
layer_length = len(self.layers)
self.verbose = verbose
self._logger = get_dist_logger()
self._logger.info(f"The total length of layers is {layer_length}", ranks=[0])
pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
def initialize(self, partition_method='parameter'):
"""Initialize the model object from the config passed
method = partition_method.lower()
# Make a partition
if method == 'layer':
num_layers = len(layers)
parts = _partition_uniform(num_layers, pipeline_parallel_size, num_chunks)
elif method == 'parameter':
param_counts = _count_layer_params(layers)
# print_rank_0(param_counts)
parts = _partition_balanced(param_counts, pipeline_parallel_size, num_chunks)
else:
raise ValueError("Method should be a pre-set string in [layer, parameter]")
:param partition_method: this parameter determines how you want to split your model layers into stages,
you can set it as 'layer' or 'parameter'
:type partition_method: str
"""
# Some space for initializing comunication groups
self._interval = None
self._partition_layers(method=partition_method)
models = self._build()
model = set_to_cuda(models)
# Display the partition
if verbose:
log_str = 'Layer allocation after partitioning: \n'
for stage in range(pipeline_parallel_size):
return model
num_layers = 0
for st, ed in parts[stage]:
num_layers += ed - st
def _partition_layers(self, method):
pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
log_str += f'\n===== stage={stage}, layers={num_layers} =====\n'
for st, ed in parts[stage]:
for idx, layer in enumerate(layers[st: ed]):
log_str += f'\t{idx + st:2d}: {layer}\n'
logger.info(log_str, ranks=[0])
method = method.lower()
# Make a partition
if method == 'layer':
num_layers = len(self.layers)
self.parts = _partition_uniform(num_layers, pipeline_parallel_size, self.num_chunks)
elif method == 'parameter':
param_counts = self._count_layer_params()
# print_rank_0(param_counts)
self.parts = _partition_balanced(param_counts, pipeline_parallel_size, self.num_chunks)
else:
raise ValueError("Method should be a pre-set string in [layer, parameter]")
# Save the partition
interval = parts[pipeline_rank]
# Display the partition
if gpc.get_global_rank() == 0 and self.verbose:
log_str = 'Layer allocation after partitioning: \n'
for stage in range(pipeline_parallel_size):
models = []
for st, ed in interval:
model = copy.deepcopy(ori_model)
model.build_from_cfg(st, ed)
models.append(model)
num_layers = 0
for st, ed in self.parts[stage]:
num_layers += ed - st
return nn.ModuleList(models) if len(models) > 1 else models[0]
log_str += f'\n===== stage={stage}, layers={num_layers} =====\n'
for st, ed in self.parts[stage]:
for idx, layer in enumerate(self.layers[st: ed]):
log_str += f'\t{idx + st:2d}: {layer}\n'
self._logger.info(log_str, ranks=[0])
# Save the partition
self._interval = self.parts[pipeline_rank]
def build_pipeline_model(layers: nn.Sequential, num_chunks: int = 1, verbose: bool = False):
"""An intializer to split the model into different stages for pipeline parallelism.
Note that `layer` must be `torch.nn.Sequential`.
def _build(self):
"""Build model from the layer cfg according to the partition
"""
models = []
for st, ed in self._interval:
model = copy.copy(self.ori_model)
model.build_from_cfg(st, ed)
models.append(model)
return models
def _count_layer_params(self):
"""Count the number of parameters in each layer
"""
param_counts = [0] * len(self.layers)
for idx, cfg in enumerate(self.layers):
layer = build_layer(cfg)
params = filter(lambda p: p.requires_grad, layer.parameters())
param_counts[idx] = sum(p.numel() for p in params)
return param_counts
:param layers: layers of model
:type config: `torch.nn.Sequential`
:param num_chunks: the number of chunks you want to have on the current stage. This value should be 1
in most cases unless you are using virutal pipeline parallelism.
:type num_chunks: int
:param verbose: whether to print the logs
:type verbose: bool
"""
pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
partitions = _partition_uniform(len(layers), pipeline_parallel_size, num_chunks)
module_list = []
for start, end in partitions[pipeline_rank]:
module_list.append(nn.Sequential(*layers[start:end]))
if verbose:
logger = get_dist_logger()
logger.info(f'Total {len(layers)} layers', ranks=[0])
for rank, part in enumerate(partitions):
log_str = f'===== stage={rank} =====\n'
for chunk, (start, end) in enumerate(part):
log_str += f'===== chunk={chunk}, layer=[{start}-{end}] =====\n'
log_str += '\n'.join([str(layer) for layer in layers[start:end]]) + '\n'
logger.info(log_str, ranks=[0])
return nn.ModuleList(module_list) if len(module_list) > 1 else module_list[0]