From 8f02a88db27757e8f5453dc8c7fe1862f28c525f Mon Sep 17 00:00:00 2001 From: ver217 Date: Mon, 20 Dec 2021 23:26:19 +0800 Subject: [PATCH] add interleaved pipeline, fix naive amp and update pipeline model initializer (#80) --- colossalai/amp/naive_amp/__init__.py | 12 +- colossalai/amp/naive_amp/_fp16_optimizer.py | 56 +-- colossalai/builder/__init__.py | 4 +- colossalai/builder/pipeline.py | 157 +++++---- colossalai/communication/p2p.py | 55 +-- colossalai/communication/utils.py | 10 +- colossalai/context/parallel_context.py | 20 ++ .../_data_parallel_gradient_handler.py | 2 +- colossalai/engine/schedule/__init__.py | 4 +- .../engine/schedule/_pipeline_schedule.py | 330 +++++++++++++++++- colossalai/initialize.py | 20 +- colossalai/utils/__init__.py | 5 +- colossalai/utils/common.py | 10 + docs/parallelization.md | 14 +- .../run_cifar10_vit2d_with_pipeline.py | 7 +- .../test_pipeline/test_partition.py | 4 +- .../test_pipeline/test_pipeline_schedule.py | 4 +- 17 files changed, 544 insertions(+), 170 deletions(-) diff --git a/colossalai/amp/naive_amp/__init__.py b/colossalai/amp/naive_amp/__init__.py index 08ae7b62a..c050ee937 100644 --- a/colossalai/amp/naive_amp/__init__.py +++ b/colossalai/amp/naive_amp/__init__.py @@ -20,10 +20,16 @@ def convert_to_naive_amp(model: nn.Module, :return: (model, optimizer) :rtype: Tuple """ - if is_no_pp_or_last_stage(): - model = NaiveAMPModel(model, output_to_fp32=True) + if isinstance(model, nn.ModuleList): + # interleaved pipeline + module_list = [] + for chunk, m in enumerate(model): + output_to_fp32 = is_no_pp_or_last_stage() and chunk == len(model) - 1 + module_list.append(NaiveAMPModel(m, output_to_fp32=output_to_fp32)) + model = nn.ModuleList(module_list) else: - model = NaiveAMPModel(model, output_to_fp32=False) + output_to_fp32 = is_no_pp_or_last_stage() + model = NaiveAMPModel(model, output_to_fp32=output_to_fp32) optimizer = NaiveAMPOptimizer(optimizer, **amp_config) return model, optimizer diff --git a/colossalai/amp/naive_amp/_fp16_optimizer.py b/colossalai/amp/naive_amp/_fp16_optimizer.py index 9ac8543c1..d34143aec 100644 --- a/colossalai/amp/naive_amp/_fp16_optimizer.py +++ b/colossalai/amp/naive_amp/_fp16_optimizer.py @@ -14,7 +14,7 @@ from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc from colossalai.logging import get_dist_logger from colossalai.utils import (print_rank_0, copy_tensor_parallel_attributes, - clip_grad_norm_fp32, count_zeros_fp32, multi_tensor_applier) + clip_grad_norm_fp32, count_zeros_fp32, multi_tensor_applier, is_using_pp) def _zero_grad_group_helper(group, set_to_none): @@ -58,7 +58,8 @@ class DynamicGradScaler: backoff_factor, growth_interval, hysteresis, - max_scale: int = None): + max_scale: int = None, + verbose: bool = False): """"Grad scaler with dynamic scale that gets adjusted during training.""" assert initial_scale > 0.0 @@ -91,6 +92,7 @@ class DynamicGradScaler: self._hysteresis_tracker = self.hysteresis self._logger = get_dist_logger() + self.verbose = verbose @property def scale(self): @@ -111,7 +113,8 @@ class DynamicGradScaler: if self._hysteresis_tracker <= 0: self._scale = torch.max(self._scale * self.backoff_factor, self.min_scale) - self._logger.info(f'overflow occurs, loss scale is adjusted to {self._scale}', ranks=[0]) + if self.verbose: + self._logger.info(f'overflow occurs, loss scale is adjusted to {self._scale}', ranks=[0]) else: # If there is no nan/inf, increment the growth tracker. self._growth_tracker += 1 @@ -122,11 +125,14 @@ class DynamicGradScaler: self._hysteresis_tracker = self.hysteresis # and scale up the loss scale. if self._max_scale is not None and self._scale >= self._max_scale: - self._logger.info( - f'Current loss scale {self._scale} has reached the max scale {self._max_scale} allowed', ranks=[0]) + if self.verbose: + self._logger.info( + f'Current loss scale {self._scale} has reached the max scale {self._max_scale} allowed', ranks=[0]) else: self._scale = self._scale * self.growth_factor - self._logger.info(f'no consecutive overflow, loss scale is adjusted to {self._scale}', ranks=[0]) + if self.verbose: + self._logger.info( + f'no consecutive overflow, loss scale is adjusted to {self._scale}', ranks=[0]) def state_dict(self): state_dict = {} @@ -162,6 +168,8 @@ class FP16Optimizer(Optimizer): :type hysterisis: int :param max_scale: maximum loss scale allowed :type max_scale: int + :param verbose: if set to `True`, will print debug info + :type verbose: bool """ def __init__(self, @@ -174,27 +182,29 @@ class FP16Optimizer(Optimizer): backoff_factor=0.5, growth_interval=1000, hysteresis=2, - max_scale: int = 2 ** 32): + max_scale: int = 2 ** 32, + verbose: bool = False): # default args for compatibility bf16 = False - params_have_main_grad = True + params_have_main_grad = False # have a defaults for compatibility with pytorch optim self.defaults = optimizer.defaults # log config self._logger = get_dist_logger() - self._logger.info(f"\n========= FP16 Optimizer Config =========\n" - f"Optimizer: {optimizer.__class__.__name__}\n" - f"clip_grad = {clip_grad}\n" - f"log_num_zeros_in_grad = {log_num_zeros_in_grad}\n" - f"initial_scale = {initial_scale}\n" - f"min_scale = {min_scale}\n" - f"growth_factor = {growth_factor}\n" - f"backoff_factor = {backoff_factor}\n" - f"growth_interval = {growth_interval}\n" - f"hysteresis = {hysteresis}\n" - f"==========================================", ranks=[0]) + if verbose: + self._logger.info(f"\n========= FP16 Optimizer Config =========\n" + f"Optimizer: {optimizer.__class__.__name__}\n" + f"clip_grad = {clip_grad}\n" + f"log_num_zeros_in_grad = {log_num_zeros_in_grad}\n" + f"initial_scale = {initial_scale}\n" + f"min_scale = {min_scale}\n" + f"growth_factor = {growth_factor}\n" + f"backoff_factor = {backoff_factor}\n" + f"growth_interval = {growth_interval}\n" + f"hysteresis = {hysteresis}\n" + f"==========================================", ranks=[0]) """Input optimizer is the base optimizer for example Adam.""" self.optimizer = optimizer @@ -212,7 +222,8 @@ class FP16Optimizer(Optimizer): backoff_factor=backoff_factor, growth_interval=growth_interval, hysteresis=hysteresis, - max_scale=max_scale + max_scale=max_scale, + verbose=verbose ) # None grad scaler is only supported for bf16. @@ -350,6 +361,11 @@ class FP16Optimizer(Optimizer): op=torch.distributed.ReduceOp.MAX, group=gpc.get_group(ParallelMode.TENSOR)) + if is_using_pp(): + torch.distributed.all_reduce(self.found_inf, + op=torch.distributed.ReduceOp.MAX, + group=gpc.get_group(ParallelMode.PIPELINE)) + # Check for nan. found_inf_flag = (self.found_inf.item() > 0) return found_inf_flag diff --git a/colossalai/builder/__init__.py b/colossalai/builder/__init__.py index 6c1105a2d..4caef0c96 100644 --- a/colossalai/builder/__init__.py +++ b/colossalai/builder/__init__.py @@ -1,10 +1,10 @@ from .builder import (build_schedule, build_lr_scheduler, build_model, build_optimizer, build_layer, build_loss, build_hooks, build_dataset, build_transform, build_data_sampler, build_gradient_handler) -from .pipeline import PipelineModelInitializer +from .pipeline import build_pipeline_model, build_pipeline_model_from_cfg __all__ = [ 'build_schedule', 'build_lr_scheduler', 'build_model', 'build_optimizer', 'build_layer', 'build_loss', 'build_hooks', 'build_dataset', 'build_transform', 'build_data_sampler', - 'build_gradient_handler', 'PipelineModelInitializer' + 'build_gradient_handler', 'build_pipeline_model', 'build_pipeline_model_from_cfg' ] diff --git a/colossalai/builder/pipeline.py b/colossalai/builder/pipeline.py index a859030a7..3e545ebb3 100644 --- a/colossalai/builder/pipeline.py +++ b/colossalai/builder/pipeline.py @@ -1,11 +1,12 @@ import copy import heapq + from colossalai.builder import build_model, build_layer from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc from colossalai.logging import get_dist_logger -from colossalai.utils import set_to_cuda +import torch.nn as nn def _binary_partition(weights, st, ed): @@ -150,7 +151,19 @@ def _partition_balanced(weights, pipeline_parallel_size, num_chunks): return parts -class PipelineModelInitializer(): +def _count_layer_params(layers): + """Count the number of parameters in each layer + """ + param_counts = [0] * len(layers) + for idx, cfg in enumerate(layers): + layer = build_layer(cfg) + params = filter(lambda p: p.requires_grad, layer.parameters()) + param_counts[idx] = sum(p.numel() for p in params) + + return param_counts + + +def build_pipeline_model_from_cfg(config, num_chunks: int = 1, partition_method: str = 'parameter', verbose: bool = False): """An intializer to split the model into different stages for pipeline parallelism. An example for the model config is shown below. The class VisionTransformerFromConfig should @@ -168,88 +181,86 @@ class PipelineModelInitializer(): :param num_chunks: the number of chunks you want to have on the current stage. This value should be 1 in most cases unless you are using virutal pipeline parallelism. :type num_chunks: int + :param partition_method: this parameter determines how you want to split your model layers into stages, + you can set it as 'layer' or 'parameter' + :type partition_method: str :param verbose: whether to print the logs :type verbose: bool - """ + ori_model = build_model(config) + layers = ori_model.layers_cfg + layer_length = len(layers) + logger = get_dist_logger() + if verbose: + logger.info(f"The total length of layers is {layer_length}", ranks=[0]) - def __init__(self, config, num_chunks, verbose=False): - self.num_chunks = num_chunks - self.ori_model = build_model(config) - self.layers = self.ori_model.layers_cfg - layer_length = len(self.layers) - self.verbose = verbose - self._logger = get_dist_logger() - self._logger.info(f"The total length of layers is {layer_length}", ranks=[0]) + pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE) + pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) - def initialize(self, partition_method='parameter'): - """Initialize the model object from the config passed + method = partition_method.lower() + # Make a partition + if method == 'layer': + num_layers = len(layers) + parts = _partition_uniform(num_layers, pipeline_parallel_size, num_chunks) + elif method == 'parameter': + param_counts = _count_layer_params(layers) + # print_rank_0(param_counts) + parts = _partition_balanced(param_counts, pipeline_parallel_size, num_chunks) + else: + raise ValueError("Method should be a pre-set string in [layer, parameter]") - :param partition_method: this parameter determines how you want to split your model layers into stages, - you can set it as 'layer' or 'parameter' - :type partition_method: str - - """ - # Some space for initializing comunication groups - self._interval = None - self._partition_layers(method=partition_method) - models = self._build() - model = set_to_cuda(models) + # Display the partition + if verbose: + log_str = 'Layer allocation after partitioning: \n' + for stage in range(pipeline_parallel_size): - return model + num_layers = 0 + for st, ed in parts[stage]: + num_layers += ed - st - def _partition_layers(self, method): - pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE) - pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + log_str += f'\n===== stage={stage}, layers={num_layers} =====\n' + for st, ed in parts[stage]: + for idx, layer in enumerate(layers[st: ed]): + log_str += f'\t{idx + st:2d}: {layer}\n' + logger.info(log_str, ranks=[0]) - method = method.lower() - # Make a partition - if method == 'layer': - num_layers = len(self.layers) - self.parts = _partition_uniform(num_layers, pipeline_parallel_size, self.num_chunks) - elif method == 'parameter': - param_counts = self._count_layer_params() - # print_rank_0(param_counts) - self.parts = _partition_balanced(param_counts, pipeline_parallel_size, self.num_chunks) - else: - raise ValueError("Method should be a pre-set string in [layer, parameter]") + # Save the partition + interval = parts[pipeline_rank] - # Display the partition - if gpc.get_global_rank() == 0 and self.verbose: - log_str = 'Layer allocation after partitioning: \n' - for stage in range(pipeline_parallel_size): + models = [] + for st, ed in interval: + model = copy.deepcopy(ori_model) + model.build_from_cfg(st, ed) + models.append(model) - num_layers = 0 - for st, ed in self.parts[stage]: - num_layers += ed - st + return nn.ModuleList(models) if len(models) > 1 else models[0] - log_str += f'\n===== stage={stage}, layers={num_layers} =====\n' - for st, ed in self.parts[stage]: - for idx, layer in enumerate(self.layers[st: ed]): - log_str += f'\t{idx + st:2d}: {layer}\n' - self._logger.info(log_str, ranks=[0]) - # Save the partition - self._interval = self.parts[pipeline_rank] +def build_pipeline_model(layers: nn.Sequential, num_chunks: int = 1, verbose: bool = False): + """An intializer to split the model into different stages for pipeline parallelism. + Note that `layer` must be `torch.nn.Sequential`. - def _build(self): - """Build model from the layer cfg according to the partition - """ - models = [] - for st, ed in self._interval: - model = copy.copy(self.ori_model) - model.build_from_cfg(st, ed) - models.append(model) - - return models - - def _count_layer_params(self): - """Count the number of parameters in each layer - """ - param_counts = [0] * len(self.layers) - for idx, cfg in enumerate(self.layers): - layer = build_layer(cfg) - params = filter(lambda p: p.requires_grad, layer.parameters()) - param_counts[idx] = sum(p.numel() for p in params) - - return param_counts + :param layers: layers of model + :type config: `torch.nn.Sequential` + :param num_chunks: the number of chunks you want to have on the current stage. This value should be 1 + in most cases unless you are using virutal pipeline parallelism. + :type num_chunks: int + :param verbose: whether to print the logs + :type verbose: bool + """ + pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE) + pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + partitions = _partition_uniform(len(layers), pipeline_parallel_size, num_chunks) + module_list = [] + for start, end in partitions[pipeline_rank]: + module_list.append(nn.Sequential(*layers[start:end])) + if verbose: + logger = get_dist_logger() + logger.info(f'Total {len(layers)} layers', ranks=[0]) + for rank, part in enumerate(partitions): + log_str = f'===== stage={rank} =====\n' + for chunk, (start, end) in enumerate(part): + log_str += f'===== chunk={chunk}, layer=[{start}-{end}] =====\n' + log_str += '\n'.join([str(layer) for layer in layers[start:end]]) + '\n' + logger.info(log_str, ranks=[0]) + return nn.ModuleList(module_list) if len(module_list) > 1 else module_list[0] diff --git a/colossalai/communication/p2p.py b/colossalai/communication/p2p.py index 1d0009d6a..99ccdf6eb 100644 --- a/colossalai/communication/p2p.py +++ b/colossalai/communication/p2p.py @@ -63,9 +63,6 @@ def _communicate(tensor_send_next=None, next_rank = gpc.get_next_global_rank( ParallelMode.PIPELINE) - # rank = dist.get_rank() - rank = gpc.get_global_rank() - ops = [] if tensor_send_prev is not None: send_prev_op = dist.P2POp(dist.isend, tensor_send_prev, prev_rank) @@ -88,7 +85,7 @@ def _communicate(tensor_send_next=None, return tensor_recv_prev, tensor_recv_next -def recv_forward(input_tensor_shape, prev_rank=None): +def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float): """Receives the input tensor from the previous member in pipeline. :param input_tensor_shape: The shape of the tensor to be recieved @@ -98,16 +95,17 @@ def recv_forward(input_tensor_shape, prev_rank=None): :return: The input tensor in forward step :rtype: :class:`torch.Tensor` """ - if gpc.is_first_rank(ParallelMode.PIPELINE): + if gpc.is_pipeline_first_stage(): input_tensor = None else: input_tensor, _ = _communicate(recv_prev=True, recv_prev_shape=input_tensor_shape, - prev_rank=prev_rank) + prev_rank=prev_rank, + dtype=dtype) return input_tensor -def recv_backward(output_grad_shape, next_rank=None): +def recv_backward(output_grad_shape, next_rank=None, dtype=torch.float): """Receives the grad tensor from the next member in pipeline. :param output_grad_shape: The shape of the tensor to be recieved @@ -117,12 +115,13 @@ def recv_backward(output_grad_shape, next_rank=None): :return: The grad of output tensor in forward step :rtype: :class:`torch.Tensor` """ - if gpc.is_last_rank(ParallelMode.PIPELINE): + if gpc.is_pipeline_last_stage(): output_tensor_grad = None else: _, output_tensor_grad = _communicate(recv_next=True, recv_next_shape=output_grad_shape, - next_rank=next_rank) + next_rank=next_rank, + dtype=dtype) return output_tensor_grad @@ -134,7 +133,7 @@ def send_forward(output_tensor, next_rank=None): :type output_tensor: :class:`torch.Tensor` :type next_rank: int, optional """ - if not gpc.is_last_rank(ParallelMode.PIPELINE): + if not gpc.is_pipeline_last_stage(): _communicate(tensor_send_next=output_tensor, next_rank=next_rank) @@ -147,7 +146,7 @@ def send_backward(input_tensor_grad, prev_rank=None): :type input_tensor_grad: :class:`torch.Tensor` :type prev_rank: int, optional """ - if not gpc.is_first_rank(ParallelMode.PIPELINE): + if not gpc.is_pipeline_first_stage(): _communicate(tensor_send_prev=input_tensor_grad, prev_rank=prev_rank) @@ -155,7 +154,8 @@ def send_backward(input_tensor_grad, prev_rank=None): def send_forward_recv_backward(output_tensor, output_grad_shape, recv_next=True, - next_rank=None): + next_rank=None, + dtype=torch.float): """Batched communication operation. Sends the input tensor to the next member in pipeline, while recieves the grad tensor from the next member in pipeline. @@ -167,20 +167,22 @@ def send_forward_recv_backward(output_tensor, :return: The grad of output tensor in forward step :rtype: :class:`torch.Tensor` """ - if gpc.is_last_rank(ParallelMode.PIPELINE): + if gpc.is_pipeline_last_stage(): output_tensor_grad = None else: _, output_tensor_grad = _communicate(tensor_send_next=output_tensor, recv_next=recv_next, recv_next_shape=output_grad_shape, - next_rank=next_rank) + next_rank=next_rank, + dtype=dtype) return output_tensor_grad def send_backward_recv_forward(input_tensor_grad, input_tensor_shape, recv_prev=True, - prev_rank=None): + prev_rank=None, + dtype=torch.float): """Batched communication operation. Sends the grad tensor to the previous member in pipeline, while recieves the input tensor from the previous member in pipeline. @@ -192,13 +194,14 @@ def send_backward_recv_forward(input_tensor_grad, :return: The input tensor in forward step :rtype: :class:`torch.Tensor` """ - if gpc.is_first_rank(ParallelMode.PIPELINE): + if gpc.is_pipeline_first_stage(): input_tensor = None else: input_tensor, _ = _communicate(tensor_send_prev=input_tensor_grad, recv_prev=recv_prev, recv_prev_shape=input_tensor_shape, - prev_rank=prev_rank) + prev_rank=prev_rank, + dtype=dtype) return input_tensor @@ -206,7 +209,8 @@ def send_forward_recv_forward(output_tensor, input_tensor_shape, recv_prev=True, prev_rank=None, - next_rank=None): + next_rank=None, + dtype=torch.float): """Batched communication operation. Sends the input tensor to the next member in pipeline, while recieves the input tensor from the previous member in pipeline. @@ -222,7 +226,8 @@ def send_forward_recv_forward(output_tensor, recv_prev=recv_prev, recv_prev_shape=input_tensor_shape, prev_rank=prev_rank, - next_rank=next_rank) + next_rank=next_rank, + dtype=dtype) return input_tensor @@ -230,7 +235,8 @@ def send_backward_recv_backward(input_tensor_grad, output_grad_shape, recv_next=True, prev_rank=None, - next_rank=None): + next_rank=None, + dtype=torch.float): """Batched communication operation. Sends the grad tensor to the previous member in pipeline, while recieves the grad tensor from the next member in pipeline. @@ -246,7 +252,8 @@ def send_backward_recv_backward(input_tensor_grad, recv_next=recv_next, recv_next_shape=output_grad_shape, prev_rank=prev_rank, - next_rank=next_rank) + next_rank=next_rank, + dtype=dtype) return output_tensor_grad @@ -257,7 +264,8 @@ def send_forward_backward_recv_forward_backward(output_tensor, recv_prev=True, recv_next=True, prev_rank=None, - next_rank=None): + next_rank=None, + dtype=torch.float): """Batched communication operation. Sends the input tensor to the next and the grad tensor to the previous, while recieves the grad tensor from the next and the input tensor from the previous. @@ -281,5 +289,6 @@ def send_forward_backward_recv_forward_backward(output_tensor, recv_prev_shape=input_tensor_shape, recv_next_shape=output_grad_shape, prev_rank=prev_rank, - next_rank=next_rank) + next_rank=next_rank, + dtype=dtype) return input_tensor, output_tensor_grad diff --git a/colossalai/communication/utils.py b/colossalai/communication/utils.py index a8dc0da1a..1eeba7bda 100644 --- a/colossalai/communication/utils.py +++ b/colossalai/communication/utils.py @@ -29,14 +29,8 @@ def send_tensor_meta(tensor, need_meta=True, next_rank=None): send_shape = torch.tensor(tensor.size(), **tensor_kwargs) send_ndims = torch.tensor(len(tensor.size()), **tensor_kwargs) - ops = [ - dist.P2POp(dist.isend, send_ndims, next_rank), - dist.P2POp(dist.isend, send_shape, next_rank) - ] - reqs = dist.batch_isend_irecv(ops) - for req in reqs: - req.wait() - torch.cuda.synchronize() + dist.send(send_ndims, next_rank) + dist.send(send_shape, next_rank) return False diff --git a/colossalai/context/parallel_context.py b/colossalai/context/parallel_context.py index 9d7d311c3..6e4e57858 100644 --- a/colossalai/context/parallel_context.py +++ b/colossalai/context/parallel_context.py @@ -53,6 +53,8 @@ class ParallelContext: self.data_parallel_size = 1 self.pipeline_parallel_size = 1 self.tensor_parallel_size = 1 + self.virtual_pipeline_parallel_size = None + self.virtual_pipeline_parallel_rank = None # logging self._verbose = False @@ -205,6 +207,18 @@ class ParallelContext: world_size = self.get_world_size(parallel_mode) return rank == world_size - 1 + def is_pipeline_first_stage(self, ignore_virtual=False): + if not ignore_virtual: + if self.virtual_pipeline_parallel_size is not None and self.virtual_pipeline_parallel_rank != 0: + return False + return self.is_first_rank(ParallelMode.PIPELINE) + + def is_pipeline_last_stage(self, ignore_virtual=False): + if not ignore_virtual: + if self.virtual_pipeline_parallel_size is not None and self.virtual_pipeline_parallel_rank != self.virtual_pipeline_parallel_size - 1: + return False + return self.is_last_rank(ParallelMode.PIPELINE) + def get_world_size(self, parallel_mode: ParallelMode): """Returns the world size for `parallel_mode`. @@ -494,3 +508,9 @@ class ParallelContext: self._logger.info( 'WARNING: CUDA is not available, thus CUDA RNG cannot be used to track CUDA random number states', ranks=[0]) + + def set_virtual_pipeline_parallel_size(self, size): + self.virtual_pipeline_parallel_size = size + + def set_virtual_pipeline_parallel_rank(self, rank): + self.virtual_pipeline_parallel_rank = rank diff --git a/colossalai/engine/gradient_handler/_data_parallel_gradient_handler.py b/colossalai/engine/gradient_handler/_data_parallel_gradient_handler.py index 9fa414cfd..d29abb2d3 100644 --- a/colossalai/engine/gradient_handler/_data_parallel_gradient_handler.py +++ b/colossalai/engine/gradient_handler/_data_parallel_gradient_handler.py @@ -32,7 +32,7 @@ class DataParallelGradientHandler(BaseGradientHandler): if tp not in buckets: buckets[tp] = [] buckets[tp].append(param) - param.main_grad = param.grad + # param.main_grad = param.grad # For each bucket, all-reduce and copy all-reduced grads. for tp in buckets: diff --git a/colossalai/engine/schedule/__init__.py b/colossalai/engine/schedule/__init__.py index a885a672e..9c8f00f40 100644 --- a/colossalai/engine/schedule/__init__.py +++ b/colossalai/engine/schedule/__init__.py @@ -1,5 +1,5 @@ from ._base_schedule import BaseSchedule -from ._pipeline_schedule import PipelineSchedule +from ._pipeline_schedule import PipelineSchedule, InterleavedPipelineSchedule from ._non_pipeline_schedule import NonPipelineSchedule -__all__ = ['BaseSchedule', 'PipelineSchedule', 'NonPipelineSchedule'] +__all__ = ['BaseSchedule', 'PipelineSchedule', 'NonPipelineSchedule', 'InterleavedPipelineSchedule'] diff --git a/colossalai/engine/schedule/_pipeline_schedule.py b/colossalai/engine/schedule/_pipeline_schedule.py index c637622a1..ad1aba27b 100644 --- a/colossalai/engine/schedule/_pipeline_schedule.py +++ b/colossalai/engine/schedule/_pipeline_schedule.py @@ -13,9 +13,8 @@ from colossalai.core import global_context as gpc from colossalai.amp.naive_amp import NaiveAMPModel from colossalai.zero import (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3) -from colossalai.utils import get_current_device +from colossalai.utils import get_current_device, switch_virtual_pipeline_parallel_rank from ._base_schedule import BaseSchedule -from colossalai.amp import AMP_TYPE def squeeze(x: Union[Tensor, tuple, list]): @@ -47,6 +46,7 @@ class PipelineSchedule(BaseSchedule): self.num_microbatches = num_microbatches self.sync_data = sync_data + self.dtype = torch.float def _move_to_device(self, data): if isinstance(data, ( @@ -122,12 +122,8 @@ class PipelineSchedule(BaseSchedule): "Pipeline schedule is currently not compatible with ZeRO Level 2 and Level 3" ) - # LSG: set default dtype to fp16 for communication if isinstance(engine.model, NaiveAMPModel): - torch.set_default_dtype(torch.half) - self.logger.warning( - 'default tensor dtype is set to torch.half for fp16 training', - ranks=[0]) + self.dtype = torch.half def forward_step(self, engine, input_tensor, return_tensors, return_loss=True): """Forward step for passed-in model. If it is the first stage, the input tensor @@ -140,7 +136,7 @@ class PipelineSchedule(BaseSchedule): :type input_tensor: :class:`torch.Tensor` :param return_tensors: a list of tensors to return :type return_tensors: List[:class:`torch.Tensor`] - + :return: output or the loss value of the current pipeline stage :rtype: :class:`torch.Tensor` """ @@ -252,7 +248,7 @@ class PipelineSchedule(BaseSchedule): for i in range(num_warmup_microbatches): if not gpc.is_first_rank(ParallelMode.PIPELINE): ft_shape = recv_tensor_meta(ft_shape) - input_tensor = recv_forward(ft_shape) + input_tensor = recv_forward(ft_shape, dtype=self.dtype) output_tensor = self.forward_step( engine, input_tensor, return_tensors, return_loss=return_loss @@ -272,7 +268,7 @@ class PipelineSchedule(BaseSchedule): if num_microbatches_remaining > 0: if not gpc.is_first_rank(ParallelMode.PIPELINE): ft_shape = recv_tensor_meta(ft_shape) - input_tensor = recv_forward(ft_shape) + input_tensor = recv_forward(ft_shape, dtype=self.dtype) # Run 1F1B in steady state. for i in range(num_microbatches_remaining): @@ -286,11 +282,11 @@ class PipelineSchedule(BaseSchedule): send_forward(output_tensor) if not last_iteration: - input_tensor = recv_forward(ft_shape) + input_tensor = recv_forward(ft_shape, dtype=self.dtype) else: output_tensor_grad = send_forward_recv_backward( - output_tensor, bt_shape) + output_tensor, bt_shape, dtype=self.dtype) # Add input_tensor and output_tensor to end of list. input_tensors.append(input_tensor) @@ -312,7 +308,7 @@ class PipelineSchedule(BaseSchedule): send_backward(input_tensor_grad) else: input_tensor = send_backward_recv_forward( - input_tensor_grad, ft_shape) + input_tensor_grad, ft_shape, dtype=self.dtype) # Run cooldown backward passes. if not forward_only: @@ -320,7 +316,7 @@ class PipelineSchedule(BaseSchedule): input_tensor = input_tensors.pop(0) output_tensor = output_tensors.pop(0) - output_tensor_grad = recv_backward(bt_shape) + output_tensor_grad = recv_backward(bt_shape, dtype=self.dtype) input_tensor_grad = self.backward_step( engine, @@ -340,3 +336,309 @@ class PipelineSchedule(BaseSchedule): return tuple((torch.cat(return_tensors, dim=0), None, None)) else: return tuple((None, None, None)) + + +class InterleavedPipelineSchedule(PipelineSchedule): + def __init__(self, num_microbatches, num_model_chunks, sync_data: bool = True): + assert num_microbatches % gpc.get_world_size(ParallelMode.PIPELINE) == 0, \ + 'num_microbatches must be an integer multiple of pipeline parallel world size' + super().__init__(num_microbatches, sync_data=sync_data) + gpc.set_virtual_pipeline_parallel_size(num_model_chunks) + gpc.set_virtual_pipeline_parallel_rank(0) + + def pre_processing(self, engine): + if isinstance(engine.optimizer, (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)): + raise TypeError( + "Pipeline schedule is currently not compatible with ZeRO Level 2 and Level 3" + ) + + if isinstance(engine.model[0], NaiveAMPModel): + self.dtype = torch.half + + def forward_step(self, engine, model, input_tensor, return_tensors, return_loss=True): + """Forward step for passed-in model. If it is the first stage, the input tensor + is obtained from data_iterator, otherwise the passed-in input_tensor is used. + Returns output tensor. This is a helper function and can be ignored by users. + """ + + if input_tensor is None: + input_tensor, label = self.load_micro_batch() + input_tensor = squeeze(input_tensor) + output_tensor = model(input_tensor) + output_tensor = squeeze(output_tensor) + + if gpc.is_pipeline_last_stage(): + if return_loss: + input_tensor, label = self.load_micro_batch() + loss_reduced = engine.criterion(output_tensor, *label) / self.num_microbatches + return_tensors.append( + tuple((output_tensor, label[0], loss_reduced))) + return loss_reduced + else: + return_tensors.append(output_tensor) + return output_tensor + else: + return output_tensor + + def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True): + """Run interleaved 1F1B schedule (model split into model chunks), with + communication between pipeline stages as needed. + + Returns dictionary with losses if the last stage, empty dict otherwise.""" + assert forward_only or return_loss, \ + 'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.' + self.load_batch(data_iter) + model = engine.model + input_tensors = [[] for _ in range(len(model))] + output_tensors = [[] for _ in range(len(model))] + return_tensors = [] + if not forward_only: + output_tensor_grads = [[] for _ in range(len(model))] + + # Used for tensor meta information communication + input_tensor_shapes = [None for _ in range(len(model))] + output_tensor_shapes = [None for _ in range(len(model))] + send_tensor_shape_flags = [True for _ in range(len(model))] + + pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE) + pipeline_parallel_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + + # Compute number of warmup and remaining microbatches. + num_model_chunks = len(model) + num_microbatches = self.num_microbatches * num_model_chunks + all_warmup_microbatches = False + if forward_only: + num_warmup_microbatches = num_microbatches + else: + # Run all forward passes and then all backward passes if number of + # microbatches is just the number of pipeline stages. + # Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on + # all workers, followed by more microbatches after depending on + # stage ID (more forward passes for earlier stages, later stages can + # immediately start with 1F1B). + if self.num_microbatches == pipeline_parallel_size: + num_warmup_microbatches = num_microbatches + all_warmup_microbatches = True + else: + num_warmup_microbatches = \ + (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2 + num_warmup_microbatches += ( + num_model_chunks - 1) * pipeline_parallel_size + num_warmup_microbatches = min(num_warmup_microbatches, + num_microbatches) + num_microbatches_remaining = \ + num_microbatches - num_warmup_microbatches + + def get_model_chunk_id(microbatch_id, forward): + """Helper method to get the model chunk ID given the iteration number.""" + microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks) + model_chunk_id = microbatch_id_in_group // pipeline_parallel_size + if not forward: + model_chunk_id = (num_model_chunks - model_chunk_id - 1) + return model_chunk_id + + def forward_step_helper(microbatch_id): + """Helper method to run forward step with model split into chunks + (run set_virtual_pipeline_model_parallel_rank() before calling + forward_step()).""" + model_chunk_id = get_model_chunk_id(microbatch_id, forward=True) + gpc.set_virtual_pipeline_parallel_rank(model_chunk_id) + + # forward step + if gpc.is_pipeline_first_stage(): + if len(input_tensors[model_chunk_id]) == \ + len(output_tensors[model_chunk_id]): + input_tensors[model_chunk_id].append(None) + input_tensor = input_tensors[model_chunk_id][-1] + output_tensor = self.forward_step( + engine, model[model_chunk_id], input_tensor, return_tensors, return_loss=return_loss) + output_tensors[model_chunk_id].append(output_tensor) + + # if forward-only, no need to save tensors for a backward pass + if forward_only: + input_tensors[model_chunk_id].pop() + output_tensors[model_chunk_id].pop() + + return output_tensor + + def backward_step_helper(microbatch_id): + """Helper method to run backward step with model split into chunks + (run set_virtual_pipeline_model_parallel_rank() before calling + backward_step()).""" + model_chunk_id = get_model_chunk_id(microbatch_id, forward=False) + gpc.set_virtual_pipeline_parallel_rank(model_chunk_id) + + if gpc.is_pipeline_last_stage(): + if len(output_tensor_grads[model_chunk_id]) == 0: + output_tensor_grads[model_chunk_id].append(None) + input_tensor = input_tensors[model_chunk_id].pop(0) + output_tensor = output_tensors[model_chunk_id].pop(0) + output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0) + input_tensor_grad = self.backward_step(engine, input_tensor, output_tensor, output_tensor_grad) + + return input_tensor_grad + + # Run warmup forward passes. + gpc.set_virtual_pipeline_parallel_rank(0) + if not gpc.is_pipeline_first_stage(): + input_tensor_shapes[0] = recv_tensor_meta(input_tensor_shapes[0]) + input_tensors[0].append(recv_forward(input_tensor_shapes[0], dtype=self.dtype)) + + for k in range(num_warmup_microbatches): + model_chunk_id = get_model_chunk_id(k, forward=True) + output_tensor = forward_step_helper(k) + if not gpc.is_pipeline_last_stage(): + output_tensor_shapes[model_chunk_id] = output_tensor.shape + send_tensor_shape_flags[model_chunk_id] = send_tensor_meta( + output_tensor, send_tensor_shape_flags[model_chunk_id]) + # Determine if tensor should be received from previous stage. + next_forward_model_chunk_id = get_model_chunk_id(k+1, forward=True) + recv_prev = True + if gpc.is_pipeline_first_stage(ignore_virtual=True): + if next_forward_model_chunk_id == 0: + recv_prev = False + if k == (num_microbatches - 1): + recv_prev = False + + # Don't send tensor downstream if on last stage. + if gpc.is_pipeline_last_stage(): + output_tensor = None + + with switch_virtual_pipeline_parallel_rank(next_forward_model_chunk_id): + if not gpc.is_pipeline_first_stage(): + input_tensor_shapes[next_forward_model_chunk_id] = recv_tensor_meta( + input_tensor_shapes[next_forward_model_chunk_id]) + # Send and receive tensors as appropriate (send tensors computed + # in this iteration; receive tensors for next iteration). + input_shape = input_tensor_shapes[next_forward_model_chunk_id] if recv_prev else None + if k == (num_warmup_microbatches - 1) and not forward_only and \ + not all_warmup_microbatches: + input_tensor_grad = None + recv_next = True + if gpc.is_pipeline_last_stage(ignore_virtual=True): + recv_next = False + output_shape = output_tensor_shapes[num_model_chunks-1] if recv_next else None + input_tensor, output_tensor_grad = \ + send_forward_backward_recv_forward_backward( + output_tensor, input_tensor_grad, + input_shape, + output_shape, + recv_prev=recv_prev, recv_next=recv_next, + dtype=self.dtype) + output_tensor_grads[num_model_chunks-1].append(output_tensor_grad) + else: + input_tensor = \ + send_forward_recv_forward( + output_tensor, + input_shape, + recv_prev=recv_prev, + dtype=self.dtype) + input_tensors[next_forward_model_chunk_id].append(input_tensor) + + # Run 1F1B in steady state. + for k in range(num_microbatches_remaining): + # Forward pass. + forward_k = k + num_warmup_microbatches + output_tensor = forward_step_helper(forward_k) + + # Backward pass. + backward_k = k + input_tensor_grad = backward_step_helper(backward_k) + + # Send output_tensor and input_tensor_grad, receive input_tensor + # and output_tensor_grad. + + # Determine if current stage has anything to send in either direction, + # otherwise set tensor to None. + forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True) + gpc.set_virtual_pipeline_parallel_rank(forward_model_chunk_id) + if gpc.is_pipeline_last_stage(): + output_tensor = None + + backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False) + gpc.set_virtual_pipeline_parallel_rank(backward_model_chunk_id) + if gpc.is_pipeline_first_stage(): + input_tensor_grad = None + + # Determine if peers are sending, and where in data structure to put + # received tensors. + recv_prev = True + if gpc.is_pipeline_first_stage(ignore_virtual=True): + # First stage is ahead of last stage by (pipeline_parallel_size - 1). + next_forward_model_chunk_id = get_model_chunk_id( + forward_k - (pipeline_parallel_size - 1), forward=True) + if next_forward_model_chunk_id == (num_model_chunks - 1): + recv_prev = False + next_forward_model_chunk_id += 1 + else: + next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, + forward=True) + + recv_next = True + if gpc.is_pipeline_last_stage(ignore_virtual=True): + # Last stage is ahead of first stage by (pipeline_parallel_size - 1). + next_backward_model_chunk_id = get_model_chunk_id( + backward_k - (pipeline_parallel_size - 1), forward=False) + if next_backward_model_chunk_id == 0: + recv_next = False + next_backward_model_chunk_id -= 1 + else: + next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, + forward=False) + + # If last iteration, don't receive; we already received one extra + # before the start of the for loop. + if k == (num_microbatches_remaining - 1): + recv_prev = False + + input_shape = input_tensor_shapes[next_forward_model_chunk_id] if recv_prev else None + output_shape = output_tensor_shapes[next_backward_model_chunk_id] if recv_next else None + # Communicate tensors. + input_tensor, output_tensor_grad = \ + send_forward_backward_recv_forward_backward( + output_tensor, input_tensor_grad, + input_shape, + output_shape, + recv_prev=recv_prev, recv_next=recv_next, + dtype=self.dtype) + + # Put input_tensor and output_tensor_grad in data structures in the + # right location. + if recv_prev: + input_tensors[next_forward_model_chunk_id].append(input_tensor) + if recv_next: + output_tensor_grads[next_backward_model_chunk_id].append( + output_tensor_grad) + + # Run cooldown backward passes (flush out pipeline). + if not forward_only: + if all_warmup_microbatches: + output_tensor_grads[num_model_chunks-1].append( + recv_backward(output_tensor_shapes[num_model_chunks-1])) + for k in range(num_microbatches_remaining, num_microbatches): + input_tensor_grad = backward_step_helper(k) + next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False) + recv_next = True + if gpc.is_pipeline_last_stage(ignore_virtual=True): + if next_backward_model_chunk_id == (num_model_chunks - 1): + recv_next = False + if k == (num_microbatches - 1): + recv_next = False + output_shape = output_tensor_shapes[next_backward_model_chunk_id] if recv_next else None + output_tensor_grads[next_backward_model_chunk_id].append( + send_backward_recv_backward( + input_tensor_grad, + output_shape, + recv_next=recv_next, + dtype=self.dtype)) + + if len(return_tensors) > 0: + if return_loss: + output, label, loss = tuple(map(list, zip(*return_tensors))) + return (torch.cat(output, dim=0), + torch.cat(label, dim=0), + sum(loss)) + else: + return tuple((torch.cat(return_tensors, dim=0), None, None)) + else: + return tuple((None, None, None)) diff --git a/colossalai/initialize.py b/colossalai/initialize.py index 8817daf8c..01d5b3d2d 100644 --- a/colossalai/initialize.py +++ b/colossalai/initialize.py @@ -280,11 +280,21 @@ def initialize(model: Union[nn.Module, List[nn.Module]], raise ConfigException( "It is not allowed to set fp16 and zero configuration in your config file at the same time") + # clip grad norm + clip_grad_norm = gpc.config.get('clip_grad_norm', 0.0) + if clip_grad_norm > 0: + if zero_cfg is not None: + raise ConfigException( + "clip_grad_norm should be specified with zero, you should specify clip_grad in zero configuration") + # initialize amp amp_mode = None if fp16_cfg is not None and fp16_cfg.mode is not None: + # TODO: pipeline only support NAIVE AMP cfg_ = fp16_cfg.copy() amp_mode = cfg_.pop('mode') + if amp_mode == AMP_TYPE.NAIVE: + cfg_['clip_grad'] = clip_grad_norm model, optimizer, criterion = convert_to_amp(model=model, optimizer=optimizer, criterion=criterion, @@ -357,16 +367,6 @@ def initialize(model: Union[nn.Module, List[nn.Module]], gradient_handlers=gradient_handlers, lr_scheduler=lr_scheduler) - # clip grad norm - clip_grad_norm = gpc.config.get('clip_grad_norm', 0.0) - if clip_grad_norm > 0: - if zero_cfg is not None: - raise ConfigException( - "clip_grad_norm should be specified with zero, you should specify clip_grad in zero configuration") - elif fp16_cfg is not None and fp16_cfg.mode == AMP_TYPE.NAIVE: - raise ConfigException( - "clip_grad_norm should be specified with AMP_TYPE.NAIVE, you should specify clip_grad in fp16 configuration") - engine = Engine( model=model, optimizer=optimizer, diff --git a/colossalai/utils/__init__.py b/colossalai/utils/__init__.py index b155ad0a3..7430ab100 100644 --- a/colossalai/utils/__init__.py +++ b/colossalai/utils/__init__.py @@ -3,7 +3,7 @@ from .common import (print_rank_0, sync_model_param_in_dp, is_dp_rank_0, is_tp_rank_0, is_no_pp_or_last_stage, is_using_ddp, is_using_pp, conditional_context, is_model_parallel_parameter, clip_grad_norm_fp32, count_zeros_fp32, copy_tensor_parallel_attributes, - param_is_not_tensor_parallel_duplicate) + param_is_not_tensor_parallel_duplicate, switch_virtual_pipeline_parallel_rank) from .cuda import get_current_device, synchronize, empty_cache, set_to_cuda from .memory import report_memory_usage from .timer import MultiTimer, Timer @@ -22,5 +22,6 @@ __all__ = ['checkpoint', 'Timer', 'MultiTimer', 'multi_tensor_applier', 'accumulate_gradient', - 'DataParallelSampler', 'get_dataloader' + 'DataParallelSampler', 'get_dataloader', + 'switch_virtual_pipeline_parallel_rank' ] diff --git a/colossalai/utils/common.py b/colossalai/utils/common.py index 2baa41a49..610986d03 100644 --- a/colossalai/utils/common.py +++ b/colossalai/utils/common.py @@ -249,3 +249,13 @@ def param_is_not_tensor_parallel_duplicate(param): return (hasattr(param, IS_TENSOR_PARALLEL) and getattr(param, IS_TENSOR_PARALLEL)) or ( gpc.get_local_rank(ParallelMode.TENSOR) == 0) + + +@contextmanager +def switch_virtual_pipeline_parallel_rank(rank): + prev_rank = gpc.virtual_pipeline_parallel_rank + try: + gpc.set_virtual_pipeline_parallel_rank(rank) + yield + finally: + gpc.set_virtual_pipeline_parallel_rank(prev_rank) diff --git a/docs/parallelization.md b/docs/parallelization.md index 1fdf2cec6..595925957 100644 --- a/docs/parallelization.md +++ b/docs/parallelization.md @@ -172,10 +172,10 @@ elif gpc.get_local_rank(ParallelMode.PIPELINE) == 1: 2. Make sure your model inherit `colossalai.nn.model.ModelFromConfig` and registered into the `MODELS` registry. Define the `self.layers_cfg` attribute. Pass in a dict/Config object which specifies the parameters of your model. -Use `colossalai.builder.pipeline.PipelineModelInitializer` to partition the layers. +Use `colossalai.builder.pipeline.build_pipeline_model_from_cfg` to partition the layers. ```python -from colossalai.builder import PipelineModelInitializer +from colossalai.builder import build_pipeline_model_from_cfg from colossalai.nn.model import ModelFromConfig from colossalai.registry import MODELS @@ -199,8 +199,11 @@ model_cfg = dict( ... ) -initializer = PipelineModelInitializer(model_cfg, num_chunks=1) -model = initializer.initialize() +# from config +model = build_pipeline_model_from_cfg(model_cfg, num_chunks=1) + +# from torch.nn.Sequential +# model = build_pipeline_model(sequential_model, num_model_chunks) ``` @@ -214,6 +217,9 @@ engine, train_dataloader, _, _ = colossalai.initialize(model, optimizer, criteri schedule = PipelineSchedule(num_microbatches=4) +# interleaved pipeline +# schedule = InterleavedPipelineSchedule(num_microbatches=4, num_model_chunks=2) + # execute a training epoch data_iter = iter(train_dataloader) diff --git a/tests/test_data_pipeline_tensor_parallel/run_cifar10_vit2d_with_pipeline.py b/tests/test_data_pipeline_tensor_parallel/run_cifar10_vit2d_with_pipeline.py index 2953dcb5b..036ac81a8 100644 --- a/tests/test_data_pipeline_tensor_parallel/run_cifar10_vit2d_with_pipeline.py +++ b/tests/test_data_pipeline_tensor_parallel/run_cifar10_vit2d_with_pipeline.py @@ -6,7 +6,7 @@ from colossalai.logging import get_dist_logger import colossalai import torch import os -from colossalai.builder import PipelineModelInitializer +from colossalai.builder import build_pipeline_model_from_cfg from colossalai.core import global_context as gpc from colossalai.utils import get_dataloader, MultiTimer from colossalai.nn.loss import CrossEntropyLoss2D @@ -50,8 +50,7 @@ def test_hybrid_parallel(): # suffix='cifar10_2d_vit_ddp1_torch_amp_grad_accum_2_clip_grad_1', mode='w') # build vit-t-32 - initializer = PipelineModelInitializer(vit_t_2d.model_cfg, num_chunks=1) - model = initializer.initialize() + model = build_pipeline_model_from_cfg(vit_t_2d.model_cfg, num_chunks=1) # build dataloaders train_dataset = CIFAR10( @@ -139,4 +138,4 @@ def test_hybrid_parallel(): if __name__ == '__main__': - main() + test_hybrid_parallel() diff --git a/tests/test_trainer/test_pipeline/test_partition.py b/tests/test_trainer/test_pipeline/test_partition.py index 9335a197b..9f011c0e2 100644 --- a/tests/test_trainer/test_pipeline/test_partition.py +++ b/tests/test_trainer/test_pipeline/test_partition.py @@ -5,7 +5,7 @@ import torch import torch.multiprocessing as mp from torch.utils.data import DataLoader -from colossalai.builder.pipeline import PipelineModelInitializer +from colossalai.builder.pipeline import build_pipeline_model_from_cfg from colossalai.core import global_context from colossalai.initialize import launch from colossalai.logging import get_dist_logger @@ -28,7 +28,7 @@ def run_partition(rank, world_size): logger.info('finished initialization') # build model - model = PipelineModelInitializer(global_context.config.model, 1, verbose=True).initialize() + model = build_pipeline_model_from_cfg(global_context.config.model, 1, verbose=True) assert isinstance(model, torch.nn.Module) logger.info('model is created') diff --git a/tests/test_trainer/test_pipeline/test_pipeline_schedule.py b/tests/test_trainer/test_pipeline/test_pipeline_schedule.py index 5247fe7f0..be2f7ab30 100644 --- a/tests/test_trainer/test_pipeline/test_pipeline_schedule.py +++ b/tests/test_trainer/test_pipeline/test_pipeline_schedule.py @@ -8,7 +8,7 @@ import torch import torch.multiprocessing as mp import model -from colossalai.builder import PipelineModelInitializer +from colossalai.builder import build_pipeline_model_from_cfg from colossalai.communication import p2p as p2p_communication from colossalai.communication.utils import send_tensor_meta, recv_tensor_meta from colossalai.context.parallel_mode import ParallelMode @@ -39,7 +39,7 @@ def run_schedule(rank, world_size): backend='nccl') # build model - model = PipelineModelInitializer(gpc.config.model, 1).initialize() + model = build_pipeline_model_from_cfg(gpc.config.model, 1) print_rank_0('model is created') train_dataset = CIFAR10(