diff --git a/colossalai/amp/naive_amp/mixed_precision_optimizer.py b/colossalai/amp/naive_amp/mixed_precision_optimizer.py new file mode 100644 index 000000000..d4183be3f --- /dev/null +++ b/colossalai/amp/naive_amp/mixed_precision_optimizer.py @@ -0,0 +1,149 @@ +from typing import Dict, List + +import torch +from torch import Tensor +from torch.nn import Parameter +from torch.optim import Optimizer + +from colossalai.interface import OptimizerWrapper + +from .mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin + + +class NaiveFP16MixedPrecisionMixin(FP16MixedPrecisionMixin): + + def __init__(self, + working_params: List[Parameter], + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32) -> None: + super().__init__(initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, + max_scale) + self.params = working_params + + def check_local_overflow(self) -> bool: + for p in self.params: + if p.grad is not None and not torch.isfinite(p.grad).all(): + return True + return False + + +class MixedPrecisionOptimizer(OptimizerWrapper): + + def __init__(self, + optim: Optimizer, + precision: str = 'fp16', + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + max_norm: float = 0.0): + super().__init__(optim) + if precision == 'fp16': + working_params = [] + for group in self.optim.param_groups: + for p in group['params']: + working_params.append(p) + self.mixed_precision = NaiveFP16MixedPrecisionMixin(working_params, + initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale) + elif precision == 'bf16': + self.mixed_precision = BF16MixedPrecisionMixin() + else: + raise ValueError(f'Unsupported precision: {precision}') + if max_norm > 0.0: + raise NotImplementedError('max_norm is not supported yet.') + self.max_norm = max_norm + self.working_to_master_map: Dict[Parameter, Tensor] = {} + self.master_to_working_map: Dict[Tensor, Parameter] = {} + + # create master weights + for group in self.optim.param_groups: + master_params = [] + for p in group['params']: + if p.requires_grad: + master_p = p + if p.dtype != torch.float: + master_p = p.detach().float() + self.working_to_master_map[p] = master_p + self.master_to_working_map[master_p] = p + master_params.append(master_p) + group['params'] = master_params + + def backward(self, loss: Tensor, *args, **kwargs): + loss = self.mixed_precision.pre_backward(loss) + loss.backward(*args, **kwargs) + + def backward_by_grad(self, tensor: Tensor, grad: Tensor): + grad = self.mixed_precision.pre_backward_by_grad(tensor, grad) + tensor.backward(grad) + + def zero_grad(self, *args, **kwargs): + for p in self.working_to_master_map.keys(): + p.grad = None + self.mixed_precision.pre_zero_grad() + return super().zero_grad(*args, **kwargs) + + def _unscale_and_clip_grads(self, total_norm: float) -> None: + div_scale = 1.0 + if self.mixed_precision is not None: + div_scale = self.mixed_precision.get_grad_div_scale() + + if self.max_norm > 0.: + # norm is in fact norm*scale + clip = ((total_norm / div_scale) + 1e-6) / self.max_norm + if clip > 1: + div_scale = clip * div_scale + + for group in 
self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                p.grad.data.mul_(1. / div_scale)
+
+    def _compute_grad_norm(self) -> float:
+        if self.max_norm <= 0.:
+            return 0.
+        grads = [p.grad for group in self.param_groups for p in group['params'] if p.grad is not None]
+        if len(grads) == 0:
+            return 0.
+        device = grads[0].device
+        # TODO(ver217): support tp
+        total_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2).to(device) for g in grads]), 2)
+        return total_norm.item()
+
+    def step(self, *args, **kwargs):
+        if self.mixed_precision.should_skip_step():
+            self.zero_grad()
+            return
+        # prepare grads
+        for group in self.optim.param_groups:
+            for p in group['params']:
+                working_param = self.master_to_working_map[p]
+                if p is working_param:
+                    continue
+                if working_param.grad is not None:
+                    p.grad = working_param.grad.data.float()
+                    working_param.grad = None
+        total_norm = self._compute_grad_norm()
+        self._unscale_and_clip_grads(total_norm)
+        self.optim.step(*args, **kwargs)
+        # update working params
+        for group in self.optim.param_groups:
+            for p in group['params']:
+                working_param = self.master_to_working_map[p]
+                if p is working_param:
+                    continue
+                working_param.data.copy_(p.data)
diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py
index ec3dc7fc1..8a28b1286 100644
--- a/colossalai/booster/booster.py
+++ b/colossalai/booster/booster.py
@@ -1,6 +1,6 @@
 import warnings
 from contextlib import contextmanager
-from typing import Callable, Iterator, List, Optional, Tuple, Union
+from typing import Any, Callable, Iterator, List, Optional, Union
 
 import torch
 import torch.nn as nn
@@ -14,6 +14,7 @@ from colossalai.interface import ModelWrapper, OptimizerWrapper
 from .accelerator import Accelerator
 from .mixed_precision import MixedPrecision, mixed_precision_factory
 from .plugin import Plugin
+from .plugin.pp_plugin_base import PipelinePluginBase
 
 __all__ = ['Booster']
 
@@ -144,14 +145,15 @@ class Booster:
     def execute_pipeline(self,
                          data_iter: Iterator,
                          model: nn.Module,
-                         criterion: Callable[[torch.Tensor], torch.Tensor],
+                         criterion: Callable[[Any, Any], torch.Tensor],
                          optimizer: Optimizer,
                          return_loss: bool = True,
-                         return_outputs: bool = False) -> Tuple[Optional[torch.Tensor], ...]:
-        # TODO: implement this method
+                         return_outputs: bool = False) -> dict:
         # run pipeline forward backward pass
         # return loss or outputs if needed
-        pass
+        assert isinstance(self.plugin,
+                          PipelinePluginBase), f'The plugin {self.plugin.__class__.__name__} does not support pipeline.'
+        return self.plugin.execute_pipeline(data_iter, model, criterion, optimizer, return_loss, return_outputs)
 
     def no_sync(self, model: nn.Module = None, optimizer: OptimizerWrapper = None) -> contextmanager:
         """Context manager to disable gradient synchronization across DP process groups.
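For context, a minimal sketch of the fp16 master/working-parameter flow that the `MixedPrecisionOptimizer` above implements (single process, no pipeline; `model` and `batch` are illustrative placeholders and not part of this patch):

    import torch
    from torch.optim import SGD
    from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer

    model = torch.nn.Linear(8, 8).half().cuda()    # working (fp16) parameters live in the model
    optim = MixedPrecisionOptimizer(SGD(model.parameters(), lr=1e-3), precision='fp16')

    batch = torch.randn(4, 8, dtype=torch.float16, device='cuda')
    loss = model(batch).mean()
    optim.backward(loss)    # scales the loss via the fp16 mixin before calling backward
    optim.step()            # casts working grads to fp32 master params, unscales, steps, copies results back to fp16
    optim.zero_grad()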
diff --git a/colossalai/booster/plugin/__init__.py b/colossalai/booster/plugin/__init__.py
index a3b87b5f1..f48bf38bd 100644
--- a/colossalai/booster/plugin/__init__.py
+++ b/colossalai/booster/plugin/__init__.py
@@ -1,9 +1,10 @@
 from .gemini_plugin import GeminiPlugin
+from .hybrid_parallel_plugin import HybridParallelPlugin
 from .low_level_zero_plugin import LowLevelZeroPlugin
 from .plugin_base import Plugin
 from .torch_ddp_plugin import TorchDDPPlugin
 
-__all__ = ['Plugin', 'TorchDDPPlugin', 'GeminiPlugin', 'LowLevelZeroPlugin']
+__all__ = ['Plugin', 'TorchDDPPlugin', 'GeminiPlugin', 'LowLevelZeroPlugin', 'HybridParallelPlugin']
 
 import torch
 from packaging import version
diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
new file mode 100644
index 000000000..37badb613
--- /dev/null
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -0,0 +1,316 @@
+import random
+from contextlib import nullcontext
+from typing import Any, Callable, Iterator, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+from torch.nn import Module
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+
+from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer
+from colossalai.checkpoint_io import CheckpointIO
+from colossalai.cluster import ProcessGroupMesh
+from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
+from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.zero.low_level import LowLevelZeroOptimizer
+
+from .pp_plugin_base import PipelinePluginBase
+
+DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2
+
+
+class HybridParallelModule(ModelWrapper):
+
+    def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup) -> None:
+        self.stage_manager = shard_config.pipeline_stage_manager
+        self.dp_group = dp_group
+        shardformer = ShardFormer(shard_config)
+        module, self.shared_params = shardformer.optimize(module)
+        # TODO(ver217): add input type cast
+        self.shared_param_process_groups = []
+        for shared_param in self.shared_params:
+            if len(shared_param) > 0:
+                self.shared_param_process_groups.append(
+                    self.stage_manager.init_process_group_by_stages(list(shared_param.keys())))
+        if precision == 'fp16':
+            module = module.half().cuda()
+        elif precision == 'bf16':
+            module = module.to(dtype=torch.bfloat16).cuda()
+        # TODO(ver217): support TP+DP
+        super().__init__(module)
+
+    def sync_shared_params(self):
+        for shared_param, group in zip(self.shared_params, self.shared_param_process_groups):
+            param = shared_param[self.stage_manager.stage]
+            dist.all_reduce(param.grad, group=group)
+
+    def no_sync(self) -> Iterator[None]:
+        # no sync grads across data parallel
+        return nullcontext()
+
+    def sync_grads(self):
+        # sync grad across data parallel
+        if self.dp_group.size() == 1:
+            return
+        for p in self.module.parameters():
+            if p.grad is not None:
+                dist.all_reduce(p.grad, group=self.dp_group)
+
+
+def init_pipeline_optimizer(optim: Optimizer, model: Module):
+    model_params = set(model.parameters())
+    new_param_groups = []
+    for group in optim.param_groups:
+        params = [p for p in group['params'] if p in model_params]
+
new_param_groups.append({**group, 'params': params}) + optim.__setstate__({'param_groups': new_param_groups}) + + +class HybridParallelOptimizer(MixedPrecisionOptimizer): + + def __init__(self, + optim: Optimizer, + model: Module, + use_pipeline: bool, + precision: str = 'fp16', + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + max_norm: float = 0): + if use_pipeline: + init_pipeline_optimizer(optim, model) + super().__init__(optim, precision, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, + hysteresis, max_scale, max_norm) + + +class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): + + def __init__( + self, + optimizer: Optimizer, + model: Module, + use_pipeline: bool, + initial_scale: int = 2**16, # grad scaler config + min_scale: int = 1, + growth_factor: float = 2., + backoff_factor: float = .5, + growth_interval: int = 2000, + hysteresis: int = 2, + max_scale: int = 2**24, + clip_grad_norm: float = 0.0, # grad clipping + verbose: bool = False, + reduce_bucket_size: int = 1024 * 1024, # communication + communication_dtype: Optional[torch.dtype] = None, + overlap_communication: bool = True, + partition_grad: bool = False, # stage 2 flag + cpu_offload: bool = False, # cpu offload + dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm + tp_process_group: Optional[ProcessGroup] = None, # if using tp + forced_dtype: Optional[torch.dtype] = None): + if use_pipeline: + init_pipeline_optimizer(optimizer, model) + super().__init__(optimizer, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, + hysteresis, max_scale, clip_grad_norm, verbose, reduce_bucket_size, communication_dtype, + overlap_communication, partition_grad, cpu_offload, dp_process_group, tp_process_group, + forced_dtype) + + +class HybridParallelPlugin(PipelinePluginBase): + + def __init__( + self, + tp_size: int, + pp_size: int, + precision: str = 'fp16', + zero_stage: int = 0, + cpu_offload: bool = False, + enable_fused_normalization: bool = False, + num_microbatches: Optional[int] = None, + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + max_norm: float = 0, + ) -> None: + super().__init__() + assert dist.get_world_size() % ( + tp_size * pp_size + ) == 0, f'world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}' + # TODO(ver217): support zero + assert zero_stage == 0, 'zero is not support yet' + self.tp_size = tp_size + self.pp_size = pp_size + self.dp_size = dist.get_world_size() // (tp_size * pp_size) + self.precision = precision + self.zero_stage = zero_stage + self.cpu_offload = cpu_offload + self.enable_fused_normalization = enable_fused_normalization + self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size) + self.stage_manager = None + self.schedule = None + assert zero_stage in (0, 1, 2) + if self.pp_size > 1: + assert num_microbatches is not None, 'num_microbatches must be specified when using pipeline parallelism' + assert self.zero_stage <= 1, 'zero stage must be 0 or 1 when using pipeline parallelism' + self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS) + self.schedule = OneForwardOneBackwardSchedule(num_microbatches, self.stage_manager) + self.tp_group = 
self.pg_mesh.get_group_along_axis(TP_AXIS) + self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) + self.shard_config = ShardConfig(tensor_parallel_process_group=self.tp_group, + pipeline_stage_manager=self.stage_manager, + enable_tensor_parallelism=self.tp_size > 1, + enable_fused_normalization=self.enable_fused_normalization) + self.amp_config = dict( + initial_scale=initial_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + min_scale=min_scale, + max_scale=max_scale, + ) + self.max_norm = max_norm + + @property + def enable_pipeline_parallelism(self) -> bool: + return self.pp_size > 1 + + def supported_devices(self) -> List[str]: + return ['cuda'] + + def supported_precisions(self) -> List[str]: + return ['fp16', 'bf16'] + + def control_device(self) -> bool: + return True + + def control_precision(self) -> bool: + return True + + def support_no_sync(self) -> bool: + return False + + def control_checkpoint_io(self) -> bool: + return True + + def configure( + self, + model: Module, + optimizer: Optional[Optimizer] = None, + criterion: Optional[Callable] = None, + dataloader: Optional[DataLoader] = None, + lr_scheduler: Optional[LRScheduler] = None, + ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: + if not isinstance(model, ModelWrapper): + model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group) + if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): + if self.zero_stage == 0: + optimizer = HybridParallelOptimizer(optimizer, + model, + use_pipeline=self.enable_pipeline_parallelism, + precision=self.precision, + max_norm=self.max_norm, + **self.amp_config) + else: + optimizer = HybridParallelZeroOptimizer(optimizer, + model, + use_pipeline=self.enable_pipeline_parallelism, + partition_grad=(self.zero_stage == 2), + cpu_offload=self.cpu_offload, + dp_process_group=self.dp_group, + tp_process_group=self.tp_group, + verbose=True, + clip_grad_norm=self.max_norm, + **self.amp_config) + return model, optimizer, criterion, dataloader, lr_scheduler + + def execute_pipeline(self, + data_iter: Iterator, + model: HybridParallelModule, + criterion: Callable[[Any, Any], torch.Tensor], + optimizer: Union[HybridParallelOptimizer, HybridParallelZeroOptimizer], + return_loss: bool = True, + return_outputs: bool = False) -> dict: + assert self.enable_pipeline_parallelism, 'pipeline parallelism is not enabled' + # return loss or outputs if needed + ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync() + with ctx: + outputs = self.schedule.forward_backward_step(model, optimizer, data_iter, criterion, return_loss, + return_outputs) + # model.sync_shared_params() + if isinstance(optimizer, HybridParallelZeroOptimizer): + optimizer.sync_grad() + else: + model.sync_grads() + return outputs + + def prepare_dataloader(self, + dataset, + batch_size, + shuffle=False, + seed=1024, + drop_last=False, + pin_memory=False, + num_workers=0, + **kwargs): + r""" + Prepare a dataloader for distributed training. The dataloader will be wrapped by + `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`. + + + Args: + dataset (`torch.utils.data.Dataset`): The dataset to be loaded. + shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False. + seed (int, optional): Random worker seed for sampling, defaults to 1024. 
+ add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True. + drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size + is not divisible by the batch size. If False and the size of dataset is not divisible by + the batch size, then the last batch will be smaller, defaults to False. + pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False. + num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0. + kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in + `DataLoader `_. + + Returns: + :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. + """ + _kwargs = kwargs.copy() + sampler = DistributedSampler(dataset, + num_replicas=self.pg_mesh.size(DP_AXIS), + rank=self.pg_mesh.coordinate(DP_AXIS), + shuffle=shuffle) + + # Deterministic dataloader + def seed_worker(worker_id): + worker_seed = seed + np.random.seed(worker_seed) + torch.manual_seed(worker_seed) + random.seed(worker_seed) + + return DataLoader(dataset, + batch_size=batch_size, + sampler=sampler, + worker_init_fn=seed_worker, + drop_last=drop_last, + pin_memory=pin_memory, + num_workers=num_workers, + **_kwargs) + + def get_checkpoint_io(self) -> CheckpointIO: + return None + + def no_sync(self, model: Module) -> Iterator[None]: + raise NotImplementedError diff --git a/colossalai/booster/plugin/pp_plugin_base.py b/colossalai/booster/plugin/pp_plugin_base.py new file mode 100644 index 000000000..67ade9330 --- /dev/null +++ b/colossalai/booster/plugin/pp_plugin_base.py @@ -0,0 +1,21 @@ +from abc import abstractmethod +from typing import Any, Callable, Iterator + +import torch + +from colossalai.interface import ModelWrapper, OptimizerWrapper + +from .plugin_base import Plugin + + +class PipelinePluginBase(Plugin): + + @abstractmethod + def execute_pipeline(self, + data_iter: Iterator, + model: ModelWrapper, + criterion: Callable[[Any, Any], torch.Tensor], + optimizer: OptimizerWrapper, + return_loss: bool = True, + return_outputs: bool = False) -> dict: + pass diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index 851a0b595..f741b8363 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -7,9 +7,9 @@ from typing import Any, List, Optional, Union import torch import torch.distributed as dist +from packaging.version import Version from torch.distributed import ProcessGroup from torch.distributed import distributed_c10d as c10d -from version_parser.version import Version from .stage_manager import PipelineStageManager diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 7bc626fe6..e1ed5f646 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -223,9 +223,6 @@ class LlamaPipelineForwards: if output_hidden_states: logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') output_hidden_states = False - if return_dict: - logger.warning_once('return_dict is not supported for pipeline models at the moment') - return_dict = False # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = LlamaPipelineForwards.llama_model_forward( @@ -311,9 +308,6 @@ class LlamaPipelineForwards: if output_hidden_states: logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') 
output_hidden_states = False - if return_dict: - logger.warning_once('return_dict is not supported for pipeline models at the moment') - return_dict = False transformer_outputs = LlamaPipelineForwards.llama_model_forward( self.model, diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index b32c285bd..ae8cd8c6e 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -1,5 +1,5 @@ from types import MethodType -from typing import Any, Callable, Dict, List, Union +from typing import Any, Callable, Dict, List, Optional, Set, Union import torch.nn as nn from torch import Tensor @@ -39,8 +39,8 @@ class ModelSharder(object): self._preprocess() # get shared params before release unheld layers, this avoid misjudgement of shared params (None is None) shared_params = self.policy.get_shared_params() - self._release_unheld_layers() - self._replace_module() + held_layers = self._release_unheld_layers() + self._replace_module(include=held_layers) self._materialize() self._postprocess() return shared_params @@ -51,7 +51,7 @@ class ModelSharder(object): def _postprocess(self) -> None: self.model = self.policy.postprocess() - def _replace_module(self,) -> None: + def _replace_module(self, include: Optional[Set[nn.Module]] = None) -> None: r""" Replace the module according to the policy, and replace the module one by one @@ -64,8 +64,13 @@ class ModelSharder(object): param_replacement = module_description.param_replacement sub_module_replacement = module_description.sub_module_replacement method_replacement = module_description.method_replacement - self._recursive_replace_layer(self.model, layer_cls, attr_replacement, param_replacement, - method_replacement, sub_module_replacement) + self._recursive_replace_layer(self.model, + layer_cls, + attr_replacement, + param_replacement, + method_replacement, + sub_module_replacement, + include=include) def _recursive_replace_layer( self, @@ -75,6 +80,7 @@ class ModelSharder(object): param_replacement: List[Callable], method_replacement: Dict[str, Callable], sub_module_replacement: List[SubModuleReplacementDescription], + include: Optional[Set[nn.Module]] = None, ) -> None: r""" Reverse the replace layer operation @@ -87,23 +93,30 @@ class ModelSharder(object): method_replacement (Dict[str, Callable]): Key is the method name, value is the method for replacement sub_module_replacement ((List[SubModuleReplacementDescription]): The function list to get sub module shard information in policy """ + # released layers are not shardable + can_replace_param_or_layer = include is None or module in include if (isinstance(origin_cls, str) and origin_cls == module.__class__.__name__) or \ (module.__class__ == origin_cls): if attr_replacement is not None: self._replace_attr(module, attr_replacement) - if param_replacement is not None: + if param_replacement is not None and can_replace_param_or_layer: self._replace_param(module, param_replacement) if method_replacement is not None: self._replace_method(module, method_replacement) - if sub_module_replacement is not None: + if sub_module_replacement is not None and can_replace_param_or_layer: self._replace_sub_module(module, sub_module_replacement) for name, child in module.named_children(): - self._recursive_replace_layer(child, origin_cls, attr_replacement, param_replacement, method_replacement, - sub_module_replacement) + self._recursive_replace_layer(child, + origin_cls, + attr_replacement, + param_replacement, + method_replacement, + 
sub_module_replacement, + include=include) def _replace_attr( self, @@ -185,13 +198,15 @@ class ModelSharder(object): setattr_(org_layer, suffix, replace_layer) - def _release_unheld_layers(self) -> None: + def _release_unheld_layers(self) -> Optional[Set[nn.Module]]: r""" Release the unheld layers in the model """ if self.shard_config and self.shard_config.pipeline_stage_manager: held_layers = self.policy.get_held_layers() set_tensors_to_none(self.model, exclude=set(held_layers)) + return set(held_layers) + return None def _materialize(self) -> None: r""" diff --git a/tests/test_booster/test_plugin/test_3d_plugin.py b/tests/test_booster/test_plugin/test_3d_plugin.py new file mode 100644 index 000000000..a58afac81 --- /dev/null +++ b/tests/test_booster/test_plugin/test_3d_plugin.py @@ -0,0 +1,99 @@ +from contextlib import nullcontext +from typing import Optional + +import torch +import torch.distributed as dist + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import HybridParallelPlugin +from colossalai.fx import is_compatible_with_meta +from colossalai.lazy.lazy_init import LazyInitContext +from colossalai.nn.optimizer import HybridAdam +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo + + +def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]: + try: + if init_method == 'lazy': + ctx = LazyInitContext() + else: + ctx = nullcontext() + plugin = HybridParallelPlugin(tp_size=2, pp_size=2, num_microbatches=4, precision='bf16') + booster = Booster(plugin=plugin) + with ctx: + model = model_fn() + optimizer = HybridAdam(model.parameters(), lr=1e-3) + criterion = lambda x: x.mean() + data = data_gen_fn() + + data = { + k: v.to('cuda').repeat(4, 1) if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v + for k, v in data.items() + } + + model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) + + data_iter = iter([data]) + + def _criterion(outputs, inputs): + outputs = output_transform_fn(outputs) + output_key = list(outputs.keys())[0] + loss = criterion(outputs[output_key]) + return loss + + booster.execute_pipeline(data_iter, model, _criterion, optimizer, return_loss=True, return_outputs=False) + optimizer.step() + + except Exception as e: + return repr(e) + + +@parameterize('init_method', ['none', 'lazy']) +def check_3d_plugin(init_method: str = 'none', early_stop: bool = True): + """check gemini plugin over model zoo + + Args: + early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True. 
+ """ + is_support_meta = is_compatible_with_meta() + if not is_support_meta and init_method == 'lazy': + return + + passed_models = [] + failed_info = {} # (model_name, error) pair + + # TODO(ver217): add more models + for name, (model_fn, data_gen_fn, output_transform_fn, _, + _) in model_zoo.get_sub_registry('transformers_llama_for_casual_lm').items(): + err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) + torch.cuda.empty_cache() + + if err is None: + passed_models.append(name) + else: + failed_info[name] = err + if early_stop: + break + + if dist.get_rank() == 0: + print(f'Init method: {init_method}') + print(f'Passed models({len(passed_models)}): {passed_models}\n\n') + print(f'Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n') + assert len(failed_info) == 0, '\n'.join([f'{k}: {v}' for k, v in failed_info.items()]) + + +def run_dist(rank, world_size, port, early_stop: bool = True): + # init dist env + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + check_3d_plugin(early_stop=early_stop) + + +@rerun_if_address_is_in_use() +def test_gemini_plugin(early_stop: bool = True): + spawn(run_dist, 4, early_stop=early_stop) + + +if __name__ == '__main__': + test_gemini_plugin(early_stop=False)