add interleaved pipeline, fix naive amp and update pipeline model initializer (#80)

ver217 2021-12-20 23:26:19 +08:00 committed by GitHub
parent 91c327cb44
commit 8f02a88db2
17 changed files with 544 additions and 170 deletions

View File

@@ -20,10 +20,16 @@ def convert_to_naive_amp(model: nn.Module,
     :return: (model, optimizer)
     :rtype: Tuple
     """
-    if is_no_pp_or_last_stage():
-        model = NaiveAMPModel(model, output_to_fp32=True)
+    if isinstance(model, nn.ModuleList):
+        # interleaved pipeline
+        module_list = []
+        for chunk, m in enumerate(model):
+            output_to_fp32 = is_no_pp_or_last_stage() and chunk == len(model) - 1
+            module_list.append(NaiveAMPModel(m, output_to_fp32=output_to_fp32))
+        model = nn.ModuleList(module_list)
     else:
-        model = NaiveAMPModel(model, output_to_fp32=False)
+        output_to_fp32 = is_no_pp_or_last_stage()
+        model = NaiveAMPModel(model, output_to_fp32=output_to_fp32)
 
     optimizer = NaiveAMPOptimizer(optimizer, **amp_config)
     return model, optimizer
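
The branch above decides, per model chunk, whether `NaiveAMPModel` should cast the fp16 output back to fp32. A toy single-process sketch of that decision logic (the plain boolean stands in for `is_no_pp_or_last_stage()` and the linear layers are hypothetical placeholders for real chunks):

```python
import torch.nn as nn

def plan_output_to_fp32(model, on_last_stage):
    # Mirrors the diff: for an nn.ModuleList of chunks, only the final chunk of
    # the final pipeline stage converts its fp16 output back to fp32.
    if isinstance(model, nn.ModuleList):
        return [on_last_stage and chunk == len(model) - 1 for chunk in range(len(model))]
    return [on_last_stage]

chunks = nn.ModuleList([nn.Linear(4, 4), nn.Linear(4, 4)])
print(plan_output_to_fp32(chunks, on_last_stage=True))   # [False, True]
print(plan_output_to_fp32(chunks, on_last_stage=False))  # [False, False]
```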

View File

@@ -14,7 +14,7 @@ from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.utils import (print_rank_0, copy_tensor_parallel_attributes,
-                              clip_grad_norm_fp32, count_zeros_fp32, multi_tensor_applier)
+                              clip_grad_norm_fp32, count_zeros_fp32, multi_tensor_applier, is_using_pp)
 
 def _zero_grad_group_helper(group, set_to_none):
@@ -58,7 +58,8 @@ class DynamicGradScaler:
                  backoff_factor,
                  growth_interval,
                  hysteresis,
-                 max_scale: int = None):
+                 max_scale: int = None,
+                 verbose: bool = False):
        """"Grad scaler with dynamic scale that gets adjusted
        during training."""
        assert initial_scale > 0.0
@@ -91,6 +92,7 @@ class DynamicGradScaler:
        self._hysteresis_tracker = self.hysteresis
 
        self._logger = get_dist_logger()
+       self.verbose = verbose
 
    @property
    def scale(self):
@@ -111,6 +113,7 @@ class DynamicGradScaler:
        if self._hysteresis_tracker <= 0:
            self._scale = torch.max(self._scale * self.backoff_factor,
                                    self.min_scale)
-           self._logger.info(f'overflow occurs, loss scale is adjusted to {self._scale}', ranks=[0])
+           if self.verbose:
+               self._logger.info(f'overflow occurs, loss scale is adjusted to {self._scale}', ranks=[0])
        else:
            # If there is no nan/inf, increment the growth tracker.
@@ -122,11 +125,14 @@ class DynamicGradScaler:
                self._hysteresis_tracker = self.hysteresis
                # and scale up the loss scale.
                if self._max_scale is not None and self._scale >= self._max_scale:
-                   self._logger.info(
-                       f'Current loss scale {self._scale} has reached the max scale {self._max_scale} allowed', ranks=[0])
+                   if self.verbose:
+                       self._logger.info(
+                           f'Current loss scale {self._scale} has reached the max scale {self._max_scale} allowed', ranks=[0])
                else:
                    self._scale = self._scale * self.growth_factor
-                   self._logger.info(f'no consecutive overflow, loss scale is adjusted to {self._scale}', ranks=[0])
+                   if self.verbose:
+                       self._logger.info(
+                           f'no consecutive overflow, loss scale is adjusted to {self._scale}', ranks=[0])
 
    def state_dict(self):
        state_dict = {}
@@ -162,6 +168,8 @@ class FP16Optimizer(Optimizer):
    :type hysterisis: int
    :param max_scale: maximum loss scale allowed
    :type max_scale: int
+   :param verbose: if set to `True`, will print debug info
+   :type verbose: bool
    """
 
    def __init__(self,
@@ -174,16 +182,18 @@ class FP16Optimizer(Optimizer):
                 backoff_factor=0.5,
                 growth_interval=1000,
                 hysteresis=2,
-                max_scale: int = 2 ** 32):
+                max_scale: int = 2 ** 32,
+                verbose: bool = False):
        # default args for compatibility
        bf16 = False
-       params_have_main_grad = True
+       params_have_main_grad = False
 
        # have a defaults for compatibility with pytorch optim
        self.defaults = optimizer.defaults
 
        # log config
        self._logger = get_dist_logger()
-       self._logger.info(f"\n========= FP16 Optimizer Config =========\n"
-                         f"Optimizer: {optimizer.__class__.__name__}\n"
-                         f"clip_grad = {clip_grad}\n"
+       if verbose:
+           self._logger.info(f"\n========= FP16 Optimizer Config =========\n"
+                             f"Optimizer: {optimizer.__class__.__name__}\n"
+                             f"clip_grad = {clip_grad}\n"
@@ -212,7 +222,8 @@ class FP16Optimizer(Optimizer):
            backoff_factor=backoff_factor,
            growth_interval=growth_interval,
            hysteresis=hysteresis,
-           max_scale=max_scale
+           max_scale=max_scale,
+           verbose=verbose
        )
 
        # None grad scaler is only supported for bf16.
@@ -350,6 +361,11 @@ class FP16Optimizer(Optimizer):
                                     op=torch.distributed.ReduceOp.MAX,
                                     group=gpc.get_group(ParallelMode.TENSOR))
 
+       if is_using_pp():
+           torch.distributed.all_reduce(self.found_inf,
+                                        op=torch.distributed.ReduceOp.MAX,
+                                        group=gpc.get_group(ParallelMode.PIPELINE))
+
        # Check for nan.
        found_inf_flag = (self.found_inf.item() > 0)
        return found_inf_flag
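
The extra all-reduce matters because, once the model is split across pipeline stages, an overflow detected while unscaling gradients on one stage must make every stage skip the optimizer step. A toy, single-process illustration of the MAX-reduction semantics (plain tensors stand in for the per-rank `found_inf` flags; no process group is involved):

```python
import torch

# One overflow flag per pipeline stage: only stage 2 saw an inf/nan this step.
found_inf_per_stage = [torch.zeros(1), torch.zeros(1), torch.ones(1), torch.zeros(1)]

# ReduceOp.MAX over the pipeline group gives every stage the same verdict,
# so all stages skip the step (and back off the loss scale) together.
global_found_inf = torch.stack(found_inf_per_stage).max()
print(global_found_inf.item() > 0)  # True -> skip this optimizer step
```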

View File

@@ -1,10 +1,10 @@
 from .builder import (build_schedule, build_lr_scheduler, build_model, build_optimizer, build_layer,
                       build_loss, build_hooks, build_dataset, build_transform, build_data_sampler,
                       build_gradient_handler)
-from .pipeline import PipelineModelInitializer
+from .pipeline import build_pipeline_model, build_pipeline_model_from_cfg
 
 __all__ = [
     'build_schedule', 'build_lr_scheduler', 'build_model', 'build_optimizer',
     'build_layer', 'build_loss', 'build_hooks', 'build_dataset', 'build_transform', 'build_data_sampler',
-    'build_gradient_handler', 'PipelineModelInitializer'
+    'build_gradient_handler', 'build_pipeline_model', 'build_pipeline_model_from_cfg'
 ]

View File

@@ -1,11 +1,12 @@
 import copy
 import heapq
 
 from colossalai.builder import build_model, build_layer
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
-from colossalai.utils import set_to_cuda
+import torch.nn as nn
 
 def _binary_partition(weights, st, ed):
@@ -150,7 +151,19 @@ def _partition_balanced(weights, pipeline_parallel_size, num_chunks):
     return parts
 
-class PipelineModelInitializer():
+def _count_layer_params(layers):
+    """Count the number of parameters in each layer
+    """
+    param_counts = [0] * len(layers)
+    for idx, cfg in enumerate(layers):
+        layer = build_layer(cfg)
+        params = filter(lambda p: p.requires_grad, layer.parameters())
+        param_counts[idx] = sum(p.numel() for p in params)
+    return param_counts
+
+
+def build_pipeline_model_from_cfg(config, num_chunks: int = 1, partition_method: str = 'parameter', verbose: bool = False):
     """An intializer to split the model into different stages for pipeline parallelism.
 
     An example for the model config is shown below. The class VisionTransformerFromConfig should
@@ -168,88 +181,86 @@ class PipelineModelInitializer():
     :param num_chunks: the number of chunks you want to have on the current stage. This value should be 1
                        in most cases unless you are using virutal pipeline parallelism.
     :type num_chunks: int
-    :param verbose: whether to print the logs
-    :type verbose: bool
-    """
-
-    def __init__(self, config, num_chunks, verbose=False):
-        self.num_chunks = num_chunks
-        self.ori_model = build_model(config)
-        self.layers = self.ori_model.layers_cfg
-        layer_length = len(self.layers)
-        self.verbose = verbose
-        self._logger = get_dist_logger()
-        self._logger.info(f"The total length of layers is {layer_length}", ranks=[0])
-
-    def initialize(self, partition_method='parameter'):
-        """Initialize the model object from the config passed
-
-        :param partition_method: this parameter determines how you want to split your model layers into stages,
-                                 you can set it as 'layer' or 'parameter'
-        :type partition_method: str
-        """
-        # Some space for initializing comunication groups
-        self._interval = None
-        self._partition_layers(method=partition_method)
-        models = self._build()
-        model = set_to_cuda(models)
-
-        return model
-
-    def _partition_layers(self, method):
-        pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
-        pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
-
-        method = method.lower()
-        # Make a partition
-        if method == 'layer':
-            num_layers = len(self.layers)
-            self.parts = _partition_uniform(num_layers, pipeline_parallel_size, self.num_chunks)
-        elif method == 'parameter':
-            param_counts = self._count_layer_params()
-            # print_rank_0(param_counts)
-            self.parts = _partition_balanced(param_counts, pipeline_parallel_size, self.num_chunks)
-        else:
-            raise ValueError("Method should be a pre-set string in [layer, parameter]")
-
-        # Display the partition
-        if gpc.get_global_rank() == 0 and self.verbose:
-            log_str = 'Layer allocation after partitioning: \n'
-            for stage in range(pipeline_parallel_size):
-
-                num_layers = 0
-                for st, ed in self.parts[stage]:
-                    num_layers += ed - st
-
-                log_str += f'\n===== stage={stage}, layers={num_layers} =====\n'
-                for st, ed in self.parts[stage]:
-                    for idx, layer in enumerate(self.layers[st: ed]):
-                        log_str += f'\t{idx + st:2d}: {layer}\n'
-            self._logger.info(log_str, ranks=[0])
-
-        # Save the partition
-        self._interval = self.parts[pipeline_rank]
-
-    def _build(self):
-        """Build model from the layer cfg according to the partition
-        """
-        models = []
-        for st, ed in self._interval:
-            model = copy.copy(self.ori_model)
-            model.build_from_cfg(st, ed)
-            models.append(model)
-
-        return models
-
-    def _count_layer_params(self):
-        """Count the number of parameters in each layer
-        """
-        param_counts = [0] * len(self.layers)
-        for idx, cfg in enumerate(self.layers):
-            layer = build_layer(cfg)
-            params = filter(lambda p: p.requires_grad, layer.parameters())
-            param_counts[idx] = sum(p.numel() for p in params)
-
-        return param_counts
+    :param partition_method: this parameter determines how you want to split your model layers into stages,
+                             you can set it as 'layer' or 'parameter'
+    :type partition_method: str
+    :param verbose: whether to print the logs
+    :type verbose: bool
+    """
+    ori_model = build_model(config)
+    layers = ori_model.layers_cfg
+    layer_length = len(layers)
+    logger = get_dist_logger()
+    if verbose:
+        logger.info(f"The total length of layers is {layer_length}", ranks=[0])
+
+    pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
+    pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
+
+    method = partition_method.lower()
+    # Make a partition
+    if method == 'layer':
+        num_layers = len(layers)
+        parts = _partition_uniform(num_layers, pipeline_parallel_size, num_chunks)
+    elif method == 'parameter':
+        param_counts = _count_layer_params(layers)
+        # print_rank_0(param_counts)
+        parts = _partition_balanced(param_counts, pipeline_parallel_size, num_chunks)
+    else:
+        raise ValueError("Method should be a pre-set string in [layer, parameter]")
+
+    # Display the partition
+    if verbose:
+        log_str = 'Layer allocation after partitioning: \n'
+        for stage in range(pipeline_parallel_size):
+
+            num_layers = 0
+            for st, ed in parts[stage]:
+                num_layers += ed - st
+
+            log_str += f'\n===== stage={stage}, layers={num_layers} =====\n'
+            for st, ed in parts[stage]:
+                for idx, layer in enumerate(layers[st: ed]):
+                    log_str += f'\t{idx + st:2d}: {layer}\n'
+        logger.info(log_str, ranks=[0])
+
+    # Save the partition
+    interval = parts[pipeline_rank]
+
+    models = []
+    for st, ed in interval:
+        model = copy.deepcopy(ori_model)
+        model.build_from_cfg(st, ed)
+        models.append(model)
+
+    return nn.ModuleList(models) if len(models) > 1 else models[0]
+
+
+def build_pipeline_model(layers: nn.Sequential, num_chunks: int = 1, verbose: bool = False):
+    """An intializer to split the model into different stages for pipeline parallelism.
+    Note that `layer` must be `torch.nn.Sequential`.
+
+    :param layers: layers of model
+    :type config: `torch.nn.Sequential`
+    :param num_chunks: the number of chunks you want to have on the current stage. This value should be 1
+                       in most cases unless you are using virutal pipeline parallelism.
+    :type num_chunks: int
+    :param verbose: whether to print the logs
+    :type verbose: bool
+    """
+    pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
+    pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
+    partitions = _partition_uniform(len(layers), pipeline_parallel_size, num_chunks)
+    module_list = []
+    for start, end in partitions[pipeline_rank]:
+        module_list.append(nn.Sequential(*layers[start:end]))
+    if verbose:
+        logger = get_dist_logger()
+        logger.info(f'Total {len(layers)} layers', ranks=[0])
+        for rank, part in enumerate(partitions):
+            log_str = f'===== stage={rank} =====\n'
+            for chunk, (start, end) in enumerate(part):
+                log_str += f'===== chunk={chunk}, layer=[{start}-{end}] =====\n'
+                log_str += '\n'.join([str(layer) for layer in layers[start:end]]) + '\n'
+            logger.info(log_str, ranks=[0])
+    return nn.ModuleList(module_list) if len(module_list) > 1 else module_list[0]
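
For reference, a hedged usage sketch of the new `build_pipeline_model` helper: the eight-layer `nn.Sequential` is a made-up stand-in, and it assumes `colossalai.launch` has already created the `PIPELINE` process group (say with a pipeline world size of 2).

```python
import torch.nn as nn
from colossalai.builder import build_pipeline_model

# Toy stand-in for a real network expressed as torch.nn.Sequential.
layers = nn.Sequential(*[nn.Linear(64, 64) for _ in range(8)])

# With pipeline_parallel_size=2 and num_chunks=2, each stage keeps two chunks
# of two layers each, returned as an nn.ModuleList so the interleaved schedule
# can iterate over the chunks; with num_chunks=1 a single nn.Sequential comes back.
model = build_pipeline_model(layers, num_chunks=2, verbose=True)
```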

View File

@@ -63,9 +63,6 @@ def _communicate(tensor_send_next=None,
        next_rank = gpc.get_next_global_rank(
            ParallelMode.PIPELINE)
 
-    # rank = dist.get_rank()
-    rank = gpc.get_global_rank()
-
    ops = []
    if tensor_send_prev is not None:
        send_prev_op = dist.P2POp(dist.isend, tensor_send_prev, prev_rank)
@@ -88,7 +85,7 @@ def _communicate(tensor_send_next=None,
    return tensor_recv_prev, tensor_recv_next
 
-def recv_forward(input_tensor_shape, prev_rank=None):
+def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float):
    """Receives the input tensor from the previous member in pipeline.
 
    :param input_tensor_shape: The shape of the tensor to be recieved
@@ -98,16 +95,17 @@ def recv_forward(input_tensor_shape, prev_rank=None):
    :return: The input tensor in forward step
    :rtype: :class:`torch.Tensor`
    """
-   if gpc.is_first_rank(ParallelMode.PIPELINE):
+   if gpc.is_pipeline_first_stage():
        input_tensor = None
    else:
        input_tensor, _ = _communicate(recv_prev=True,
                                       recv_prev_shape=input_tensor_shape,
-                                      prev_rank=prev_rank)
+                                      prev_rank=prev_rank,
+                                      dtype=dtype)
    return input_tensor
 
-def recv_backward(output_grad_shape, next_rank=None):
+def recv_backward(output_grad_shape, next_rank=None, dtype=torch.float):
    """Receives the grad tensor from the next member in pipeline.
 
    :param output_grad_shape: The shape of the tensor to be recieved
@@ -117,12 +115,13 @@ def recv_backward(output_grad_shape, next_rank=None):
    :return: The grad of output tensor in forward step
    :rtype: :class:`torch.Tensor`
    """
-   if gpc.is_last_rank(ParallelMode.PIPELINE):
+   if gpc.is_pipeline_last_stage():
        output_tensor_grad = None
    else:
        _, output_tensor_grad = _communicate(recv_next=True,
                                             recv_next_shape=output_grad_shape,
-                                            next_rank=next_rank)
+                                            next_rank=next_rank,
+                                            dtype=dtype)
    return output_tensor_grad
 
@@ -134,7 +133,7 @@ def send_forward(output_tensor, next_rank=None):
    :type output_tensor: :class:`torch.Tensor`
    :type next_rank: int, optional
    """
-   if not gpc.is_last_rank(ParallelMode.PIPELINE):
+   if not gpc.is_pipeline_last_stage():
        _communicate(tensor_send_next=output_tensor,
                     next_rank=next_rank)
 
@@ -147,7 +146,7 @@ def send_backward(input_tensor_grad, prev_rank=None):
    :type input_tensor_grad: :class:`torch.Tensor`
    :type prev_rank: int, optional
    """
-   if not gpc.is_first_rank(ParallelMode.PIPELINE):
+   if not gpc.is_pipeline_first_stage():
        _communicate(tensor_send_prev=input_tensor_grad,
                     prev_rank=prev_rank)
 
@@ -155,7 +154,8 @@ def send_backward(input_tensor_grad, prev_rank=None):
 def send_forward_recv_backward(output_tensor,
                                output_grad_shape,
                                recv_next=True,
-                               next_rank=None):
+                               next_rank=None,
+                               dtype=torch.float):
    """Batched communication operation. Sends the input tensor to the
    next member in pipeline, while recieves the grad tensor from the
    next member in pipeline.
@@ -167,20 +167,22 @@ def send_forward_recv_backward(output_tensor,
    :return: The grad of output tensor in forward step
    :rtype: :class:`torch.Tensor`
    """
-   if gpc.is_last_rank(ParallelMode.PIPELINE):
+   if gpc.is_pipeline_last_stage():
        output_tensor_grad = None
    else:
        _, output_tensor_grad = _communicate(tensor_send_next=output_tensor,
                                             recv_next=recv_next,
                                             recv_next_shape=output_grad_shape,
-                                            next_rank=next_rank)
+                                            next_rank=next_rank,
+                                            dtype=dtype)
    return output_tensor_grad
 
 def send_backward_recv_forward(input_tensor_grad,
                                input_tensor_shape,
                                recv_prev=True,
-                               prev_rank=None):
+                               prev_rank=None,
+                               dtype=torch.float):
    """Batched communication operation. Sends the grad tensor to the
    previous member in pipeline, while recieves the input tensor from the
    previous member in pipeline.
@@ -192,13 +194,14 @@ def send_backward_recv_forward(input_tensor_grad,
    :return: The input tensor in forward step
    :rtype: :class:`torch.Tensor`
    """
-   if gpc.is_first_rank(ParallelMode.PIPELINE):
+   if gpc.is_pipeline_first_stage():
        input_tensor = None
    else:
        input_tensor, _ = _communicate(tensor_send_prev=input_tensor_grad,
                                       recv_prev=recv_prev,
                                       recv_prev_shape=input_tensor_shape,
-                                      prev_rank=prev_rank)
+                                      prev_rank=prev_rank,
+                                      dtype=dtype)
    return input_tensor
 
@@ -206,7 +209,8 @@ def send_forward_recv_forward(output_tensor,
                               input_tensor_shape,
                               recv_prev=True,
                               prev_rank=None,
-                              next_rank=None):
+                              next_rank=None,
+                              dtype=torch.float):
    """Batched communication operation. Sends the input tensor to the
    next member in pipeline, while recieves the input tensor from the
    previous member in pipeline.
@@ -222,7 +226,8 @@ def send_forward_recv_forward(output_tensor,
                                   recv_prev=recv_prev,
                                   recv_prev_shape=input_tensor_shape,
                                   prev_rank=prev_rank,
-                                  next_rank=next_rank)
+                                  next_rank=next_rank,
+                                  dtype=dtype)
    return input_tensor
 
@@ -230,7 +235,8 @@ def send_backward_recv_backward(input_tensor_grad,
                                 output_grad_shape,
                                 recv_next=True,
                                 prev_rank=None,
-                                next_rank=None):
+                                next_rank=None,
+                                dtype=torch.float):
    """Batched communication operation. Sends the grad tensor to the
    previous member in pipeline, while recieves the grad tensor from the
    next member in pipeline.
@@ -246,7 +252,8 @@ def send_backward_recv_backward(input_tensor_grad,
                                         recv_next=recv_next,
                                         recv_next_shape=output_grad_shape,
                                         prev_rank=prev_rank,
-                                        next_rank=next_rank)
+                                        next_rank=next_rank,
+                                        dtype=dtype)
    return output_tensor_grad
 
@@ -257,7 +264,8 @@ def send_forward_backward_recv_forward_backward(output_tensor,
                                                 recv_prev=True,
                                                 recv_next=True,
                                                 prev_rank=None,
-                                                next_rank=None):
+                                                next_rank=None,
+                                                dtype=torch.float):
    """Batched communication operation. Sends the input tensor to the next and
    the grad tensor to the previous, while recieves the grad tensor from the
    next and the input tensor from the previous.
@@ -281,5 +289,6 @@ def send_forward_backward_recv_forward_backward(output_tensor,
                                                    recv_prev_shape=input_tensor_shape,
                                                    recv_next_shape=output_grad_shape,
                                                    prev_rank=prev_rank,
-                                                   next_rank=next_rank)
+                                                   next_rank=next_rank,
+                                                   dtype=dtype)
    return input_tensor, output_tensor_grad
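
Threading `dtype` through these helpers lets the receiving stage pre-allocate its buffer in the precision the sender will actually transmit (fp16 under naive AMP), instead of flipping the process-wide default dtype as the schedule used to do. A toy illustration of that buffer-allocation detail, with no process group involved:

```python
import torch

recv_shape = (4, 128, 768)

# Old behaviour: the buffer picked up whatever the global default dtype was,
# which forced torch.set_default_dtype(torch.half) for fp16 training.
default_buffer = torch.empty(recv_shape, dtype=torch.get_default_dtype())

# New behaviour: the schedule passes its own dtype explicitly to recv_forward & co.
fp16_buffer = torch.empty(recv_shape, dtype=torch.half)

print(default_buffer.dtype, fp16_buffer.dtype)  # torch.float32 torch.float16
```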

View File

@@ -29,14 +29,8 @@ def send_tensor_meta(tensor, need_meta=True, next_rank=None):
        send_shape = torch.tensor(tensor.size(), **tensor_kwargs)
        send_ndims = torch.tensor(len(tensor.size()), **tensor_kwargs)
-       ops = [
-           dist.P2POp(dist.isend, send_ndims, next_rank),
-           dist.P2POp(dist.isend, send_shape, next_rank)
-       ]
-       reqs = dist.batch_isend_irecv(ops)
-       for req in reqs:
-           req.wait()
-       torch.cuda.synchronize()
+       dist.send(send_ndims, next_rank)
+       dist.send(send_shape, next_rank)
 
    return False

View File

@@ -53,6 +53,8 @@ class ParallelContext:
        self.data_parallel_size = 1
        self.pipeline_parallel_size = 1
        self.tensor_parallel_size = 1
+       self.virtual_pipeline_parallel_size = None
+       self.virtual_pipeline_parallel_rank = None
 
        # logging
        self._verbose = False
@@ -205,6 +207,18 @@ class ParallelContext:
        world_size = self.get_world_size(parallel_mode)
        return rank == world_size - 1
 
+   def is_pipeline_first_stage(self, ignore_virtual=False):
+       if not ignore_virtual:
+           if self.virtual_pipeline_parallel_size is not None and self.virtual_pipeline_parallel_rank != 0:
+               return False
+       return self.is_first_rank(ParallelMode.PIPELINE)
+
+   def is_pipeline_last_stage(self, ignore_virtual=False):
+       if not ignore_virtual:
+           if self.virtual_pipeline_parallel_size is not None and self.virtual_pipeline_parallel_rank != self.virtual_pipeline_parallel_size - 1:
+               return False
+       return self.is_last_rank(ParallelMode.PIPELINE)
+
    def get_world_size(self, parallel_mode: ParallelMode):
        """Returns the world size for `parallel_mode`.
@@ -494,3 +508,9 @@ class ParallelContext:
            self._logger.info(
                'WARNING: CUDA is not available, thus CUDA RNG cannot be used to track CUDA random number states',
                ranks=[0])
+
+   def set_virtual_pipeline_parallel_size(self, size):
+       self.virtual_pipeline_parallel_size = size
+
+   def set_virtual_pipeline_parallel_rank(self, rank):
+       self.virtual_pipeline_parallel_rank = rank
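
The two new stage checks gate on the virtual (chunk-level) rank before falling back to the physical pipeline rank. A condensed, standalone rendering of the same logic, with plain booleans standing in for `is_first_rank`/`is_last_rank`:

```python
def is_pipeline_first_stage(virtual_size, virtual_rank, physical_first, ignore_virtual=False):
    # A stage is only the "real" first stage while it is running model chunk 0.
    if not ignore_virtual and virtual_size is not None and virtual_rank != 0:
        return False
    return physical_first

def is_pipeline_last_stage(virtual_size, virtual_rank, physical_last, ignore_virtual=False):
    # Symmetrically, only the last chunk of the physically last stage counts.
    if not ignore_virtual and virtual_size is not None and virtual_rank != virtual_size - 1:
        return False
    return physical_last

# With 2 chunks per stage, physical rank 0 is the true first stage only for chunk 0.
assert is_pipeline_first_stage(2, 0, physical_first=True)
assert not is_pipeline_first_stage(2, 1, physical_first=True)
assert is_pipeline_first_stage(2, 1, physical_first=True, ignore_virtual=True)
```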

View File

@@ -32,7 +32,7 @@ class DataParallelGradientHandler(BaseGradientHandler):
                if tp not in buckets:
                    buckets[tp] = []
                buckets[tp].append(param)
-               param.main_grad = param.grad
+               # param.main_grad = param.grad
 
        # For each bucket, all-reduce and copy all-reduced grads.
        for tp in buckets:

View File

@@ -1,5 +1,5 @@
 from ._base_schedule import BaseSchedule
-from ._pipeline_schedule import PipelineSchedule
+from ._pipeline_schedule import PipelineSchedule, InterleavedPipelineSchedule
 from ._non_pipeline_schedule import NonPipelineSchedule
 
-__all__ = ['BaseSchedule', 'PipelineSchedule', 'NonPipelineSchedule']
+__all__ = ['BaseSchedule', 'PipelineSchedule', 'NonPipelineSchedule', 'InterleavedPipelineSchedule']

View File

@@ -13,9 +13,8 @@ from colossalai.core import global_context as gpc
 from colossalai.amp.naive_amp import NaiveAMPModel
 from colossalai.zero import (ZeroRedundancyOptimizer_Level_2,
                              ZeroRedundancyOptimizer_Level_3)
-from colossalai.utils import get_current_device
+from colossalai.utils import get_current_device, switch_virtual_pipeline_parallel_rank
 from ._base_schedule import BaseSchedule
-from colossalai.amp import AMP_TYPE
 
 def squeeze(x: Union[Tensor, tuple, list]):
@@ -47,6 +46,7 @@ class PipelineSchedule(BaseSchedule):
        self.num_microbatches = num_microbatches
        self.sync_data = sync_data
+       self.dtype = torch.float
 
    def _move_to_device(self, data):
        if isinstance(data, (
@@ -122,12 +122,8 @@ class PipelineSchedule(BaseSchedule):
                "Pipeline schedule is currently not compatible with ZeRO Level 2 and Level 3"
            )
 
-       # LSG: set default dtype to fp16 for communication
        if isinstance(engine.model, NaiveAMPModel):
-           torch.set_default_dtype(torch.half)
-           self.logger.warning(
-               'default tensor dtype is set to torch.half for fp16 training',
-               ranks=[0])
+           self.dtype = torch.half
 
    def forward_step(self, engine, input_tensor, return_tensors, return_loss=True):
        """Forward step for passed-in model. If it is the first stage, the input tensor
@@ -252,7 +248,7 @@ class PipelineSchedule(BaseSchedule):
        for i in range(num_warmup_microbatches):
            if not gpc.is_first_rank(ParallelMode.PIPELINE):
                ft_shape = recv_tensor_meta(ft_shape)
-           input_tensor = recv_forward(ft_shape)
+           input_tensor = recv_forward(ft_shape, dtype=self.dtype)
            output_tensor = self.forward_step(
                engine, input_tensor, return_tensors,
                return_loss=return_loss
@@ -272,7 +268,7 @@ class PipelineSchedule(BaseSchedule):
        if num_microbatches_remaining > 0:
            if not gpc.is_first_rank(ParallelMode.PIPELINE):
                ft_shape = recv_tensor_meta(ft_shape)
-           input_tensor = recv_forward(ft_shape)
+           input_tensor = recv_forward(ft_shape, dtype=self.dtype)
 
        # Run 1F1B in steady state.
        for i in range(num_microbatches_remaining):
@@ -286,11 +282,11 @@ class PipelineSchedule(BaseSchedule):
                send_forward(output_tensor)
 
                if not last_iteration:
-                   input_tensor = recv_forward(ft_shape)
+                   input_tensor = recv_forward(ft_shape, dtype=self.dtype)
 
            else:
                output_tensor_grad = send_forward_recv_backward(
-                   output_tensor, bt_shape)
+                   output_tensor, bt_shape, dtype=self.dtype)
 
                # Add input_tensor and output_tensor to end of list.
                input_tensors.append(input_tensor)
@@ -312,7 +308,7 @@ class PipelineSchedule(BaseSchedule):
                    send_backward(input_tensor_grad)
                else:
                    input_tensor = send_backward_recv_forward(
-                       input_tensor_grad, ft_shape)
+                       input_tensor_grad, ft_shape, dtype=self.dtype)
 
        # Run cooldown backward passes.
        if not forward_only:
@@ -320,7 +316,7 @@ class PipelineSchedule(BaseSchedule):
                input_tensor = input_tensors.pop(0)
                output_tensor = output_tensors.pop(0)
 
-               output_tensor_grad = recv_backward(bt_shape)
+               output_tensor_grad = recv_backward(bt_shape, dtype=self.dtype)
 
                input_tensor_grad = self.backward_step(
                    engine,
@@ -340,3 +336,309 @@ class PipelineSchedule(BaseSchedule):
                return tuple((torch.cat(return_tensors, dim=0), None, None))
        else:
            return tuple((None, None, None))
+
+
+class InterleavedPipelineSchedule(PipelineSchedule):
+    def __init__(self, num_microbatches, num_model_chunks, sync_data: bool = True):
+        assert num_microbatches % gpc.get_world_size(ParallelMode.PIPELINE) == 0, \
+            'num_microbatches must be an integer multiple of pipeline parallel world size'
+        super().__init__(num_microbatches, sync_data=sync_data)
+        gpc.set_virtual_pipeline_parallel_size(num_model_chunks)
+        gpc.set_virtual_pipeline_parallel_rank(0)
+
+    def pre_processing(self, engine):
+        if isinstance(engine.optimizer, (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)):
+            raise TypeError(
+                "Pipeline schedule is currently not compatible with ZeRO Level 2 and Level 3"
+            )
+
+        if isinstance(engine.model[0], NaiveAMPModel):
+            self.dtype = torch.half
+
+    def forward_step(self, engine, model, input_tensor, return_tensors, return_loss=True):
+        """Forward step for passed-in model. If it is the first stage, the input tensor
+        is obtained from data_iterator, otherwise the passed-in input_tensor is used.
+        Returns output tensor. This is a helper function and can be ignored by users.
+        """
+        if input_tensor is None:
+            input_tensor, label = self.load_micro_batch()
+        input_tensor = squeeze(input_tensor)
+        output_tensor = model(input_tensor)
+        output_tensor = squeeze(output_tensor)
+
+        if gpc.is_pipeline_last_stage():
+            if return_loss:
+                input_tensor, label = self.load_micro_batch()
+                loss_reduced = engine.criterion(output_tensor, *label) / self.num_microbatches
+                return_tensors.append(
+                    tuple((output_tensor, label[0], loss_reduced)))
+                return loss_reduced
+            else:
+                return_tensors.append(output_tensor)
+                return output_tensor
+        else:
+            return output_tensor
+
+    def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True):
+        """Run interleaved 1F1B schedule (model split into model chunks), with
+        communication between pipeline stages as needed.
+
+        Returns dictionary with losses if the last stage, empty dict otherwise."""
+        assert forward_only or return_loss, \
+            'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
+
+        self.load_batch(data_iter)
+        model = engine.model
+        input_tensors = [[] for _ in range(len(model))]
+        output_tensors = [[] for _ in range(len(model))]
+        return_tensors = []
+        if not forward_only:
+            output_tensor_grads = [[] for _ in range(len(model))]
+
+        # Used for tensor meta information communication
+        input_tensor_shapes = [None for _ in range(len(model))]
+        output_tensor_shapes = [None for _ in range(len(model))]
+        send_tensor_shape_flags = [True for _ in range(len(model))]
+
+        pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
+        pipeline_parallel_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
+
+        # Compute number of warmup and remaining microbatches.
+        num_model_chunks = len(model)
+        num_microbatches = self.num_microbatches * num_model_chunks
+        all_warmup_microbatches = False
+        if forward_only:
+            num_warmup_microbatches = num_microbatches
+        else:
+            # Run all forward passes and then all backward passes if number of
+            # microbatches is just the number of pipeline stages.
+            # Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on
+            # all workers, followed by more microbatches after depending on
+            # stage ID (more forward passes for earlier stages, later stages can
+            # immediately start with 1F1B).
+            if self.num_microbatches == pipeline_parallel_size:
+                num_warmup_microbatches = num_microbatches
+                all_warmup_microbatches = True
+            else:
+                num_warmup_microbatches = \
+                    (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
+                num_warmup_microbatches += (
+                    num_model_chunks - 1) * pipeline_parallel_size
+                num_warmup_microbatches = min(num_warmup_microbatches,
+                                              num_microbatches)
+        num_microbatches_remaining = \
+            num_microbatches - num_warmup_microbatches
+
+        def get_model_chunk_id(microbatch_id, forward):
+            """Helper method to get the model chunk ID given the iteration number."""
+            microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
+            model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
+            if not forward:
+                model_chunk_id = (num_model_chunks - model_chunk_id - 1)
+            return model_chunk_id
+
+        def forward_step_helper(microbatch_id):
+            """Helper method to run forward step with model split into chunks
+            (run set_virtual_pipeline_model_parallel_rank() before calling
+            forward_step())."""
+            model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
+            gpc.set_virtual_pipeline_parallel_rank(model_chunk_id)
+
+            # forward step
+            if gpc.is_pipeline_first_stage():
+                if len(input_tensors[model_chunk_id]) == \
+                        len(output_tensors[model_chunk_id]):
+                    input_tensors[model_chunk_id].append(None)
+            input_tensor = input_tensors[model_chunk_id][-1]
+            output_tensor = self.forward_step(
+                engine, model[model_chunk_id], input_tensor, return_tensors, return_loss=return_loss)
+            output_tensors[model_chunk_id].append(output_tensor)
+
+            # if forward-only, no need to save tensors for a backward pass
+            if forward_only:
+                input_tensors[model_chunk_id].pop()
+                output_tensors[model_chunk_id].pop()
+
+            return output_tensor
+
+        def backward_step_helper(microbatch_id):
+            """Helper method to run backward step with model split into chunks
+            (run set_virtual_pipeline_model_parallel_rank() before calling
+            backward_step())."""
+            model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
+            gpc.set_virtual_pipeline_parallel_rank(model_chunk_id)
+
+            if gpc.is_pipeline_last_stage():
+                if len(output_tensor_grads[model_chunk_id]) == 0:
+                    output_tensor_grads[model_chunk_id].append(None)
+            input_tensor = input_tensors[model_chunk_id].pop(0)
+            output_tensor = output_tensors[model_chunk_id].pop(0)
+            output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
+            input_tensor_grad = self.backward_step(engine, input_tensor, output_tensor, output_tensor_grad)
+
+            return input_tensor_grad
+
+        # Run warmup forward passes.
+        gpc.set_virtual_pipeline_parallel_rank(0)
+        if not gpc.is_pipeline_first_stage():
+            input_tensor_shapes[0] = recv_tensor_meta(input_tensor_shapes[0])
+        input_tensors[0].append(recv_forward(input_tensor_shapes[0], dtype=self.dtype))
+
+        for k in range(num_warmup_microbatches):
+            model_chunk_id = get_model_chunk_id(k, forward=True)
+            output_tensor = forward_step_helper(k)
+            if not gpc.is_pipeline_last_stage():
+                output_tensor_shapes[model_chunk_id] = output_tensor.shape
+                send_tensor_shape_flags[model_chunk_id] = send_tensor_meta(
+                    output_tensor, send_tensor_shape_flags[model_chunk_id])
+            # Determine if tensor should be received from previous stage.
+            next_forward_model_chunk_id = get_model_chunk_id(k+1, forward=True)
+            recv_prev = True
+            if gpc.is_pipeline_first_stage(ignore_virtual=True):
+                if next_forward_model_chunk_id == 0:
+                    recv_prev = False
+            if k == (num_microbatches - 1):
+                recv_prev = False
+
+            # Don't send tensor downstream if on last stage.
+            if gpc.is_pipeline_last_stage():
+                output_tensor = None
+
+            with switch_virtual_pipeline_parallel_rank(next_forward_model_chunk_id):
+                if not gpc.is_pipeline_first_stage():
+                    input_tensor_shapes[next_forward_model_chunk_id] = recv_tensor_meta(
+                        input_tensor_shapes[next_forward_model_chunk_id])
+            # Send and receive tensors as appropriate (send tensors computed
+            # in this iteration; receive tensors for next iteration).
+            input_shape = input_tensor_shapes[next_forward_model_chunk_id] if recv_prev else None
+            if k == (num_warmup_microbatches - 1) and not forward_only and \
+                    not all_warmup_microbatches:
+                input_tensor_grad = None
+                recv_next = True
+                if gpc.is_pipeline_last_stage(ignore_virtual=True):
+                    recv_next = False
+                output_shape = output_tensor_shapes[num_model_chunks-1] if recv_next else None
+                input_tensor, output_tensor_grad = \
+                    send_forward_backward_recv_forward_backward(
+                        output_tensor, input_tensor_grad,
+                        input_shape,
+                        output_shape,
+                        recv_prev=recv_prev, recv_next=recv_next,
+                        dtype=self.dtype)
+                output_tensor_grads[num_model_chunks-1].append(output_tensor_grad)
+            else:
+                input_tensor = \
+                    send_forward_recv_forward(
+                        output_tensor,
+                        input_shape,
+                        recv_prev=recv_prev,
+                        dtype=self.dtype)
+            input_tensors[next_forward_model_chunk_id].append(input_tensor)
+
+        # Run 1F1B in steady state.
+        for k in range(num_microbatches_remaining):
+            # Forward pass.
+            forward_k = k + num_warmup_microbatches
+            output_tensor = forward_step_helper(forward_k)
+
+            # Backward pass.
+            backward_k = k
+            input_tensor_grad = backward_step_helper(backward_k)
+
+            # Send output_tensor and input_tensor_grad, receive input_tensor
+            # and output_tensor_grad.
+
+            # Determine if current stage has anything to send in either direction,
+            # otherwise set tensor to None.
+            forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
+            gpc.set_virtual_pipeline_parallel_rank(forward_model_chunk_id)
+            if gpc.is_pipeline_last_stage():
+                output_tensor = None
+
+            backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
+            gpc.set_virtual_pipeline_parallel_rank(backward_model_chunk_id)
+            if gpc.is_pipeline_first_stage():
+                input_tensor_grad = None
+
+            # Determine if peers are sending, and where in data structure to put
+            # received tensors.
+            recv_prev = True
+            if gpc.is_pipeline_first_stage(ignore_virtual=True):
+                # First stage is ahead of last stage by (pipeline_parallel_size - 1).
+                next_forward_model_chunk_id = get_model_chunk_id(
+                    forward_k - (pipeline_parallel_size - 1), forward=True)
+                if next_forward_model_chunk_id == (num_model_chunks - 1):
+                    recv_prev = False
+                next_forward_model_chunk_id += 1
+            else:
+                next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1,
+                                                                 forward=True)
+
+            recv_next = True
+            if gpc.is_pipeline_last_stage(ignore_virtual=True):
+                # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
+                next_backward_model_chunk_id = get_model_chunk_id(
+                    backward_k - (pipeline_parallel_size - 1), forward=False)
+                if next_backward_model_chunk_id == 0:
+                    recv_next = False
+                next_backward_model_chunk_id -= 1
+            else:
+                next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1,
+                                                                  forward=False)
+
+            # If last iteration, don't receive; we already received one extra
+            # before the start of the for loop.
+            if k == (num_microbatches_remaining - 1):
+                recv_prev = False
+
+            input_shape = input_tensor_shapes[next_forward_model_chunk_id] if recv_prev else None
+            output_shape = output_tensor_shapes[next_backward_model_chunk_id] if recv_next else None
+            # Communicate tensors.
+            input_tensor, output_tensor_grad = \
+                send_forward_backward_recv_forward_backward(
+                    output_tensor, input_tensor_grad,
+                    input_shape,
+                    output_shape,
+                    recv_prev=recv_prev, recv_next=recv_next,
+                    dtype=self.dtype)
+
+            # Put input_tensor and output_tensor_grad in data structures in the
+            # right location.
+            if recv_prev:
+                input_tensors[next_forward_model_chunk_id].append(input_tensor)
+            if recv_next:
+                output_tensor_grads[next_backward_model_chunk_id].append(
+                    output_tensor_grad)
+
+        # Run cooldown backward passes (flush out pipeline).
+        if not forward_only:
+            if all_warmup_microbatches:
+                output_tensor_grads[num_model_chunks-1].append(
+                    recv_backward(output_tensor_shapes[num_model_chunks-1]))
+            for k in range(num_microbatches_remaining, num_microbatches):
+                input_tensor_grad = backward_step_helper(k)
+                next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False)
+                recv_next = True
+                if gpc.is_pipeline_last_stage(ignore_virtual=True):
+                    if next_backward_model_chunk_id == (num_model_chunks - 1):
+                        recv_next = False
+                if k == (num_microbatches - 1):
+                    recv_next = False
+                output_shape = output_tensor_shapes[next_backward_model_chunk_id] if recv_next else None
+                output_tensor_grads[next_backward_model_chunk_id].append(
+                    send_backward_recv_backward(
+                        input_tensor_grad,
+                        output_shape,
+                        recv_next=recv_next,
+                        dtype=self.dtype))
+
+        if len(return_tensors) > 0:
+            if return_loss:
+                output, label, loss = tuple(map(list, zip(*return_tensors)))
+                return (torch.cat(output, dim=0),
+                        torch.cat(label, dim=0),
+                        sum(loss))
+            else:
+                return tuple((torch.cat(return_tensors, dim=0), None, None))
+        else:
+            return tuple((None, None, None))
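
To make the scheduling arithmetic above concrete, here is a standalone reproduction of `get_model_chunk_id` and the warmup-count formula for a hypothetical 4-stage pipeline with 2 model chunks per stage; the sizes are illustrative only.

```python
pipeline_parallel_size = 4          # assumed pipeline world size
num_model_chunks = 2                # chunks per stage (virtual pipeline size)
num_microbatches = 8 * num_model_chunks

def get_model_chunk_id(microbatch_id, forward):
    # Microbatches cycle through the chunks in groups of pipeline_parallel_size.
    microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
    model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
    if not forward:
        model_chunk_id = num_model_chunks - model_chunk_id - 1
    return model_chunk_id

print([get_model_chunk_id(k, forward=True) for k in range(8)])
# -> [0, 0, 0, 0, 1, 1, 1, 1]: the first pipeline_parallel_size microbatches run
# chunk 0, the next group runs chunk 1, then the pattern repeats.

for rank in range(pipeline_parallel_size):
    warmup = (pipeline_parallel_size - rank - 1) * 2 + (num_model_chunks - 1) * pipeline_parallel_size
    warmup = min(warmup, num_microbatches)
    print(rank, warmup)  # earlier stages run more warmup forward passes before 1F1B
```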

View File

@@ -280,11 +280,21 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
        raise ConfigException(
            "It is not allowed to set fp16 and zero configuration in your config file at the same time")
 
+   # clip grad norm
+   clip_grad_norm = gpc.config.get('clip_grad_norm', 0.0)
+   if clip_grad_norm > 0:
+       if zero_cfg is not None:
+           raise ConfigException(
+               "clip_grad_norm should be specified with zero, you should specify clip_grad in zero configuration")
+
    # initialize amp
    amp_mode = None
    if fp16_cfg is not None and fp16_cfg.mode is not None:
+       # TODO: pipeline only support NAIVE AMP
        cfg_ = fp16_cfg.copy()
        amp_mode = cfg_.pop('mode')
+       if amp_mode == AMP_TYPE.NAIVE:
+           cfg_['clip_grad'] = clip_grad_norm
+
        model, optimizer, criterion = convert_to_amp(model=model,
                                                     optimizer=optimizer,
                                                     criterion=criterion,
@@ -357,16 +367,6 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
                                              gradient_handlers=gradient_handlers,
                                              lr_scheduler=lr_scheduler)
 
-   # clip grad norm
-   clip_grad_norm = gpc.config.get('clip_grad_norm', 0.0)
-   if clip_grad_norm > 0:
-       if zero_cfg is not None:
-           raise ConfigException(
-               "clip_grad_norm should be specified with zero, you should specify clip_grad in zero configuration")
-       elif fp16_cfg is not None and fp16_cfg.mode == AMP_TYPE.NAIVE:
-           raise ConfigException(
-               "clip_grad_norm should be specified with AMP_TYPE.NAIVE, you should specify clip_grad in fp16 configuration")
-
    engine = Engine(
        model=model,
        optimizer=optimizer,
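
With this rearrangement, a top-level `clip_grad_norm` in the config is injected into the naive-AMP optimizer as `clip_grad` instead of being rejected. A hypothetical config-file sketch (the field values are illustrative, not taken from the repository):

```python
from colossalai.amp import AMP_TYPE

# Hypothetical training config: gradient clipping sits at the top level and,
# when mode is AMP_TYPE.NAIVE, is forwarded to the fp16 optimizer as clip_grad.
clip_grad_norm = 1.0

fp16 = dict(
    mode=AMP_TYPE.NAIVE,
    initial_scale=2 ** 16,
    verbose=False,
)
```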

View File

@@ -3,7 +3,7 @@ from .common import (print_rank_0, sync_model_param_in_dp, is_dp_rank_0,
                      is_tp_rank_0, is_no_pp_or_last_stage, is_using_ddp,
                      is_using_pp, conditional_context, is_model_parallel_parameter,
                      clip_grad_norm_fp32, count_zeros_fp32, copy_tensor_parallel_attributes,
-                     param_is_not_tensor_parallel_duplicate)
+                     param_is_not_tensor_parallel_duplicate, switch_virtual_pipeline_parallel_rank)
 from .cuda import get_current_device, synchronize, empty_cache, set_to_cuda
 from .memory import report_memory_usage
 from .timer import MultiTimer, Timer
@@ -22,5 +22,6 @@ __all__ = ['checkpoint',
            'Timer', 'MultiTimer',
            'multi_tensor_applier',
            'accumulate_gradient',
-           'DataParallelSampler', 'get_dataloader'
+           'DataParallelSampler', 'get_dataloader',
+           'switch_virtual_pipeline_parallel_rank'
            ]

View File

@@ -249,3 +249,13 @@ def param_is_not_tensor_parallel_duplicate(param):
     return (hasattr(param, IS_TENSOR_PARALLEL) and
             getattr(param, IS_TENSOR_PARALLEL)) or (
         gpc.get_local_rank(ParallelMode.TENSOR) == 0)
+
+
+@contextmanager
+def switch_virtual_pipeline_parallel_rank(rank):
+    prev_rank = gpc.virtual_pipeline_parallel_rank
+    try:
+        gpc.set_virtual_pipeline_parallel_rank(rank)
+        yield
+    finally:
+        gpc.set_virtual_pipeline_parallel_rank(prev_rank)
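
A minimal illustration of the new context manager (it assumes the global context is importable; no distributed launch is needed just to flip the attribute): the virtual rank is switched for the duration of the block, e.g. while receiving tensor metadata for a different model chunk, and is always restored afterwards.

```python
from colossalai.core import global_context as gpc
from colossalai.utils import switch_virtual_pipeline_parallel_rank

gpc.set_virtual_pipeline_parallel_rank(0)
with switch_virtual_pipeline_parallel_rank(1):
    # Inside the block the stage behaves as if it were running model chunk 1.
    assert gpc.virtual_pipeline_parallel_rank == 1
assert gpc.virtual_pipeline_parallel_rank == 0
```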

View File

@@ -172,10 +172,10 @@ elif gpc.get_local_rank(ParallelMode.PIPELINE) == 1:
 2. Make sure your model inherit `colossalai.nn.model.ModelFromConfig` and registered into the
    `MODELS` registry. Define the `self.layers_cfg` attribute.
    Pass in a dict/Config object which specifies the parameters of your model.
-   Use `colossalai.builder.pipeline.PipelineModelInitializer` to partition the layers.
+   Use `colossalai.builder.pipeline.build_pipeline_model_from_cfg` to partition the layers.
 
 ```python
-from colossalai.builder import PipelineModelInitializer
+from colossalai.builder import build_pipeline_model_from_cfg
 from colossalai.nn.model import ModelFromConfig
 from colossalai.registry import MODELS
@@ -199,8 +199,11 @@ model_cfg = dict(
     ...
 )
 
-initializer = PipelineModelInitializer(model_cfg, num_chunks=1)
-model = initializer.initialize()
+# from config
+model = build_pipeline_model_from_cfg(model_cfg, num_chunks=1)
+
+# from torch.nn.Sequential
+# model = build_pipeline_model(sequential_model, num_model_chunks)
 
 ```
@@ -214,6 +217,9 @@ engine, train_dataloader, _, _ = colossalai.initialize(model, optimizer, criteri
 
 schedule = PipelineSchedule(num_microbatches=4)
 
+# interleaved pipeline
+# schedule = InterleavedPipelineSchedule(num_microbatches=4, num_model_chunks=2)
+
 # execute a training epoch
 data_iter = iter(train_dataloader)

View File

@@ -6,7 +6,7 @@ from colossalai.logging import get_dist_logger
 import colossalai
 import torch
 import os
-from colossalai.builder import PipelineModelInitializer
+from colossalai.builder import build_pipeline_model_from_cfg
 from colossalai.core import global_context as gpc
 from colossalai.utils import get_dataloader, MultiTimer
 from colossalai.nn.loss import CrossEntropyLoss2D
@@ -50,8 +50,7 @@ def test_hybrid_parallel():
     #                      suffix='cifar10_2d_vit_ddp1_torch_amp_grad_accum_2_clip_grad_1', mode='w')
 
     # build vit-t-32
-    initializer = PipelineModelInitializer(vit_t_2d.model_cfg, num_chunks=1)
-    model = initializer.initialize()
+    model = build_pipeline_model_from_cfg(vit_t_2d.model_cfg, num_chunks=1)
 
     # build dataloaders
     train_dataset = CIFAR10(
@@ -139,4 +138,4 @@ def test_hybrid_parallel():
 
 if __name__ == '__main__':
-    main()
+    test_hybrid_parallel()

View File

@@ -5,7 +5,7 @@ import torch
 import torch.multiprocessing as mp
 from torch.utils.data import DataLoader
 
-from colossalai.builder.pipeline import PipelineModelInitializer
+from colossalai.builder.pipeline import build_pipeline_model_from_cfg
 from colossalai.core import global_context
 from colossalai.initialize import launch
 from colossalai.logging import get_dist_logger
@@ -28,7 +28,7 @@ def run_partition(rank, world_size):
     logger.info('finished initialization')
 
     # build model
-    model = PipelineModelInitializer(global_context.config.model, 1, verbose=True).initialize()
+    model = build_pipeline_model_from_cfg(global_context.config.model, 1, verbose=True)
     assert isinstance(model, torch.nn.Module)
     logger.info('model is created')

View File

@@ -8,7 +8,7 @@ import torch
 import torch.multiprocessing as mp
 
 import model
-from colossalai.builder import PipelineModelInitializer
+from colossalai.builder import build_pipeline_model_from_cfg
 from colossalai.communication import p2p as p2p_communication
 from colossalai.communication.utils import send_tensor_meta, recv_tensor_meta
 from colossalai.context.parallel_mode import ParallelMode
@@ -39,7 +39,7 @@ def run_schedule(rank, world_size):
            backend='nccl')
 
     # build model
-    model = PipelineModelInitializer(gpc.config.model, 1).initialize()
+    model = build_pipeline_model_from_cfg(gpc.config.model, 1)
     print_rank_0('model is created')
 
     train_dataset = CIFAR10(