From dc003c304c4c1772f62914d0b34b1f2d96d901ab Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com> Date: Thu, 2 Nov 2023 10:21:24 +0800 Subject: [PATCH] [moe] merge moe into main (#4978) * update moe module * support openmoe --- .../plugin/moe_hybrid_parallel_plugin.py | 382 ++++++ colossalai/context/__init__.py | 2 - colossalai/context/moe_context.py | 132 -- .../kernel/triton/llama_act_combine_kernel.py | 185 +++ .../engine/gradient_handler/__init__.py | 2 - colossalai/legacy/initialize.py | 12 - colossalai/moe/__init__.py | 17 + colossalai/moe/_operation.py | 275 ++++ colossalai/moe/checkpoint.py | 274 ++++ colossalai/moe/experts.py | 156 +++ colossalai/moe/layers.py | 361 ++++++ colossalai/moe/load_balance.py | 442 +++++++ .../{nn/loss/loss_moe.py => moe/loss.py} | 9 +- colossalai/moe/manager.py | 162 +++ colossalai/moe/routers.py | 419 +++++++ colossalai/moe/utils.py | 177 +++ colossalai/nn/layer/__init__.py | 1 - colossalai/nn/layer/moe/__init__.py | 21 - colossalai/nn/layer/moe/_operation.py | 171 --- colossalai/nn/layer/moe/checkpoint.py | 40 - colossalai/nn/layer/moe/experts.py | 201 --- colossalai/nn/layer/moe/layers.py | 212 ---- colossalai/nn/layer/moe/routers.py | 235 ---- colossalai/nn/layer/moe/utils.py | 71 -- colossalai/nn/loss/__init__.py | 1 - colossalai/tensor/moe_tensor/__init__.py | 0 colossalai/tensor/moe_tensor/api.py | 137 ++ colossalai/tensor/moe_tensor/moe_info.py | 28 + colossalai/utils/moe.py | 53 - colossalai/zero/low_level/low_level_optim.py | 307 ++++- examples/language/openmoe/README.md | 129 ++ .../openmoe/benchmark/benchmark_cai.py | 296 +++++ .../openmoe/benchmark/benchmark_cai.sh | 78 ++ .../openmoe/benchmark/benchmark_cai_dist.sh | 47 + .../openmoe/benchmark/benchmark_fsdp.py | 139 ++ .../openmoe/benchmark/benchmark_fsdp.sh | 34 + .../language/openmoe/benchmark/hostfile.txt | 2 + examples/language/openmoe/benchmark/utils.py | 126 ++ examples/language/openmoe/infer.py | 57 + examples/language/openmoe/infer.sh | 1 + examples/language/openmoe/model/__init__.py | 0 .../openmoe/model/convert_openmoe_ckpt.py | 224 ++++ .../openmoe/model/convert_openmoe_ckpt.sh | 1 + .../openmoe/model/modeling_openmoe.py | 1113 +++++++++++++++++ .../openmoe/model/openmoe_8b_config.json | 24 + .../openmoe/model/openmoe_base_config.json | 24 + .../language/openmoe/model/openmoe_policy.py | 562 +++++++++ examples/language/openmoe/requirements.txt | 5 + examples/language/openmoe/test_ci.sh | 37 + examples/language/openmoe/train.py | 377 ++++++ examples/language/openmoe/train.sh | 40 + pytest.ini | 2 +- .../triton/test_llama_act_combine.py | 56 + tests/test_moe/moe_utils.py | 169 +++ tests/test_moe/test_grad_handler.py | 65 +- tests/test_moe/test_kernel.py | 50 +- tests/test_moe/test_moe_checkpoint.py | 140 ++- tests/test_moe/test_moe_colo_init.py | 55 - tests/test_moe/test_moe_ep_tp.py | 81 ++ tests/test_moe/test_moe_group.py | 84 +- tests/test_moe/test_moe_hybrid_zero.py | 97 ++ tests/test_moe/test_moe_load_balance.py | 190 +++ tests/test_moe/test_moe_router.py | 41 + tests/test_moe/test_moe_zero_fwd_bwd.py | 105 ++ tests/test_moe/test_moe_zero_init.py | 106 -- tests/test_moe/test_moe_zero_model.py | 70 -- tests/test_moe/test_moe_zero_optim.py | 163 +-- 67 files changed, 7618 insertions(+), 1657 deletions(-) create mode 100644 colossalai/booster/plugin/moe_hybrid_parallel_plugin.py delete mode 100644 colossalai/context/moe_context.py create mode 100644 colossalai/kernel/triton/llama_act_combine_kernel.py create mode 100644 
colossalai/moe/__init__.py create mode 100644 colossalai/moe/_operation.py create mode 100644 colossalai/moe/checkpoint.py create mode 100644 colossalai/moe/experts.py create mode 100644 colossalai/moe/layers.py create mode 100644 colossalai/moe/load_balance.py rename colossalai/{nn/loss/loss_moe.py => moe/loss.py} (92%) create mode 100644 colossalai/moe/manager.py create mode 100644 colossalai/moe/routers.py create mode 100644 colossalai/moe/utils.py delete mode 100644 colossalai/nn/layer/moe/__init__.py delete mode 100644 colossalai/nn/layer/moe/_operation.py delete mode 100644 colossalai/nn/layer/moe/checkpoint.py delete mode 100644 colossalai/nn/layer/moe/experts.py delete mode 100644 colossalai/nn/layer/moe/layers.py delete mode 100644 colossalai/nn/layer/moe/routers.py delete mode 100644 colossalai/nn/layer/moe/utils.py create mode 100644 colossalai/tensor/moe_tensor/__init__.py create mode 100644 colossalai/tensor/moe_tensor/api.py create mode 100644 colossalai/tensor/moe_tensor/moe_info.py delete mode 100644 colossalai/utils/moe.py create mode 100644 examples/language/openmoe/README.md create mode 100644 examples/language/openmoe/benchmark/benchmark_cai.py create mode 100755 examples/language/openmoe/benchmark/benchmark_cai.sh create mode 100755 examples/language/openmoe/benchmark/benchmark_cai_dist.sh create mode 100644 examples/language/openmoe/benchmark/benchmark_fsdp.py create mode 100755 examples/language/openmoe/benchmark/benchmark_fsdp.sh create mode 100644 examples/language/openmoe/benchmark/hostfile.txt create mode 100644 examples/language/openmoe/benchmark/utils.py create mode 100644 examples/language/openmoe/infer.py create mode 100644 examples/language/openmoe/infer.sh create mode 100644 examples/language/openmoe/model/__init__.py create mode 100644 examples/language/openmoe/model/convert_openmoe_ckpt.py create mode 100644 examples/language/openmoe/model/convert_openmoe_ckpt.sh create mode 100644 examples/language/openmoe/model/modeling_openmoe.py create mode 100644 examples/language/openmoe/model/openmoe_8b_config.json create mode 100644 examples/language/openmoe/model/openmoe_base_config.json create mode 100644 examples/language/openmoe/model/openmoe_policy.py create mode 100644 examples/language/openmoe/requirements.txt create mode 100644 examples/language/openmoe/test_ci.sh create mode 100644 examples/language/openmoe/train.py create mode 100644 examples/language/openmoe/train.sh create mode 100644 tests/test_infer_ops/triton/test_llama_act_combine.py create mode 100644 tests/test_moe/moe_utils.py delete mode 100644 tests/test_moe/test_moe_colo_init.py create mode 100644 tests/test_moe/test_moe_ep_tp.py create mode 100644 tests/test_moe/test_moe_hybrid_zero.py create mode 100644 tests/test_moe/test_moe_load_balance.py create mode 100644 tests/test_moe/test_moe_router.py create mode 100644 tests/test_moe/test_moe_zero_fwd_bwd.py delete mode 100644 tests/test_moe/test_moe_zero_init.py delete mode 100644 tests/test_moe/test_moe_zero_model.py diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py new file mode 100644 index 000000000..b67642b0d --- /dev/null +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -0,0 +1,382 @@ +import random +from types import MethodType +from typing import Callable, Optional, OrderedDict, Tuple + +import numpy as np +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.nn import Module +from torch.optim 
import Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +from colossalai.booster.plugin.hybrid_parallel_plugin import ( + HybridParallelAMPOptimizer, + HybridParallelModule, + HybridParallelNaiveOptimizer, + HybridParallelPlugin, + get_param_info, + init_pipeline_optimizer, +) +from colossalai.cluster import ProcessGroupMesh +from colossalai.interface import ModelWrapper, OptimizerWrapper +from colossalai.moe import MoeCheckpintIO +from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer import ShardConfig +from colossalai.shardformer.policies.base_policy import Policy +from colossalai.zero.low_level import LowLevelZeroOptimizer + +PP_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2 + + +class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): + def __init__( + self, + optimizer: Optimizer, + model: Module, + use_pipeline: bool, + param_info: OrderedDict, + initial_scale: int = 2**16, # grad scaler config + min_scale: int = 1, + growth_factor: float = 2.0, + backoff_factor: float = 0.5, + growth_interval: int = 2000, + hysteresis: int = 2, + max_scale: int = 2**24, + clip_grad_norm: float = 0.0, # grad clipping + verbose: bool = False, + reduce_bucket_size: int = 1024 * 1024, # communication + communication_dtype: Optional[torch.dtype] = None, + overlap_communication: bool = True, + partition_grad: bool = False, # stage 2 flag + cpu_offload: bool = False, # cpu offload + dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm + tp_process_group: Optional[ProcessGroup] = None, # if using tp + pp_process_group: Optional[ProcessGroup] = None, + forced_dtype: Optional[torch.dtype] = None, + moe_extra_dp_process_group: Optional[ProcessGroup] = None, + ): + self.param_info = param_info + self.stage_manager = model.stage_manager + self.shared_params = model.shared_params + self.dp_pg = dp_process_group + self.tp_pg = tp_process_group + self.pp_pg = pp_process_group + if use_pipeline: + init_pipeline_optimizer(optimizer, model) + super().__init__( + optimizer=optimizer, + initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale, + clip_grad_norm=clip_grad_norm, + verbose=verbose, + reduce_bucket_size=reduce_bucket_size, + communication_dtype=communication_dtype, + overlap_communication=overlap_communication, + partition_grad=partition_grad, + cpu_offload=cpu_offload, + dp_process_group=dp_process_group, + forced_dtype=forced_dtype, + moe_extra_dp_process_group=moe_extra_dp_process_group, + ) + + +class MoeHybridParallelPlugin(HybridParallelPlugin): + """ + Plugin for Moe Hybrid Parallel Training. + Tensor parallel, pipeline parallel and data parallel(DDP/ZeRO) can be picked and combined in this plugin. + The size of tp and pp should be passed in by user, then the size of dp is automatically calculated from dp_size = world_size / (tp_size * pp_size). + + Example: + >>> from colossalai.booster import Booster + >>> from colossalai.booster.plugin import HybridParallelPlugin + + >>> model, train_dataset, optimizer, criterion = ... 
+ >>> plugin = HybridParallelPlugin(tp_size=2, pp_size=2) + + >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8) + >>> booster = Booster(plugin=plugin) + >>> model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader) + + Args: + tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1. + pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1. + precision (str, optional): Specifies the precision of parameters during training. + Auto-mixied precision will be used when this argument is set to 'fp16' or 'bf16', otherwise model is trained with 'fp32'. + Defaults to 'fp16'. + zero_stage (int, optional): The stage of ZeRO for data parallelism. Can only be choosed from [0, 1, 2]. + When set to 0, ZeRO will not be used. Defaults to 0. + enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer. + Currently all the optimization methods include fused normalization, flash attention and JIT. + Defaults to False. + enable_fused_normalization (bool, optional): Whether to switch on fused normalization in Shardformer. Defaults to False. + enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False. + enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False. + enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False. + enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False. + num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None. + microbatch_size (int, optional): Microbatch size when using pipeline parallelism. + Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline. + If ``num_microbatches`` is provided, this will be ignored. Defaults to None. + initial_scale (float, optional): The initial loss scale of AMP. Defaults to 2**16. + min_scale (float, optional): The minimum loss scale of AMP. Defaults to 1. + growth_factor (float, optional): The multiplication factor for increasing loss scale when using AMP. Defaults to 2. + backoff_factor (float, optional): The multiplication factor for decreasing loss scale when using AMP. Defaults to 0.5. + growth_interval (int, optional): The number of steps to increase loss scale when no overflow occurs when using AMP. Defaults to 1000. + hysteresis (int, optional): The number of overflows before decreasing loss scale when using AMP. Defaults to 2. + max_scale (float, optional): The maximum loss scale of AMP. Defaults to 2**32. + max_norm (float, optional): Maximum norm for gradient clipping. Defaults to 0. + broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training when using DDP. Defaults to True. + ddp_bucket_cap_mb (int, optional): The bucket size in MB when using DDP. Defaults to 25. + find_unused_parameters (bool, optional): Whether to find unused parameters when using DDP. Defaults to False. + check_reduction (bool, optional): Whether to check reduction when using DDP. Defaults to False. + gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view when using DDP. Defaults to False. + static_graph (bool, optional): Whether to use static graph when using DDP. Defaults to False. 
+ zero_bucket_size_in_m (int, optional): Gradient reduce bucket size in million elements when using ZeRO. Defaults to 12. + cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False. + communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None. + overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True. + """ + + def __init__( + self, + tp_size: int, + pp_size: int, + extra_dp_size: int = 1, + precision: str = "fp16", + zero_stage: int = 0, + enable_all_optimization: bool = False, + enable_fused_normalization: bool = False, + enable_flash_attention: bool = False, + enable_jit_fused: bool = False, + enable_sequence_parallelism: bool = False, + enable_sequence_overlap: bool = False, + num_microbatches: Optional[int] = None, + microbatch_size: Optional[int] = None, + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + max_norm: float = 0, + broadcast_buffers: bool = True, + ddp_bucket_cap_mb: int = 25, + find_unused_parameters: bool = False, + check_reduction: bool = False, + gradient_as_bucket_view: bool = False, + static_graph: bool = False, + zero_bucket_size_in_m: int = 12, + cpu_offload: bool = False, + communication_dtype: Optional[torch.dtype] = None, + overlap_communication: bool = True, + use_ep_inside: bool = True, + custom_policy: Policy = None, + ) -> None: + assert ( + dist.get_world_size() % (tp_size * pp_size) == 0 + ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}" + + if enable_sequence_parallelism: + assert tp_size > 1, "Sequence parallelism must be enabled when using tensor parallelism" + + self.tp_size = tp_size + self.pp_size = pp_size + self.dp_size = dist.get_world_size() // (tp_size * pp_size) + self.precision = precision + self.zero_stage = zero_stage + self.cpu_offload = cpu_offload + self.enable_all_optimization = enable_all_optimization + self.enable_fused_normalization = enable_fused_normalization + self.enable_flash_attention = enable_flash_attention + self.enable_jit_fused = enable_jit_fused + self.enable_sequence_parallelism = enable_sequence_parallelism + # we change pg mesh to (pp, dp, tp) for better moe performance + self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size) + + # sync moe in outer dp group, and sync other param in global dp group + if extra_dp_size > 1: + ep_size = self.dp_size // extra_dp_size + if use_ep_inside: + self.pg_mesh_moe = ProcessGroupMesh(self.pp_size, extra_dp_size, ep_size) + self.moe_extra_dp_group = self.pg_mesh_moe.get_group_along_axis(1) + if dist.get_rank() == 0: + print(f"Zero Parallel: pp {self.pp_size}, outer_dp {extra_dp_size}, inner_dp {ep_size}") + else: + self.pg_mesh_moe = ProcessGroupMesh(self.pp_size, ep_size, extra_dp_size) + self.moe_extra_dp_group = self.pg_mesh_moe.get_group_along_axis(2) + if dist.get_rank() == 0: + print(f"Zero Parallel: pp {self.pp_size}, outer_dp {ep_size}, inner_dp {extra_dp_size}") + else: + self.moe_extra_dp_group = None + + self.stage_manager = None + self.schedule = None + self.custom_policy = custom_policy + assert zero_stage in (0, 1, 2) + if self.pp_size > 1: + assert ( + num_microbatches is not None or microbatch_size is not None + ), "num_microbatches or microbatch_size must be 
specified when using pipeline parallelism" + assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism" + self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS) + self.schedule = OneForwardOneBackwardSchedule( + self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size + ) + self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) + self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) + self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS) + self.shard_config = ShardConfig( + tensor_parallel_process_group=self.tp_group, + pipeline_stage_manager=self.stage_manager, + enable_tensor_parallelism=self.tp_size > 1, + enable_all_optimization=self.enable_all_optimization, + enable_fused_normalization=self.enable_fused_normalization, + enable_flash_attention=self.enable_flash_attention, + enable_jit_fused=self.enable_jit_fused, + enable_sequence_parallelism=enable_sequence_parallelism, + enable_sequence_overlap=enable_sequence_overlap, + ) + self.amp_config = dict( + initial_scale=initial_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + min_scale=min_scale, + max_scale=max_scale, + ) + + self.ddp_config = dict( + broadcast_buffers=broadcast_buffers, + bucket_cap_mb=ddp_bucket_cap_mb, + find_unused_parameters=find_unused_parameters, + check_reduction=check_reduction, + gradient_as_bucket_view=gradient_as_bucket_view, + static_graph=static_graph, + ) + + self.zero_config = dict( + reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024, + communication_dtype=communication_dtype, + overlap_communication=overlap_communication, + cpu_offload=cpu_offload, + partition_grad=(self.zero_stage == 2), + ) + + self.max_norm = max_norm + + def prepare_dataloader( + self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs + ): + r""" + Prepare a dataloader for distributed training. The dataloader will be wrapped by + `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`. + + + Args: + dataset (`torch.utils.data.Dataset`): The dataset to be loaded. + shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False. + seed (int, optional): Random worker seed for sampling, defaults to 1024. + add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True. + drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size + is not divisible by the batch size. If False and the size of dataset is not divisible by + the batch size, then the last batch will be smaller, defaults to False. + pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False. + num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0. + kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in + `DataLoader `_. + + Returns: + :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. 
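+
+        Example:
+            >>> # Assuming ``plugin`` is an initialized MoeHybridParallelPlugin and
+            >>> # ``train_dataset`` is a ``torch.utils.data.Dataset``.
+            >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8, shuffle=True)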
+ """ + _kwargs = kwargs.copy() + sampler = DistributedSampler( + dataset, num_replicas=self.pg_mesh.size(DP_AXIS), rank=self.pg_mesh.coordinate(DP_AXIS), shuffle=shuffle + ) + + # Deterministic dataloader + def seed_worker(worker_id): + worker_seed = seed + np.random.seed(worker_seed) + torch.manual_seed(worker_seed) + random.seed(worker_seed) + + return DataLoader( + dataset, + batch_size=batch_size, + sampler=sampler, + worker_init_fn=seed_worker, + drop_last=drop_last, + pin_memory=pin_memory, + num_workers=num_workers, + **_kwargs, + ) + + def get_checkpoint_io(self) -> MoeCheckpintIO: + self.checkpoint_io = MoeCheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) + return self.checkpoint_io + + def configure( + self, + model: Module, + optimizer: Optional[Optimizer] = None, + criterion: Optional[Callable] = None, + dataloader: Optional[DataLoader] = None, + lr_scheduler: Optional[LRScheduler] = None, + ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: + param_info = get_param_info(optimizer) + if not isinstance(model, ModelWrapper): + use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0 + model = HybridParallelModule( + model, self.precision, self.shard_config, self.dp_group, use_ddp, self.ddp_config, self.custom_policy + ) + if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): + if self.zero_stage == 0: + if self.precision in ["fp16", "bf16"]: + optimizer = HybridParallelAMPOptimizer( + optimizer, + model, + use_pipeline=self.enable_pipeline_parallelism, + param_info=param_info, + precision=self.precision, + max_norm=self.max_norm, + **self.amp_config, + ) + self.checkpoint_io.link_master_and_working_param( + optimizer.working_to_master_map, optimizer.master_to_working_map + ) + else: + optimizer = HybridParallelNaiveOptimizer( + optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info + ) + else: + assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1." + assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO." 
+ optimizer = HybridParallelZeroOptimizer( + optimizer, + model, + use_pipeline=self.enable_pipeline_parallelism, + param_info=param_info, + dp_process_group=self.dp_group, + tp_process_group=self.tp_group, + pp_process_group=self.pp_group, + moe_extra_dp_process_group=self.moe_extra_dp_group, + verbose=True, + clip_grad_norm=self.max_norm, + **self.zero_config, + **self.amp_config, + ) + # inject update_master_params + model.update_master_params = MethodType(optimizer.update_master_params, model) + + return model, optimizer, criterion, dataloader, lr_scheduler diff --git a/colossalai/context/__init__.py b/colossalai/context/__init__.py index ab57301bb..3e94b7cfe 100644 --- a/colossalai/context/__init__.py +++ b/colossalai/context/__init__.py @@ -1,7 +1,5 @@ from .config import Config, ConfigException -# from .moe_context import MOE_CONTEXT - __all__ = [ "Config", "ConfigException", diff --git a/colossalai/context/moe_context.py b/colossalai/context/moe_context.py deleted file mode 100644 index 066dfc722..000000000 --- a/colossalai/context/moe_context.py +++ /dev/null @@ -1,132 +0,0 @@ -from typing import Tuple - -import torch -import torch.distributed as dist - -from colossalai.context.singleton_meta import SingletonMeta -from colossalai.legacy.tensor import ProcessGroup - - -def _check_sanity(): - from colossalai.legacy.core import global_context as gpc - - if gpc.tensor_parallel_size > 1 or gpc.pipeline_parallel_size > 1: - raise NotImplementedError("Moe is not compatible with tensor or " "pipeline parallel at present.") - - -class MoeParallelInfo: - """Moe parallelism information, storing parallel sizes and groups.""" - - def __init__(self, ep_size: int, dp_size: int): - _check_sanity() - self.ep_size = ep_size - self.dp_size = dp_size - self.pg = ProcessGroup(tp_degree=ep_size, dp_degree=dp_size) - self.ep_group = self.pg.tp_process_group() - self.dp_group = self.pg.dp_process_group() - - -class MoeContext(metaclass=SingletonMeta): - """MoE parallel context manager. This class manages different - parallel groups in MoE context and MoE loss in training. 
- """ - - def __init__(self): - self.world_size = 1 - # Users may want to set maximum expert parallel size smaller than the world size - # since very low bandwidth across nodes may constrain the performance of MoE - # When we have a maximum expert parallel size, we have a minimum data parallel size naturally - self.max_ep_size = 1 - self.min_dp_size = 1 - self.aux_loss = None - self.use_kernel_optim = True - - self.has_setup = False - self._parallel_info_dict = dict() - - @property - def parallel_info_dict(self): - return self._parallel_info_dict - - @property - def is_initialized(self): - return self.has_setup - - def setup(self, seed: int, use_kernel_optim: bool = True): - assert not self.is_initialized, "MoE distributed context shouldn't be set up again" - _check_sanity() - assert torch.cuda.is_available(), "MoE requires to enable CUDA first" - - self.world_size = dist.get_world_size() - - from colossalai.legacy.core import global_context as gpc - - self.max_ep_size = gpc.config.get("max_ep_size", self.world_size) - assert ( - self.world_size % self.max_ep_size == 0 - ), "Maximum expert parallel size must be a factor of the number of GPUs" - self.min_dp_size = self.world_size // self.max_ep_size - - # Enabling kernel optimization may raise error in some cases - # Users can close kernel optimization manually - self.use_kernel_optim = use_kernel_optim - - from .random import moe_set_seed - - moe_set_seed(seed) - self.has_setup = True - - def get_info(self, num_experts: int) -> Tuple[int, MoeParallelInfo]: - """Calculate the Data Parallel Group and Expert Parallel Group. - - Parameters - ---------- - num_experts : int - The number experts - - Returns - ------- - int, MoeParallelInfo - number of local experts, the MoeParallelInfo of the current ep_size - """ - - gt_flag = num_experts % self.max_ep_size == 0 # check whether num_experts is greater - lt_flag = self.max_ep_size % num_experts == 0 # check whether num_experts is less - - assert gt_flag or lt_flag, ( - "Automatic experts placement dose not not support expert number" - " is not a multiple of ep size or vice versa." - ) - - # If the number of experts is greater than maximum expert parallel size. 
a.k.a ep_size, - # there are multiple experts in each GPU and each GPU has different experts - # So it's data parallel size is 1 - # Otherwise, there is only one expert in each GPU - # The data parallel size should be calculated - dp_size = 1 if gt_flag else self.max_ep_size // num_experts - ep_size = self.max_ep_size // dp_size - - # Calculate the number of experts for each GPU - num_local_experts = 1 if lt_flag else num_experts // self.max_ep_size - - # Don't forget to multiply minimum data parallel size - dp_size *= self.min_dp_size - if not (ep_size in self.parallel_info_dict): - self.parallel_info_dict[ep_size] = MoeParallelInfo(ep_size, dp_size) - - return num_local_experts, self.parallel_info_dict[ep_size] - - def set_kernel_not_use(self): - self.use_kernel_optim = False - - def reset_loss(self): - self.aux_loss = 0 - - def add_loss(self, loss): - self.aux_loss += loss - - def get_loss(self): - return self.aux_loss - - -MOE_CONTEXT = MoeContext() diff --git a/colossalai/kernel/triton/llama_act_combine_kernel.py b/colossalai/kernel/triton/llama_act_combine_kernel.py new file mode 100644 index 000000000..45996c0dc --- /dev/null +++ b/colossalai/kernel/triton/llama_act_combine_kernel.py @@ -0,0 +1,185 @@ +from functools import reduce +from typing import Any, Tuple + +import torch +from torch import Tensor +from torch.cuda.amp import custom_bwd, custom_fwd + +try: + import triton + import triton.language as tl + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +if HAS_TRITON: + PRECISION_MAP = { + "fp32": (0, torch.float32), + "fp16": (1, torch.float16), + "bf16": (2, torch.bfloat16), + } + + @triton.jit + def _llama_act_combine_forward( + X_GATE1, + X_GATE2, + X_UP, + Y, + stride, # how much to increase the pointer when moving by 1 row + N, # number of columns in X + BLOCK_SIZE: tl.constexpr, + ): + # Map the program id to the row of X and Y it should compute. + row = tl.program_id(0) + X_GATE1 += row * stride + X_GATE2 += row * stride + X_UP += row * stride + Y += row * stride + + # do activation and combine, and store in y + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.) + x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.) + x_up = tl.load(X_UP + cols, mask=mask, other=0.) + x_gate2_sigmoid = tl.sigmoid(x_gate2.to(tl.float32)).to(x_gate2.dtype) + y = x_gate1 * x_gate2 * x_gate2_sigmoid * x_up + # Write output + tl.store(Y + cols, y, mask=mask) + + @triton.jit + def _llama_act_combine_backward( + X_GATE1, + X_GATE2, + X_UP, + X_GATE1_GRAD, + X_GATE2_GRAD, + X_UP_GRAD, + Y_GRAD, + stride, # how much to increase the pointer when moving by 1 row + N, # number of columns in X + BLOCK_SIZE: tl.constexpr, + ): + # Map the program id to the row of X and Y it should compute. + row = tl.program_id(0) + X_GATE1 += row * stride + X_GATE2 += row * stride + X_UP += row * stride + X_GATE1_GRAD += row * stride + X_GATE2_GRAD += row * stride + X_UP_GRAD += row * stride + Y_GRAD += row * stride + + # do activation and combine, and store in y + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.) + x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.) + x_up = tl.load(X_UP + cols, mask=mask, other=0.) + y_grad = tl.load(Y_GRAD + cols, mask=mask, other=0.) 
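+            # sigmoid(x_gate2) is recomputed here in fp32; it is not saved by the
+            # forward pass (only x_gate1, x_gate2 and x_up are kept for backward).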
+ + # forward: y = x_gate1 * x_gate2 * tl.sigmoid(x_gate2) * x_up + x_gate2_sigmoid = tl.sigmoid(x_gate2.to(tl.float32)).to(x_gate2.dtype) + x_gate2_act = y_grad * x_gate2 * x_gate2_sigmoid + x_up_grad = x_gate2_act * x_gate1 + x_gate1_grad = x_gate2_act * x_up + # grad(x*sigmoid(x)) = sigmoid(x) + x * sigmoid(x) * [1 − sigmoid(x)] + # = sigmoid(x) * {1 + x * [(1 − sigmoid(x)]} + x_gate2_grad = (y_grad * x_gate1 * x_up) * x_gate2_sigmoid * (1 + x_gate2 * (1 - x_gate2_sigmoid)) + + # Write output + tl.store(X_GATE1_GRAD + cols, x_gate1_grad, mask=mask) + tl.store(X_GATE2_GRAD + cols, x_gate2_grad, mask=mask) + tl.store(X_UP_GRAD + cols, x_up_grad, mask=mask) + + class LlamaActCombine(torch.autograd.Function): + """ + act(x_gate) * x_up + + Args: + x_gate (torch.Tensor): (b, l, 2d) x_gate + x_up (torch.Tensor): (b, l, d) x_up + activation (str): only support swiglu + precision (str): fp32, fp16, bf16 + """ + + @staticmethod + @custom_fwd + def forward(ctx: Any, x_gate: torch.Tensor, x_up: torch.Tensor, activation: str = "swiglu") -> torch.Tensor: + """ + act(x_gate) * x_up + + Args: + x_gate (torch.Tensor): (b, l, 2d) x gate + x_up (torch.Tensor): (b, l, d) x up + activation (str): only support swiglu + """ + assert activation == "swiglu", "Only swiglu is supported" + + # split x gate + assert x_gate.shape[-1] % 2 == 0, "axis size must be divisible by 2" + x_gate1, x_gate2 = torch.split(x_gate, x_gate.shape[-1] // 2, -1) + x_gate1 = x_gate1.contiguous() + x_gate2 = x_gate2.contiguous() + if not x_up.is_contiguous(): + x_up = x_up.contiguous() + # assert shape + assert x_gate1.shape == x_gate2.shape == x_up.shape + + # add ctx for backward + if x_gate.requires_grad: + ctx.save_for_backward(x_gate1, x_gate2, x_up) + + # allocate output + y = torch.empty_like(x_up) + M, N = reduce(lambda x, y: x * y, x_up.shape[:-1]), x_up.shape[-1] + + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x_gate.element_size() + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_SIZE: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + num_warps = min(max(BLOCK_SIZE // 256, 1), 8) + # restore setting + ctx.M, ctx.N, ctx.BLOCK_SIZE, ctx.num_warps = M, N, BLOCK_SIZE, num_warps + # enqueue kernel + _llama_act_combine_forward[(M,)](x_gate1, + x_gate2, + x_up, + y, + x_up.stride(-2), + N, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps) + return y + + @staticmethod + @custom_bwd + def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, Tensor, None, None]: + # restore from ctx + (x_gate1, x_gate2, x_up) = ctx.saved_tensors + M, N, BLOCK_SIZE, num_warps = ctx.M, ctx.N, ctx.BLOCK_SIZE, ctx.num_warps + + # init grad + y_grad = grad_outputs[0] + x_gate1_grad, x_gate2_grad, x_up_grad = torch.empty_like(x_gate1), torch.empty_like( + x_gate2), torch.empty_like(x_up) + + # enqueue kernel + _llama_act_combine_backward[(M,)](x_gate1, + x_gate2, + x_up, + x_gate1_grad, + x_gate2_grad, + x_up_grad, + y_grad, + x_up.stride(-2), + N, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps) + x_gate_grad = torch.cat([x_gate1_grad, x_gate2_grad], dim=-1) + return x_gate_grad, x_up_grad, None, None diff --git a/colossalai/legacy/engine/gradient_handler/__init__.py b/colossalai/legacy/engine/gradient_handler/__init__.py index 78928b138..713df5a64 100644 --- a/colossalai/legacy/engine/gradient_handler/__init__.py +++ b/colossalai/legacy/engine/gradient_handler/__init__.py @@ -1,6 +1,5 @@ from ._base_gradient_handler import 
BaseGradientHandler from ._data_parallel_gradient_handler import DataParallelGradientHandler -from ._moe_gradient_handler import MoeGradientHandler from ._pipeline_parallel_gradient_handler import PipelineSharedModuleGradientHandler from ._sequence_parallel_gradient_handler import SequenceParallelGradientHandler from ._zero_gradient_handler import ZeROGradientHandler @@ -10,6 +9,5 @@ __all__ = [ "DataParallelGradientHandler", "ZeROGradientHandler", "PipelineSharedModuleGradientHandler", - "MoeGradientHandler", "SequenceParallelGradientHandler", ] diff --git a/colossalai/legacy/initialize.py b/colossalai/legacy/initialize.py index ce9c62655..4035bd6b5 100644 --- a/colossalai/legacy/initialize.py +++ b/colossalai/legacy/initialize.py @@ -16,7 +16,6 @@ from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader from colossalai.context import Config, ConfigException -from colossalai.context.moe_context import MOE_CONTEXT from colossalai.interface import OptimizerWrapper from colossalai.legacy.amp import AMP_TYPE, convert_to_amp from colossalai.legacy.amp.naive_amp import NaiveAMPModel @@ -36,7 +35,6 @@ from colossalai.legacy.zero import ShardedOptimizerV2, convert_to_zero_v2 from colossalai.legacy.zero.gemini.ophooks import BaseOpHook from colossalai.logging import get_dist_logger from colossalai.utils import get_current_device -from colossalai.utils.moe import sync_moe_model_param def get_default_parser(): @@ -323,8 +321,6 @@ def initialize( if not use_zero: if is_using_sequence(): sync_model_param(model, ParallelMode.SEQUENCE_DP) - elif MOE_CONTEXT.is_initialized: - sync_moe_model_param(model) elif is_using_ddp(): sync_model_param(model, ParallelMode.DATA) else: @@ -377,14 +373,6 @@ def initialize( "added even though not specified in the configuration", ranks=[0], ) - elif is_using_ddp() and MOE_CONTEXT.is_initialized: - gradient_handler_cfg = [dict(type="MoeGradientHandler")] - if verbose: - logger.info( - "Data parallel training is detected with moe parallel, MoeGradientHandler is automatically " - "added even though not specified in the configuration", - ranks=[0], - ) elif is_using_sequence(): model = DDP( model, diff --git a/colossalai/moe/__init__.py b/colossalai/moe/__init__.py new file mode 100644 index 000000000..f32e89dfa --- /dev/null +++ b/colossalai/moe/__init__.py @@ -0,0 +1,17 @@ +from .checkpoint import MoeCheckpintIO +from .experts import MLPExperts +from .layers import SparseMLP +from .routers import MoeRouter, Top1Router, Top2Router, TopKRouter +from .utils import NormalNoiseGenerator, UniformNoiseGenerator + +__all__ = [ + "MLPExperts", + "MoeRouter", + "Top1Router", + "Top2Router", + "TopKRouter", + "NormalNoiseGenerator", + "UniformNoiseGenerator", + "SparseMLP", + "MoeCheckpintIO", +] diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py new file mode 100644 index 000000000..542c63727 --- /dev/null +++ b/colossalai/moe/_operation.py @@ -0,0 +1,275 @@ +from typing import Any, Optional, Tuple + +import torch +import torch.distributed as dist +from torch import Tensor +from torch.cuda.amp import custom_bwd, custom_fwd +from torch.distributed import ProcessGroup + +from colossalai.moe.manager import MOE_MANAGER + +MOE_KERNEL = None + + +def load_moe(): + global MOE_KERNEL + from colossalai.kernel.op_builder import MOEBuilder + + MOE_KERNEL = MOEBuilder().load() + + +class AllGather(torch.autograd.Function): + @staticmethod + def forward( + ctx: Any, + inputs: Tensor, + group: Optional[ProcessGroup] = None, + overlap: bool = False, 
+ ) -> Tuple[Tensor, Any]: + """ + Returns: + outputs: Tensor + handle: Optional[Work], if overlap is True + """ + assert ctx is not None or not overlap + + if ctx is not None: + ctx.comm_grp = group + + comm_size = dist.get_world_size(group) + if comm_size == 1: + return inputs.unsqueeze(0), None + + buffer_shape = (comm_size,) + inputs.shape + outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device) + buffer_list = list(torch.chunk(outputs, comm_size, dim=0)) + if not overlap: + dist.all_gather(buffer_list, inputs, group=group) + return outputs, None + else: + handle = dist.all_gather(buffer_list, inputs, group=group, async_op=True) + return outputs, handle + + @staticmethod + def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]: + return ( + ReduceScatter.forward(None, grad_outputs[0], ctx.comm_grp, False)[0], + None, + None, + ) + + +class ReduceScatter(torch.autograd.Function): + @staticmethod + def forward( + ctx: Any, + inputs: Tensor, + group: Optional[ProcessGroup] = None, + overlap: bool = False, + ) -> Tuple[Tensor, Any]: + """ + Returns: + outputs: Tensor + handle: Optional[Work], if overlap is True + """ + assert ctx is not None or not overlap + + if ctx is not None: + ctx.comm_grp = group + + comm_size = dist.get_world_size(group) + if comm_size == 1: + return inputs.squeeze(0), None + + if not inputs.is_contiguous(): + inputs = inputs.contiguous() + + output_shape = inputs.shape[1:] + outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device) + buffer_list = list(torch.chunk(inputs, comm_size, dim=0)) + if not overlap: + dist.reduce_scatter(outputs, buffer_list, group=group) + return outputs, None + else: + handle = dist.reduce_scatter(outputs, buffer_list, group=group, async_op=True) + return outputs, handle + + @staticmethod + def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]: + # TODO: support async backward + return ( + AllGather.forward(None, grad_outputs[0], ctx.comm_grp, False)[0], + None, + None, + ) + + +class AllToAll(torch.autograd.Function): + """Dispatches input tensor [e, c, h] to all experts by all_to_all_single + operation in torch.distributed. 
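+    The backward pass applies the same all_to_all_single exchange to the incoming
+    gradient, since the gradient of an all-to-all is itself an all-to-all.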
+ """ + + @staticmethod + def forward( + ctx: Any, + inputs: Tensor, + group: Optional[ProcessGroup] = None, + overlap: bool = False, + ) -> Tuple[Tensor, Any]: + """ + Returns: + outputs: Tensor + handle: Optional[Work], if overlap is True + """ + if ctx is not None: + ctx.comm_grp = group + if not inputs.is_contiguous(): + inputs = inputs.contiguous() + if dist.get_world_size(group) == 1: + return inputs, None + output = torch.empty_like(inputs) + if not overlap: + dist.all_to_all_single(output, inputs, group=group) + return output, None + else: + handle = dist.all_to_all_single(output, inputs, group=group, async_op=True) + return output, handle + + @staticmethod + def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]: + return ( + AllToAll.forward(None, grad_outputs[0], ctx.comm_grp)[0], + None, + None, + ) + + +class MoeDispatch(torch.autograd.Function): + @staticmethod + @custom_fwd + def forward(ctx, tokens, mask, dest_idx, ec): + s = tokens.size(0) + h = tokens.size(1) + dtype = tokens.dtype + + if MOE_KERNEL is None: + load_moe() + if tokens.dtype != torch.float32: + tokens = tokens.to(torch.float32) + expert_input = MOE_KERNEL.dispatch_forward(s, ec, h, tokens, mask, dest_idx) + if expert_input.dtype != dtype: + expert_input = expert_input.to(dtype) + ctx.save_for_backward(mask, dest_idx) + ctx.s = s + ctx.h = h + ctx.ec = ec + ctx.dtype = dtype + + return expert_input + + @staticmethod + @custom_bwd + def backward(ctx, output_grad): + mask, dest_idx = ctx.saved_tensors + if output_grad.dtype != torch.float32: + output_grad = output_grad.to(torch.float32) + d_tokens = MOE_KERNEL.dispatch_backward(ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx) + if d_tokens.dtype != ctx.dtype: + d_tokens = d_tokens.to(ctx.dtype) + return d_tokens, None, None, None + + +class MoeCombine(torch.autograd.Function): + @staticmethod + @custom_fwd + def forward(ctx, expert_tokens, logits, mask, dest_idx, ec): + assert logits.dtype == torch.float32 + + s = logits.size(0) + e = logits.size(1) + c = ec // e + h = expert_tokens.size(-1) + dtype = expert_tokens.dtype + + if expert_tokens.dtype != torch.float32: + expert_tokens = expert_tokens.to(torch.float32) + if MOE_KERNEL is None: + load_moe() + output = MOE_KERNEL.combine_forward(s, e, c, h, expert_tokens, logits, mask, dest_idx) + if output.dtype != dtype: + output = output.to(dtype) + + ctx.save_for_backward(expert_tokens, logits, mask, dest_idx) + ctx.s = s + ctx.e = e + ctx.c = c + ctx.h = h + ctx.dtype = dtype + + return output + + @staticmethod + @custom_bwd + def backward(ctx, tokens_grad): + expert_tokens, logits, mask, dest_idx = ctx.saved_tensors + if tokens_grad.dtype != torch.float32: + tokens_grad = tokens_grad.to(torch.float32) + + d_expert, d_logits = MOE_KERNEL.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, tokens_grad, expert_tokens, logits, + mask, dest_idx) + if d_expert.dtype != ctx.dtype: + d_expert = d_expert.to(ctx.dtype) + + return d_expert, d_logits, None, None, None + + +def moe_cumsum(inputs: Tensor, use_kernel: bool = False): + dim0 = inputs.size(0) + flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0) + if flag and use_kernel: + if MOE_KERNEL is None: + load_moe() + return MOE_KERNEL.cumsum_sub_one(inputs) + else: + return torch.cumsum(inputs, dim=0) - 1 + + +class MoeInGradScaler(torch.autograd.Function): + """ + Scale the gradient back by the number of experts + because the batch size increases in the moe stage + """ + + @staticmethod + def forward(ctx: Any, inputs: Tensor, ep_size: 
int) -> Tensor: + if ctx is not None: + ctx.ep_size = ep_size + return inputs + + @staticmethod + def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]: + assert len(grad_outputs) == 1 + grad = grad_outputs[0] + if ctx.ep_size != 1: + grad = grad * ctx.ep_size + return grad, None + + +class MoeOutGradScaler(torch.autograd.Function): + """ + Scale the gradient by the number of experts + because the batch size increases in the moe stage + """ + + @staticmethod + def forward(ctx: Any, inputs: Tensor, ep_size: int) -> Tensor: + ctx.ep_size = ep_size + return inputs + + @staticmethod + def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]: + assert len(grad_outputs) == 1 + grad = grad_outputs[0] + if ctx.ep_size != 1: + grad = grad / ctx.ep_size + return grad, None diff --git a/colossalai/moe/checkpoint.py b/colossalai/moe/checkpoint.py new file mode 100644 index 000000000..386fc2010 --- /dev/null +++ b/colossalai/moe/checkpoint.py @@ -0,0 +1,274 @@ +import logging +import os +from copy import deepcopy +from pathlib import Path +from typing import Iterator, Optional, OrderedDict, Tuple + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed import ProcessGroup +from torch.optim import Optimizer + +from colossalai.checkpoint_io import CheckpointIndexFile, HybridParallelCheckpointIO +from colossalai.checkpoint_io.utils import ( + StateDictSharder, + gather_distributed_param, + get_model_base_filenames, + is_safetensors_available, + load_shard_state_dict, + load_state_dict_into_model, + save_config_file, + save_state_dict_shards, +) +from colossalai.moe.manager import MOE_MANAGER +from colossalai.tensor.moe_tensor.api import get_dp_rank, get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor + + +class MoeCheckpintIO(HybridParallelCheckpointIO): + + def __init__( + self, + dp_group: ProcessGroup, + pp_group: ProcessGroup, + tp_group: ProcessGroup, + zero_stage: int, + ) -> None: + assert zero_stage in [ + 0, + 1, + 2, + ], f"zero_stage should be 0 or 1 or 2, got {zero_stage}" + super().__init__(dp_group, pp_group, tp_group, zero_stage) + self.parallel = MOE_MANAGER.parallel + + def pre_load_model(self, model: nn.Module, state_dict: dict) -> dict: + """ + Preprocess state_dict before loading and slice the state_dict of MOE tensors. + """ + for name, param in state_dict.items(): + if ".experts." in name: + if name in dict(model.named_parameters()): + model_param = dict(model.named_parameters())[name] + if is_moe_tensor(model_param): + ep_rank = get_ep_rank(model_param) + ep_size = get_ep_size(model_param) + expert_num = param.shape[0] // ep_size + assert param.shape[0] % ep_size == 0 + param = param[ep_rank * expert_num:(ep_rank + 1) * expert_num] + state_dict[name] = param + dist.barrier() + return state_dict + + def _model_sharder( + self, + state_dict: nn.Module, + prefix: str = "", + keep_vars: bool = False, + size_per_shard: int = 1024, + ) -> Iterator[Tuple[OrderedDict, int]]: + # An internel method that breaks state_dict of model into shards within limited size. + state_dict_sharder = StateDictSharder(size_per_shard) + + for name, param in state_dict.items(): + if param is None: + continue + # Gather tensor pieces when using tensor parallel. + param_ = gather_distributed_param(param, keep_vars=False) + block, block_size = state_dict_sharder.append_param(prefix + name, param_) + if block is not None: + yield block, block_size + + # Return the last block in sharder. 
+ yield state_dict_sharder.current_block, state_dict_sharder.current_block_size + + def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool) -> None: + state_dict = torch.load(checkpoint) + state_dict = self.pre_load_model(model, state_dict) + model.load_state_dict(state_dict, strict=strict if self.pp_size == 1 else False) + + def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False): + """ + Load sharded model with the given path to index file of checkpoint folder. + + Args: + model (nn.Module): The model to be loaded. + checkpoint_index_file (str): Path to the index file of checkpointing folder. + strict (bool, optional): For name matching during loading state_dict. Defaults to False. + This argument should be manually set to False since params on same device might be stored in different files. + """ + + # Check whether the checkpoint uses safetensors. + use_safetensors = False + if "safetensors" in checkpoint_index_file.name: + use_safetensors = True + + if use_safetensors and not is_safetensors_available(): + raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.") + + # Read checkpoint index file. + ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file) + ckpt_root_path = ckpt_index_file.root_path + weight_map = ckpt_index_file.weight_map + strict = False + + # Load params & buffers to model. + # Keep a record of loaded files so that file will not be repeatedly loaded. + loaded_file = set() + + def _load(name: str): + if name not in weight_map: + raise ValueError(f"{name} is not stored in checkpoint, please check your checkpointing configuration!") + filename = weight_map[name] + + # If this param/buffer has been loaded before, directly return. + if filename in loaded_file: + return + + file_path = os.path.join(ckpt_root_path, filename) + state_dict = load_shard_state_dict(Path(file_path), use_safetensors) + state_dict = self.pre_load_model(model, state_dict) + missing_keys = [] + + load_state_dict_into_model( + model, + state_dict, + missing_keys=missing_keys, + strict=strict, + load_sub_module=True, + ) + loaded_file.add(filename) + + # Load parameters. + for name, _ in model.named_parameters(): + _load(name) + + if self.verbose: + logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.") + + def pre_save_model(self, model: nn.Module) -> dict: + state_dict = model.state_dict() + for name, param in model.named_parameters(): + if ".experts." 
in name and is_moe_tensor(param): + ep_group = get_ep_group(param) + ep_rank = get_ep_rank(param) + ep_size = get_ep_size(param) + dp_rank = get_dp_rank(param) + if dp_rank == 0: + param = param.data.cuda() + all_param = [deepcopy(param) for _ in range(ep_size)] + # gather param from every ep rank + dist.all_gather(all_param, param, group=ep_group) + if ep_rank == 0: + all_param = torch.cat(all_param, dim=0) + state_dict[name] = all_param.cpu() + if self.pp_size > 1: + if self.dp_rank == 0: + out = [None for _ in range(self.pp_size)] + dist.all_gather_object(out, state_dict, group=self.pp_group) + if self.pp_rank == 0: + new_state_dict = {} + for o in out: + new_state_dict.update(o) + state_dict = new_state_dict + dist.barrier() + return state_dict + + def save_unsharded_model( + self, + model: nn.Module, + checkpoint: str, + gather_dtensor: bool, + use_safetensors: bool, + ): + state_dict = self.pre_save_model(model) + if dist.get_rank() == 0: + torch.save(state_dict, checkpoint) + dist.barrier() + + def save_sharded_model( + self, + model: nn.Module, + checkpoint: str, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + size_per_shard: int = 1024, + use_safetensors: bool = False, + ) -> None: + """ + Save sharded model checkpoint under the given checkpointing path. + The following files will be created under the path: + - An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names. + - Multiple files that store state tensors of models. + The filenames are in the form of "pytorch_model.-000XX.bin" + + Args: + model (nn.Module): Model on local device to be saved. + checkpoint (str): Checkpointing path which should be a directory path. + gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True. + prefix (str, optional): Perfix of file to save. Defaults to None. + size_per_shard (int, optional): Size per shard in MB. Defaults to 1024. + use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False. + """ + if os.path.isfile(checkpoint): + logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + return + + Path(checkpoint).mkdir(parents=True, exist_ok=True) + + # Then collect the sharded parameters & buffers along tp_group. + # Only devices with tp_rank == 0 are responsible for model saving. + state_dict = self.pre_save_model(model) + + if dist.get_rank() == 0: + state_dict_shard = self._model_sharder(state_dict, size_per_shard=size_per_shard) + + # Devices along the same dp_group share the same copies of model. + # So only let the device with dp_rank == 0 save the model. + if self.dp_rank != 0: + return + + weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors) + index_file = CheckpointIndexFile(checkpoint) + control_saving = self.tp_rank == 0 + + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=weights_name, + is_master=control_saving, + use_safetensors=use_safetensors, + ) + if control_saving: + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + save_config_file(model, checkpoint) + if self.verbose: + logging.info(f"The model is split into checkpoint shards. 
" + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}.") + dist.barrier() + + # ======================================================== + # Abstract methods for optimizer loading/saving implementation + # ======================================================== + + def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str): + raise NotImplementedError() + + def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path): + raise NotImplementedError() + + def save_sharded_optimizer( + self, + optimizer: Optimizer, + checkpoint: Path, + gather_dtensor: bool, + prefix: str, + size_per_shard: int, + ): + raise NotImplementedError() + + def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool): + raise NotImplementedError() diff --git a/colossalai/moe/experts.py b/colossalai/moe/experts.py new file mode 100644 index 000000000..3471b2876 --- /dev/null +++ b/colossalai/moe/experts.py @@ -0,0 +1,156 @@ +import math +from typing import Callable, Optional, Tuple + +import torch +import torch.nn as nn + +from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON +from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler +from colossalai.moe.manager import MOE_MANAGER +from colossalai.moe.utils import get_activation +from colossalai.shardformer.layer.utils import Randomizer +from colossalai.tensor.moe_tensor.api import get_ep_size, set_moe_tensor_info + +if HAS_TRITON: + from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine + + +class MLPExperts(nn.Module): + """ + SparseMLP is a multi-layer perceptron with sparse expert parallel layers. + + Args: + num_experts (int): The number of experts + hidden_size (int): The hidden size of MLP + intermediate_size (int): The intermediate size of MLP + expert_parallel (str, optional): The parallelism of experts. Now we have None, EP and TP. 
+ activation (optional): The activation function of MLP + drop_rate (float, optional): The drop rate of MLP + gated (bool, optional): Whether to use gated MLP + use_kernel (bool, optional): Whether to use kernel optimization + """ + + def __init__( + self, + num_experts: int, + hidden_size: int, + intermediate_size: int, + expert_parallel: Optional[str] = None, + activation: Optional[Callable] = None, + drop_rate: Optional[float] = 0, + gated: Optional[bool] = False, + use_kernel: Optional[bool] = False, + ): + super().__init__() + assert expert_parallel in ["EP", "TP", None] + self.expert_parallel = expert_parallel + self.num_total_experts = num_experts + self.gated = gated + self.use_kernel = use_kernel + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + + # get expert parallel info + if expert_parallel is not None: + self.num_local_experts, self.moe_info = MOE_MANAGER.get_info( + num_experts, use_tp=True if expert_parallel == "TP" else False) + # get settings for different parallel + self.ep_size = get_ep_size(self) + if expert_parallel == "TP": + intermediate_size = intermediate_size // self.ep_size + num_experts = self.num_total_experts + else: + num_experts = self.num_local_experts + else: + self.num_local_experts = self.num_total_experts + self.ep_size = 1 + + if gated: + self.wi_gate = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size * 2)) + self.wi_up = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size)) + else: + self.wi = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size)) + self.wo = nn.Parameter(torch.empty(num_experts, intermediate_size, hidden_size)) + + self.act_name = activation + self.act = get_activation(activation) + self.drop = nn.Dropout(p=drop_rate) + + if expert_parallel is not None: + for param in self.parameters(): + set_moe_tensor_info(param, self.moe_info) + + # init param + self.reset_parameters() + + @torch.no_grad() + def reset_parameters(self): + # expert param should be different + if self.expert_parallel is not None: + seed_ctx = Randomizer(MOE_MANAGER.seed).fork_rng(enable_cpu=True) + else: + seed_ctx = Randomizer(42).fork_rng(enable_cpu=True) + with seed_ctx: + if self.gated: + torch.nn.init.normal_(self.wi_gate, std=math.sqrt(0.1 / self.hidden_size)) + torch.nn.init.normal_(self.wi_up, std=math.sqrt(0.1 / self.hidden_size)) + else: + torch.nn.init.normal_(self.wi, std=math.sqrt(0.1 / self.hidden_size)) + torch.nn.init.normal_(self.wo, std=math.sqrt(0.1 / self.intermediate_size)) + + def forward( + self, + x: torch.Tensor, + param_slice: Tuple[slice] = (slice(None),), + use_sparse: bool = True, + ) -> torch.Tensor: + """ + forward: hidden_size --> intermediate_size --> hidden_size + + Args: + x (torch.Tensor): The input tensor of shape (num_groups, num_experts, capacity, hidden_size) + + Returns: + torch.Tensor: The output tensor of shape (num_groups, num_experts, capacity, hidden_size) + """ + x = MoeInGradScaler.apply(x, self.ep_size) + + e = x.size(1) + h = x.size(-1) + + x = x.transpose(0, 1) + inshape = x.shape + x = x.reshape(e, -1, h) + + if self.use_kernel and use_sparse: + seq_len = x.shape[1] + with torch.no_grad(): + mask = x[:, :, 0] != 0.0 + mask = torch.sum(mask, dim=-1) + x_list = [] + for i in range(e): + x_list.append(x[i, :mask[i]]) + x = x_list + + if self.gated: + x_gate = [torch.mm(x[i], self.wi_gate[param_slice][i]) for i in range(e)] + x_up = [torch.mm(x[i], self.wi_up[param_slice][i]) for i in range(e)] + if self.use_kernel and HAS_TRITON and 
self.act_name == "swiglu": + x = [LlamaActCombine.apply(x_gate[i], x_up[i]) for i in range(e)] + else: + x = [self.act(x_gate[i]) * x_up[i] for i in range(e)] + else: + x = [torch.mm(x[i], self.wi[param_slice][i]) for i in range(e)] + x = [self.act(x[i]) for i in range(e)] + x = [self.drop(x[i]) for i in range(e)] + x = [torch.mm(x[i], self.wo[param_slice][i]) for i in range(e)] + + if self.use_kernel and use_sparse: + for i in range(e): + x[i] = torch.nn.functional.pad(x[i], (0, 0, 0, seq_len - x[i].shape[0]), mode="constant", value=0) + + x = torch.cat([x[i].unsqueeze(0) for i in range(e)], dim=0) + x = x.reshape(inshape) + x = x.transpose(0, 1).contiguous() + x = MoeOutGradScaler.apply(x, self.ep_size) + return x diff --git a/colossalai/moe/layers.py b/colossalai/moe/layers.py new file mode 100644 index 000000000..bd2cefbe9 --- /dev/null +++ b/colossalai/moe/layers.py @@ -0,0 +1,361 @@ +import dataclasses +import math +from typing import Any, Optional, Tuple + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F + +from colossalai.moe._operation import AllGather, AllToAll, MoeCombine, MoeDispatch, ReduceScatter +from colossalai.moe.experts import MLPExperts +from colossalai.moe.load_balance import LoadBalancer +from colossalai.moe.manager import MOE_MANAGER +from colossalai.moe.routers import MoeRouter, get_router_cls +from colossalai.moe.utils import get_noise_generator +from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_size + + +class SparseMLP(nn.Module): + """A class for users to create MoE modules in their models. + + Args: + dim_model (int): Hidden dimension of training model + num_experts (int): The number experts + top_k (int, optional): The number of experts for dispatchment of each token + capacity_factor_train (float, optional): Capacity factor in routing during training + capacity_factor_eval (float, optional): Capacity factor in routing during evaluation + min_capacity (int, optional): The minimum number of the capacity of each expert + noisy_policy (str, optional): The policy of noisy function. Now we have 'Jitter' and 'Gaussian'. + 'Jitter' can be found in `Switch Transformer paper`_. + 'Gaussian' can be found in `ViT-MoE paper`_. + drop_tks (bool, optional): Whether drops tokens in evaluation + use_residual (bool, optional): Makes this MoE layer a Residual MoE. + More information can be found in `Microsoft paper`_. + residual_instance (nn.Module, optional): The instance of residual module in Residual MoE + expert_instance (MoeExperts, optional): The instance of experts module in MoeLayer + expert_cls (Type[nn.Module], optional): The class of each expert when no instance is given + expert_args (optional): The args of expert when no instance is given + + .. _Switch Transformer paper: + https://arxiv.org/abs/2101.03961 + .. _ViT-MoE paper: + https://arxiv.org/abs/2106.05974 + .. 
_Microsoft paper: + https://arxiv.org/abs/2201.05596 + """ + + def __init__( + self, + num_experts: int, + hidden_size: int, + intermediate_size: int, + router_top_k: int = 1, + router_capacity_factor_train: Optional[float] = 1.25, + router_capacity_factor_eval: Optional[float] = 2.0, + router_min_capacity: Optional[int] = 4, + router_noisy_policy: Optional[str] = None, + router_drop_tks: Optional[bool] = True, + mlp_activation: Optional[str] = None, + mlp_gated: Optional[bool] = False, + enable_load_balance: Optional[bool] = False, + load_balance_tolerance: Optional[float] = 0.1, + load_balance_beam_width: Optional[int] = 8, + load_balance_group_swap_factor: Optional[float] = 0.4, + enable_kernel: Optional[bool] = False, + enable_comm_overlap: Optional[bool] = False, + ): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_experts = num_experts + self.gated = mlp_gated + self.enable_kernel = enable_kernel + self.enable_comm_overlap = enable_comm_overlap + self.expert_parallel = MOE_MANAGER.get_parallel() + + # moe router + noisy_func = get_noise_generator(router_noisy_policy, num_experts) + router_cls = get_router_cls(router_top_k) + self.topk = router_top_k + self.router: MoeRouter = router_cls( + capacity_factor_train=router_capacity_factor_train, + capacity_factor_eval=router_capacity_factor_eval, + min_capacity=router_min_capacity, + noisy_func=noisy_func, + drop_tks=router_drop_tks, + ) + + # gate + self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, self.hidden_size)) + + # moe experts + self.experts = MLPExperts( + num_experts=self.num_experts, + expert_parallel=self.expert_parallel, + hidden_size=self.hidden_size, + intermediate_size=self.intermediate_size, + activation=mlp_activation, + gated=mlp_gated, + use_kernel=self.enable_kernel, + ) + + # get parallel settings + if self.expert_parallel is not None: + self.ep_group = get_ep_group(self.experts) + self.ep_size = get_ep_size(self.experts) + self.dp_group = get_dp_group(self.experts) + else: + self.ep_group = None + self.dp_group = None + self.num_local_experts = self.experts.num_local_experts + + # load balance + self.enable_load_balance = enable_load_balance + if self.enable_load_balance == True: + self.load_balancer = LoadBalancer( + experts=self.experts, + gate=self.gate_weight, + local_expert_num=self.num_local_experts, + expert_num=self.num_experts, + ep_group=self.ep_group, + dp_group=self.dp_group, + tolerance=load_balance_tolerance, + beam_width=load_balance_beam_width, + group_swap_factor=load_balance_group_swap_factor, + ) + + # init param + self.reset_parameters() + + @torch.no_grad() + def reset_parameters(self): + torch.nn.init.normal_(self.gate_weight, std=math.sqrt(0.1 / self.hidden_size)) + + def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + inputs (torch.Tensor): The input tensor of shape (batch_size, seq_len, hidden_size) + + Returns: + torch.Tensor: The output tensor of shape (batch_size, seq_len, hidden_size) + """ + # reshape the input tokens + tokens = inputs.reshape(-1, self.hidden_size) + + # the data type of the inputs in the gating should be fp32 + fp32_input = tokens.to(torch.float) + fp32_weight = self.gate_weight.to(torch.float) + gate_output = F.linear(fp32_input, fp32_weight) + + # update expert load + if self.enable_load_balance == True: + with torch.no_grad(): + # TODO: optimize computation + expert_load = torch.topk(gate_output, k=self.topk, dim=-1)[1] + # TODO: bincount introduces 
synchronize, fix it + expert_load = torch.bincount(expert_load.view(-1)) + self.load_balancer.update_load(expert_load) + + # the result from the router + route_result_list = self.router(inputs=gate_output, use_kernel=self.enable_kernel, ep_group=self.ep_group) + + # dispatch_data: (num_experts, capacity, hidden_size) + if self.enable_kernel: + dispatch_data = MoeDispatch.apply(tokens, *route_result_list[1:]) + dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.hidden_size) + else: + sec_mask_f = route_result_list[1].type_as(inputs) + dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens) + + # expert_output: (num_groups, num_experts, capacity, hidden_size) + if self.expert_parallel == "EP": + expert_output = self._ep_process(dispatch_data, overlap=self.enable_comm_overlap) + elif self.expert_parallel == "TP": + expert_output = self._tp_process(dispatch_data, overlap=self.enable_comm_overlap) + elif self.expert_parallel is None: + expert_output = self._local_process(dispatch_data) + else: + raise NotImplementedError("This kind of communication has not been implemented yet.\n" + "Please use Experts build function.") + + if self.enable_kernel: + expert_output = expert_output.reshape(-1, self.hidden_size) + ans = MoeCombine.apply(expert_output, *route_result_list) + else: + combine_weights = route_result_list[0].type_as(inputs) + combine_weights = combine_weights.view(combine_weights.shape[0], -1) + expert_output = expert_output.view(-1, expert_output.shape[-1]) + ans = torch.matmul(combine_weights, expert_output) + + ans = ans.reshape(inputs.shape) + return ans + + def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor: + expert_in = expert_in.unsqueeze(0) + expert_out = self.experts(expert_in) + return expert_out + + def _ep_process(self, dispatch_data: torch.Tensor, overlap: bool = False) -> torch.Tensor: + """ + Expert Parallel + + Args: + dispatch_data (torch.Tensor): (num_experts, capacity, hidden_size) + + Returns: + torch.Tensor: (num_experts, capacity, hidden_size) + """ + if not overlap or dist.get_world_size(self.ep_group) == 1: + expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0] + expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size) + expert_output = self.experts(expert_input) + expert_output = AllToAll.apply(expert_output, self.ep_group, False)[0] + return expert_output + + else: + + @dataclasses.dataclass + class Capsule: + data: torch.Tensor + handle: Any = None + + NUM_CHUNK = 4 + NUM_STAGES = 4 + + assert (dispatch_data.shape[1] % NUM_CHUNK == 0), "arbitrary chunk num is not supported yet" + chunk_size = dispatch_data.shape[1] // NUM_CHUNK + input_shape = (self.ep_size, self.num_local_experts, -1, self.hidden_size) + dispatch_data = dispatch_data.reshape(*input_shape) + chunk_data = torch.split(dispatch_data, chunk_size, dim=2) + output = torch.empty_like(dispatch_data) + + offset = 0 + _expert_in, expert_in, _expert_out, expert_out = None, None, None, None + + for i in range(NUM_CHUNK + NUM_STAGES - 1): + if expert_out is not None: + expert_out.handle.wait() + output[:, :, offset:offset + chunk_size, :] = expert_out.data + offset += chunk_size + expert_out = None + + # all2all last output + if _expert_out is not None: + expert_out = Capsule(*AllToAll.apply(_expert_out.data, self.ep_group, True),) + _expert_out = None + + # all2all next input + if 0 <= i < NUM_CHUNK: + _expert_in = Capsule(*AllToAll.apply(chunk_data[i].contiguous(), self.ep_group, True)) + + # compute + if 
expert_in is not None: + expert_in.handle.wait() + _expert_out = Capsule(data=self.experts(expert_in.data), handle=None) + expert_in = None + + if _expert_in is not None: + expert_in = _expert_in + _expert_in = None + + return output + + def _tp_process(self, dispatch_data: torch.Tensor, overlap: bool = False) -> torch.Tensor: + """ + without overlap: + | C | + | A | | R | + + with overlap: + | C1 || C2 || C3 || C4 | + | A1 || A2 | | R1 | A3 || R2 | A4 || R3 | | R4 | + + where C is computation, A is all gather, R is reduce scatter. + + Args: + dispatch_data (torch.Tensor): (num_experts, capacity, hidden_size) + + Returns: + torch.Tensor: (num_experts, capacity, hidden_size) + """ + if not overlap or dist.get_world_size(self.ep_group) == 1: + expert_in = AllGather.apply(dispatch_data, self.ep_group, False)[0] + expert_out = self.experts(expert_in) + expert_out = ReduceScatter.apply(expert_out, self.ep_group, False)[0] + return expert_out + else: + + @dataclasses.dataclass + class Capsule: + data: torch.Tensor + handle: Any + indices: Tuple + + NUM_CHUNK = 4 + NUM_STAGES = 4 + + assert (dispatch_data.shape[0] % NUM_CHUNK == 0 + ), "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts" + chunk_size = dispatch_data.shape[0] // NUM_CHUNK + chunk_data = torch.split(dispatch_data, chunk_size, dim=0) + output = torch.empty_like(dispatch_data) + + def get_chunk_slice(idx: int, chunk_size: int) -> Tuple[slice]: + return (slice(idx * chunk_size, (idx + 1) * chunk_size),) + + _expert_in, expert_in, _expert_out, expert_out = None, None, None, None + + for i in range(NUM_CHUNK + NUM_STAGES - 1): + if expert_out is not None: + expert_out.handle.wait() + output[expert_out.indices] = expert_out.data + expert_out = None + + # reduce scatter last output + if _expert_out is not None: + expert_out = Capsule( + *ReduceScatter.apply(_expert_out.data, self.ep_group, True), + indices=_expert_out.indices, + ) + _expert_out = None + + # all gather next input + if 0 <= i < NUM_CHUNK: + _expert_in = Capsule( + *AllGather.apply(chunk_data[i].contiguous(), self.ep_group, True), + indices=get_chunk_slice(i, chunk_size), + ) + + # compute + if expert_in is not None: + expert_in.handle.wait() + _expert_out = Capsule( + self.experts(expert_in.data, expert_in.indices), + handle=None, + indices=expert_in.indices, + ) + expert_in = None + + if _expert_in is not None: + expert_in = _expert_in + _expert_in = None + + return output + + +def apply_load_balance(model: nn.Module, optim: Any) -> None: + """ + apply load balance to every experts in the model + """ + + def _apply_recursive(module: nn.Module): + for _, sub_module in module.named_children(): + if isinstance(sub_module, SparseMLP): + if sub_module.enable_load_balance == True: + sub_module.load_balancer.balance_load(optim) + _apply_recursive(sub_module) + + torch.cuda.empty_cache() + _apply_recursive(model) + torch.cuda.empty_cache() diff --git a/colossalai/moe/load_balance.py b/colossalai/moe/load_balance.py new file mode 100644 index 000000000..85c12d73f --- /dev/null +++ b/colossalai/moe/load_balance.py @@ -0,0 +1,442 @@ +from copy import deepcopy +from typing import List, Optional, Tuple + +import torch +import torch.distributed as dist +from torch import Tensor, nn +from torch.distributed import ProcessGroup + +from colossalai.cluster import ProcessGroupMesh +from colossalai.moe.experts import MLPExperts +from colossalai.moe.manager import MOE_MANAGER +from colossalai.zero.low_level import LowLevelZeroOptimizer + + +class 
LoadBalancer: + def __init__( + self, + experts: MLPExperts, + gate: nn.Parameter, + local_expert_num: int, + expert_num: int, + ep_group: ProcessGroup, + dp_group: ProcessGroup, + tolerance: Optional[float] = 0.1, + beam_width: Optional[int] = 8, + group_swap_factor: Optional[float] = 0.4, + ) -> None: + self.experts: MLPExperts = experts + self.gate: nn.Parameter = gate + self.moe_ep_group: ProcessGroup = ep_group + self.moe_ep_ranks = MOE_MANAGER.parallel_info_dict[dist.get_world_size(self.moe_ep_group)].ep_group_ranks + self.moe_dp_group: ProcessGroup = dp_group + self.tolerance = tolerance + self.beam_width = beam_width + self.group_swap_factor = group_swap_factor + self.local_expert_num = local_expert_num + self.expert_num = expert_num + self.local_load = None + # TODO: use a global process group mesh + pp_size = 1 if MOE_MANAGER.pp_size is None else MOE_MANAGER.pp_size + global_dp_group = ProcessGroupMesh(pp_size, dist.get_world_size() // pp_size) + self.global_dp_group = global_dp_group.get_group_along_axis(1) + self.global_dp_rank = dist.get_rank(self.global_dp_group) + self.global_dp_size = dist.get_world_size(self.global_dp_group) + + def _clear_load(self) -> None: + self.local_load = None + + def _sync_load(self) -> Tensor: + new_load = self.local_load.clone().detach() + # all reduce load between ep group + dist.all_reduce(new_load, group=self.moe_ep_group) + # all reduce load between dp group + dist.all_reduce(new_load, group=self.moe_dp_group) + return new_load + + @staticmethod + def _get_diff_from_avg(data: List, group: int, avg: float) -> float: + return abs(sum(data[group]) / len(data[group]) - avg) + + @staticmethod + def _swap_data(data: List, group_i: int, index_i: int, group_j: int, index_j: int) -> None: + data[group_i][index_i], data[group_j][index_j] = ( + data[group_j][index_j], + data[group_i][index_i], + ) + + @staticmethod + def _normalize_data(data: List) -> List: + max_value = max(max(sublist) for sublist in data) + data = [[i / max_value for i in sublist] for sublist in data] + return data + + @staticmethod + def _get_swap_loss( + group_swap_factor: float, + swap_list: List, + group_i: int, + index_i: int, + group_j: int, + index_j: int, + ) -> float: + """ + Get swap loss. The swap loss is used to avoid the situation that + the same index is swapped twice and the same group is swapped for multiple times. + """ + swap_loss = 0 + for swap in swap_list: + for group_id, index_id in zip([group_i, group_j], [index_i, index_j]): + # the group has been swapped + if group_id in [swap[0], swap[2]]: + # the index has been swapped + # we want to avoid the situation that the same index is swapped twice + if index_id in [swap[1], swap[3]]: + swap_loss += 1e5 + # the index has not been swapped + # this is acceptable but as less as possible + else: + swap_loss += group_swap_factor + return swap_loss + + @staticmethod + def _check_convergence(data: List, avg: float, tolerance: float): + """ + Check whether the data is converged after swap. + """ + for sublist in data: + if abs(sum(sublist) / len(sublist) - avg) > tolerance * avg: + return False + return True + + def _beam_search( + self, + inputs: Tuple[List, float, List], + beam_width: int, + avg: float, + group_swap_factor: float, + ) -> List: + """ + Beam search for the best swap combination. + Specifically, we swap two elements from two groups and calculate the score. + The score is the difference between the origin group sum and the new group sum. + The larger the score, the better the swap combination. 
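        In the notation of the code below, a candidate swap of data[group_i][index_i]
        with data[group_j][index_j] is scored (an illustrative restatement of the code) as

            score = (|mean(group_i) - avg| + |mean(group_j) - avg|) before the swap
                  - (|mean(group_i) - avg| + |mean(group_j) - avg|) after the swap
                  - swap_loss

        where swap_loss (see _get_swap_loss) penalizes reusing a group or index that an
        earlier swap already touched; a positive gain is added to the running beam score.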
+ + Args: + inputs (Tuple): (data, origin_score, swap_list) + beam_width (int): beam width for beam search + avg (float): average value of the data + group_swap_factor (float): group loss for group swap loss + + Returns: + List: results list + """ + data, origin_score, swap_list = inputs + results = [] + group_num = len(data) + group_size = len(data[0]) + origin_diff_list = [self._get_diff_from_avg(data, i, avg) for i in range(group_num)] + + for group_num_i in range(group_num): + for group_size_i in range(group_size): + for group_num_j in range(group_num_i + 1, group_num): + for group_size_j in range(group_size): + new_data = deepcopy(data) + # calculate origin group sum + origin_diff = origin_diff_list[group_num_i] + origin_diff_list[group_num_j] + # swap data + self._swap_data( + new_data, + group_num_i, + group_size_i, + group_num_j, + group_size_j, + ) + # calculate new group sum + new_diff = self._get_diff_from_avg(new_data, group_num_i, avg) + self._get_diff_from_avg( + new_data, group_num_j, avg + ) + # caculate score + new_score = origin_diff - new_diff + if new_score > 0: + new_score = origin_score + new_score + # get swap loss + swap_loss = self._get_swap_loss( + group_swap_factor, + swap_list, + group_num_i, + group_size_i, + group_num_j, + group_size_j, + ) + new_score = new_score - swap_loss + # update swap list + new_swap_list = swap_list + [(group_num_i, group_size_i, group_num_j, group_size_j)] + results.append((new_data, new_score, new_swap_list)) + # sort results + results.sort(key=lambda x: x[1], reverse=True) + # select top k results + results = results[:beam_width] + return results + + def _load_to_list(self, load: Tensor) -> List: + load_len = len(load) + assert load_len % self.local_expert_num == 0 + load_list = [] + tmp_list = [] + for i in range(len(load)): + tmp_list.append(float(load[i])) + if (i + 1) % self.local_expert_num == 0: + load_list.append(tmp_list) + tmp_list = [] + return load_list + + def _search_balance( + self, + data: List, + tolerance: Optional[float] = 0.1, + beam_width: Optional[int] = 8, + group_swap_factor: Optional[float] = 0.4, + return_swapped_data: Optional[bool] = False, + ) -> Tuple[List, List]: + """ + Search for the best swap combination to balance the data within the specified tolerance. + And return the balanced data and the swap list. The swap list is used to record the swap. + The swap list is a list of tuples. Each tuple is a swap operation. + + Args: + data (List): expert load list. + E.g. [[9.2, 8.3], [2.3, 10.0], [6.1, 7.2], [5.3, 3.2]] + This means there are 4 devices and each devices has 2 experts. + The value is the load of the expert. + tolerance (float): tolerance for balance. + beam_width (int): beam width for beam search. + group_swap_factor (float): group swap factor for group swap loss. + The bigger it is, the less times a group will be swapped. + return_swapped_data (bool): whether to return the swapped data. + + Returns: + Tuple: (balanced data, swap list). + The swap list is a list of tuples. Each tuple is a swap operation. + E.g. [(0, 0, 1, 0), (...), (...)]. The first tuple means + the first expert of the first device is swapped with the first expert + of the second device. 
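        Example (illustrative only; `balancer` stands for an already-constructed
        LoadBalancer, and the exact swaps depend on the search):

            load = [[9.2, 8.3], [2.3, 10.0], [6.1, 7.2], [5.3, 3.2]]
            swap_list = balancer._search_balance(load)
            # an entry such as (1, 1, 3, 0) would exchange the 10.0-load expert
            # on device 1 with the 5.3-load expert on device 3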
+ """ + norm_data = self._normalize_data(data) + avg = sum(sum(sublist) / len(sublist) for sublist in norm_data) / len(norm_data) + results = [(norm_data, 0, [])] + stop_flag = False + + while stop_flag == False: + new_results = [] + best_score = results[0][1] + for i in range(len(results)): + new_results.extend(self._beam_search(results[i], beam_width, avg, group_swap_factor)) + if len(new_results) == 0: + stop_flag = True + break + new_results.sort(key=lambda x: x[1], reverse=True) + new_best_score = new_results[0][1] + if new_best_score == best_score: + stop_flag = True + break + new_results = new_results[:beam_width] + results = new_results + for i in results: + if self._check_convergence(results[0][0], avg, tolerance): + stop_flag = True + break + + swap_list = results[0][2] + if return_swapped_data: + out = deepcopy(data) + for swap in swap_list: + self._swap_data(out, *swap) + return out, swap_list + else: + return swap_list + + @staticmethod + def _swap_expert_single_tensor( + weight: nn.Parameter, + expert_idx: int, + comm_group: ProcessGroup, + send_first: bool, + comm_rank: int, + ): + # exchange weight + local_weight = weight.data[expert_idx] + new_weight = torch.empty_like(local_weight) + if send_first: + dist.send(local_weight, dst=comm_rank, group=comm_group) + dist.recv(new_weight, src=comm_rank, group=comm_group) + else: + dist.recv(new_weight, src=comm_rank, group=comm_group) + dist.send(local_weight, dst=comm_rank, group=comm_group) + weight.data[expert_idx] = new_weight + + def _swap_expert_param_and_optim( + self, + weight: nn.Parameter, + expert_idx: int, + comm_group: ProcessGroup, + send_first: bool, + comm_rank: int, + optim: LowLevelZeroOptimizer, + ): + # need to update master and working param if master param exists + # else just update working param + if weight in optim.optim.state: + master_weight_ptr = None + working_weight_ptr = weight + exp_avg_ptr = optim.optim.state[working_weight_ptr]["exp_avg"] + exp_avg_sq_ptr = optim.optim.state[working_weight_ptr]["exp_avg_sq"] + else: + master_weight_ptr = optim._param_store.working_to_master_param[id(weight)] + working_weight_ptr = weight + exp_avg_ptr = optim.optim.state[master_weight_ptr]["exp_avg"] + exp_avg_sq_ptr = optim.optim.state[master_weight_ptr]["exp_avg_sq"] + + # exchange weight + self._swap_expert_single_tensor( + working_weight_ptr, + expert_idx, + comm_group, + send_first, + comm_rank, + ) + if master_weight_ptr is not None: + # TODO: exchange master weight, skip for now + # master weight is shared by dp group + tmp = working_weight_ptr.view(-1).split( + working_weight_ptr.numel() // dist.get_world_size(self.moe_dp_group) + )[dist.get_rank(self.moe_dp_group)] + master_weight_ptr.data.copy_(tmp.clone().detach().to(master_weight_ptr.device).to(master_weight_ptr.dtype)) + # exchange optim + self._swap_expert_single_tensor(exp_avg_ptr, expert_idx, comm_group, send_first, comm_rank) + self._swap_expert_single_tensor(exp_avg_sq_ptr, expert_idx, comm_group, send_first, comm_rank) + + def _gather_global_dp_group(self, data: Tensor) -> Tensor: + data_list = [torch.zeros_like(data) for _ in range(self.global_dp_size)] + dist.all_gather(data_list, data, group=self.global_dp_group) + data_list = torch.cat(data_list, dim=0) + return data_list + + def _swap_moe_param(self, swap_list: List, optim: LowLevelZeroOptimizer) -> None: + """ + Swap moe param and optim. + We use different strategies to swap expert and gate. + For expert, we exchange the param and optim of the expert by p2p. 
+ For gate, we all gather the gate choose the part we want. + + Args: + swap_list (List) + optim (LowLevelZeroOptimizer) + """ + # get all experts weights + local_rank = dist.get_rank(self.moe_ep_group) + if self.experts.gated: + weight_list = [self.experts.wi_up, self.experts.wi_gate] + else: + weight_list = [self.experts.wi] + weight_list.append(self.experts.wo) + + # gate optim should be obtained first + gate_shape = self.gate.shape + # get master weight and optim + master_gate_weight = optim._param_store.working_to_master_param[id(self.gate)] + gate_exp_avg = optim.optim.state[master_gate_weight]["exp_avg"] + gate_exp_avg_sq = optim.optim.state[master_gate_weight]["exp_avg_sq"] + # gather + global_master_gate_weight = self._gather_global_dp_group(master_gate_weight).view(gate_shape) + global_gate_exp_avg = self._gather_global_dp_group(gate_exp_avg).view(gate_shape) + global_gate_exp_avg_sq = self._gather_global_dp_group(gate_exp_avg_sq).view(gate_shape) + assert ( + self.gate.shape + == global_master_gate_weight.shape + == global_gate_exp_avg.shape + == global_gate_exp_avg_sq.shape + ) + + for swap in swap_list: + source_group, source_idx, target_group, target_idx = swap + source_rank = self.moe_ep_ranks[source_group] + target_rank = self.moe_ep_ranks[target_group] + # exchange expert + if local_rank in [source_group, target_group]: + for weight in weight_list: + if local_rank == source_group: + self._swap_expert_param_and_optim( + weight, + source_idx, + self.moe_ep_group, + True, + target_rank, + optim, + ) + elif local_rank == target_group: + self._swap_expert_param_and_optim( + weight, + target_idx, + self.moe_ep_group, + False, + source_rank, + optim, + ) + # exchange gate + source_expert_pos = source_group * self.local_expert_num + source_idx + target_expert_pos = target_group * self.local_expert_num + target_idx + for gate in [ + self.gate, + global_master_gate_weight, + global_gate_exp_avg, + global_gate_exp_avg_sq, + ]: + origin_source = gate.data[source_expert_pos].clone().detach() + origin_target = gate.data[target_expert_pos].clone().detach() + gate.data[source_expert_pos], gate.data[target_expert_pos] = ( + origin_target, + origin_source, + ) + + # update gate + global_master_gate_weight = global_master_gate_weight.view(-1).split( + global_master_gate_weight.numel() // self.global_dp_size + )[self.global_dp_rank] + master_gate_weight.data.copy_(global_master_gate_weight) + global_gate_exp_avg = global_gate_exp_avg.view(-1).split(global_gate_exp_avg.numel() // self.global_dp_size)[ + self.global_dp_rank + ] + gate_exp_avg.data.copy_(global_gate_exp_avg) + global_gate_exp_avg_sq = global_gate_exp_avg_sq.view(-1).split( + global_gate_exp_avg_sq.numel() // self.global_dp_size + )[self.global_dp_rank] + gate_exp_avg_sq.data.copy_(global_gate_exp_avg_sq) + + @torch.no_grad() + def update_load(self, load: Tensor) -> None: + if len(load) != self.expert_num: + padding_size = self.expert_num - len(load) + padding = torch.zeros(padding_size, dtype=load.dtype, device=load.device) + load = torch.cat((load, padding), dim=0) + if self.local_load is None: + self.local_load = load + else: + self.local_load += load + + @torch.no_grad() + def balance_load(self, optim: LowLevelZeroOptimizer) -> None: + # prepare load + load = self._sync_load() + load = self._load_to_list(load) + # search balance + swap_list = self._search_balance(load) + if dist.get_rank() == 0: + if len(swap_list) > 0: + print(f"[Load Balance] Applying expert swap...") + else: + print(f"[Load Balance] Invalid swap, skip...") 
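        # Illustration (hypothetical values): with 2 local experts per ep rank, a
        # swap_list entry (0, 1, 2, 0) asks _swap_moe_param below to exchange local
        # expert 1 on ep rank 0 with local expert 0 on ep rank 2 -- swapping their
        # weights, their Adam exp_avg / exp_avg_sq states, and the matching gate rows.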
+ # swap expert and gate + self._swap_moe_param(swap_list, optim) + # clear load + self._clear_load() diff --git a/colossalai/nn/loss/loss_moe.py b/colossalai/moe/loss.py similarity index 92% rename from colossalai/nn/loss/loss_moe.py rename to colossalai/moe/loss.py index 40cea788c..75624510b 100644 --- a/colossalai/nn/loss/loss_moe.py +++ b/colossalai/moe/loss.py @@ -1,11 +1,9 @@ import torch.nn as nn from torch.nn.modules.loss import _Loss -from colossalai.context.moe_context import MOE_CONTEXT -from colossalai.legacy.registry import LOSSES +from colossalai.moe.manager import MOE_MANAGER -@LOSSES.register_module class MoeCrossEntropyLoss(_Loss): r"""torch.nn.CrossEntropyLoss added with auxiliary loss. @@ -45,11 +43,10 @@ class MoeCrossEntropyLoss(_Loss): `Cross_entropy `_. """ main_loss = self.loss(*args) - aux_loss = MOE_CONTEXT.get_loss() + aux_loss = MOE_MANAGER.get_loss() return main_loss + self.aux_weight * aux_loss -@LOSSES.register_module class MoeLoss(_Loss): """A wrapper class for any loss module to add with auxiliary loss. @@ -77,5 +74,5 @@ class MoeLoss(_Loss): The ``args`` and ``kwargs`` may include different parameters varying with different loss function. """ main_loss = self.loss_fn(*args, **kwargs) - aux_loss = MOE_CONTEXT.get_loss() + aux_loss = MOE_MANAGER.get_loss() return main_loss + self.aux_weight * aux_loss diff --git a/colossalai/moe/manager.py b/colossalai/moe/manager.py new file mode 100644 index 000000000..f237ea134 --- /dev/null +++ b/colossalai/moe/manager.py @@ -0,0 +1,162 @@ +from typing import Tuple + +import torch +import torch.distributed as dist + +from colossalai.context.singleton_meta import SingletonMeta +from colossalai.tensor.moe_tensor.api import get_moe_info +from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo + + +class MoeManager(metaclass=SingletonMeta): + """MoE manager. This class manages different + parallel groups in MoE context and MoE loss in training. + """ + + def __init__(self): + self.parallel = None + self.seed = None + self.mode = None + self.use_ep_inside = None + self.world_size = None + self._parallel_info_dict = dict() + + # router + self.router_aux_loss = [] + self.router_z_loss = [] + + # fixed mode + self.pp_size = None + self.dp_size = None + self.ep_size = None + + # dynamic mode + # Users may want to set maximum expert parallel size smaller than the world size + # since very low bandwidth across nodes may constrain the performance of MoE + # When we have a maximum expert parallel size, we have a minimum data parallel size naturally + self.max_ep_size = None + + self.has_setup = False + + @property + def parallel_info_dict(self): + return self._parallel_info_dict + + @property + def is_initialized(self): + return self.has_setup + + def setup( + self, + seed: int, + parallel: str = None, + mode: str = "dynamic", + max_ep_size: int = 8, + fixed_dp_size: int = 0, + fixed_ep_size: int = 0, + fixed_pp_size: int = 0, + use_ep_inside: bool = True, + ) -> None: + """ + Setup MoE distributed context. + + Args: + seed (int): Random seed. Defaults to 42. + use_kernel_optim (bool, optional): Use cuda kernel. Defaults to True. + parallel (bool, optional): Parallel mode, should be EP, TP or None. Defaults to None. + mode (str, optional): Should be "fixed" or "dynamic". Defaults to "dynamic". + In fixed mode, the ep size and dp size is fixed. + In dynamic mode, the ep size and dp size will be changed according to num experts. + max_ep_size (int, optional): Max ep size in dynamic mode. Defaults to 8. 
+ fixed_dp_size (int, optional): Fixed dp size in fixed mode. Defaults to 0. + fixed_ep_size (int, optional): Fixed ep size in fixed mode. Defaults to 0. + fixed_pp_size (int, optional): Fixed pp size in fixed mode. Defaults to 0. + use_ep_inside (bool, optional): Use ep inside dp if True, dp inside ep if False. Defaults to True. + """ + assert (not self.is_initialized), "MoE distributed context shouldn't be set up again" + assert torch.cuda.is_available(), "MoE requires CUDA to be enabled first" + + self.seed = seed + dist.get_rank() + self.parallel = parallel + self.use_ep_inside = use_ep_inside + self.world_size = dist.get_world_size() + + # init by mode + self.mode = mode + assert self.mode in ["fixed", "dynamic"], "mode should be fixed or dynamic" + if self.mode == "dynamic": + self.max_ep_size = min(max_ep_size, self.world_size) + else: + assert (fixed_dp_size > 0 and fixed_ep_size > 0 + and fixed_pp_size > 0), "dp_size, ep_size and pp_size should be greater than 0" + assert (isinstance(fixed_dp_size, int) and isinstance(fixed_ep_size, int) + and isinstance(fixed_pp_size, int)), "dp_size, ep_size and pp_size should be int" + self.ep_size = fixed_ep_size + self.dp_size = fixed_dp_size + self.pp_size = fixed_pp_size + + self.has_setup = True + + def get_info(self, num_experts: int, use_tp: bool = False) -> Tuple[int, MoeParallelInfo]: + """Calculate the Data Parallel Group and Expert Parallel Group. + + Parameters + ---------- + num_experts : int + The number of experts + + Returns + ------- + int, MoeParallelInfo + number of local experts, the MoeParallelInfo of the current ep_size + """ + + if self.mode == "dynamic": + gt_flag = (num_experts % self.max_ep_size == 0) # check whether num_experts is greater + lt_flag = (self.max_ep_size % num_experts == 0) # check whether num_experts is less + assert gt_flag or lt_flag, ("Automatic expert placement does not support an expert number" + " that is not a multiple of the ep size, or vice versa.") + dp_size = 1 if gt_flag else self.world_size // num_experts + ep_size = min(self.world_size // dp_size, self.max_ep_size) + dp_size = self.world_size // ep_size + pp_size = 1 + else: + dp_size = self.dp_size + ep_size = self.ep_size + pp_size = self.pp_size + + # Calculate the number of experts for each GPU + if use_tp: + num_local_experts = num_experts + else: + if self.mode == "dynamic": + num_local_experts = 1 if lt_flag else num_experts // self.max_ep_size + else: + num_local_experts = num_experts // ep_size + + if not (ep_size in self.parallel_info_dict): + self.parallel_info_dict[ep_size] = get_moe_info(ep_size, dp_size, pp_size, ep_inside=self.use_ep_inside) + if dist.get_rank() == 0: + if self.use_ep_inside: + print(f"MoE Parallel: pp {pp_size}, dp {dp_size}, ep {ep_size}") + else: + print(f"MoE Parallel: pp {pp_size}, ep {ep_size}, dp {dp_size}") + + return num_local_experts, self.parallel_info_dict[ep_size] + + def reset_loss(self): + self.router_aux_loss, self.router_z_loss = [], [] + + def add_loss(self, aux_loss: float = 0.0, z_loss: float = 0.0): + self.router_aux_loss.append(aux_loss) + self.router_z_loss.append(z_loss) + + def get_loss(self): + cur_loss = self.router_aux_loss, self.router_z_loss + return cur_loss + + def get_parallel(self): + return self.parallel + + +MOE_MANAGER = MoeManager() diff --git a/colossalai/moe/routers.py b/colossalai/moe/routers.py new file mode 100644 index 000000000..7960a74d4 --- /dev/null +++ b/colossalai/moe/routers.py @@ -0,0 +1,419 @@ +import math +from abc import ABC +from typing import Callable, Optional,
Tuple + + import torch + import torch.distributed as dist + import torch.nn as nn + import torch.nn.functional as F + from torch.distributed import ProcessGroup + + from colossalai.moe._operation import moe_cumsum + from colossalai.moe.manager import MOE_MANAGER + from colossalai.utils import get_current_device + + + class MoeRouter(nn.Module, ABC): + """Base class for all MoE routers. + Args: + k_value (int): The value of top_k. + capacity_factor_train (float): Capacity factor in routing of training. + capacity_factor_eval (float): Capacity factor in routing of evaluation. + min_capacity (int): The minimum number of the capacity of each expert. + noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits. + drop_tks (bool, optional): Whether to drop tokens in evaluation + """ + + def __init__(self, + k_value: int, + capacity_factor_train: float, + capacity_factor_eval: float, + min_capacity: int, + noisy_func: Optional[Callable] = None, + drop_tks: bool = True, + use_kernel: bool = False): + super().__init__() + self.k_value = k_value + self.capacity_factor_train = capacity_factor_train + self.capacity_factor_eval = capacity_factor_eval + self.min_capacity = min_capacity + self.noisy_func = noisy_func + self.drop_tks = drop_tks + self._aux_loss = None + self._z_loss = None + self.use_kernel = use_kernel + + def get_capacity(self, logits_shape): + capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval + capacity = math.floor(self.k_value * capacity_factor * logits_shape[-2] / logits_shape[-1]) + capacity += capacity % 2 + capacity = max(capacity, self.min_capacity) + assert capacity > 0 + return int(capacity) + + def set_aux_loss(self, router_probs: torch.Tensor, expert_indices: torch.Tensor, num_experts: int) -> None: + """Computes auxiliary load balancing loss as in Switch Transformer. + + See Switch Transformer (https://arxiv.org/abs/2101.03961). This function + implements the loss function presented in equations (4) - (6). It aims to + penalize those cases where the routing between experts is unbalanced. + + Args: + router_probs: Probability assigned to each expert per token. Shape: + [num_groups, tokens_per_group, num_experts]. + expert_indices: [num_groups, tokens_per_group, num_selected_experts] + indices identifying the top num_selected_experts for a given token. + """ + assert self._aux_loss is None + if router_probs.dim() == expert_indices.dim() == 2: + router_probs = router_probs.unsqueeze(0) + expert_indices = expert_indices.unsqueeze(0) + assert router_probs.dim() == expert_indices.dim() == 3, \ + "router_probs and expert_indices must be 3D tensors" + + # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts]. + expert_mask = F.one_hot(expert_indices, num_experts) + # For a given token, determine if it was routed to a given expert. + # Shape: [num_groups, tokens_per_group, num_experts] + expert_mask = expert_mask.max(dim=-2)[0] + + tokens_per_group_and_expert = torch.mean(expert_mask.float(), dim=-2) + router_prob_per_group_and_expert = torch.mean(router_probs.float(), dim=-2) + aux_loss = num_experts**2 * torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert) + self._aux_loss = aux_loss + + def set_z_loss(self, router_logits: torch.Tensor): + """Compute router z-loss. + + The router z-loss was introduced in Designing Effective Sparse Expert Models + (https://arxiv.org/abs/2202.08906). It encourages router logits to remain + small in an effort to improve stability.
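        In the shapes used here, and matching the computation below, this is

            z_loss = (1 / (G * T)) * sum over g, t of ( logsumexp_e router_logits[g, t, e] )^2

        with G = num_groups and T = tokens_per_group.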
+ + Args: + router_logits: [num_groups, tokens_per_group, num_experts] router logits. + """ + assert self._z_loss is None + if router_logits.dim() == 2: + router_logits = router_logits.unsqueeze(0) + assert router_logits.dim() == 3, "router_logits must be 3D tensor" + num_groups, tokens_per_group, _ = router_logits.shape + log_z = torch.logsumexp(router_logits, dim=-1) + z_loss = torch.sum(log_z**2, dtype=torch.float32) / (num_groups * tokens_per_group) + self._z_loss = z_loss + + def pop_router_loss(self) -> torch.Tensor: + assert self._aux_loss is not None + MOE_MANAGER.add_loss(self._aux_loss, self._z_loss) + self._aux_loss = None + self._z_loss = None + + +class Top1Router(MoeRouter): + """Top1 router that returns the dispatch mask (batch_size * seq_len, num_experts, capacity) + and combine weight (batch_size * seq_len, num_experts, capacity) for routing usage. More detailed + function can be found in the paper about Switch Transformer of Google. + + Args: + capacity_factor_train (float, optional): Capacity factor in routing of training. + capacity_factor_eval (float, optional): Capacity factor in routing of evaluation. + min_capacity (int, optional): The minimum number of the capacity of each expert. + select_policy (str, optional): The policy about tokens selection. + noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits. + drop_tks (bool, optional): Whether drops tokens in evaluation + """ + + def __init__(self, + capacity_factor_train: float = 1.25, + capacity_factor_eval: float = 2.0, + min_capacity: int = 4, + select_policy: str = "first", + noisy_func: Optional[Callable] = None, + drop_tks: bool = True): + super().__init__(k_value=1, + capacity_factor_train=capacity_factor_train, + capacity_factor_eval=capacity_factor_eval, + min_capacity=min_capacity, + noisy_func=noisy_func, + drop_tks=drop_tks) + self.select_policy = select_policy + assert select_policy in {"first", "random"} + if select_policy == "random": + self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(0.0, device=get_current_device()), + high=torch.tensor(1.0, + device=get_current_device())).rsample + + def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple: + """ + Args: + inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts). + + Returns: + 1. use_kernel is False: + The combine weight tensor of shape (batch_size * seq_len, num_experts, capacity). + The dispatch mask tensor of shape (batch_size * seq_len, num_experts, capacity). + 2. use_kernel is True: + ... 
+ """ + if self.noisy_func is not None and self.training: + inputs = self.noisy_func(inputs) + + assert inputs.dtype == torch.float + probs = F.softmax(inputs, dim=-1) + num_experts = probs.size(-1) + capacity = self.get_capacity(inputs.shape) + + top1_idx = torch.argmax(inputs, dim=-1) + mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) + + # caculate router loss + self.set_aux_loss(probs, top1_idx.unsqueeze(-1), num_experts) + self.set_z_loss(inputs) + self.pop_router_loss() + + if not self.training and not self.drop_tks and ep_group is not None: + max_num = torch.max(torch.sum(mask, dim=0)) + dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group) + capacity = max_num.item() + + if self.select_policy == "random": + rand_mask = mask * self.uniform(mask.shape) + _, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0) + mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1) + ranks = moe_cumsum(mask, use_kernel=self.use_kernel) + elif self.select_policy == "first": + ranks = moe_cumsum(mask, use_kernel=self.use_kernel) + mask = mask * torch.lt(ranks, capacity) + else: + raise NotImplementedError("Not support such select policy yet.") + + ranks = torch.sum(mask * ranks, dim=-1) + + if use_kernel: + mask = torch.sum(mask, dim=-1) + mask = torch.stack([mask], dim=0).to(torch.int32) + dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32) + return probs, mask, dest_idx, num_experts * capacity + else: + ranks = F.one_hot(ranks, num_classes=capacity) + weight = mask * probs.type_as(inputs) + combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1) + sec_mask = combine_weights.bool() + return combine_weights, sec_mask + + +class Top2Router(MoeRouter): + """Top2 router that returns the dispatch mask (batch_size * seq_len, num_experts, capacity) + and combine weight (batch_size * seq_len, num_experts, capacity) for routing usage. More detailed + function can be found in the paper about ViT-MoE. + + Args: + capacity_factor_train (float, optional): Capacity factor in routing of training. + capacity_factor_eval (float, optional): Capacity factor in routing of evaluation. + min_capacity (int, optional): The minimum number of the capacity of each expert + noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits. + drop_tks (bool, optional): Whether drops tokens in evaluation. + """ + + def __init__(self, + capacity_factor_train: float = 1.25, + capacity_factor_eval: float = 2.0, + min_capacity: int = 4, + noisy_func: Optional[Callable] = None, + drop_tks: bool = True): + super().__init__(k_value=2, + capacity_factor_train=capacity_factor_train, + capacity_factor_eval=capacity_factor_eval, + min_capacity=min_capacity, + noisy_func=noisy_func, + drop_tks=drop_tks) + + def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple: + """ + Args: + inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts). + + Returns: + 1. use_kernel is False: + The combine weight tensor of shape (batch_size * seq_len, num_experts, capacity). + The dispatch mask tensor of shape (batch_size * seq_len, num_experts, capacity). + 2. use_kernel is True: + ... 
+ """ + if self.noisy_func is not None and self.training: + inputs = self.noisy_func(inputs) + + assert inputs.dtype == torch.float + probs = F.softmax(inputs, dim=-1) + num_experts = probs.size(-1) + capacity = self.get_capacity(inputs.shape) + + top1_idx = torch.argmax(probs, dim=-1) + mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) + logits_except1 = probs.masked_fill(mask1.bool(), float("-inf")) + top2_idx = torch.argmax(logits_except1, dim=-1) + mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32) + + cmask = (mask1 + mask2) # loss: [s, e] + cmask = cmask.float() / 2.0 # div 2 to normalize it to 1 + + # caculate loss + expert_indices = torch.stack([top1_idx, top2_idx], dim=-1) + self.set_aux_loss(probs, expert_indices, num_experts) + self.set_z_loss(inputs) + self.pop_router_loss() + + if not self.training and not self.drop_tks and ep_group is not None: + max_num = torch.max(torch.sum(cmask, dim=0)) + dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group) + capacity = max_num.item() + + rank1 = moe_cumsum(mask1, use_kernel=self.use_kernel) # rank1: [s, e] + rank2 = moe_cumsum(mask2, use_kernel=self.use_kernel) + rank2 += torch.sum(mask1, dim=-2, keepdim=True) + + mask1 *= torch.lt(rank1, capacity) + mask2 *= torch.lt(rank2, capacity) + + rank1 = torch.sum(mask1 * rank1, dim=-1) + rank2 = torch.sum(mask2 * rank2, dim=-1) + + if use_kernel: + mask1 = torch.sum(mask1, dim=-1) + mask2 = torch.sum(mask2, dim=-1) + + mask = torch.stack([mask1, mask2], dim=0).to(torch.int32) + dest_idx = torch.stack([top1_idx * capacity + rank1, top2_idx * capacity + rank2], dim=0).to(torch.int32) + + return probs, mask, dest_idx, num_experts * capacity + else: + # >>> original code + # weight1 = mask1 * probs.type_as(inputs) + # weight2 = mask2 * probs.type_as(inputs) + # rank1_sc = F.one_hot(rank1, num_classes=capacity) + # rank2_sc = F.one_hot(rank2, num_classes=capacity) + + # cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1) + # cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1) + # cb_weight = cb_weight1 + cb_weight2 + # sec_mask = cb_weight.bool() + + weight1 = mask1 * probs.type_as(inputs) + weight2 = mask2 * probs.type_as(inputs) + + cb_weight = torch.zeros(inputs.shape + (capacity,), device=inputs.device) + sec_mask = torch.zeros_like(cb_weight, dtype=torch.bool) + indices = torch.arange(0, inputs.shape[0], device=inputs.device) + cb_weight[indices, top1_idx[indices], rank1[indices]] += weight1[indices, top1_idx[indices]] + cb_weight[indices, top2_idx[indices], rank2[indices]] += weight2[indices, top2_idx[indices]] + sec_mask[indices, top1_idx[indices], rank1[indices]] |= mask1.bool()[indices, top1_idx[indices]] + sec_mask[indices, top2_idx[indices], rank2[indices]] |= mask2.bool()[indices, top2_idx[indices]] + + return cb_weight, sec_mask + + +class TopKRouter(MoeRouter): + """Masked matmul router using tokens choose top-k experts assignment. + + NOTE: this is modified from flaxformer. + This router uses the same mechanism as in Switch Transformer + (https://arxiv.org/abs/2101.03961) and V-MoE + (https://arxiv.org/abs/2106.05974): tokens choose their top experts. Items are + sorted by router_probs and then routed to their choice of expert until the + expert's expert_capacity is reached. There is no guarantee that each token is + processed by an expert, or that each expert receives at least one token. + + Attributes: + num_selected_experts: Maximum number of experts to which each token is + routed. 
Tokens may be routed to fewer experts if particular experts are + oversubscribed / reach capacity. + """ + + def __init__(self, + num_selected_experts: int, + capacity_factor_train: float = 1.25, + capacity_factor_eval: float = 2.0, + min_capacity: int = 4, + noisy_func: Optional[Callable] = None, + drop_tks: bool = True): + super().__init__(num_selected_experts, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_func, + drop_tks) + + def forward( + self, + router_probs: torch.Tensor, + expert_capacity: int, + ) -> Tuple: + """Computes masks for the top-k experts per token. + + Args: + router_probs: [num_groups, tokens_per_group, num_experts] + probabilities used to determine the routing of tokens to the experts. + + Returns: + Dispatch and combine arrays for routing with masked matmuls. + """ + # TODO: add parallel group + num_groups, _, num_experts = router_probs.shape + + # Top-k router probability and corresponding expert indices for each token. + # Shape: [num_groups, tokens_per_group, num_selected_experts]. + expert_gate, expert_index = torch.topk(router_probs, self.k_value) + + self.set_aux_loss(router_probs, expert_index, num_experts) + self.pop_router_loss() + + # Make num_selected_experts the leading axis to ensure that top-1 choices + # have priority over top-2 choices, which have priority over top-3 choices, + # etc. + expert_index = torch.transpose(expert_index, 1, 2) + # Shape: [num_groups, num_selected_experts * tokens_per_group] + expert_index = expert_index.reshape(num_groups, -1) + + # Create mask out of indices. + # Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts]. + expert_mask = F.one_hot(expert_index, num_experts).to(torch.int32) + + # Experts have a fixed capacity that we cannot exceed. A token's priority + # within the expert's buffer is given by the masked, cumulative capacity of + # its target expert. + # Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts]. + token_priority = torch.cumsum(expert_mask, dim=1) * expert_mask - 1 + # Shape: [num_groups, num_selected_experts, tokens_per_group, num_experts]. + token_priority = token_priority.reshape((num_groups, self.k_value, -1, num_experts)) + # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts]. + token_priority = torch.transpose(token_priority, 1, 2) + # For each token, across all selected experts, select the only non-negative + # (unmasked) priority. Now, for group G routing to expert E, token T has + # non-negative priority (i.e. token_priority[G,T,E] >= 0) if and only if E + # is its targeted expert. + # Shape: [num_groups, tokens_per_group, num_experts]. + token_priority = torch.max(token_priority, dim=2)[0] + + # Token T can only be routed to expert E if its priority is positive and + # less than the expert capacity. One-hot matrix will ignore indices outside + # the range [0, expert_capacity). + # Shape: [num_groups, tokens_per_group, num_experts, expert_capacity]. + valid_mask = torch.logical_and(token_priority >= 0, token_priority < expert_capacity) + token_priority = torch.masked_fill(token_priority, ~valid_mask, 0) + dispatch_mask = F.one_hot(token_priority, expert_capacity).to(torch.bool) + valid_mask = valid_mask.unsqueeze(-1).expand(-1, -1, -1, expert_capacity) + dispatch_mask = torch.masked_fill(dispatch_mask, ~valid_mask, 0) + + # The combine array will be used for combining expert outputs, scaled by the + # router probabilities. Shape: [num_groups, tokens_per_group, num_experts, + # expert_capacity]. 
+ combine_array = torch.einsum('...te,...tec->...tec', router_probs, dispatch_mask) + + return combine_array, dispatch_mask + + +def get_router_cls(top_k: int, grouped: bool = False) -> MoeRouter: + if not grouped: + if top_k == 1: + return Top1Router + elif top_k == 2: + return Top2Router + else: + raise NotImplementedError("top_k > 2 is not supported yet") + else: + return TopKRouter diff --git a/colossalai/moe/utils.py b/colossalai/moe/utils.py new file mode 100644 index 000000000..0938e4206 --- /dev/null +++ b/colossalai/moe/utils.py @@ -0,0 +1,177 @@ +import contextlib +from typing import Any, Callable, Dict, List + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F + +from colossalai.moe.manager import MOE_MANAGER +from colossalai.tensor.moe_tensor.api import get_dp_group, get_dp_group_ranks, get_ep_size, is_moe_tensor +from colossalai.utils import get_current_device + + +class ForceFP32Parameter(torch.nn.Parameter): + + def half(self, memory_format=None): + return self.data.clone() + + +class NormalNoiseGenerator: + """Generates a random noisy mask for logits tensor. + + All noise is generated from a normal distribution :math:`(0, 1 / E^2)`, where + `E = the number of experts`. + + Args: + num_experts (int): The number of experts. + """ + + def __init__(self, num_experts: int): + self.normal = torch.distributions.normal.Normal( + loc=torch.tensor(0.0, device=get_current_device()), + scale=torch.tensor(1.0 / num_experts**2, device=get_current_device()), + ).rsample + + def __call__(self, inputs: torch.Tensor): + noisy = self.normal(inputs.shape) + return inputs + noisy + + +class UniformNoiseGenerator: + """Generates a random noisy mask for logits tensor. + copied from mesh tensorflow: + Multiply values by a random number between :math:`1-epsilon` and :math:`1+epsilon`. + Makes models more resilient to rounding errors introduced by bfloat16. + This seems particularly important for logits. + + Args: + eps (float, optional): Epsilon in generator, defaults to 1e-2. + """ + + def __init__(self, eps: float = 1e-2): + self.uniform = torch.distributions.uniform.Uniform( + low=torch.tensor(1.0 - eps, device=get_current_device()), + high=torch.tensor(1.0 + eps, device=get_current_device()), + ).rsample + + def __call__(self, inputs: torch.Tensor): + noisy = self.uniform(inputs.shape) + return inputs * noisy + + +def autocast_softmax(logit: torch.Tensor, dim: int): + return F.softmax(logit, dim=dim, dtype=torch.float32) + + +def get_noise_generator(noise_type: str, num_experts: int) -> Callable: + if noise_type is None: + return None + elif noise_type == "Jitter": + noisy_func = UniformNoiseGenerator() + elif noise_type == "Gaussian": + noisy_func = NormalNoiseGenerator(num_experts) + else: + raise NotImplementedError("Unsupported input noisy policy") + return noisy_func + + +def get_activation(act: str) -> Callable: + if act is None or act == "relu": + return torch.nn.ReLU() + elif act == "gelu": + return torch.nn.GELU() + elif act == "swiglu": + return SwiGLU + else: + raise NotImplementedError("Unsupported activation function") + + +def SwiGLU(x): + """Gated linear unit activation function.
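        Computes swiglu(x) = x1 * silu(x2) = x1 * (x2 * sigmoid(x2)), where x1 and x2
        are the two halves of x split along the last dimension (as in the code below).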
+ Args: + x : input array + axis: the axis along which the split should be computed (default: -1) + """ + size = x.shape[-1] + assert size % 2 == 0, "axis size must be divisible by 2" + x1, x2 = torch.split(x, size // 2, -1) + return x1 * (x2 * torch.sigmoid(x2)) + + +@contextlib.contextmanager +def skip_init(): + """ + skip param random init + """ + + def _skip_init(*args, **kwargs): + pass + + init_func = { + "constant_": torch.nn.init.constant_, + "uniform_": torch.nn.init.uniform_, + "normal_": torch.nn.init.normal_, + "kaiming_uniform_": torch.nn.init.kaiming_uniform_, + "kaiming_normal_": torch.nn.init.kaiming_normal_, + "xavier_normal_": torch.nn.init.xavier_normal_, + "xavier_uniform_": torch.nn.init.xavier_uniform_, + "trunc_normal_": torch.nn.init.trunc_normal_, + } + + for method_name, original_init in init_func.items(): + setattr(torch.nn.init, method_name, _skip_init) + + yield + + for method_name, original_init in init_func.items(): + setattr(torch.nn.init, method_name, original_init) + + return + + +def get_moe_epsize_param_dict(model: nn.Module) -> Dict[int, List[nn.Parameter]]: + """Returns a parameter dictionary, the key of which is the expert parallel + size of every parameter. Since the parameters in data parallelism is replicated + in each GPU, we set their ep_size to 1. + + Args: + model (:class:`torch.nn.Module`): A pyTorch `nn.Module` from which we get dict. + """ + epsize_param_dict = dict() + for param in model.parameters(): + if not is_moe_tensor(param): + ep_size = 1 # set ep_size to 1 for dp parameters + else: + ep_size = get_ep_size(param) + if ep_size not in epsize_param_dict: + epsize_param_dict[ep_size] = [] + epsize_param_dict[ep_size].append(param) + + return epsize_param_dict + + +def sync_moe_model_param(model: nn.Module): + """Make sure model parameters are consistent in MoE parallel context. + + Args: + model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency. 
+ """ + param_dict = get_moe_epsize_param_dict(model) + + # synchronize the parameters whose dp_group is the whole world + if 1 in param_dict: + for param in param_dict[1]: + dist.broadcast(param, src=0) + + for ep_size in param_dict: + # When ep_size = world_size, communication is not needed + if ep_size != 1 and ep_size != MOE_MANAGER.world_size: + for param in param_dict[ep_size]: + src_rank = get_dp_group_ranks(param)[0] + dist.broadcast(param, src=src_rank, group=get_dp_group(param)) + + +def set_moe_args(config: Any, args: dict): + for k, v in args.items(): + setattr(config, k, v) diff --git a/colossalai/nn/layer/__init__.py b/colossalai/nn/layer/__init__.py index 9aeab9f44..16281fe0b 100644 --- a/colossalai/nn/layer/__init__.py +++ b/colossalai/nn/layer/__init__.py @@ -1,2 +1 @@ -# from .moe import * from .utils import * diff --git a/colossalai/nn/layer/moe/__init__.py b/colossalai/nn/layer/moe/__init__.py deleted file mode 100644 index 6a5ccff51..000000000 --- a/colossalai/nn/layer/moe/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -from .checkpoint import load_moe_model, save_moe_model -from .experts import Experts, FFNExperts, TPExperts -from .layers import MoeLayer, MoeModule -from .routers import MoeRouter, Top1Router, Top2Router -from .utils import NormalNoiseGenerator, UniformNoiseGenerator, build_ffn_experts - -__all__ = [ - "Experts", - "FFNExperts", - "TPExperts", - "Top1Router", - "Top2Router", - "MoeLayer", - "NormalNoiseGenerator", - "UniformNoiseGenerator", - "build_ffn_experts", - "MoeModule", - "MoeRouter", - "save_moe_model", - "load_moe_model", -] diff --git a/colossalai/nn/layer/moe/_operation.py b/colossalai/nn/layer/moe/_operation.py deleted file mode 100644 index 2f0b7e436..000000000 --- a/colossalai/nn/layer/moe/_operation.py +++ /dev/null @@ -1,171 +0,0 @@ -from typing import Any, Optional, Tuple - -import torch -import torch.distributed as dist -from torch import Tensor -from torch.distributed import ProcessGroup - -COL_MOE_KERNEL_FLAG = False - -try: - from colossalai._C import moe -except: - moe = None - - -def build_moe_if_not_prebuilt(): - # load moe kernel during runtime if not pre-built - global moe - if moe is None: - from colossalai.kernel.op_builder import MOEBuilder - - moe = MOEBuilder().load() - - -class AllGather(torch.autograd.Function): - @staticmethod - def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor: - global moe - - if moe is None: - from colossalai.kernel.op_builder import MOEBuilder - - moe = MOEBuilder().load() - - if ctx is not None: - ctx.comm_grp = group - - comm_size = dist.get_world_size(group) - if comm_size == 1: - return inputs.unsqueeze(0) - - buffer_shape = (comm_size,) + inputs.shape - outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device) - buffer_list = list(torch.chunk(outputs, comm_size, dim=0)) - dist.all_gather(buffer_list, inputs, group=group) - return outputs - - @staticmethod - def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]: - return ReduceScatter.forward(None, grad_outputs, ctx.comm_grp), None - - -class ReduceScatter(torch.autograd.Function): - @staticmethod - def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor: - if ctx is not None: - ctx.comm_grp = group - - comm_size = dist.get_world_size(group) - if comm_size == 1: - return inputs.squeeze(0) - - if not inputs.is_contiguous(): - inputs = inputs.contiguous() - - output_shape = inputs.shape[1:] - outputs = torch.empty(output_shape, dtype=inputs.dtype, 
device=inputs.device) - buffer_list = list(torch.chunk(inputs, comm_size, dim=0)) - dist.reduce_scatter(outputs, buffer_list, group=group) - return outputs - - @staticmethod - def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]: - return AllGather.forward(None, grad_outputs, ctx.comm_grp), None - - -class AllToAll(torch.autograd.Function): - """Dispatches input tensor [e, c, h] to all experts by all_to_all_single - operation in torch.distributed. - """ - - @staticmethod - def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor: - if ctx is not None: - ctx.comm_grp = group - if not inputs.is_contiguous(): - inputs = inputs.contiguous() - if dist.get_world_size(group) == 1: - return inputs - output = torch.empty_like(inputs) - dist.all_to_all_single(output, inputs, group=group) - return output - - @staticmethod - def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]: - return AllToAll.forward(None, *grad_outputs, ctx.comm_grp), None - - -class MoeDispatch(torch.autograd.Function): - @staticmethod - def forward(ctx, tokens, mask, dest_idx, ec): - s = tokens.size(0) - h = tokens.size(1) - - # load moe kernel during runtime if not pre-built - build_moe_if_not_prebuilt() - - expert_input = moe.dispatch_forward(s, ec, h, tokens, mask, dest_idx) - - ctx.save_for_backward(mask, dest_idx) - ctx.s = s - ctx.h = h - ctx.ec = ec - - return expert_input - - @staticmethod - def backward(ctx, output_grad): - mask, dest_idx = ctx.saved_tensors - d_tokens = moe.dispatch_backward(ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx) - return d_tokens, None, None, None - - -class MoeCombine(torch.autograd.Function): - @staticmethod - def forward(ctx, expert_tokens, logits, mask, dest_idx, ec): - assert logits.dtype == torch.float32 - - s = logits.size(0) - e = logits.size(1) - c = ec // e - h = expert_tokens.size(-1) - - # load moe kernel during runtime if not pre-built - build_moe_if_not_prebuilt() - - fp16_flag = expert_tokens.dtype == torch.float16 - cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens - ctokens = moe.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx) - output = ctokens.to(torch.float16) if fp16_flag else ctokens - - ctx.save_for_backward(expert_tokens, logits, mask, dest_idx) - ctx.s = s - ctx.e = e - ctx.c = c - ctx.h = h - ctx.fp16_flag = fp16_flag - - return output - - @staticmethod - def backward(ctx, tokens_grad): - expert_tokens, logits, mask, dest_idx = ctx.saved_tensors - - cb_grad = tokens_grad.to(torch.float32) if tokens_grad.dtype is torch.float16 else tokens_grad - cb_input = expert_tokens.to(torch.float32) if ctx.fp16_flag else expert_tokens - d_expert, d_logits = moe.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, cb_grad, cb_input, logits, mask, dest_idx) - d_expert = d_expert.to(torch.float16) if ctx.fp16_flag else d_expert - - return d_expert, d_logits, None, None, None - - -def moe_cumsum(inputs: Tensor): - dim0 = inputs.size(0) - flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0) - if flag and COL_MOE_KERNEL_FLAG: - # load moe kernel during runtime if not pre-built - build_moe_if_not_prebuilt() - return moe.cumsum_sub_one(inputs) - else: - return torch.cumsum(inputs, dim=0) - 1 diff --git a/colossalai/nn/layer/moe/checkpoint.py b/colossalai/nn/layer/moe/checkpoint.py deleted file mode 100644 index adad19d58..000000000 --- a/colossalai/nn/layer/moe/checkpoint.py +++ /dev/null @@ -1,40 +0,0 @@ -import torch -import torch.distributed as dist -import torch.nn as 
nn - -from .experts import MoeExperts - - -def save_moe_model(model: nn.Module, save_path: str): - state_dict = model.state_dict() - if dist.get_rank() == 0: - torch.save(state_dict, save_path) - dist.barrier() - - -def load_moe_model(model: nn.Module, load_path: str): - state_dict = torch.load(load_path) - - for prefix, module in model.named_modules(): - if prefix.endswith(".moe_layer.experts"): - # this module should be an Experts instance - assert isinstance(module, MoeExperts) - - ep_rank = dist.get_rank(module.dist_info.ep_group) - num_local = module.num_local_experts - for i in range(num_local): - expert_id = ep_rank * num_local + i - for name, _ in module.experts[i].named_parameters(): - cur_key = f"{prefix}.experts.{i}.{name}" - param_key = f"{prefix}.experts.{expert_id}.{name}" - load_param = state_dict[param_key] - state_dict[cur_key] = load_param - - for name, _ in module.experts[0].named_parameters(): - pop_pre = f"{prefix}.experts." - pop_suf = f".{name}" - for i in range(num_local, module.num_total_experts): - pop_key = f"{pop_pre}{i}{pop_suf}" - state_dict.pop(pop_key) - - model.load_state_dict(state_dict) diff --git a/colossalai/nn/layer/moe/experts.py b/colossalai/nn/layer/moe/experts.py deleted file mode 100644 index 4b2ecb241..000000000 --- a/colossalai/nn/layer/moe/experts.py +++ /dev/null @@ -1,201 +0,0 @@ -import math -from copy import deepcopy -from typing import Type - -import torch -import torch.distributed as dist -import torch.nn as nn - -from colossalai.context.moe_context import MOE_CONTEXT -from colossalai.legacy.context import ParallelMode, seed -from colossalai.legacy.zero.init_ctx import no_shard_zero_decrator -from colossalai.utils import get_current_device - - -class MoeExperts(nn.Module): - """Basic class for experts in MoE. It stores what kind of communication experts use - to exchange tokens, how many experts in a single GPU and parallel information such as - expert parallel size, data parallel size and their distributed communication groups. - """ - - def __init__(self, comm_name: str, num_experts: int): - super().__init__() - assert comm_name in { - "all_to_all", - "all_gather", - }, "This kind of communication has not been implemented yet.\n Please use Experts build function." - self.comm_name = comm_name - self.num_total_experts = num_experts - # Get the configuration of experts' deployment and parallel information from moe context - self.num_local_experts, self.dist_info = MOE_CONTEXT.get_info(num_experts) - - -@no_shard_zero_decrator(is_replicated=False) -class Experts(MoeExperts): - """A wrapper class to create experts. It will create E experts across the - moe model parallel group, where E is the number of experts. Every expert - is a instance of the class, 'expert' in initialization parameters. 
- - Args: - expert_cls (:class:`torch.nn.Module`): The class of all experts - num_experts (int): The number of experts - expert_args: Args used to initialize experts, the args could be found in corresponding expert class - """ - - def __init__(self, expert_cls: Type[nn.Module], num_experts: int, **expert_args): - super().__init__("all_to_all", num_experts) - - # Use seed to make every expert different from others - with seed(ParallelMode.TENSOR): - self.experts = nn.ModuleList([expert_cls(**expert_args) for _ in range(self.num_local_experts)]) - - # Attach parallel information for all parameters in Experts - for exp in self.experts: - for param in exp.parameters(): - param.__setattr__("moe_info", self.dist_info) - - def forward(self, inputs: torch.Tensor): - # Split inputs for each expert - expert_input = torch.chunk(inputs, self.num_local_experts, dim=1) - expert_output = [] - - # Get outputs from each expert - for i in range(self.num_local_experts): - expert_output.append(self.experts[i](expert_input[i])) - - # Concatenate all outputs together - output = torch.cat(expert_output, dim=1).contiguous() - return output - - def state_dict(self, destination=None, prefix="", keep_vars=False): - assert keep_vars == False, "Only support keep_vars=False now" - dp_rank = dist.get_rank(self.dist_info.dp_group) - ep_rank = dist.get_rank(self.dist_info.ep_group) - submodule_dict = dict() - example_submodule = None - for name, subm in self.experts.named_modules(): - if subm is self.experts: - continue - module_number = self.num_local_experts * ep_rank + int(name) - submodule_dict[module_number] = subm - example_submodule = subm - - if dp_rank == 0: - local_prefix = prefix + "experts." - buffer_module = deepcopy(example_submodule) - for i in range(self.num_total_experts): - source_rank = i // self.num_local_experts - current_prefix = local_prefix + str(i) + "." 
- comm_module = submodule_dict.get(i, buffer_module) - for name, param in comm_module.named_parameters(): - dist.broadcast(param.data, src=source_rank, group=self.dist_info.ep_group) - if ep_rank == 0: - destination[current_prefix + name] = param.data.cpu() - - dist.barrier() - - -class FFNExperts(MoeExperts): - """Use torch.bmm to speed up for multiple experts.""" - - def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0): - super().__init__("all_to_all", num_experts) - - self.w1 = nn.Parameter(torch.empty(self.num_local_experts, d_model, d_ff, device=get_current_device())) - self.b1 = nn.Parameter(torch.empty(self.num_local_experts, 1, d_ff, device=get_current_device())) - - self.w2 = nn.Parameter(torch.empty(self.num_local_experts, d_ff, d_model, device=get_current_device())) - self.b2 = nn.Parameter(torch.empty(self.num_local_experts, 1, d_model, device=get_current_device())) - - s1 = math.sqrt(0.1 / d_model) - s2 = math.sqrt(0.1 / d_ff) - - with seed(ParallelMode.TENSOR): - nn.init.trunc_normal_(self.w1, std=s1) - nn.init.trunc_normal_(self.b1, std=s1) - nn.init.trunc_normal_(self.w2, std=s2) - nn.init.trunc_normal_(self.b2, std=s2) - - self.act = nn.GELU() if activation is None else activation - self.drop = nn.Dropout(p=drop_rate) - - for param in self.parameters(): - param.__setattr__("moe_info", self.dist_info) - - def forward(self, inputs): # inputs [g, el, c, h] - el = inputs.size(1) - h = inputs.size(-1) - - inputs = inputs.transpose(0, 1) - inshape = inputs.shape - inputs = inputs.reshape(el, -1, h) - - out_ff = torch.baddbmm(self.b1, inputs, self.w1) - out_act = self.act(out_ff) - with seed(ParallelMode.TENSOR): - out_inter = self.drop(out_act) - - out_model = torch.baddbmm(self.b2, out_inter, self.w2) - with seed(ParallelMode.TENSOR): - outputs = self.drop(out_model) # outputs [el, gc, h] - - outputs = outputs.reshape(inshape) - outputs = outputs.transpose(0, 1).contiguous() - return outputs - - -class TPExperts(MoeExperts): - """Use tensor parallelism to split each expert evenly, which can deploy experts in - case that the number of experts can't be divide by maximum expert parallel size or - maximum expert parallel size can't be divide by the number of experts. 
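For intuition about the batched-experts trick used by `FFNExperts` above, here is a small standalone sketch (toy shapes, not the class itself): with the expert weights stacked along a leading dimension, one `torch.baddbmm` call evaluates every local expert at once, which is equivalent to looping over the experts one by one.

```python
import torch

e, c, d_model, d_ff = 4, 3, 8, 16          # experts, tokens per expert, dims
x = torch.randn(e, c, d_model)             # tokens already dispatched per expert
w1 = torch.randn(e, d_model, d_ff)
b1 = torch.randn(e, 1, d_ff)

# one batched matmul for all experts ...
batched = torch.baddbmm(b1, x, w1)

# ... matches evaluating each expert separately
looped = torch.stack([x[i] @ w1[i] + b1[i] for i in range(e)])
assert torch.allclose(batched, looped, atol=1e-5)
```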
- """ - - def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0): - super().__init__("all_gather", MOE_CONTEXT.max_ep_size) - - assert d_ff % MOE_CONTEXT.max_ep_size == 0, "d_ff should be divide by maximum expert parallel size" - - p_ff = d_ff // MOE_CONTEXT.max_ep_size - - self.w1 = nn.Parameter(torch.empty(num_experts, d_model, p_ff, device=get_current_device())) - self.b1 = nn.Parameter(torch.empty(num_experts, 1, p_ff, device=get_current_device())) - - self.w2 = nn.Parameter(torch.empty(num_experts, p_ff, d_model, device=get_current_device())) - self.b2 = nn.Parameter(torch.empty(num_experts, 1, d_model, device=get_current_device())) - - s1 = math.sqrt(0.1 / d_model) - s2 = math.sqrt(0.1 / d_ff) - - with seed(ParallelMode.TENSOR): - nn.init.trunc_normal_(self.w1, std=s1) - nn.init.trunc_normal_(self.b1, std=s1) - nn.init.trunc_normal_(self.w2, std=s2) - - nn.init.trunc_normal_(self.b2, std=s2) - - self.act = nn.GELU() if activation is None else activation - self.drop = nn.Dropout(p=drop_rate) - - self.w1.__setattr__("moe_info", self.dist_info) - self.w2.__setattr__("moe_info", self.dist_info) - self.b1.__setattr__("moe_info", self.dist_info) - - def forward(self, inputs): # inputs [g, e, c, h] - e = inputs.size(1) - h = inputs.size(-1) - - inputs = inputs.transpose(0, 1) - inshape = inputs.shape - inputs = inputs.reshape(e, -1, h) - - out_ff = torch.baddbmm(self.b1, inputs, self.w1) - out_act = self.act(out_ff) - with seed(ParallelMode.TENSOR): - out_inter = self.drop(out_act) - - out_model = torch.baddbmm(self.b2, out_inter, self.w2) - outputs = self.drop(out_model) # outputs [e, gc, h] - - outputs = outputs.reshape(inshape) - outputs = outputs.transpose(0, 1).contiguous() - return outputs # outputs [g, e, c, h] diff --git a/colossalai/nn/layer/moe/layers.py b/colossalai/nn/layer/moe/layers.py deleted file mode 100644 index 23d483e6a..000000000 --- a/colossalai/nn/layer/moe/layers.py +++ /dev/null @@ -1,212 +0,0 @@ -import math -from typing import Optional, Tuple, Type - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from colossalai.context.moe_context import MOE_CONTEXT -from colossalai.legacy.zero.init_ctx import no_shard_zero_context, no_shard_zero_decrator -from colossalai.nn.layer.moe._operation import ( - COL_MOE_KERNEL_FLAG, - AllGather, - AllToAll, - MoeCombine, - MoeDispatch, - ReduceScatter, -) -from colossalai.nn.layer.moe.experts import Experts, MoeExperts -from colossalai.nn.layer.moe.routers import MoeRouter, Top1Router, Top2Router -from colossalai.nn.layer.moe.utils import NormalNoiseGenerator, UniformNoiseGenerator -from colossalai.utils import get_current_device - - -@no_shard_zero_decrator(is_replicated=True) -class MoeLayer(nn.Module): - """A MoE layer, that puts its input tensor to its gate and uses the output logits - to router all tokens, is mainly used to exchange all tokens for every expert across - the moe tensor group by all to all communication. Then it will get the output of all - experts and exchange the output. At last returns the output of the moe system. - - Args: - dim_model (int): Dimension of model. - num_experts (int): The number of experts. - router (MoeRouter): Instance of router used in routing. - experts (MoeExperts): Instance of experts generated by Expert. 
- """ - - def __init__(self, dim_model: int, num_experts: int, router: MoeRouter, experts: MoeExperts): - super().__init__() - self.d_model = dim_model - self.num_experts = num_experts - self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, dim_model)) - self.router: MoeRouter = router - self.experts: MoeExperts = experts - self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_CONTEXT.use_kernel_optim else False - self.ep_group = experts.dist_info.ep_group - self.ep_size = experts.dist_info.ep_size - self.num_local_experts = experts.num_local_experts - - nn.init.trunc_normal_(self.gate_weight, std=math.sqrt(0.1 / dim_model)) - - def a2a_process(self, dispatch_data: torch.Tensor): - expert_input = AllToAll.apply(dispatch_data, self.ep_group) - input_shape = expert_input.shape - expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.d_model) - expert_output = self.experts(expert_input) - expert_output = expert_output.reshape(input_shape) - expert_output = AllToAll.apply(expert_output, self.ep_group) - return expert_output - - def tp_process(self, dispatch_data: torch.Tensor): - expert_in = AllGather.apply(dispatch_data, self.ep_group) - expert_out = self.experts(expert_in) - expert_out = ReduceScatter.apply(expert_out, self.ep_group) - return expert_out - - def forward(self, inputs: torch.Tensor) -> Tuple: - # reshape the input tokens - tokens = inputs.reshape(-1, self.d_model) - - # the data type of the inputs in the gating should be fp32 - fp32_input = tokens.to(torch.float) - fp32_weight = self.gate_weight.to(torch.float) - gate_output = F.linear(fp32_input, fp32_weight) - - # the result from the router - route_result_list = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group) - - if self.use_kernel: - dispatch_data = MoeDispatch.apply(tokens, *route_result_list[1:]) - dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.d_model) - else: - sec_mask_f = route_result_list[1].type_as(inputs) - dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens) - - # dispatch_data [e, c, h] - if self.experts.comm_name == "all_to_all": - expert_output = self.a2a_process(dispatch_data) - elif self.experts.comm_name == "all_gather": - expert_output = self.tp_process(dispatch_data) - else: - raise NotImplementedError( - "This kind of communication has not been implemented yet.\n Please use Experts " "build function." - ) - # expert_output [e, c, h] - if self.use_kernel: - expert_output = expert_output.reshape(-1, self.d_model) - ans = MoeCombine.apply(expert_output, *route_result_list) - else: - combine_weights = route_result_list[0].type_as(inputs) - combine_weights = combine_weights.view(combine_weights.shape[0], -1) - expert_output = expert_output.view(-1, expert_output.shape[-1]) - ans = torch.matmul(combine_weights, expert_output) - - ans = ans.reshape(inputs.shape) - l_aux = self.router.pop_routing_loss() - return ans, l_aux - - -class MoeModule(nn.Module): - """A class for users to create MoE modules in their models. - - Args: - dim_model (int): Hidden dimension of training model - num_experts (int): The number experts - top_k (int, optional): The number of experts for dispatchment of each token - capacity_factor_train (float, optional): Capacity factor in routing during training - capacity_factor_eval (float, optional): Capacity factor in routing during evaluation - min_capacity (int, optional): The minimum number of the capacity of each expert - noisy_policy (str, optional): The policy of noisy function. 
Now we have 'Jitter' and 'Gaussian'. - 'Jitter' can be found in `Switch Transformer paper`_. - 'Gaussian' can be found in `ViT-MoE paper`_. - drop_tks (bool, optional): Whether drops tokens in evaluation - use_residual (bool, optional): Makes this MoE layer a Residual MoE. - More information can be found in `Microsoft paper`_. - residual_instance (nn.Module, optional): The instance of residual module in Residual MoE - expert_instance (MoeExperts, optional): The instance of experts module in MoeLayer - expert_cls (Type[nn.Module], optional): The class of each expert when no instance is given - expert_args (optional): The args of expert when no instance is given - - .. _Switch Transformer paper: - https://arxiv.org/abs/2101.03961 - .. _ViT-MoE paper: - https://arxiv.org/abs/2106.05974 - .. _Microsoft paper: - https://arxiv.org/abs/2201.05596 - """ - - def __init__( - self, - dim_model: int, - num_experts: int, - top_k: int = 1, - capacity_factor_train: float = 1.25, - capacity_factor_eval: float = 2.0, - min_capacity: int = 4, - noisy_policy: Optional[str] = None, - drop_tks: bool = True, - use_residual: bool = False, - residual_instance: Optional[nn.Module] = None, - expert_instance: Optional[MoeExperts] = None, - expert_cls: Optional[Type[nn.Module]] = None, - **expert_args, - ): - super().__init__() - - noisy_func = None - if noisy_policy is not None: - if noisy_policy == "Jitter": - noisy_func = UniformNoiseGenerator() - elif noisy_policy == "Gaussian": - noisy_func = NormalNoiseGenerator(num_experts) - else: - raise NotImplementedError("Unsupported input noisy policy") - - if top_k == 1: - moe_router_cls = Top1Router - elif top_k == 2: - moe_router_cls = Top2Router - else: - raise NotImplementedError("top_k > 2 is not supported yet") - - self.moe_router = moe_router_cls( - capacity_factor_train=capacity_factor_train, - capacity_factor_eval=capacity_factor_eval, - min_capacity=min_capacity, - noisy_func=noisy_func, - drop_tks=drop_tks, - ) - self.use_residual = use_residual - if use_residual: - if residual_instance is not None: - self.residual_module = residual_instance - else: - assert expert_cls is not None, "Expert class can't be None when residual instance is not given" - self.residual_module = expert_cls(**expert_args) - - with no_shard_zero_context(): - self.residual_combine = nn.Linear(dim_model, 2, device=get_current_device()) - - if expert_instance is not None: - my_experts = expert_instance - else: - assert expert_cls is not None, "Expert class can't be None when experts instance is not given" - my_experts = Experts(expert_cls, num_experts, **expert_args) - - self.moe_layer = MoeLayer( - dim_model=dim_model, num_experts=num_experts, router=self.moe_router, experts=my_experts - ) - - def forward(self, inputs: torch.Tensor): - moe_output, l_aux = self.moe_layer(inputs) - - if self.use_residual: - residual_output = self.residual_module(inputs) - combine_coef = self.residual_combine(inputs) - combine_coef = F.softmax(combine_coef, dim=-1) - output = moe_output * combine_coef[..., 0:1] + residual_output * combine_coef[..., 1:] - else: - output = moe_output - - return output, l_aux diff --git a/colossalai/nn/layer/moe/routers.py b/colossalai/nn/layer/moe/routers.py deleted file mode 100644 index 7ba83b278..000000000 --- a/colossalai/nn/layer/moe/routers.py +++ /dev/null @@ -1,235 +0,0 @@ -import math -from abc import ABC -from typing import Callable, Optional - -import torch -import torch.distributed as dist -import torch.nn as nn -import torch.nn.functional as F -from 
torch.distributed import ProcessGroup - -from colossalai.nn.layer.moe._operation import moe_cumsum -from colossalai.utils import get_current_device - - -class MoeRouter(nn.Module, ABC): - """Base class for all MoE routers. - Args: - k_value (int): The value of top_k. - capacity_factor_train (float): Capacity factor in routing of training. - capacity_factor_eval (float): Capacity factor in routing of evaluation. - min_capacity (int): The minimum number of the capacity of each expert. - noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits. - drop_tks (bool, optional): Whether drops tokens in evaluation - """ - - def __init__( - self, - k_value: int, - capacity_factor_train: float, - capacity_factor_eval: float, - min_capacity: int, - noisy_func: Callable = None, - drop_tks: bool = True, - ): - super().__init__() - self.k_value = k_value - self.capacity_factor_train = capacity_factor_train - self.capacity_factor_eval = capacity_factor_eval - self.min_capacity = min_capacity - self.noisy_func = noisy_func - self.drop_tks = drop_tks - self._routing_loss = None - - def get_capacity(self, logits_shape): - capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval - capacity = math.floor(self.k_value * capacity_factor * logits_shape[-2] / logits_shape[-1]) - capacity += capacity % 2 - capacity = max(capacity, self.min_capacity) - assert capacity > 0 - return capacity - - def set_routing_loss(self, aux_loss: torch.Tensor) -> None: - assert self._routing_loss is None - self._routing_loss = aux_loss - - def pop_routing_loss(self) -> torch.Tensor: - assert self._routing_loss is not None - reservation = self._routing_loss - self._routing_loss = None - return reservation - - -class Top1Router(MoeRouter): - """Top1 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c] - for routing usage. More detailed function can be found in the paper about Switch Transformer - of Google. - Args: - capacity_factor_train (float, optional): Capacity factor in routing of training. - capacity_factor_eval (float, optional): Capacity factor in routing of evaluation. - min_capacity (int, optional): The minimum number of the capacity of each expert. - select_policy (str, optional): The policy about tokens selection. - noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits. 
- drop_tks (bool, optional): Whether drops tokens in evaluation - """ - - def __init__( - self, - capacity_factor_train: float = 1.25, - capacity_factor_eval: float = 2.0, - min_capacity: int = 4, - select_policy: str = "first", - noisy_func: Callable = None, - drop_tks: bool = True, - ): - super().__init__( - k_value=1, - capacity_factor_train=capacity_factor_train, - capacity_factor_eval=capacity_factor_eval, - min_capacity=min_capacity, - noisy_func=noisy_func, - drop_tks=drop_tks, - ) - self.select_policy = select_policy - assert select_policy in {"first", "random"} - if select_policy == "random": - self.uniform = torch.distributions.uniform.Uniform( - low=torch.tensor(0.0, device=get_current_device()), high=torch.tensor(1.0, device=get_current_device()) - ).rsample - - def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None): - if self.noisy_func is not None and self.training: - inputs = self.noisy_func(inputs) - - assert inputs.dtype == torch.float - logits = F.softmax(inputs, dim=-1) - num_experts = logits.size(-1) - capacity = self.get_capacity(logits.shape) - - top1_idx = torch.argmax(inputs, dim=-1) - mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) - - # caculate the auxiliary loss - me = torch.mean(logits, dim=0) - ce = torch.mean(mask.float(), dim=0) - l_aux = num_experts * torch.sum(me * ce) - self.set_routing_loss(l_aux) - - if not self.training and not self.drop_tks: - max_num = torch.max(torch.sum(mask, dim=0)) - dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group) - capacity = max_num.item() - - if self.select_policy == "random": - rand_mask = mask * self.uniform(mask.shape) - _, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0) - mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1) - ranks = moe_cumsum(mask) - elif self.select_policy == "first": - ranks = moe_cumsum(mask) - mask = mask * torch.lt(ranks, capacity) - else: - raise NotImplementedError("Not support such select policy yet.") - - ranks = torch.sum(mask * ranks, dim=-1) - - if use_kernel: - mask = torch.sum(mask, dim=-1) - mask = torch.stack([mask], dim=0).to(torch.int32) - dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32) - return logits, mask, dest_idx, num_experts * capacity - else: - ranks = F.one_hot(ranks, num_classes=capacity) - weight = mask * logits.type_as(inputs) - combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1) - sec_mask = combine_weights.bool() - return combine_weights, sec_mask - - -class Top2Router(MoeRouter): - """Top2 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c] - for routing usage. More detailed function can be found in the paper about ViT-MoE. - Args: - capacity_factor_train (float, optional): Capacity factor in routing of training. - capacity_factor_eval (float, optional): Capacity factor in routing of evaluation. - min_capacity (int, optional): The minimum number of the capacity of each expert - noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits. - drop_tks (bool, optional): Whether drops tokens in evaluation. 
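To make the routing math behind these routers concrete, here is a minimal self-contained sketch of top-1 routing with a capacity limit and the load-balancing auxiliary loss (illustrative only; the function name and defaults are made up for the example, and the top-2 case extends it with a second mask):

```python
import math

import torch
import torch.nn.functional as F


def top1_route(gate_logits: torch.Tensor, capacity_factor: float = 1.25, min_capacity: int = 4):
    """gate_logits: [num_tokens, num_experts] raw gate outputs (toy sketch)."""
    num_tokens, num_experts = gate_logits.shape
    probs = F.softmax(gate_logits, dim=-1)
    capacity = max(min_capacity, math.floor(capacity_factor * num_tokens / num_experts))

    top1_idx = probs.argmax(dim=-1)
    mask = F.one_hot(top1_idx, num_classes=num_experts)          # [s, e], one expert per token

    # load-balancing auxiliary loss: mean routed fraction * mean gate probability, per expert
    l_aux = num_experts * torch.sum(probs.mean(dim=0) * mask.float().mean(dim=0))

    # cumulative position of each token inside its expert's buffer ("first come, first served")
    ranks = torch.cumsum(mask, dim=0) - 1
    mask = mask * (ranks < capacity)                             # drop tokens beyond capacity
    positions = (mask * ranks).sum(dim=-1)                       # buffer slot of each kept token
    return mask, positions, capacity, l_aux


mask, positions, capacity, l_aux = top1_route(torch.randn(16, 4))
print(capacity, int(mask.sum()), float(l_aux))
```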
- """ - - def __init__( - self, - capacity_factor_train: float = 1.25, - capacity_factor_eval: float = 2.0, - min_capacity: int = 4, - noisy_func: Callable = None, - drop_tks: bool = True, - ): - super().__init__( - k_value=2, - capacity_factor_train=capacity_factor_train, - capacity_factor_eval=capacity_factor_eval, - min_capacity=min_capacity, - noisy_func=noisy_func, - drop_tks=drop_tks, - ) - - def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None): - # inputs: [s, h] - if self.noisy_func is not None and self.training: - inputs = self.noisy_func(inputs) - - assert inputs.dtype == torch.float - logits = F.softmax(inputs, dim=-1) # logits: [s, e] - num_experts = logits.size(-1) - capacity = self.get_capacity(logits.shape) - - top1_idx = torch.argmax(logits, dim=-1) - mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) - logits_except1 = logits.masked_fill(mask1.bool(), float("-inf")) - top2_idx = torch.argmax(logits_except1, dim=-1) - mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32) - - cmask = mask1 + mask2 # loss: [s, e] - - # caculate the auxiliary loss - me = torch.mean(logits, dim=0) - ce = torch.mean(cmask.float(), dim=0) - l_aux = num_experts * torch.sum(me * ce) / 2.0 # div 2 to normalize it to 1 - self.set_routing_loss(l_aux) - - if not self.training and not self.drop_tks: - max_num = torch.max(torch.sum(cmask, dim=0)) - dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group) - capacity = max_num.item() - - rank1 = moe_cumsum(mask1) # rank1: [s, e] - rank2 = moe_cumsum(mask2) - rank2 += torch.sum(mask1, dim=-2, keepdim=True) - - mask1 *= torch.lt(rank1, capacity) - mask2 *= torch.lt(rank2, capacity) - - rank1 = torch.sum(mask1 * rank1, dim=-1) - rank2 = torch.sum(mask2 * rank2, dim=-1) - - if use_kernel: - mask1 = torch.sum(mask1, dim=-1) - mask2 = torch.sum(mask2, dim=-1) - - mask = torch.stack([mask1, mask2], dim=0).to(torch.int32) - dest_idx = torch.stack([top1_idx * capacity + rank1, top2_idx * capacity + rank2], dim=0).to(torch.int32) - - return logits, mask, dest_idx, num_experts * capacity - else: - weight1 = mask1 * logits.type_as(inputs) - weight2 = mask2 * logits.type_as(inputs) - rank1_sc = F.one_hot(rank1, num_classes=capacity) - rank2_sc = F.one_hot(rank2, num_classes=capacity) - - cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1) - cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1) - cb_weight = cb_weight1 + cb_weight2 - sec_mask = cb_weight.bool() - - return cb_weight, sec_mask diff --git a/colossalai/nn/layer/moe/utils.py b/colossalai/nn/layer/moe/utils.py deleted file mode 100644 index 4f31dd557..000000000 --- a/colossalai/nn/layer/moe/utils.py +++ /dev/null @@ -1,71 +0,0 @@ -import torch -import torch.nn.functional as F - -from colossalai.context.moe_context import MOE_CONTEXT -from colossalai.utils import get_current_device - -from .experts import FFNExperts, TPExperts - - -class ForceFP32Parameter(torch.nn.Parameter): - def half(self, memory_format=None): - return self.data.clone() - - -class NormalNoiseGenerator: - """Generates a random noisy mask for logits tensor. - - All noise is generated from a normal distribution :math:`(0, 1 / E^2)`, where - `E = the number of experts`. - - Args: - num_experts (int): The number of experts. 
- """ - - def __init__(self, num_experts: int): - self.normal = torch.distributions.normal.Normal( - loc=torch.tensor(0.0, device=get_current_device()), - scale=torch.tensor(1.0 / num_experts**2, device=get_current_device()), - ).rsample - - def __call__(self, inputs: torch.Tensor): - noisy = self.normal(inputs.shape) - return inputs + noisy - - -class UniformNoiseGenerator: - """Generates a random noisy mask for logits tensor. - copied from mesh tensorflow: - Multiply values by a random number between :math:`1-epsilon` and :math:`1+epsilon`. - Makes models more resilient to rounding errors introduced by bfloat16. - This seems particularly important for logits. - - Args: - eps (float, optional): Epsilon in generator, defaults 1e-2. - """ - - def __init__(self, eps: float = 1e-2): - self.uniform = torch.distributions.uniform.Uniform( - low=torch.tensor(1.0 - eps, device=get_current_device()), - high=torch.tensor(1.0 + eps, device=get_current_device()), - ).rsample - - def __call__(self, inputs: torch.Tensor): - noisy = self.uniform(inputs.shape) - return inputs * noisy - - -def autocast_softmax(logit: torch.Tensor, dim: int): - if logit.dtype != torch.float32: - logit = logit.float() - return F.softmax(logit, dim=dim) - - -def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0): - mep_size = MOE_CONTEXT.max_ep_size - if num_experts % mep_size == 0 or mep_size % num_experts == 0: - return FFNExperts(num_experts, d_model, d_ff, activation, drop_rate) - elif d_ff % mep_size == 0: - return TPExperts(num_experts, d_model, d_ff, activation, drop_rate) - else: - raise NotImplementedError(f"Can not build {num_experts} experts in {mep_size} GPUS.") diff --git a/colossalai/nn/loss/__init__.py b/colossalai/nn/loss/__init__.py index 7c6fb099d..e69de29bb 100644 --- a/colossalai/nn/loss/__init__.py +++ b/colossalai/nn/loss/__init__.py @@ -1 +0,0 @@ -# from .loss_moe import MoeCrossEntropyLoss, MoeLoss diff --git a/colossalai/tensor/moe_tensor/__init__.py b/colossalai/tensor/moe_tensor/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/colossalai/tensor/moe_tensor/api.py b/colossalai/tensor/moe_tensor/api.py new file mode 100644 index 000000000..c9efec63f --- /dev/null +++ b/colossalai/tensor/moe_tensor/api.py @@ -0,0 +1,137 @@ +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + +from .moe_info import MoeParallelInfo + + +def is_moe_tensor(tensor: torch.Tensor) -> bool: + """ + Check whether the given tensor is a moe tensor. + + Args: + tensor (torch.Tensor): The tensor to be checked. + + Returns: + bool: Whether the given tensor is a moe tensor. + """ + return hasattr(tensor, "moe_info") + + +def set_moe_tensor_info(tensor: torch.Tensor, moe_info: MoeParallelInfo) -> None: + """ + Set moe info for the given tensor. + + Args: + tensor (torch.Tensor): The tensor to be set. + moe_info (dict): The moe info to be set. + + """ + tensor.__setattr__("moe_info", moe_info) + + +def get_moe_info(ep_size: int, dp_size: int, pp_size: int, ep_inside: bool) -> MoeParallelInfo: + """ + Get moe info for the given tensor. + + Args: + ep_size (int): The expert parallel size. + dp_size (int): The data parallel size. + pp_size (int): The pipeline parallel size. + ep_inside (bool, optional): Use ep inside dp if True, dp inside ep if Fasle. + + Returns: + dict: The moe info of the given tensor. 
+ """ + return MoeParallelInfo(ep_inside, ep_size, dp_size, pp_size) + + +def get_ep_group(tensor: torch.Tensor) -> ProcessGroup: + """ + Get the expert parallel group of the given tensor. + + Args: + tensor (torch.Tensor): The tensor to be checked. + + Returns: + torch.distributed.ProcessGroup: The expert parallel group of the given tensor. + """ + return tensor.moe_info.ep_group + + +def get_ep_size(tensor: torch.Tensor) -> int: + """ + Get the expert parallel size of the given tensor. + + Args: + tensor (torch.Tensor): The tensor to be checked. + + Returns: + int: The expert parallel size of the given tensor. + """ + return tensor.moe_info.ep_size + + +def get_dp_group(tensor: torch.Tensor) -> ProcessGroup: + """ + Get the data parallel group of the given tensor. + + Args: + tensor (torch.Tensor): The tensor to be checked. + + Returns: + torch.distributed.ProcessGroup: The data parallel group of the given tensor. + """ + return tensor.moe_info.dp_group + + +def get_ep_rank(tensor: torch.Tensor) -> int: + """ + Get the expert parallel rank of the given tensor. + + Args: + tensor (torch.Tensor): The tensor to be checked. + + Returns: + int: The expert parallel rank of the given tensor. + """ + return dist.get_rank(get_ep_group(tensor)) + + +def get_dp_rank(tensor: torch.Tensor) -> int: + """ + Get the data parallel rank of the given tensor. + + Args: + tensor (torch.Tensor): The tensor to be checked. + + Returns: + int: The data parallel rank of the given tensor. + """ + return dist.get_rank(get_dp_group(tensor)) + + +def get_ep_group_ranks(tensor: torch.Tensor) -> int: + """ + Get the expert parallel group ranks of the given tensor. + + Args: + tensor (torch.Tensor): The tensor to be checked. + + Returns: + int: The expert parallel group ranks of the given tensor. + """ + return tensor.moe_info.ep_group_ranks + + +def get_dp_group_ranks(tensor: torch.Tensor) -> int: + """ + Get the data parallel group ranks of the given tensor. + + Args: + tensor (torch.Tensor): The tensor to be checked. + + Returns: + int: The data parallel group ranks of the given tensor. + """ + return tensor.moe_info.dp_group_ranks diff --git a/colossalai/tensor/moe_tensor/moe_info.py b/colossalai/tensor/moe_tensor/moe_info.py new file mode 100644 index 000000000..5097ac104 --- /dev/null +++ b/colossalai/tensor/moe_tensor/moe_info.py @@ -0,0 +1,28 @@ +from colossalai.cluster import ProcessGroupMesh + + +class MoeParallelInfo: + """Moe parallelism information, storing parallel sizes and groups.""" + + def __init__(self, ep_inside: bool, ep_size: int, dp_size: int, pp_size: int = 1): + """ + init MoeParallelInfo with ep_size, dp_size and pp_size + + Args: + ep_size (int): expert parallel size + dp_size (int): data parallel (zero) size + pp_size (int, optional): pipeline parallel size. Defaults to 1. + ep_inside (bool, optional): Use ep inside dp if True, dp inside ep if Fasle. Defaults to True. 
+ """ + self.pp_size, self.dp_size, self.ep_size = pp_size, dp_size, ep_size + if ep_inside: + self.pp_axis, self.dp_axis, self.ep_axis = 0, 1, 2 + self.pg = ProcessGroupMesh(self.pp_size, self.dp_size, self.ep_size) + else: + self.pp_axis, self.ep_axis, self.dp_axis = 0, 1, 2 + self.pg = ProcessGroupMesh(self.pp_size, self.ep_size, self.dp_size) + + self.ep_group = self.pg.get_group_along_axis(self.ep_axis) + self.ep_group_ranks = self.pg.get_ranks_in_group(self.ep_group) + self.dp_group = self.pg.get_group_along_axis(self.dp_axis) + self.dp_group_ranks = self.pg.get_ranks_in_group(self.dp_group) diff --git a/colossalai/utils/moe.py b/colossalai/utils/moe.py deleted file mode 100644 index 1b75448bd..000000000 --- a/colossalai/utils/moe.py +++ /dev/null @@ -1,53 +0,0 @@ -from typing import Dict, List - -import torch.distributed as dist -import torch.nn as nn - -from colossalai.context.moe_context import MOE_CONTEXT -from colossalai.legacy.context import ParallelMode -from colossalai.legacy.core import global_context as gpc -from colossalai.legacy.utils import is_using_ddp - - -def get_moe_epsize_param_dict(model: nn.Module) -> Dict[int, List[nn.Parameter]]: - """Returns a parameter dictionary, the key of which is the expert parallel - size of every parameter. Since the parameters in data parallelism is replicated - in each GPU, we set their ep_size to 1. - - Args: - model (:class:`torch.nn.Module`): A pyTorch `nn.Module` from which we get dict. - """ - epsize_param_dict = dict() - for param in model.parameters(): - if not hasattr(param, "moe_info"): - ep_size = 1 # set ep_size to 1 for dp parameters - else: - ep_size = param.moe_info.ep_size - if ep_size not in epsize_param_dict: - epsize_param_dict[ep_size] = [] - epsize_param_dict[ep_size].append(param) - - return epsize_param_dict - - -def sync_moe_model_param(model: nn.Module): - """Make sure model parameters are consistent in MoE parallel context. - - Args: - model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency. 
- """ - if is_using_ddp(): - param_dict = get_moe_epsize_param_dict(model) - - # synchronize the parameters whose dp_group is the whole world - if 1 in param_dict: - src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[0] - for param in param_dict[1]: - dist.broadcast(param, src=src_rank, group=gpc.get_group(ParallelMode.DATA)) - - for ep_size in param_dict: - # When ep_size = world_size, communication is not needed - if ep_size != 1 and ep_size != MOE_CONTEXT.world_size: - src_rank = dist.get_rank(MOE_CONTEXT.parallel_info_dict[ep_size].ep_group) - for param in param_dict[ep_size]: - dist.broadcast(param, src=src_rank, group=param.moe_info.dp_group) diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index e6974a676..932053dd1 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -8,6 +8,7 @@ import torch import torch.distributed as dist import torch.nn as nn from torch import Tensor, inf +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from torch.distributed import ProcessGroup from torch.optim import Optimizer @@ -18,6 +19,7 @@ from colossalai.amp.naive_amp.mixed_precision_mixin import ( ) from colossalai.interface import OptimizerWrapper from colossalai.logging import get_dist_logger +from colossalai.tensor.moe_tensor.api import is_moe_tensor # from colossalai.tensor import ColoParameter, ProcessGroup from colossalai.utils.cuda import get_current_device @@ -75,6 +77,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): cpu_offload: bool = False, # cpu offload dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm forced_dtype: Optional[torch.dtype] = None, + moe_extra_dp_process_group: Optional[ProcessGroup] = None, master_weights: bool = True, # master weights ): super(LowLevelZeroOptimizer, self).__init__(optim=optimizer) @@ -95,6 +98,16 @@ class LowLevelZeroOptimizer(OptimizerWrapper): self._local_rank = dist.get_rank(group=self.dp_pg) self._world_size = dist.get_world_size(group=self.dp_pg) + # extra dp + # This group is used to sync moe param, dp_world_size = moe_duplicates * extra_dp_size. + # Non moe param will be sync by global dp pg, moe param will be sync by extra dp pg. + # Moe param grad is be split as non moe param by global dp pg, and grad will be merged in step. + # And moe working and master param are split by extra dp pg. 
+ self.moe_extra_dp_pg = moe_extra_dp_process_group + if self.moe_extra_dp_pg is not None: + self.moe_extra_dp_pg_size = dist.get_world_size(group=self.moe_extra_dp_pg) + self.moe_extra_dp_pg_rank = dist.get_rank(group=self.moe_extra_dp_pg) + # working and master params for mixed precision training self._working_param_groups = dict() self._master_param_groups_of_current_rank = dict() @@ -126,6 +139,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper): self._grad_store = GradientStore(self.dp_pg, partition_grad=partition_grad) self._bucket_store = BucketStore(self.dp_pg) + # moe param should not be stored in working_groups + # because they have different parallel strategy + # so we need to store them separately in param_groups + # instead of working_groups + moe_params = list() + # iterate over the param group in the optimizer # partition these param groups for data parallel training # and add buffers to parameter store for future access @@ -133,6 +152,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper): group_params = list() for param in param_group["params"]: if param.requires_grad: + if self.moe_extra_dp_pg is None: + # skip moe param + if is_moe_tensor(param): + moe_params.append(param) + continue group_params.append(param) # add the working params to working_param_groups for bookkeeping @@ -146,6 +170,15 @@ class LowLevelZeroOptimizer(OptimizerWrapper): # managed by this data parallel rank param_group["params"] = master_param_current_rank + # if there are moe params, store in addtional group in optim + if len(moe_params) > 0: + param_group = dict() + for key, value in self.optim.param_groups[0].items(): + if key != "params": + param_group[key] = value + param_group["params"] = moe_params + self.optim.param_groups.append(param_group) + # intialize communication stream for # communication-compuation overlapping if self._overlap_communication: @@ -208,13 +241,20 @@ class LowLevelZeroOptimizer(OptimizerWrapper): param.data = padding_param[: param.numel()].view(param.shape) else: padding_param = param.data.view(-1) - splited_params = padding_param.split(padding_param.numel() // self._world_size) + + if self.moe_extra_dp_pg is not None and is_moe_tensor(param): + splited_params = padding_param.split(padding_param.numel() // self.moe_extra_dp_pg_size) + splited_params = splited_params[self.moe_extra_dp_pg_rank] + else: + splited_params = padding_param.split(padding_param.numel() // self._world_size) + splited_params = splited_params[self._local_rank] # use fp32 when master_weights is True if self._master_weights is True: - splited_param_current_rank = splited_params[self._local_rank].detach().float().to(device) + splited_param_current_rank = splited_params.detach().float().to(device) else: - splited_param_current_rank = splited_params[self._local_rank] + splited_param_current_rank = splited_params + params_current_rank.append(splited_param_current_rank) self._param_store.link_master_and_working_param(splited_param_current_rank, param) @@ -247,8 +287,43 @@ class LowLevelZeroOptimizer(OptimizerWrapper): if self._bucket_store.num_elements_in_bucket() > 0: self._bucket_store.build_grad_in_bucket() - flat_grads = self._bucket_store.get_flatten_grad() - flat_grads /= self._world_size + if self.moe_extra_dp_pg is None: + flat_grads = self._bucket_store.get_flatten_grad() + flat_grads /= self._world_size + else: + # record moe and non moe param + moe_list = [] + for param in self._bucket_store._param_list: + moe_list.append(is_moe_tensor(param)) + + # divide them into different groups + moe_grad_list 
= [] + non_moe_grad_list = [] + for grad_list in self._bucket_store._grad_in_bucket.values(): + non_moe_cur_grad = [] + moe_cur_grad = [] + for i in range(len(grad_list)): + if moe_list[i] == True: + moe_cur_grad.append(grad_list[i]) + else: + non_moe_cur_grad.append(grad_list[i]) + if len(moe_cur_grad) > 0: + moe_grad_list.append(moe_cur_grad) + if len(non_moe_cur_grad) > 0: + non_moe_grad_list.append(non_moe_cur_grad) + + if len(non_moe_grad_list) > 0: + non_moe_flat_grads = [] + for grad_list in non_moe_grad_list: + non_moe_flat_grads.append(_flatten_dense_tensors(grad_list)) + non_moe_flat_grads = _flatten_dense_tensors(non_moe_flat_grads) + non_moe_flat_grads /= self._world_size + + if len(moe_grad_list) > 0: + moe_flat_grads = [] + for grad_list in moe_grad_list: + moe_flat_grads.append(_flatten_dense_tensors(grad_list)) + moe_flat_grads = _flatten_dense_tensors(moe_flat_grads) # ready to add other tensors to bucket self._bucket_store.reset_num_elements_in_bucket() @@ -256,7 +331,13 @@ class LowLevelZeroOptimizer(OptimizerWrapper): if self._overlap_communication: stream = self._comm_stream # in case of the memory being reused in the default stream - flat_grads.record_stream(stream) + if self.moe_extra_dp_pg is None: + flat_grads.record_stream(stream) + else: + if len(non_moe_grad_list) > 0: + non_moe_flat_grads.record_stream(stream) + if len(moe_grad_list) > 0: + moe_flat_grads.record_stream(stream) # waiting for ops in the default stream finishing stream.wait_stream(torch.cuda.current_stream()) else: @@ -265,49 +346,108 @@ class LowLevelZeroOptimizer(OptimizerWrapper): with torch.cuda.stream(stream): group_id = self._bucket_store.current_group_id - grad_dtype = flat_grads.dtype - if self._communication_dtype is not None: - flat_grads = flat_grads.to(self._communication_dtype) + if self.moe_extra_dp_pg is None: + grad_dtype = flat_grads.dtype + if self._communication_dtype is not None: + flat_grads = flat_grads.to(self._communication_dtype) if not self._partition_grads: - dist.all_reduce(flat_grads, group=self.dp_pg) - if flat_grads.dtype != grad_dtype: - flat_grads = flat_grads.to(grad_dtype) + if self.moe_extra_dp_pg is None: + dist.all_reduce(flat_grads, group=self.dp_pg) + if flat_grads.dtype != grad_dtype: + flat_grads = flat_grads.to(grad_dtype) - flat_grads_per_rank = flat_grads.split(flat_grads.numel() // self._world_size) - grad_in_bucket = self._bucket_store.get_grad() + flat_grads_per_rank = flat_grads.split(flat_grads.numel() // self._world_size) + grad_in_bucket = self._bucket_store.get_grad() + self._update_unpartitoned_grad(grad_in_bucket.values(), flat_grads_per_rank, group_id) - for rank, grad_list in grad_in_bucket.items(): - sync_tensor(flat_grads_per_rank[rank], grad_list) - for grad in grad_list: - param_id = self._bucket_store.get_param_id_of_grad(grad) - if ( - len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) - < self._world_size - ): - self._grad_store.append_gradients_by_param_id(grad, group_id, param_id) - else: - self._grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id) + # sync extra zero group + else: + # sync non moe param in global dp group + if len(non_moe_grad_list) > 0: + dist.all_reduce(non_moe_flat_grads, group=self.dp_pg) + flat_grads_per_rank = non_moe_flat_grads.split( + non_moe_flat_grads.numel() // self._world_size + ) + self._update_unpartitoned_grad(non_moe_grad_list, flat_grads_per_rank, group_id) + + # sync moe param only in zero group + if len(moe_grad_list) > 0: + 
dist.all_reduce(moe_flat_grads, group=self.moe_extra_dp_pg) + flat_grads_per_rank = moe_flat_grads.split(moe_flat_grads.numel() // self._world_size) + self._update_unpartitoned_grad(moe_grad_list, flat_grads_per_rank, group_id) else: - flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size)) - recieved_grad = torch.zeros_like(flat_grads_list[0]) - dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg) + if self.moe_extra_dp_pg is None: + flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size)) + recieved_grad = torch.zeros_like(flat_grads_list[0]) + dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg) - if recieved_grad.dtype != grad_dtype: - recieved_grad = recieved_grad.to(grad_dtype) + if recieved_grad.dtype != grad_dtype: + recieved_grad = recieved_grad.to(grad_dtype) - grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank] - sync_tensor(recieved_grad, grad_in_bucket_current_rank) - for grad in grad_in_bucket_current_rank: - param_id = self._bucket_store.get_param_id_of_grad(grad) - if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < 1: - self._grad_store.append_gradients_by_param_id(grad, group_id, param_id) - else: - self._grad_store.add_gradients_by_param_id(grad, 0, group_id, param_id) + grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank] + self._update_partitoned_grad(grad_in_bucket_current_rank, recieved_grad, group_id, 1) + else: + # categorize moe and non moe param + grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank] + moe_grad_in_bucket_current_rank = [] + non_moe_grad_in_bucket_current_rank = [] + for idx, grad in enumerate(grad_in_bucket_current_rank): + if moe_list[idx] == True: + moe_grad_in_bucket_current_rank.append(grad) + else: + non_moe_grad_in_bucket_current_rank.append(grad) + + if len(non_moe_grad_list) > 0: + flat_grads_list = list( + non_moe_flat_grads.split(len(non_moe_flat_grads) // self._world_size) + ) + recieved_grad = torch.zeros_like(flat_grads_list[0]) + dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg) + self._update_partitoned_grad( + non_moe_grad_in_bucket_current_rank, recieved_grad, group_id, 1 + ) + + if len(moe_grad_list) > 0: + flat_grads_list = list( + moe_flat_grads.split(len(moe_flat_grads) // self.moe_extra_dp_pg_size) + ) + recieved_grad = torch.zeros_like(flat_grads_list[0]) + dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.moe_extra_dp_pg) + param_slice = self._world_size // self.moe_extra_dp_pg_size + recieved_grad = list(recieved_grad.split(len(recieved_grad) // param_slice)) + for split_recieved_grad in recieved_grad: + split_recieved_grad = _unflatten_dense_tensors( + split_recieved_grad, moe_grad_in_bucket_current_rank + ) + for real_grad, grad in zip(split_recieved_grad, moe_grad_in_bucket_current_rank): + param_id = self._bucket_store.get_param_id_of_grad(grad) + self._add_grad(real_grad, param_slice, group_id, param_id) self._bucket_store.reset() + def _update_unpartitoned_grad(self, origin_grad_list: List, flat_grad_list: List, group_id: int) -> None: + for rank, grad_list in enumerate(origin_grad_list): + sync_tensor(flat_grad_list[rank], grad_list) + for grad in grad_list: + param_id = self._bucket_store.get_param_id_of_grad(grad) + self._add_grad(grad, self._world_size, group_id, param_id, rank) + + def _update_partitoned_grad( + self, origin_grad_list: List, flat_grad: torch.Tensor, group_id: int, partition_num: 
int + ) -> None: + sync_tensor(flat_grad, origin_grad_list) + for grad in origin_grad_list: + param_id = self._bucket_store.get_param_id_of_grad(grad) + self._add_grad(grad, partition_num, group_id, param_id) + + def _add_grad(self, grad: torch.Tensor, partition_num: int, group_id: int, param_id: int, rank: int = 0) -> None: + if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < partition_num: + self._grad_store.append_gradients_by_param_id(grad, group_id, param_id) + else: + self._grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id) + def _add_to_bucket(self, param, group_id): param_size = param.numel() @@ -424,13 +564,23 @@ class LowLevelZeroOptimizer(OptimizerWrapper): # else the splited grad should be attached to the splited param grads = self._grad_store.get_partitioned_gradients_by_param_id(group_id, id(working_param)) if len(grads) > 0: - real_working_params[group_id].append(working_param) + # moe hybrid zero + if self.moe_extra_dp_pg is not None and is_moe_tensor(working_param): + real_working_params[group_id].append(working_param) + if self._partition_grads: + grad = grads + else: + param_slice = self._world_size // self.moe_extra_dp_pg_size + grad = grads[ + self.moe_extra_dp_pg_rank * param_slice : (self.moe_extra_dp_pg_rank + 1) * param_slice + ] + grad = flatten(grad) + else: + real_working_params[group_id].append(working_param) + grad = grads[grad_index] # no need to copy fp32 grad if master_weights is False - grad = ( - grads[grad_index].to(splited_param.dtype).to(splited_param.device) - if self._master_weights - else grads[grad_index] - ) + if self._master_weights: + grad = grad.to(splited_param.dtype).to(splited_param.device) splited_param.grad = grad grad_partition_groups.append(grad) real_master_params[group_id].append(splited_param) @@ -449,24 +599,43 @@ class LowLevelZeroOptimizer(OptimizerWrapper): global_norm = calculate_global_norm_from_list(norm_list=norm_groups) self._unscale_and_clip_grads(grad_partition_groups, global_norm) + # TODO: we should store master param for ep + if len(self.param_groups) > len(self._working_param_groups): + for param in self.param_groups[-1]["params"]: + param.data = param.data.to(torch.float32) + param.grad = param.grad.to(torch.float32) + # update the parameters self.optim.step() + # release the moe gradm + if len(self.param_groups) > len(self._working_param_groups): + for param in self.param_groups[-1]["params"]: + param.grad = None + param.data = param.data.to(self._dtype) + # release the grad grad_partition_groups = [] for group_id in range(self.num_param_groups): release_param_grad(self._master_param_groups_of_current_rank[group_id]) # update working partition updated by the current rank - # dtype = real_working_params[0][0].dtype for group_id in range(self.num_param_groups): master_working_param = self.optim.param_groups[group_id]["params"] for idx, splited_param in enumerate(master_working_param): working_param = real_working_params[group_id][idx] - all_splited_param = [ - torch.zeros(splited_param.shape, device="cuda", dtype=self._dtype) for _ in range(self._world_size) - ] - dist.all_gather(all_splited_param, splited_param.cuda().to(self._dtype), group=self.dp_pg) + if self.moe_extra_dp_pg is not None and is_moe_tensor(working_param): + all_splited_param = [ + torch.zeros(splited_param.shape, device="cuda", dtype=self._dtype) + for _ in range(self.moe_extra_dp_pg_size) + ] + dist.all_gather(all_splited_param, splited_param.cuda().to(self._dtype), group=self.moe_extra_dp_pg) + 
else: + all_splited_param = [ + torch.zeros(splited_param.shape, device="cuda", dtype=self._dtype) + for _ in range(self._world_size) + ] + dist.all_gather(all_splited_param, splited_param.cuda().to(self._dtype), group=self.dp_pg) working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param)) self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id] @@ -488,7 +657,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper): norm_type = float(norm_type) if norm_type == inf: total_norm = max(grad.data.abs().max() for grad in gradients) - total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg) total_norm = total_norm_cuda.item() @@ -596,10 +764,16 @@ class LowLevelZeroOptimizer(OptimizerWrapper): for k, v in state.items(): if isinstance(v, torch.Tensor) and k != "step": working_param = self._param_store.master_to_working_param[id(param)] - gather_tensor = [ - torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self._world_size) - ] - dist.all_gather(gather_tensor, v.cuda(), group=self.dp_pg) + if self.moe_extra_dp_pg is not None and is_moe_tensor(v): + gather_tensor = [ + torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size) + ] + dist.all_gather(gather_tensor, v.cuda(), group=self.moe_extra_dp_pg) + else: + gather_tensor = [ + torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self._world_size) + ] + dist.all_gather(gather_tensor, v.cuda(), group=self.dp_pg) param_state = ( torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu() ) @@ -624,8 +798,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper): v = v.flatten() if padding_size > 0: v = torch.nn.functional.pad(v, [0, padding_size]) - v_list = v.split(v.numel() // self._world_size) - zero_state_dict["state"][param_idx][k] = v_list[self._local_rank].detach().clone() + if self.moe_extra_dp_pg is not None and is_moe_tensor(v): + v_list = v.split(v.numel() // self.moe_extra_dp_pg_size) + zero_state_dict["state"][param_idx][k] = v_list[self.moe_extra_dp_pg_rank].detach().clone() + else: + v_list = v.split(v.numel() // self._world_size) + zero_state_dict["state"][param_idx][k] = v_list[self._local_rank].detach().clone() self.optim.load_state_dict(zero_state_dict) @@ -656,8 +834,16 @@ class LowLevelZeroOptimizer(OptimizerWrapper): for k, v in states.items(): if isinstance(v, torch.Tensor) and k != "step": - state_tensor = [torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self._world_size)] - dist.all_gather(state_tensor, v.cuda(), group=self.dp_pg) + if self.moe_extra_dp_pg is not None and is_moe_tensor(v): + state_tensor = [ + torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size) + ] + dist.all_gather(state_tensor, v.cuda(), group=self.moe_extra_dp_pg) + else: + state_tensor = [ + torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self._world_size) + ] + dist.all_gather(state_tensor, v.cuda(), group=self.dp_pg) state_tensor = ( torch.stack(state_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu() ) @@ -688,7 +874,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper): working_param = p.data.view(-1) if padding_size > 0: working_param = torch.nn.functional.pad(working_param, [0, padding_size]) - master_param.copy_(working_param.chunk(self._world_size)[self._local_rank]) + if self.moe_extra_dp_pg 
is not None and is_moe_tensor(p):
+                master_param.copy_(working_param.chunk(self.moe_extra_dp_pg_size)[self.moe_extra_dp_pg_rank])
+            else:
+                master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])

     def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
         return self._param_store.working_to_master_param
diff --git a/examples/language/openmoe/README.md b/examples/language/openmoe/README.md
new file mode 100644
index 000000000..a0821a533
--- /dev/null
+++ b/examples/language/openmoe/README.md
@@ -0,0 +1,129 @@
+## OpenMoE
+[OpenMoE](https://github.com/XueFuzhao/OpenMoE) is the open-source community's first decoder-only MoE transformer. OpenMoE is implemented in JAX, and [Colossal-AI](https://github.com/hpcaitech/ColossalAI) has pioneered efficient open-source support for this model in PyTorch, enabling a broader range of users to participate in and use this model. The following [Colossal-AI](https://github.com/hpcaitech/ColossalAI) example demonstrates fine-tuning and inference.
+
+## Usage
+
+### 1. Installation
+
+Please install the latest ColossalAI from source.
+
+```bash
+CUDA_EXT=1 pip install -U git+https://github.com/hpcaitech/ColossalAI
+```
+
+Then install dependencies.
+
+```bash
+cd ColossalAI/examples/language/openmoe
+pip install -r requirements.txt
+```
+
+Additionally, we recommend using torch 1.13.1. We have tested our code on this version and found it compatible with flash attention.
+
+### 2. Install kernels (Optional)
+
+We use the `Triton`, `FlashAttention` and `Apex` kernels for better performance. They are not required, but we recommend installing them to fully utilize your hardware.
+```bash
+# install triton via pip
+pip install triton
+
+# install flash attention via pip
+pip install flash-attn==2.0.5
+
+# install apex from source
+git clone https://github.com/NVIDIA/apex.git
+cd apex
+git checkout 741bdf50825a97664db08574981962d66436d16a
+pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ --global-option="--cuda_ext"
+```
+
+### 3. Train
+You can use `colossalai run` to launch single-node training:
+```bash
+colossalai run --standalone --nproc_per_node YOUR_GPU_PER_NODE train.py --OTHER_CONFIGURATIONS
+```
+You can also use `colossalai run` to launch multi-node training:
+```bash
+colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE train.py --OTHER_CONFIGURATIONS
+```
+
+Here is a sample hostfile:
+
+```text
+hostname1
+hostname2
+hostname3
+hostname4
+```
+
+Each hostname is the IP address of one of your nodes. Make sure the master node can reach all nodes (including itself) via passwordless SSH.
+
+Here are the details of the CLI arguments:
+
+- Model configuration: `--model_name`. `base` and `8b` are supported for OpenMoE.
+- Booster plugin: `--plugin`. `ep`, `ep_zero` and `hybrid` are supported. `ep_zero` is recommended for general cases, `ep` provides the lowest memory consumption, and `hybrid` suits large-scale training.
+- Output path: `--output_path`. The path to save your model. The default value is `./outputs`.
+- Number of epochs: `--num_epochs`. The default value is 1.
+- Local batch size: `--batch_size`. Batch size per GPU. The default value is 1.
+- Save interval: `-i`, `--save_interval`. The interval (steps) of saving checkpoints. The default value is 1000.
+- Mixed precision: `--precision`. The default value is "bf16". "fp16", "bf16" and "fp32" are supported.
+- Max length: `--max_length`. Max sequence length. Defaults to 2048.
+- Dataset: `-d`, `--dataset`. The default dataset is `yizhongw/self_instruct`. It supports any dataset from `datasets` with the same data format.
+- Task Name: `--task_name`. Task of the corresponding dataset. Defaults to `super_natural_instructions`.
+- Learning rate: `--lr`. The default value is 1e-5.
+- Weight decay: `--weight_decay`. The default value is 0.
+- Zero stage: `--zero_stage`. ZeRO stage. 2 is recommended for `ep` and 1 for `ep_zero`.
+- Extra dp size: `--extra_dp_size`. Extra data-parallel size for MoE parameters in the `ep_zero` plugin. Recommended to be 2 or 4.
+- Use kernel: `--use_kernel`. Use kernel optimizations. Flash attention and triton need to be installed to enable all kernel optimizations; they are skipped if not installed.
+- Use layernorm kernel: `--use_layernorm_kernel`. Use the layernorm kernel. Requires apex; an error is raised if it is not installed.
+- Router aux loss factor: `--router_aux_loss_factor`. MoE router aux loss factor. You can refer to ST-MoE for details.
+- Router z loss factor: `--router_z_loss_factor`. MoE router z loss factor. You can refer to ST-MoE for details.
+- Label smoothing: `--label_smoothing`. Label smoothing.
+- Z loss factor: `--z_loss_factor`. The final outputs' classification z loss factor.
+- Load balance: `--load_balance`. Expert load balance. Defaults to False. Recommended to enable.
+- Load balance interval: `--load_balance_interval`. Expert load balance interval.
+- Communication overlap: `--comm_overlap`. Use communication overlap for MoE. Recommended to enable for multi-node training.
+
+### 4. Shell Script Examples
+
+For your convenience, we provide some shell scripts for training with various configurations. Here we will show an example of how to run training
+for OpenMoE.
+
+#### a. Running environment
+This experiment was performed on a single compute node with 8 A800 80GB GPUs in total for OpenMoE-8B. The GPUs are fully connected with NVLink.
+
+#### b. Running command
+We demonstrate how to run three plugins in `train.sh`. You can choose any one of them and use your own arguments.
+
+```bash
+bash train.sh
+```
+
+#### c.
Multi-Nodes Training + +To run on multi-nodes, you can modify the script as: +```bash +colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \ +train.py --OTHER_CONFIGURATIONS +``` + +## Reference +``` +@article{bian2021colossal, + title={Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training}, + author={Bian, Zhengda and Liu, Hongxin and Wang, Boxiang and Huang, Haichen and Li, Yongbin and Wang, Chuanrui and Cui, Fan and You, Yang}, + journal={arXiv preprint arXiv:2110.14883}, + year={2021} +} +``` + +```bibtex +@misc{openmoe2023, + author = {Fuzhao Xue, Zian Zheng, Yao Fu, Jinjie Ni, Zangwei Zheng, Wangchunshu Zhou and Yang You}, + title = {OpenMoE: Open Mixture-of-Experts Language Models}, + year = {2023}, + publisher = {GitHub}, + journal = {GitHub repository}, + howpublished = {\url{https://github.com/XueFuzhao/OpenMoE}}, +} +``` diff --git a/examples/language/openmoe/benchmark/benchmark_cai.py b/examples/language/openmoe/benchmark/benchmark_cai.py new file mode 100644 index 000000000..112a12cb6 --- /dev/null +++ b/examples/language/openmoe/benchmark/benchmark_cai.py @@ -0,0 +1,296 @@ +import argparse +import json +import os + +import torch +import torch.distributed as dist +from huggingface_hub import snapshot_download +from model.modeling_openmoe import OpenMoeForCausalLM, set_openmoe_args +from model.openmoe_policy import OpenMoeForCausalLMPolicy +from torch.utils.data import Dataset +from tqdm import tqdm +from transformers import T5Tokenizer +from transformers.models.llama import LlamaConfig +from utils import PerformanceEvaluator, get_model_numel + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin +from colossalai.cluster import DistCoordinator +from colossalai.moe.layers import apply_load_balance +from colossalai.moe.manager import MOE_MANAGER +from colossalai.moe.utils import skip_init +from colossalai.nn.optimizer import HybridAdam +from colossalai.utils import get_current_device + + +def move_to_cuda(batch, device): + return {k: v.to(device) for k, v in batch.items()} + + +def load_ckpt(repo_name: str, model: OpenMoeForCausalLM, booster: Booster): + ckpt_path = snapshot_download(repo_name) + # single ckpt + if os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin")): + ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin") + # shard ckpt + elif os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin.index.json")): + ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin.index.json") + else: + raise ValueError(f"Invalid checkpoint path: {ckpt_path}") + booster.load_model(model, ckpt_path) + + +class RandomDataset(Dataset): + def __init__( + self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 256384, tokenizer: T5Tokenizer = None + ): + self.num_samples = num_samples + self.max_length = max_length + if os.path.exists("./mock_data.json"): + self.input_ids = [] + self.attention_mask = [] + with open("./mock_data.json", "r") as f: + data = json.load(f) + for v in data.values(): + d = v["text"] + encode = tokenizer( + "" + d, + return_tensors="pt", + add_special_tokens=False, + max_length=max_length, + truncation=True, + padding="max_length", + ) + self.input_ids.append(encode["input_ids"]) + self.attention_mask.append(encode["attention_mask"]) + self.input_ids = torch.cat(self.input_ids, dim=0).to(get_current_device()) + self.attention_mask = torch.cat(self.attention_mask, dim=0).to(get_current_device()) + 
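+            # Tile the tokenized mock samples so the dataset provides exactly num_samples entries.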
repeat_times = num_samples // self.input_ids.shape[0] + 1 + self.input_ids = self.input_ids.repeat(repeat_times, 1)[:num_samples] + self.attention_mask = self.attention_mask.repeat(repeat_times, 1)[:num_samples] + else: + self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device()) + self.attention_mask = torch.ones_like(self.input_ids) + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + return { + "input_ids": self.input_ids[idx], + "attention_mask": self.attention_mask[idx], + "labels": self.input_ids[idx], + } + + +def parse_args(): + # basic settings + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_name", + type=str, + default="base", + choices=["base", "8b"], + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--batch_size", + type=int, + default=4, + help="Batch size (per dp group) for the training dataloader.", + ) + parser.add_argument( + "--seq_length", + type=int, + default=2048, + help="sequence length for the training dataloader.", + ) + parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") + parser.add_argument( + "--plugin", + type=str, + default="hybrid", + help="parallel plugin", + ) + # hybrid plugin + parser.add_argument("--pp_size", type=int, default=2, help="pp size") + parser.add_argument("--dp_size", type=int, default=1, help="dp size") + parser.add_argument("--ep_size", type=int, default=2, help="ep size") + parser.add_argument("--zero_stage", type=int, default=2, help="zero stage in hybrid plugin") + parser.add_argument("--microbatch_size", type=int, default=1, help="microbatch size") + parser.add_argument("--extra_dp_size", type=int, default=1) + # kernel + parser.add_argument( + "--use_kernel", + action="store_true", + help="Use kernel optim. 
Need to install flash attention, apex, triton to enable all kernel optimizations.", + ) + # bench + parser.add_argument("--warmup", type=int, default=20) + parser.add_argument("--active", type=int, default=20) + # load balance + parser.add_argument("--load_balance", action="store_true") + + # overlap + parser.add_argument("--overlap_alltoall", action="store_true") + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + # Launch ColossalAI + colossalai.launch_from_torch(config={}, seed=args.seed) + coordinator = DistCoordinator() + + # Set plugin + booster_kwargs = {} + hybrid_dict = { + "tp_size": 1, + "custom_policy": OpenMoeForCausalLMPolicy(), + "enable_fused_normalization": args.use_kernel, + "enable_jit_fused": args.use_kernel, + "precision": "bf16", + "zero_stage": args.zero_stage, + } + mgr_dict = { + "seed": 42, + } + if args.plugin == "ep": + dp_size = dist.get_world_size() + plugin = MoeHybridParallelPlugin( + pp_size=1, + **hybrid_dict, + ) + MOE_MANAGER.setup( + parallel="EP", + max_ep_size=dp_size, + **mgr_dict, + ) + elif args.plugin == "ep_zero": + dp_size = dist.get_world_size() + use_ep_inside = False + plugin = MoeHybridParallelPlugin( + pp_size=1, + extra_dp_size=args.extra_dp_size, + use_ep_inside=use_ep_inside, + **hybrid_dict, + ) + MOE_MANAGER.setup( + parallel="EP", + max_ep_size=dp_size // args.extra_dp_size, + use_ep_inside=use_ep_inside, + **mgr_dict, + ) + elif args.plugin == "hybrid": + dp_size = dist.get_world_size() // args.pp_size + plugin = MoeHybridParallelPlugin( + pp_size=args.pp_size, + zero_stage=args.zero_stage, + microbatch_size=args.microbatch_size, + **hybrid_dict, + ) + MOE_MANAGER.setup( + parallel="EP", + mode="fixed", + fixed_dp_size=args.dp_size, + fixed_ep_size=args.ep_size, + fixed_pp_size=args.pp_size, + **mgr_dict, + ) + else: + raise ValueError(f"Invalid plugin {args.plugin}") + coordinator.print_on_master(f"Set plugin as {plugin}") + + # Build OpenMoe model + repo_name = "hpcaitech/openmoe-" + args.model_name + config = LlamaConfig.from_pretrained(repo_name) + set_openmoe_args( + config, + num_experts=config.num_experts, + moe_layer_interval=config.moe_layer_interval, + enable_load_balance=args.load_balance, + enable_kernel=args.use_kernel, + enable_comm_overlap=args.overlap_alltoall, + ) + with skip_init(): + model = OpenMoeForCausalLM(config) + coordinator.print_on_master(f"Finish init model with config:\n{config}") + + # Enable gradient checkpointing + model.gradient_checkpointing_enable() + + # Prepare tokenizer and dataloader + tokenizer = T5Tokenizer.from_pretrained("google/umt5-small") + dataset = RandomDataset( + num_samples=args.batch_size * (args.warmup + args.active + 1) * dp_size, + max_length=args.seq_length, + tokenizer=tokenizer, + ) + dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size) + + # Set optimizer + optimizer = HybridAdam(model.parameters(), weight_decay=0.01, lr=1e-5) + + model_numel = get_model_numel(model) + performance_evaluator = PerformanceEvaluator( + model_numel, + enable_grad_checkpoint=True, + ignore_steps=args.warmup, + dp_world_size=dp_size, + ) + + # Set booster + booster = Booster(plugin=plugin, **booster_kwargs) + load_ckpt(repo_name, model, booster) + model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader) + use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1 + is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() + 
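+    # With pipeline parallelism, booster.execute_pipeline consumes the dataloader iterator itself;
+    # otherwise each batch is fetched and moved to the GPU manually in the loop below.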
coordinator.print_on_master(f"Finish init booster") + + # Start finetuning + coordinator.print_on_master(f"Start training") + model.train() + train_dataloader_iter = iter(dataloader) + total_len = len(train_dataloader_iter) - 1 + exmaple_data = next(train_dataloader_iter) + with tqdm(range(total_len), disable=not coordinator.is_master()) as pbar: + for step in pbar: + performance_evaluator.on_step_start(step) + if use_pipeline: + # Forward pass + outputs = booster.execute_pipeline( + train_dataloader_iter, + model, + lambda x, y: x.loss, + optimizer, + return_loss=True, + return_outputs=True, + ) + # Backward and optimize + if is_pp_last_stage: + loss = outputs["loss"] + pbar.set_postfix({"loss": loss.item()}) + else: + # Forward pass + data = next(train_dataloader_iter) + data = move_to_cuda(data, torch.cuda.current_device()) + outputs = model(**data) + loss = outputs["loss"] + # Backward + booster.backward(loss, optimizer) + pbar.set_postfix({"loss": loss.item()}) + + optimizer.step() + optimizer.zero_grad() + performance_evaluator.on_step_end(exmaple_data["input_ids"]) + if (step == args.warmup // 2) and args.load_balance: + coordinator.print_on_master(f"Apply load balance") + apply_load_balance(model, optimizer) + performance_evaluator.on_fit_end() + + +if __name__ == "__main__": + main() diff --git a/examples/language/openmoe/benchmark/benchmark_cai.sh b/examples/language/openmoe/benchmark/benchmark_cai.sh new file mode 100755 index 000000000..f269e260d --- /dev/null +++ b/examples/language/openmoe/benchmark/benchmark_cai.sh @@ -0,0 +1,78 @@ +#!/bin/bash + +set -xue + +NUM_GPU=8 +MODEL="8b" +SEQ_LENGTH=2048 +WARMUP=20 +ACTIVE=4 + +# HACK: make model importable +example_dir=$(dirname $(realpath $(dirname $0))) +if [ -z ${PYTHONPATH+x} ]; then + export PYTHONPATH=$example_dir +else + export PYTHONPATH=$example_dir:$PYTHONPATH +fi + + +# ep +echo -e "\n\n Naive EP \n\n" +torchrun --standalone --nproc_per_node $NUM_GPU \ + $example_dir/benchmark/benchmark_cai.py \ + --model_name $MODEL \ + --batch_size 8 \ + --seq_length $SEQ_LENGTH \ + --warmup $WARMUP \ + --active $ACTIVE \ + --plugin ep \ + --zero_stage 2 + + +# ep_zero +echo -e "\n\n EP-ZERO \n\n" +torchrun --standalone --nproc_per_node $NUM_GPU \ + $example_dir/benchmark/benchmark_cai.py \ + --model_name $MODEL \ + --batch_size 16 \ + --seq_length $SEQ_LENGTH \ + --warmup $WARMUP \ + --active $ACTIVE \ + --plugin ep_zero \ + --use_kernel \ + --extra_dp_size 2 \ + --zero_stage 1 \ + --load_balance + +echo -e "\n\n EP-ZERO + Overlap \n\n" +torchrun --standalone --nproc_per_node $NUM_GPU \ + $example_dir/benchmark/benchmark_cai.py \ + --model_name $MODEL \ + --batch_size 16 \ + --seq_length $SEQ_LENGTH \ + --warmup $WARMUP \ + --active $ACTIVE \ + --plugin ep_zero \ + --use_kernel \ + --extra_dp_size 2 \ + --zero_stage 1 \ + --load_balance \ + --overlap_alltoall + + +# hybrid +torchrun --standalone --nproc_per_node $NUM_GPU \ + $example_dir/benchmark/benchmark_cai.py \ + --model_name $MODEL \ + --batch_size 128 \ + --seq_length $SEQ_LENGTH \ + --warmup $WARMUP \ + --active $ACTIVE \ + --use_kernel \ + --plugin hybrid \ + --pp_size 2 \ + --dp_size 1 \ + --ep_size 4 \ + --zero_stage 1 \ + --microbatch_size 32 diff --git a/examples/language/openmoe/benchmark/benchmark_cai_dist.sh b/examples/language/openmoe/benchmark/benchmark_cai_dist.sh new file mode 100755 index 000000000..06d57e4f0 --- /dev/null +++ b/examples/language/openmoe/benchmark/benchmark_cai_dist.sh @@ -0,0 +1,47 @@ +#!/bin/bash + +set -xue + +NUM_GPU=8 +MODEL="8b" 
+SEQ_LENGTH=2048 +WARMUP=20 +ACTIVE=4 + +# HACK: make model importable +example_dir=$(dirname $(realpath $(dirname $0))) +if [ -z ${PYTHONPATH+x} ]; then + export PYTHONPATH=$example_dir +else + export PYTHONPATH=$example_dir:$PYTHONPATH +fi + + +# ep +echo -e "\n\n Naive EP \n\n" +colossalai run --nproc_per_node $NUM_GPU --hostfile "hostfile.txt" \ + $example_dir/benchmark/benchmark_cai.py \ + --model_name $MODEL \ + --batch_size 12 \ + --seq_length $SEQ_LENGTH \ + --warmup $WARMUP \ + --active $ACTIVE \ + --plugin ep \ + --zero_stage 2 + + +# ep_zero +echo -e "\n\n EP-ZERO \n\n" +colossalai run --nproc_per_node $NUM_GPU --hostfile "hostfile.txt" \ + $example_dir/benchmark/benchmark_cai.py \ + --model_name $MODEL \ + --batch_size 20 \ + --seq_length $SEQ_LENGTH \ + --warmup $WARMUP \ + --active $ACTIVE \ + --plugin ep_zero \ + --use_kernel \ + --extra_dp_size 2 \ + --zero_stage 1 \ + --load_balance \ + --overlap_alltoall diff --git a/examples/language/openmoe/benchmark/benchmark_fsdp.py b/examples/language/openmoe/benchmark/benchmark_fsdp.py new file mode 100644 index 000000000..45a11ad63 --- /dev/null +++ b/examples/language/openmoe/benchmark/benchmark_fsdp.py @@ -0,0 +1,139 @@ +import argparse +import functools +import os + +import torch +import torch.distributed as dist +import tqdm +from model.modeling_openmoe import LlamaConfig, OpenMoeDecoderLayer, OpenMoeForCausalLM, set_openmoe_args +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision +from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy +from torch.utils.data import Dataset +from torch.utils.data.distributed import DistributedSampler +from transformers.models.llama import LlamaConfig +from utils import PerformanceEvaluator, get_model_numel + +from colossalai.moe.manager import MOE_MANAGER + + +class RandomDataset(Dataset): + def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000): + self.num_samples = num_samples + self.max_length = max_length + self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length)) + self.attention_mask = torch.ones_like(self.input_ids) + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + return { + "input_ids": self.input_ids[idx], + "attention_mask": self.attention_mask[idx], + "labels": self.input_ids[idx], + } + + +def fsdp_main(rank, world_size, args): + # initialize the process group + + # initialize the process group + dist.init_process_group("nccl") + + MOE_MANAGER.setup(seed=42, parallel=None) + + dp_size = dist.get_world_size() + dataset = RandomDataset( + max_length=args.seq_length, + num_samples=args.batch_size * (args.warmup + args.active) * dp_size, + ) + sampler = DistributedSampler(dataset, rank=rank, num_replicas=world_size, shuffle=False) + train_kwargs = {"batch_size": args.batch_size, "sampler": sampler} + train_loader = torch.utils.data.DataLoader(dataset, **train_kwargs) + torch.cuda.set_device(rank) + + config = LlamaConfig.from_pretrained("hpcaitech/openmoe-%s" % args.model_name) + set_openmoe_args( + config, + num_experts=config.num_experts, + moe_layer_interval=config.moe_layer_interval, + enable_load_balance=False, + enable_kernel=False, + enable_comm_overlap=False, + ) + torch.set_default_dtype(torch.float16) + model = OpenMoeForCausalLM(config) + torch.set_default_dtype(torch.float32) + auto_wrap_policy = functools.partial( + transformer_auto_wrap_policy, + transformer_layer_cls={ + 
OpenMoeDecoderLayer, + }, + ) + model = FSDP( + model, + mixed_precision=MixedPrecision( + param_dtype=torch.bfloat16, + reduce_dtype=torch.bfloat16, + buffer_dtype=torch.bfloat16, + ), + auto_wrap_policy=auto_wrap_policy, + device_id=torch.cuda.current_device(), + ) + optimizer = torch.optim.Adam(model.parameters(), weight_decay=0.01, lr=1e-5) + model.train() + + model_numel = get_model_numel(model) + performance_evaluator = PerformanceEvaluator( + model_numel, + enable_grad_checkpoint=True, + ignore_steps=args.warmup, + dp_world_size=dist.get_world_size(), + ) + + for step, data in tqdm.tqdm(enumerate(train_loader), total=len(train_loader)): + performance_evaluator.on_step_start(step) + input_ids, attention_mask, labels = ( + data["input_ids"].cuda(), + data["attention_mask"].cuda(), + data["labels"].cuda(), + ) + + optimizer.zero_grad() + output = model( + input_ids=input_ids, + labels=labels, + attention_mask=attention_mask, + chunk_head=False, + ) + loss = output["loss"] + loss.backward() + optimizer.step() + performance_evaluator.on_step_end(input_ids) + + performance_evaluator.on_fit_end() + if dist.get_rank() == 0: + print(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_name", + type=str, + default="base", + choices=["base", "8b"], + help="base or 8b", + ) + parser.add_argument("--batch_size", type=int, default=1) + parser.add_argument("--seq_length", type=int, default=2048) + parser.add_argument("--warmup", type=int, default=20) + parser.add_argument("--active", type=int, default=20) + args = parser.parse_args() + + torch.manual_seed(42) + + world_size = int(os.environ["WORLD_SIZE"]) + local_rank = int(os.environ["LOCAL_RANK"]) + fsdp_main(local_rank, world_size, args) diff --git a/examples/language/openmoe/benchmark/benchmark_fsdp.sh b/examples/language/openmoe/benchmark/benchmark_fsdp.sh new file mode 100755 index 000000000..c6f5624dd --- /dev/null +++ b/examples/language/openmoe/benchmark/benchmark_fsdp.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +set -xue + +MODEL="8b" +BATCH_SIZE=1 +SEQ_LENGTH=2048 +WARMUP=8 +ACTIVE=4 + +# HACK: make model importable +example_dir=$(dirname $(realpath $(dirname $0))) +if [ -z ${PYTHONPATH+x} ]; then + export PYTHONPATH=$example_dir +else + export PYTHONPATH=$example_dir:$PYTHONPATH +fi + +# single node +torchrun --standalone $example_dir/benchmark/benchmark_fsdp.py \ + --model_name $MODEL \ + --batch_size $BATCH_SIZE \ + --seq_length $SEQ_LENGTH \ + --warmup $WARMUP \ + --active $ACTIVE + +# multi node +torchrun --nproc_per_node=8 --nnodes=2 --node_rank=node_rank --master_addr=master_addr --master_port=master_port \ + $example_dir/benchmark/benchmark_fsdp.py \ + --model_name $MODEL \ + --batch_size $BATCH_SIZE \ + --seq_length $SEQ_LENGTH \ + --warmup $WARMUP \ + --active $ACTIVE diff --git a/examples/language/openmoe/benchmark/hostfile.txt b/examples/language/openmoe/benchmark/hostfile.txt new file mode 100644 index 000000000..994b3e2cf --- /dev/null +++ b/examples/language/openmoe/benchmark/hostfile.txt @@ -0,0 +1,2 @@ +host1 +host2 diff --git a/examples/language/openmoe/benchmark/utils.py b/examples/language/openmoe/benchmark/utils.py new file mode 100644 index 000000000..7a0955bb0 --- /dev/null +++ b/examples/language/openmoe/benchmark/utils.py @@ -0,0 +1,126 @@ +from time import time +from typing import Optional + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch import Tensor + +from 
colossalai.logging import DistributedLogger
+
+
+def print_model_numel(logger: DistributedLogger, model: nn.Module) -> None:
+    B = 1024**3
+    M = 1024**2
+    K = 1024
+    outputs = "Model param count: "
+    model_param = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    if model_param >= B:
+        outputs += f"{model_param / B:.2f} B\n"
+    elif model_param >= M:
+        outputs += f"{model_param / M:.2f} M\n"
+    elif model_param >= K:
+        outputs += f"{model_param / K:.2f} K\n"
+    else:
+        outputs += f"{model_param}\n"
+    logger.info(outputs, ranks=[0])
+
+
+def get_model_numel(model: nn.Module) -> int:
+    model_param = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    return model_param
+
+
+def divide(x: float, y: float) -> float:
+    if y == 0:
+        return float("inf")
+    elif y == float("inf"):
+        return float("nan")
+    return x / y
+
+
+@torch.no_grad()
+def all_reduce_mean(x: float, world_size: int) -> float:
+    if world_size == 1:
+        return x
+    tensor = torch.tensor([x], device=torch.cuda.current_device())
+    dist.all_reduce(tensor)
+    tensor = tensor / world_size
+    return tensor.item()
+
+
+class Timer:
+
+    def __init__(self) -> None:
+        self.start_time: Optional[float] = None
+        self.duration: float = 0.0
+
+    def start(self) -> None:
+        self.start_time = time()
+
+    def end(self) -> None:
+        assert self.start_time is not None
+        self.duration += time() - self.start_time
+        self.start_time = None
+
+    def reset(self) -> None:
+        self.duration = 0.0
+
+
+class PerformanceEvaluator:
+    """
+    Callback for evaluating the performance of the model.
+    Args:
+        model_numel: The number of parameters of the model.
+        enable_grad_checkpoint: Whether gradient checkpointing is enabled.
+        ignore_steps: The number of warm-up steps to ignore when measuring performance.
+        dp_world_size: The data-parallel world size.
+ """ + + def __init__( + self, + model_numel: int, + enable_grad_checkpoint: bool = False, + ignore_steps: int = 0, + dp_world_size: Optional[int] = None, + ) -> None: + self.model_numel = model_numel + self.enable_grad_checkpoint = enable_grad_checkpoint + self.ignore_steps = ignore_steps + self.dp_world_size = dp_world_size + self.world_size = dist.get_world_size() + self.disable: bool = False + self.timer = Timer() + self.num_samples: int = 0 + self.flop: int = 0 + + def on_step_start(self, step: int) -> None: + self.disable = self.ignore_steps > 0 and step < self.ignore_steps + if self.disable: + return + torch.cuda.synchronize() + self.timer.start() + + def on_step_end(self, input_ids: Tensor, **kwargs) -> None: + if self.disable: + return + torch.cuda.synchronize() + self.timer.end() + + batch_size, seq_len = input_ids.shape + + self.num_samples += batch_size + self.flop += (batch_size * seq_len * self.model_numel * 2 * (3 + int(self.enable_grad_checkpoint))) + + def on_fit_end(self) -> None: + avg_duration = all_reduce_mean(self.timer.duration, self.world_size) + avg_throughput = self.num_samples * self.dp_world_size / (avg_duration + 1e-12) + mp_world_size = self.world_size // self.dp_world_size + avg_tflops_per_gpu = self.flop / 1e12 / (avg_duration + 1e-12) / mp_world_size + if dist.get_rank() == 0: + print( + f"num_samples: {self.num_samples}, dp_world_size: {self.dp_world_size}, flop: {self.flop}, avg_duration: {avg_duration}, " + f"avg_throughput: {avg_throughput}") + print(f"Throughput: {avg_throughput:.2f} samples/sec, TFLOPS per GPU: {avg_tflops_per_gpu:.2f}") diff --git a/examples/language/openmoe/infer.py b/examples/language/openmoe/infer.py new file mode 100644 index 000000000..db90c6e34 --- /dev/null +++ b/examples/language/openmoe/infer.py @@ -0,0 +1,57 @@ +from argparse import ArgumentParser + +import torch +from model.modeling_openmoe import OpenMoeForCausalLM, set_openmoe_args +from transformers import T5Tokenizer +from transformers.models.llama import LlamaConfig + + +def parse_args(): + parser = ArgumentParser() + parser.add_argument("--model", default="base", type=str, help="model path", choices=["base", "8b", "test"]) + return parser.parse_args() + + +def inference(args): + tokenizer = T5Tokenizer.from_pretrained("google/umt5-small") + if args.model == "test": + config = LlamaConfig.from_pretrained("hpcaitech/openmoe-base") + set_openmoe_args(config, + num_experts=config.num_experts, + moe_layer_interval=config.moe_layer_interval, + enable_kernel=True) + model = OpenMoeForCausalLM(config) + else: + config = LlamaConfig.from_pretrained(f"hpcaitech/openmoe-{args.model}") + set_openmoe_args(config, + num_experts=config.num_experts, + moe_layer_interval=config.moe_layer_interval, + enable_kernel=False) + model = OpenMoeForCausalLM.from_pretrained(f"hpcaitech/openmoe-{args.model}", config=config) + model = model.eval().bfloat16() + model = model.to(torch.cuda.current_device()) + + input_str = """``` +y = list(map(int, ['1', 'hello', '2'])) +``` +What error does this program produce? 
+ValueError: invalid literal for int() with base 10: 'hello' + +``` +sum = 0 +for i in range(100): + sum += i +``` +What is the value of sum immediately after the 10th time line 3 is executed?""" + + # print("model config: ", model.config) + input_ids = tokenizer("" + input_str, return_tensors="pt", add_special_tokens=False) + input_ids = input_ids.input_ids.to(torch.cuda.current_device()) + generation_output = model.generate(input_ids, use_cache=True, do_sample=True, max_new_tokens=64) + out = tokenizer.decode(generation_output[0], skip_special_tokens=False) + print(f"output: \n{out}\n") + + +if __name__ == "__main__": + args = parse_args() + inference(args) diff --git a/examples/language/openmoe/infer.sh b/examples/language/openmoe/infer.sh new file mode 100644 index 000000000..a578203eb --- /dev/null +++ b/examples/language/openmoe/infer.sh @@ -0,0 +1 @@ +python infer.py --model "base" diff --git a/examples/language/openmoe/model/__init__.py b/examples/language/openmoe/model/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/language/openmoe/model/convert_openmoe_ckpt.py b/examples/language/openmoe/model/convert_openmoe_ckpt.py new file mode 100644 index 000000000..20b1e780d --- /dev/null +++ b/examples/language/openmoe/model/convert_openmoe_ckpt.py @@ -0,0 +1,224 @@ +# coding=utf-8 +# Copyright 2022 Google LLC and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Convert T5X checkpoint to PyTorch + +Steps: +- Install gsutil according to https://cloud.google.com/storage/docs/gsutil_install +- Get a T5X checkpoint at https://github.com/google-research/t5x/blob/main/docs/models.md#t5-11-checkpoints Example: + `gsutil -m cp -r gs://t5-data/pretrained_models/t5x/t5_1_1_small $HOME/` +- Create or download a corresponding config for the downloaded model. E.g. for T5 v1.1 small, you can use + https://huggingface.co/google/t5-v1_1-small/blob/main/config.json +- Convert: + ``` + python3 convert_t5x_checkpoint_to_pytorch.py --t5x_checkpoint_path=$HOME/t5_1_1_small --config_file=config.json\ + --pytorch_dump_path=$HOME/t5_1_1_small_pt + ``` +""" + +import argparse +import collections + +import torch +from flax import traverse_util +from modeling_openmoe import OpenMoeForCausalLM +from t5x import checkpoints +from transformers import LlamaConfig +from transformers.utils import logging + +logging.set_verbosity_info() + + +def t5x_attention_lookup(params, i, prefix, layer_name="attention"): + """Returns the KOQV parameters of (self-)attention. Does not transpose.""" + k = params[f"{prefix}/layers_{i}/{layer_name}/key/kernel"] + o = params[f"{prefix}/layers_{i}/{layer_name}/out/kernel"] + q = params[f"{prefix}/layers_{i}/{layer_name}/query/kernel"] + v = params[f"{prefix}/layers_{i}/{layer_name}/value/kernel"] + return k, o, q, v + + +def t5x_mlp_lookup(params, i, prefix, split_mlp_wi=False): + """Returns the MLP parameters of a layer. 
Does not transpose.""" + if split_mlp_wi: + wi_0 = params[f"{prefix}/layers_{i}/mlp/wi_0/kernel"] + wi_1 = params[f"{prefix}/layers_{i}/mlp/wi_1/kernel"] + wi = (wi_0, wi_1) + else: + wi = params[f"{prefix}/layers_{i}/mlp/wi/kernel"] + + wo = params[f"{prefix}/layers_{i}/mlp/wo/kernel"] + return wi, wo + + +def t5x_extra_mlp_lookup(params, i, prefix, split_mlp_wi=False): + """Returns the MLP parameters of a layer. Does not transpose.""" + if split_mlp_wi: + wi_0 = params[f"{prefix}/layers_{i}/extra_mlp/wi_0/kernel"] + wi_1 = params[f"{prefix}/layers_{i}/extra_mlp/wi_1/kernel"] + wi = (wi_0, wi_1) + else: + wi = params[f"{prefix}/layers_{i}/extra_mlp/wi/kernel"] + + wo = params[f"{prefix}/layers_{i}/extra_mlp/wo/kernel"] + return wi, wo + + +def t5x_experts_lookup(params, i, prefix, split_mlp_wi=False): + """Returns the MLP parameters of a layer. Does not transpose.""" + if split_mlp_wi: + wi_0 = params[f"{prefix}/layers_{i}/mlp/expert/wi_0/kernel"] + wi_1 = params[f"{prefix}/layers_{i}/mlp/expert/wi_1/kernel"] + wi = (wi_0, wi_1) + else: + wi = params[f"{prefix}/layers_{i}/mlp/expert/wi/kernel"] + + wo = params[f"{prefix}/layers_{i}/mlp/expert/wo/kernel"] + return wi, wo + + +def t5x_gate_lookup(params, i, prefix, split_mlp_wi=False): + """Returns the MLP parameters of a layer. Does not transpose.""" + return params[f"{prefix}/layers_{i}/mlp/router/router_weights/w/kernel"] + + +def t5x_layer_norm_lookup(params, i, prefix, layer_name): + """Returns the layer norm param of a layer.""" + return params[f"{prefix}/layers_{i}/{layer_name}/scale"] + + +def convert_t5x_to_pytorch(variables: dict, *, num_layers: int, moe_interval: int): + """Converts the parameters from T5X-Flax to Transformers-PyTorch.""" + old = traverse_util.flatten_dict(variables["target"]) + old = {"/".join(k): v for k, v in old.items()} + + # v1.1 models have a gated GeLU with wi_0 and wi_1 instead of wi + split_mlp_wi = True + print("Split MLP:", split_mlp_wi) + + new = collections.OrderedDict() + print(old.keys()) + for key, value in old.items(): + print(f"{key}: {value.shape}") + + # Shared embeddings. + new["model.embed_tokens.weight"] = old["token_embedder/embedding"] + + # Decoder. + for i in range(num_layers): + # Block i, layer 0 (Self Attention). + layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_self_attention_layer_norm") + k, o, q, v = t5x_attention_lookup(old, i, "decoder", "self_attention") + new[f"model.layers.{i}.input_layernorm.weight"] = layer_norm + new[f"model.layers.{i}.self_attn.k_proj.weight"] = k.T + new[f"model.layers.{i}.self_attn.o_proj.weight"] = o.T + new[f"model.layers.{i}.self_attn.q_proj.weight"] = q.T + new[f"model.layers.{i}.self_attn.v_proj.weight"] = v.T + + # Block i, layer 2 (MLP). 
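+        # Layers at every moe_interval-th position carry router and expert weights plus an extra dense MLP;
+        # all other layers map to a plain gated MLP.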
+ layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_mlp_layer_norm") + new[f"model.layers.{i}.post_attention_layernorm.weight"] = layer_norm + + if (i + 1) % moe_interval == 0: + # moe + gate = t5x_gate_lookup(old, i, "decoder", split_mlp_wi) + new[f"model.layers.{i}.mlp.gate_weight"] = gate.T + wi, wo = t5x_experts_lookup(old, i, "decoder", split_mlp_wi) + new[f"model.layers.{i}.mlp.experts.wi_gate"] = wi[0] + new[f"model.layers.{i}.mlp.experts.wi_up"] = wi[1] + new[f"model.layers.{i}.mlp.experts.wo"] = wo + # extra + layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_extra_mlp_layer_norm") + new[f"model.layers.{i}.pre_extra_mlp_layernorm.weight"] = layer_norm + wi, wo = t5x_extra_mlp_lookup(old, i, "decoder", split_mlp_wi) + new[f"model.layers.{i}.extra_mlp.gate_proj.weight"] = wi[0].T + new[f"model.layers.{i}.extra_mlp.up_proj.weight"] = wi[1].T + new[f"model.layers.{i}.extra_mlp.down_proj.weight"] = wo.T + else: + wi, wo = t5x_mlp_lookup(old, i, "decoder", split_mlp_wi) + new[f"model.layers.{i}.mlp.gate_proj.weight"] = wi[0].T + new[f"model.layers.{i}.mlp.up_proj.weight"] = wi[1].T + new[f"model.layers.{i}.mlp.down_proj.weight"] = wo.T + + new["model.norm.weight"] = old["decoder/decoder_norm/scale"] + + # LM Head (only in v1.1 checkpoints, in v1.0 embeddings are used instead) + if "decoder/logits_dense/kernel" in old: + new["lm_head.weight"] = old["decoder/logits_dense/kernel"].T + + return new + + +def make_state_dict(converted_params): + """Prepares a state dict for the PyTorch model.""" + # Make a state dict with torch tensors. + state_dict = collections.OrderedDict([(k, torch.from_numpy(v.copy())) for (k, v) in converted_params.items()]) + + return state_dict + + +def load_t5x_weights_in_t5(model, config, t5x_checkpoint_path): + """Replaces the params in model witht the T5X converted params.""" + variables = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path) + converted = convert_t5x_to_pytorch(variables, + num_layers=config.num_hidden_layers, + moe_interval=config.moe_layer_interval) + state_dict = make_state_dict(converted) + model.load_state_dict(state_dict, strict=True) + + +def convert_t5x_checkpoint_to_pytorch(t5x_checkpoint_path, config_file, pytorch_dump_path): + """Loads the config and model, converts the T5X checkpoint, and saves a PyTorch checkpoint.""" + # Initialise PyTorch model + config = LlamaConfig.from_json_file(config_file) + print(f"Building PyTorch model from configuration: {config}") + # Non-v1.1 checkpoints could also use T5Model, but this works for all. + # The v1.0 checkpoints will simply have an LM head that is the word embeddings. + model = OpenMoeForCausalLM(config) + + # Load weights from tf checkpoint + load_t5x_weights_in_t5(model, config, t5x_checkpoint_path) + + # Save pytorch-model + print(f"Save PyTorch model to {pytorch_dump_path}") + model.save_pretrained(pytorch_dump_path) + + # Verify that we can load the checkpoint. 
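+    # Reloading the dumped checkpoint right away surfaces any state-dict or config mismatch.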
+ model.from_pretrained(pytorch_dump_path) + print("Done") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Converts a native T5X checkpoint into a PyTorch checkpoint.") + # Required parameters + parser.add_argument("--t5x_checkpoint_path", + default=None, + type=str, + required=True, + help="Path to the T5X checkpoint.") + parser.add_argument( + "--config_file", + default=None, + type=str, + required=True, + help="The config json file corresponding to the pre-trained T5 model.\nThis specifies the model architecture.", + ) + parser.add_argument("--pytorch_dump_path", + default=None, + type=str, + required=True, + help="Path to the output PyTorch model.") + args = parser.parse_args() + convert_t5x_checkpoint_to_pytorch(args.t5x_checkpoint_path, args.config_file, args.pytorch_dump_path) diff --git a/examples/language/openmoe/model/convert_openmoe_ckpt.sh b/examples/language/openmoe/model/convert_openmoe_ckpt.sh new file mode 100644 index 000000000..c0d53f562 --- /dev/null +++ b/examples/language/openmoe/model/convert_openmoe_ckpt.sh @@ -0,0 +1 @@ +python convert_openmoe_ckpt.py --t5x_checkpoint_path /path/to/t5x --config_file /path/to/config --pytorch_dump_path /path/to/save diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py new file mode 100644 index 000000000..7e3e6b3ed --- /dev/null +++ b/examples/language/openmoe/model/modeling_openmoe.py @@ -0,0 +1,1113 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" PyTorch OpenMoE model.""" +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.models.llama.modeling_llama import LlamaConfig, LlamaRMSNorm +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) + +from colossalai.kernel.cuda_native.mha.flash_attn_2 import HAS_FLASH_ATTN +from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON +from colossalai.moe.layers import SparseMLP +from colossalai.moe.manager import MOE_MANAGER +from colossalai.moe.utils import get_activation, set_moe_args + +if HAS_TRITON: + from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "LlamaConfig" + + +def set_openmoe_args( + config: LlamaConfig, + num_experts: int, + moe_layer_interval: int, + router_topk: int = 2, + router_capacity_factor_train: float = 1.25, + router_capacity_factor_eval: float = 2.0, + router_min_capacity: int = 4, + router_noisy_policy: str = None, + router_drop_tks: bool = True, + router_aux_loss_factor: float = 0.01, + router_z_loss_factor: float = 0.0001, + mlp_gated: bool = True, + label_smoothing: float = 0.001, + z_loss_factor: float = 0.01, + enable_load_balance: bool = False, + load_balance_tolerance: float = 0.1, + load_balance_beam_width: int = 8, + load_balance_group_swap_factor: float = 0.4, + enable_kernel: bool = False, + enable_comm_overlap: bool = False, +) -> None: + """ + MoE related arguments. + It inserts the MoE arguments into the Llama config. + + Args: + config (LlamaConfig): Transformers Llama config. + num_experts (int, optional): Number of experts. + moe_layer_interval (int, optional): The interval moe layer. + router_topk (int, optional): Moe router top k. Defaults to 2. + router_capacity_factor_train (float, optional): Moe router max capacity for train. Defaults to 1.25. + router_capacity_factor_eval (float, optional): Moe router max capacity for eval. Defaults to 2.0. + router_min_capacity (int, optional): Moe router min capacity. Defaults to 4. + router_noisy_policy (str, optional): Moe router noisy policy. You can choose [Jitter, Gaussian, None]. Defaults to None. + router_drop_tks (bool, optional): Whether moe router drop tokens which exceed max capacity. Defaults to True. + router_aux_loss_factor (float, optional): Moe router aux loss. You can refer to STMoE for details. Defaults to 0.01. + router_z_loss_factor (float, optional): Moe router z loss. You can refer to STMoE for details. Defaults to 0.01. + mlp_gated (bool, optional): Use gate in mlp. Defaults to True. + label_smoothing (float, optional): Label smoothing. Defaults to 0.001. + z_loss_factor (float, optional): The final outputs' classification z loss factor. Defaults to 0.01. + enable_load_balance (bool, optional): Expert load balance. Defaults to False. + load_balance_tolerance (float, optional): Expert load balance search's difference tolerance. Defaults to 0.1. + load_balance_beam_width (int, optional): Expert load balance search's beam width. Defaults to 8. + load_balance_group_swap_factor (float, optional): Expert load balance group swap factor. Longer value encourages less swap. Defaults to 0.4. 
+ enable_kernel (bool, optional): Use kernel optimization. Defaults to False. + enable_comm_overlap (bool, optional): Use communication overlap for MoE. Recommended to enable for muiti-node training. Defaults to False. + """ + moe_args = dict( + num_experts=num_experts, + moe_layer_interval=moe_layer_interval, + router_topk=router_topk, + router_capacity_factor_train=router_capacity_factor_train, + router_capacity_factor_eval=router_capacity_factor_eval, + router_min_capacity=router_min_capacity, + router_noisy_policy=router_noisy_policy, + router_drop_tks=router_drop_tks, + router_aux_loss_factor=router_aux_loss_factor, + router_z_loss_factor=router_z_loss_factor, + mlp_gated=mlp_gated, + label_smoothing=label_smoothing, + z_loss_factor=z_loss_factor, + enable_load_balance=enable_load_balance, + load_balance_tolerance=load_balance_tolerance, + load_balance_beam_width=load_balance_beam_width, + load_balance_group_swap_factor=load_balance_group_swap_factor, + enable_kernel=enable_kernel, + enable_comm_overlap=enable_comm_overlap, + ) + set_moe_args(config, moe_args) + + +# Copied from transformers.models.bart.modeling_bart._make_causal_mask +def _make_causal_mask( + input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 +): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + +# Copied from transformers.models.bart.modeling_bart._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +def generate_fixed_pos_embedding(features, length, min_timescale=1.0, max_timescale=10000.0): + """Generate Sin/Cos for Rotary Embeddings. 
+ + Args: + features: an integer + length: an integer + min_timescale: an optional float + max_timescale: an optional float + + Returns: + output_sin: a float32 Tensor with shape [length, features] + output_cos: a float32 Tensor with shape [length, features] + """ + fraction = torch.arange(0, features, 2, dtype=torch.float32).cuda() / features + timescale = min_timescale * (max_timescale / min_timescale) ** fraction + rotational_frequency = 1.0 / timescale + + sinusoid_inp = torch.einsum("i,j->ij", torch.arange(length, dtype=torch.float32).cuda(), rotational_frequency) + + sinusoid_inp = torch.cat([sinusoid_inp, sinusoid_inp], dim=-1) + + return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp) + + +def apply_rotary_embedding(q, k, cos, sin, decode=False, rotary_index=None): + """Helper function to apply Rotary Embeddings.""" + cos = cos.to(q.dtype) + sin = sin.to(q.dtype) + + if len(k.shape) == 3: + # for multi query attention + k = k.unsqueeze(2) + multiquery = True + else: + multiquery = False + + batch, qlen, qheads, d = q.shape + kbatch, klen, kheads, kd = k.shape + assert batch == kbatch, f"{batch} != {kbatch}" + assert d == kd, f"{d} != {kd}" + if decode and qlen == 1 and rotary_index is not None: + qcos = cos[rotary_index + 1, :] + qsin = sin[rotary_index + 1, :] + qcos = qcos.unsqueeze(2) + qsin = qsin.unsqueeze(2) + kcos, ksin = cos[:klen, :], sin[:klen, :] + kcos = kcos.unsqueeze(0).unsqueeze(2) + ksin = ksin.unsqueeze(0).unsqueeze(2) + else: + qcos, qsin = cos[:qlen, :], sin[:qlen, :] + qcos = qcos.unsqueeze(0).unsqueeze(2) + qsin = qsin.unsqueeze(0).unsqueeze(2) + kcos, ksin = qcos, qsin + + out_q = (q * qcos) + (rotate_half(q) * qsin) + out_k = (k * kcos) + (rotate_half(k) * ksin) + + if multiquery: + out_k = out_k.squeeze(2) + + return out_q, out_k + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def SwiGLU(x): + """Gated linear unit activation function. 
+ Args: + x : input array + axis: the axis along which the split should be computed (default: -1) + """ + size = x.shape[-1] + assert size % 2 == 0, "axis size must be divisible by 2" + x1, x2 = torch.split(x, size // 2, -1) + return x1 * (x2 * torch.sigmoid(x2)) + + +class OpenMoeMLP(nn.Module): + def __init__(self, config: LlamaConfig): + super().__init__() + self.pretraining_tp = config.pretraining_tp + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.hidden_act = config.hidden_act + self.act_fn = get_activation(self.hidden_act) + self.use_kernel = config.enable_kernel + + def forward(self, x): + if self.pretraining_tp > 1: + slice = self.intermediate_size // self.pretraining_tp + gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) + up_proj_slices = self.up_proj.weight.split(slice, dim=0) + down_proj_slices = self.down_proj.weight.split(slice, dim=1) + + gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1) + up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1) + + intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) + down_proj = [F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.pretraining_tp)] + down_proj = sum(down_proj) + else: + if HAS_TRITON and self.use_kernel and self.hidden_act == "swiglu": + down_proj = self.down_proj(LlamaActCombine.apply(self.gate_proj(x), self.up_proj(x))) + else: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + return down_proj + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class OpenMoeAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = config.head_dim + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.pretraining_tp = config.pretraining_tp + self.max_position_embeddings = config.max_position_embeddings + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self.sin, self.cos = generate_fixed_pos_embedding(self.head_dim, self.max_position_embeddings, 1.0, 1e4) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + use_kernel: bool = True, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp + query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.pretraining_tp, dim=0) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + # cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, 
position_ids) + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + max_length = max(query_states.shape[1], key_states.shape[1]) + assert max_length <= self.sin.shape[0] + sin, cos = self.sin[:max_length], self.cos[:max_length] + # TODO: for inference, we can add emb kv into cache to avoid computation + query_states, key_states = apply_rotary_embedding( + query_states, key_states, cos, sin, decode=True if q_len == 1 else False, rotary_index=position_ids + ) + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + if HAS_FLASH_ATTN and use_kernel: + from flash_attn import flash_attn_func + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + attn_output = flash_attn_func(query_states, key_states, value_states, softmax_scale=1.0, causal=True) + attn_output = attn_output.transpose(1, 2).contiguous() + else: + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + if self.training: + attention_mask = attention_mask.clone().detach() + attention_mask[:, :, :, 0] = 0 + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim) + + if self.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class OpenMoeDecoderLayer(nn.Module): + def __init__(self, config: LlamaConfig, moe: bool): + super().__init__() + self.hidden_size = config.hidden_size + self.moe = moe + self.self_attn = OpenMoeAttention(config=config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + if self.moe: + self.mlp = SparseMLP( + 
num_experts=config.num_experts, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + router_top_k=config.router_topk, + router_capacity_factor_train=config.router_capacity_factor_train, + router_capacity_factor_eval=config.router_capacity_factor_eval, + router_min_capacity=config.router_min_capacity, + router_noisy_policy=config.router_noisy_policy, + router_drop_tks=config.router_drop_tks, + mlp_activation=config.hidden_act, + mlp_gated=config.mlp_gated, + enable_load_balance=config.enable_load_balance, + load_balance_tolerance=config.load_balance_tolerance, + load_balance_beam_width=config.load_balance_beam_width, + load_balance_group_swap_factor=config.load_balance_group_swap_factor, + enable_kernel=config.enable_kernel, + enable_comm_overlap=config.enable_comm_overlap, + ) + self.pre_extra_mlp_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.extra_mlp = OpenMoeMLP(config) + else: + self.mlp = OpenMoeMLP(config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + if self.moe: + residual = hidden_states + hidden_states = self.pre_extra_mlp_layernorm(hidden_states) + hidden_states = self.extra_mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +LLAMA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`LlamaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class OpenMoePreTrainedModel(PreTrainedModel): + config_class = LlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlamaDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, OpenMoeModel): + module.gradient_checkpointing = value + + +LLAMA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. 
+ + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class OpenMoeModel(OpenMoePreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [ + OpenMoeDecoderLayer(config, moe=True if (i + 1) % config.moe_layer_interval == 0 else False) + for i in range(config.num_hidden_layers) + ] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def 
forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class OpenMoeForCausalLM(OpenMoePreTrainedModel): + # _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = OpenMoeModel(config) + self.pretraining_tp = config.pretraining_tp + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + chunk_head: Optional[bool] = True, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). 
Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + # reset moe loss + MOE_MANAGER.reset_loss() + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + if self.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.pretraining_tp)] + logits = torch.cat(logits, dim=-1) + + loss = None + # if no training, just do forward + if labels is None: + logits = self.lm_head(hidden_states) + logits = logits.float() + # the vocab size for openmoe is 30w+ + # which causes great activation memory in training, up to 20G for one sequence + # so we use chunk and checkpoint to reduce memory + else: + if chunk_head == True: + + def create_custom_forward(module): + def custom_forward(*inputs): + logits = module(inputs[0]) + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous().float() + shift_labels = inputs[1][..., 1:].contiguous() + # Flatten the tokens + loss = self._calculate_loss(shift_logits, shift_labels) + return loss + + return custom_forward + + aux_loss, z_loss = self._calculate_router_loss() + loss = aux_loss + z_loss + for batch_idx in range(hidden_states.shape[0]): + loss = loss + torch.utils.checkpoint.checkpoint( + create_custom_forward(self.lm_head), + hidden_states[batch_idx : batch_idx + 1, :], + labels[batch_idx : batch_idx + 1, :], + ) + logits = None + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + aux_loss, z_loss = self._calculate_router_loss() + loss = aux_loss + z_loss + loss = loss + self._calculate_loss(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + 
hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): + if past_key_values: + input_ids = input_ids[:, -1:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -1].unsqueeze(-1) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + def _calculate_router_loss(self, aux_loss: list = None, z_loss: list = None): + if aux_loss is None or z_loss is None: + aux_loss, z_loss = MOE_MANAGER.get_loss() + assert len(aux_loss) == len(z_loss) == self.config.num_hidden_layers // self.config.moe_layer_interval + aux_loss = self.config.router_aux_loss_factor * sum(aux_loss) / len(aux_loss) + z_loss = self.config.router_z_loss_factor * sum(z_loss) / len(z_loss) + return aux_loss, z_loss + + def _calculate_loss(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: + """Compute cross entropy and entropy for log probs and targets. + + Args: + logits: [batch, length, num_classes] float array. + targets: categorical targets [batch, length] int array. + + Returns: + Tuple of scalar loss. + """ + if len(logits.shape) != len(targets.shape) + 1: + raise ValueError( + "Incorrect shapes. Got shape %s logits and %s targets" % (str(logits.shape), str(targets.shape)) + ) + vocab_size = logits.shape[-1] + confidence = 1.0 - self.config.label_smoothing + low_confidence = (1.0 - confidence) / (vocab_size - 1) + normalizing_constant = -( + confidence * math.log(confidence) + (vocab_size - 1) * low_confidence * math.log(low_confidence + 1e-20) + ) + + # one hot + soft_targets = targets[..., None] == torch.arange(vocab_size, device=targets.device).reshape( + (1,) * len(targets.shape) + (-1,) + ) + soft_targets = torch.where( + soft_targets, torch.full_like(soft_targets, confidence), torch.full_like(soft_targets, low_confidence) + ) + soft_targets = soft_targets.to(torch.float32) + + # cross entropy + total_loss = ZLossCrossEntropy.apply(logits, soft_targets, self.config.z_loss_factor) + total_loss = total_loss - normalizing_constant + total_loss = torch.mean(torch.sum(total_loss, dim=-1), dim=0) + return total_loss + + +class ZLossCrossEntropy(torch.autograd.Function): + """Computes cross entropy loss with stable custom gradient. + + Computes a stabilized-gradient version of: + -jnp.sum(targets * nn.log_softmax(logits), axis=-1) + + If z_loss > 0, then an auxiliary loss equal to z_loss*log(z)^2 + will be added to the cross entropy loss (z = softmax normalization constant). + The two uses of z_loss are: + 1. 
To keep the logits from drifting too far from zero, which can cause + unacceptable roundoff errors in bfloat16. + 2. To encourage the logits to be normalized log-probabilities. + + Args: + logits: [batch, length, num_classes] float array. + targets: categorical one-hot targets [batch, length, num_classes] float + array. + z_loss: coefficient for auxilliary z-loss loss term. + + Returns: + tuple with the total loss and the z_loss, both + float arrays with shape [batch, length]. + """ + + @staticmethod + def forward(ctx, logits, targets, z_loss): + max_logit = torch.max(logits, dim=-1, keepdim=True)[0] + shifted = logits - max_logit + exp_shifted = torch.exp(shifted) + sum_exp = torch.sum(exp_shifted, axis=-1, keepdims=True) + sum_exp_log = torch.log(sum_exp) + log_softmax = shifted - sum_exp_log + loss = -torch.sum(targets * log_softmax, axis=-1) + # Add auxilliary z-loss term. + log_z = torch.squeeze(sum_exp_log + max_logit, axis=-1) + total_z_loss = z_loss * torch.square(log_z) + loss += total_z_loss + ctx.z_loss = z_loss + ctx.save_for_backward(logits, targets, exp_shifted, sum_exp, log_softmax, log_z) + return loss + + @staticmethod + def backward(ctx, *grad_outputs): + assert len(grad_outputs) == 1 + g = grad_outputs[0] + z_loss = ctx.z_loss + logits, targets, exp_shifted, sum_exp, log_softmax, log_z = ctx.saved_tensors + # z-loss term adds the (2 * z_loss * log_z) factor. + deriv = (1 + 2 * z_loss * log_z).unsqueeze(-1) * exp_shifted / sum_exp - targets + g_logits = g.unsqueeze(-1) * deriv + g_targets = -g.unsqueeze(-1) * log_softmax + + return ( + g_logits.to(logits.dtype), + g_targets.to(targets.dtype), + None, + ) diff --git a/examples/language/openmoe/model/openmoe_8b_config.json b/examples/language/openmoe/model/openmoe_8b_config.json new file mode 100644 index 000000000..248697c37 --- /dev/null +++ b/examples/language/openmoe/model/openmoe_8b_config.json @@ -0,0 +1,24 @@ +{ + "architectures": [ + "OpenMoeForCausalLM" + ], + "intermediate_size": 8192, + "hidden_size": 2048, + "num_hidden_layers": 24, + "head_dim": 128, + "num_attention_heads": 24, + "dropout_rate": 0.0, + "layer_norm_epsilon": 1e-06, + "vocab_size": 256384, + "hidden_act": "swiglu", + "num_experts": 32, + "topk": 2, + "capacity_factor_train": 1.25, + "capacity_factor_eval": 2.0, + "min_capacity": 4, + "noisy_policy": null, + "drop_tks": true, + "expert_parallel": null, + "gated": true, + "moe_layer_interval": 6 +} diff --git a/examples/language/openmoe/model/openmoe_base_config.json b/examples/language/openmoe/model/openmoe_base_config.json new file mode 100644 index 000000000..5a7c97bd1 --- /dev/null +++ b/examples/language/openmoe/model/openmoe_base_config.json @@ -0,0 +1,24 @@ +{ + "architectures": [ + "OpenMoeForCausalLM" + ], + "intermediate_size": 2048, + "hidden_size": 768, + "num_hidden_layers": 12, + "head_dim": 64, + "num_attention_heads": 12, + "dropout_rate": 0.0, + "layer_norm_epsilon": 1e-06, + "vocab_size": 256384, + "hidden_act": "swiglu", + "num_experts": 16, + "topk": 2, + "capacity_factor_train": 1.25, + "capacity_factor_eval": 2.0, + "min_capacity": 4, + "noisy_policy": null, + "drop_tks": true, + "expert_parallel": null, + "gated": true, + "moe_layer_interval": 4 +} diff --git a/examples/language/openmoe/model/openmoe_policy.py b/examples/language/openmoe/model/openmoe_policy.py new file mode 100644 index 000000000..f354bbea9 --- /dev/null +++ b/examples/language/openmoe/model/openmoe_policy.py @@ -0,0 +1,562 @@ +import warnings +from functools import partial +from typing import Callable, 
Dict, List, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.nn import Module +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.utils import logging + +from colossalai.moe.manager import MOE_MANAGER +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col +from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + +from .modeling_openmoe import OpenMoeDecoderLayer, OpenMoeForCausalLM, OpenMoeModel + +__all__ = ["OpenMoePolicy", "OpenMoeForCausalLMPolicy"] + + +class OpenMoePolicy(Policy): + + def config_sanity_check(self): + pass + + def preprocess(self): + if self.shard_config.enable_tensor_parallelism: + # Resize embedding + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + + return self.model + + def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: + policy = {} + + if self.shard_config.enable_sequence_parallelism: + self.shard_config.enable_sequence_parallelism = False + raise NotImplementedError( + "openmoe dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") + + if self.shard_config.enable_tensor_parallelism: + raise NotImplementedError("Tensor parallelism is not supported for openmoe model now.") + + # optimization configuration + if self.shard_config.enable_fused_normalization: + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="input_layernorm", + target_module=FusedRMSNorm, + ), + SubModuleReplacementDescription( + suffix="post_attention_layernorm", + target_module=FusedRMSNorm, + ), + SubModuleReplacementDescription( + suffix="pre_extra_mlp_layernorm", + target_module=FusedRMSNorm, + ignore_if_not_exist=True, + ), + ], + policy=policy, + target_key=OpenMoeDecoderLayer, + ) + + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="norm", + target_module=FusedRMSNorm, + ), + policy=policy, + target_key=OpenMoeModel, + ) + + if self.shard_config.enable_flash_attention: + raise NotImplementedError("Flash attention has already been replaced in openmoe.") + + return policy + + def postprocess(self): + return self.model + + def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: + """If under pipeline parallel setting, replacing the original forward method of huggingface + to customized forward method, and add this changing to policy.""" + if self.pipeline_stage_manager: + stage_manager = self.pipeline_stage_manager + if self.model.__class__.__name__ == "OpenMoeModel": + module = self.model + else: + module = self.model.model + + layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} + self.append_or_create_method_replacement(description=method_replacement, + policy=policy, + target_key=model_cls) + + return + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + assert 
self.pipeline_stage_manager is not None + + if self.model.__class__.__name__ == "OpenMoeModel": + module = self.model + else: + module = self.model.model + stage_manager = self.pipeline_stage_manager + + held_layers = [] + layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.embed_tokens) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.layers[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.norm) + + return held_layers + + @staticmethod + def distribute_layers(num_layers: int, num_stages: int) -> List[int]: + """Divide layers into stages + + """ + if num_layers == 24 and num_stages == 4: + return [7, 7, 7, 3] + elif num_layers == 24 and num_stages == 2: + return [15, 9] + elif num_layers == 12 and num_stages == 4: + return [5, 5, 5, 1] + elif num_layers == 12 and num_stages == 2: + return [8, 4] + else: + print(f"num_layers: {num_layers}, num_stages: {num_stages} not optimized, use origin pp policy") + return Policy.distribute_layers(num_layers, num_stages) + + +class OpenMoeModelPolicy(OpenMoePolicy): + + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + policy = super().module_policy() + if self.pipeline_stage_manager: + # set None as default + self.set_pipeline_forward( + model_cls=OpenMoeModel, + new_forward=OpenMoePipelineForwards.openmoe_model_forward, + policy=policy, + ) + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + held_layers = super().get_held_layers() + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in llama model""" + return [] + + +class OpenMoeForCausalLMPolicy(OpenMoePolicy): + + def module_policy(self): + policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + # add a new item for casual lm + new_item = { + OpenMoeForCausalLM: + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True), + ) + ]) + } + policy.update(new_item) + + if self.pipeline_stage_manager: + # set None as default + self.set_pipeline_forward( + model_cls=OpenMoeForCausalLM, + new_forward=OpenMoePipelineForwards.llama_for_causal_lm_forward, + policy=policy, + ) + + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + stage_manager = self.pipeline_stage_manager + held_layers = super().get_held_layers() + if stage_manager.is_last_stage(): + held_layers.append(self.model.lm_head) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + llama_model = self.model.model + if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: + if (id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight) + and self.pipeline_stage_manager.num_stages > 1): + # tie weights + return [{ + 0: llama_model.embed_tokens.weight, + self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight, + }] + return [] + + +class OpenMoePipelineForwards: + """ + This class serves as a micro library for forward function substitution of Llama models + under pipeline setting. 
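+    Besides hidden states, these forwards also carry the MoE router aux/z losses across
+    stages (via `past_router_aux_loss` / `past_router_z_loss`) so that the last stage can
+    assemble the complete router loss.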
+ """ + + @staticmethod + def openmoe_model_forward( + self: OpenMoeModel, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + past_router_aux_loss: Optional[torch.FloatTensor] = None, + past_router_z_loss: Optional[torch.FloatTensor] = None, + ): + # reset moe loss for different data + MOE_MANAGER.reset_loss() + + logger = logging.get_logger(__name__) + + output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions) + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) + + # retrieve input_ids and inputs_embeds + if stage_manager.is_first_stage(): + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + device = input_ids.device if input_ids is not None else inputs_embeds.device + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = inputs_embeds + else: + input_shape = hidden_states.shape[:-1] + batch_size, seq_length = input_shape + device = hidden_states.device + + seq_length_with_past = seq_length + past_key_values_length = 0 + + # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. 
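+        # Note: returning attentions, hidden states or the kv-cache is not supported under
+        # pipeline parallelism, so the corresponding flags are force-disabled below with a
+        # warning. Each stage only runs its own slice of decoder layers, taken from the
+        # stage_index = [start_idx, end_idx] computed in OpenMoePolicy.set_pipeline_forward;
+        # e.g. with 24 layers on 4 stages the policy assigns [7, 7, 7, 3] layers per stage,
+        # so stage 2 receives stage_index = [14, 21] and runs self.layers[14:21].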
+ if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + if use_cache: + logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") + use_cache = False + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + # embed positions, for the first stage, hidden_states is the input embeddings, + # for the other stages, hidden_states is the output of the previous stage + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), + dtype=torch.bool, + device=hidden_states.device, + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, + (batch_size, seq_length), + hidden_states, + past_key_values_length, + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + start_idx, end_idx = stage_index[0], stage_index[1] + for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = (past_key_values[idx] if past_key_values is not None else None) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if stage_manager.is_last_stage(): + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + + # concat past losses with current ones + router_aux_loss, router_z_loss = MOE_MANAGER.get_loss() + if past_router_aux_loss is not None and past_router_z_loss is not None: + router_aux_loss = past_router_aux_loss + router_aux_loss + router_z_loss = past_router_z_loss + router_z_loss + + if stage_manager.is_last_stage(): + return tuple([ + hidden_states, + next_cache, + all_hidden_states, + all_self_attns, + router_aux_loss, + router_z_loss, 
+ ]) + # always return dict for imediate stage + return { + "hidden_states": hidden_states, + "router_aux_loss": router_aux_loss, + "router_z_loss": router_z_loss, + } + + @staticmethod + def llama_for_causal_lm_forward( + self: OpenMoeForCausalLM, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + chunk_head: Optional[bool] = True, + past_router_aux_loss: Optional[torch.FloatTensor] = None, + past_router_z_loss: Optional[torch.FloatTensor] = None, + ): + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you consciours? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." + ```""" + logger = logging.get_logger(__name__) + output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions) + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) + + # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. 
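+        # On the last pipeline stage this forward applies lm_head plus the (optionally chunked,
+        # activation-checkpointed) cross entropy and adds the accumulated router aux/z losses;
+        # intermediate stages simply return hidden_states together with the running router
+        # losses. A rough sketch of the chunk_head trick used below (illustrative names only):
+        #     for b in range(hidden_states.shape[0]):
+        #         loss = loss + checkpoint(head_and_loss, hidden_states[b:b + 1], labels[b:b + 1])
+        # so only one sample's logits over the very large vocabulary are materialized at a time.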
+ if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = OpenMoePipelineForwards.openmoe_model_forward( + self.model, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + past_router_aux_loss=past_router_aux_loss, + past_router_z_loss=past_router_z_loss, + ) + + if stage_manager.is_last_stage(): + ( + hidden_states, + past_key_values, + all_hidden_states, + attentions, + router_aux_loss, + router_z_loss, + ) = outputs + + if self.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.pretraining_tp)] + logits = torch.cat(logits, dim=-1) + + loss = None + # if no training, just do forward + if labels is None: + logits = self.lm_head(hidden_states) + logits = logits.float() + # the vocab size for openmoe is 30w+ + # which causes great activation memory in training, up to 20G for one sequence + # so we use chunk and checkpoint to reduce memory + else: + if chunk_head == True: + + def create_custom_forward(module): + + def custom_forward(*inputs): + logits = module(inputs[0]) + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous().float() + shift_labels = inputs[1][..., 1:].contiguous() + # Flatten the tokens + loss = self._calculate_loss(shift_logits, shift_labels) + return loss + + return custom_forward + + aux_loss, z_loss = self._calculate_router_loss(router_aux_loss, router_z_loss) + loss = aux_loss + z_loss + for batch_idx in range(hidden_states.shape[0]): + loss = loss + torch.utils.checkpoint.checkpoint( + create_custom_forward(self.lm_head), + hidden_states[batch_idx:batch_idx + 1, :], + labels[batch_idx:batch_idx + 1, :], + ) + logits = None + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + aux_loss, z_loss = self._calculate_router_loss(router_aux_loss, router_z_loss) + loss = aux_loss + z_loss + loss = loss + self._calculate_loss(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=attentions, + ) + else: + hidden_states = outputs["hidden_states"] + router_aux_loss = outputs["router_aux_loss"] + router_z_loss = outputs["router_z_loss"] + return { + "hidden_states": hidden_states, + "past_router_aux_loss": router_aux_loss, + "past_router_z_loss": router_z_loss, + } diff --git a/examples/language/openmoe/requirements.txt b/examples/language/openmoe/requirements.txt new file mode 100644 index 000000000..ccf02ba1d --- 
/dev/null +++ b/examples/language/openmoe/requirements.txt @@ -0,0 +1,5 @@ +colossalai >= 0.3.3 +torch >= 1.8.1 +transformers >= 4.20.0 +sentencepiece +datasets diff --git a/examples/language/openmoe/test_ci.sh b/examples/language/openmoe/test_ci.sh new file mode 100644 index 000000000..960c83adb --- /dev/null +++ b/examples/language/openmoe/test_ci.sh @@ -0,0 +1,37 @@ +pip install -r requirements.txt + +# inference +python infer.py --model "test" + +# train +torchrun --standalone --nproc_per_node 4 train.py \ + --num_epoch 1 \ + --model_name "test" \ + --plugin "ep" \ + --batch_size 1 + +torchrun --standalone --nproc_per_node 4 train.py \ + --num_epoch 1 \ + --model_name "test" \ + --plugin "ep_zero" \ + --batch_size 1 \ + --zero_stage 1 \ + --extra_dp_size 2 \ + +torchrun --standalone --nproc_per_node 4 train.py \ + --num_epoch 1 \ + --model_name "test" \ + --plugin "ep_zero" \ + --batch_size 1 \ + --zero_stage 2 \ + --extra_dp_size 2 \ + +torchrun --standalone --nproc_per_node 4 train.py \ + --model_name "test" \ + --plugin "hybrid" \ + --num_epoch 1 \ + --pp_size 2 \ + --dp_size 1 \ + --ep_size 2 \ + --zero_stage 1 \ + --batch_size 1 diff --git a/examples/language/openmoe/train.py b/examples/language/openmoe/train.py new file mode 100644 index 000000000..e8c2f6aaa --- /dev/null +++ b/examples/language/openmoe/train.py @@ -0,0 +1,377 @@ +import argparse +import os +from functools import partial +from typing import Dict + +import torch +import torch.distributed as dist +from datasets import load_dataset +from huggingface_hub import snapshot_download +from model.modeling_openmoe import OpenMoeForCausalLM, set_openmoe_args +from model.openmoe_policy import OpenMoeForCausalLMPolicy +from torch.utils.data import Dataset +from tqdm import tqdm +from transformers import T5Tokenizer +from transformers.models.llama import LlamaConfig + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin +from colossalai.cluster import DistCoordinator +from colossalai.moe.layers import apply_load_balance +from colossalai.moe.manager import MOE_MANAGER +from colossalai.moe.utils import skip_init +from colossalai.nn.optimizer import HybridAdam +from colossalai.utils import get_current_device + + +def move_to_cuda(batch, device): + return {k: v.to(device) for k, v in batch.items()} + + +def load_ckpt(repo_name: str, model: OpenMoeForCausalLM, booster: Booster): + ckpt_path = snapshot_download(repo_name) + # single ckpt + if os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin")): + ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin") + # shard ckpt + elif os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin.index.json")): + ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin.index.json") + else: + raise ValueError(f"Invalid checkpoint path: {ckpt_path}") + booster.load_model(model, ckpt_path) + + +def tokenize_data(batch, tokenizer: T5Tokenizer, max_length: int) -> Dict: + texts = ["" + sample["prompt"] + sample["completion"] for sample in batch] + data = tokenizer( + texts, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_length, + add_special_tokens=False, + ) + data = {k: v.cuda() for k, v in data.items()} + data["labels"] = data["input_ids"].clone() + return data + + +class RandomDataset(Dataset): + def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000, tokenizer=None): + self.num_samples = num_samples + self.max_length = 
max_length + self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device()) + self.attention_mask = torch.ones_like(self.input_ids) + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + return { + "input_ids": self.input_ids[idx], + "attention_mask": self.attention_mask[idx], + "labels": self.input_ids[idx], + } + + +def parse_args(): + # basic settings + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_name", + type=str, + default="base", + choices=["base", "8b", "test"], + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--plugin", + type=str, + default="hybrid", + choices=["ep", "ep_zero", "hybrid"], + help="Parallel methos. ep_zero is recommended for general cases. ep can provides least memory consumption and hybrid suits large scale training.", + ) + parser.add_argument( + "--output_path", + type=str, + default="./outputs", + help="The path of your saved model after finetuning.", + ) + parser.add_argument("--num_epoch", type=int, default=1, help="Number of epochs.") + parser.add_argument( + "--batch_size", + type=int, + default=1, + help="Batch size (per dp group) for the training dataloader.", + ) + parser.add_argument( + "--save_interval", + type=int, + default=1000, + help=" The interval (steps) of saving checkpoints.", + ) + parser.add_argument( + "--precision", + type=str, + default="bf16", + choices=["fp32", "bf16", "fp16"], + help="The mixed precision training.", + ) + parser.add_argument("--max_length", type=int, default=2048, help="Max sequence length.") + parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") + parser.add_argument( + "--dataset", + type=str, + default="yizhongw/self_instruct", + help="dataset name from `datasets` repo.", + ) + parser.add_argument( + "--task_name", + type=str, + default="super_natural_instructions", + help="task of corresponding dataset.", + ) + + # optim + parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.") + parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") + + # zero stage for all plugins + parser.add_argument("--zero_stage", type=int, default=2, help="zero stage.") + # ep_zero plugin + parser.add_argument( + "--extra_dp_size", type=int, default=1, help="ep_zero plugin's moe dp size. Recommended to be 2 or 4." + ) + # hybrid plugin + parser.add_argument("--pp_size", type=int, default=2, help="pp size for hybrid plugin") + parser.add_argument("--dp_size", type=int, default=1, help="dp size for hybrid plugin") + parser.add_argument("--ep_size", type=int, default=2, help="ep size for hybrid plugin") + parser.add_argument("--microbatch_size", type=int, default=1, help="Microbatch size in pipeline for hybrid plugin") + + # kernel + parser.add_argument( + "--use_kernel", + action="store_true", + help="Use kernel optim. Need to install flash attention and triton to enable all kernel optimizations. Skip if not installed.", + ) + parser.add_argument( + "--use_layernorm_kernel", + action="store_true", + help="Use layernorm kernel. Need to install apex. Raise error if not installed.", + ) + + # loss + parser.add_argument( + "--router_aux_loss_factor", + type=float, + default=0.01, + help="Moe router z loss. You can refer to STMoE for details.", + ) + parser.add_argument( + "--router_z_loss_factor", + type=float, + default=0.0001, + help="Moe router aux loss. 
You can refer to STMoE for details.", + ) + parser.add_argument("--label_smoothing", type=float, default=0.0, help="Label smoothing.") + parser.add_argument( + "--z_loss_factor", type=float, default=0.0001, help="The final outputs' classification z loss factor." + ) + + # load balance + parser.add_argument( + "--load_balance", action="store_true", help="Expert load balance. Defaults to False. Recommend to enable." + ) + parser.add_argument("--load_balance_interval", type=int, default=1000, help="Expert load balance interval.") + # communicate overlap + parser.add_argument( + "--comm_overlap", + action="store_true", + help="Use communication overlap for MoE. Recommended to enable for muiti-node training.", + ) + + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + # Launch ColossalAI + colossalai.launch_from_torch(config={}, seed=args.seed) + coordinator = DistCoordinator() + test_mode = args.model_name == "test" + + # Set plugin + booster_kwargs = {} + hybrid_dict = { + "tp_size": 1, + "custom_policy": OpenMoeForCausalLMPolicy(), + "enable_fused_normalization": args.use_layernorm_kernel, + "enable_jit_fused": args.use_kernel, + "precision": args.precision, + "zero_stage": args.zero_stage, + } + mgr_dict = { + "seed": 42, + } + if args.plugin == "ep": + dp_size = dist.get_world_size() + plugin = MoeHybridParallelPlugin( + pp_size=1, + **hybrid_dict, + ) + MOE_MANAGER.setup( + parallel="EP", + max_ep_size=dp_size, + **mgr_dict, + ) + elif args.plugin == "ep_zero": + dp_size = dist.get_world_size() + use_ep_inside = False + plugin = MoeHybridParallelPlugin( + pp_size=1, + extra_dp_size=args.extra_dp_size, + use_ep_inside=use_ep_inside, + **hybrid_dict, + ) + MOE_MANAGER.setup( + parallel="EP", + max_ep_size=dp_size // args.extra_dp_size, + use_ep_inside=use_ep_inside, + **mgr_dict, + ) + elif args.plugin == "hybrid": + dp_size = dist.get_world_size() // args.pp_size + plugin = MoeHybridParallelPlugin( + pp_size=args.pp_size, + microbatch_size=args.microbatch_size, + **hybrid_dict, + ) + MOE_MANAGER.setup( + parallel="EP", + mode="fixed", + fixed_dp_size=args.dp_size, + fixed_ep_size=args.ep_size, + fixed_pp_size=args.pp_size, + **mgr_dict, + ) + else: + raise ValueError(f"Invalid plugin {args.plugin}") + coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}") + + # Build OpenMoe model + if test_mode: + config = LlamaConfig.from_pretrained("hpcaitech/openmoe-base") + config.hidden_size = 128 + config.intermediate_size = 256 + config.vocab_size = 32000 + else: + repo_name = "hpcaitech/openmoe-" + args.model_name + config = LlamaConfig.from_pretrained(repo_name) + set_openmoe_args( + config, + num_experts=config.num_experts, + moe_layer_interval=config.moe_layer_interval, + router_aux_loss_factor=args.router_aux_loss_factor, + router_z_loss_factor=args.router_z_loss_factor, + z_loss_factor=args.z_loss_factor, + enable_load_balance=args.load_balance, + enable_comm_overlap=args.comm_overlap, + enable_kernel=args.use_kernel, + ) + with skip_init(): + model = OpenMoeForCausalLM(config) + coordinator.print_on_master(f"Finish init model with config:\n{config}") + + # Enable gradient checkpointing + model.gradient_checkpointing_enable() + + # Prepare tokenizer and dataloader + tokenizer = T5Tokenizer.from_pretrained("google/umt5-small") + if test_mode: + dataset = RandomDataset(num_samples=20, tokenizer=tokenizer) + collate_fn = None + else: + dataset = load_dataset(args.dataset, args.task_name) + dataset = dataset["train"] + collate_fn = 
partial(tokenize_data, tokenizer=tokenizer, max_length=args.max_length) + dataloader = plugin.prepare_dataloader( + dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn + ) + + # Set optimizer + optimizer = HybridAdam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) + + # Set booster + booster = Booster(plugin=plugin, **booster_kwargs) + if not test_mode: + load_ckpt(repo_name, model, booster) + model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader) + use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1 + is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() + coordinator.print_on_master(f"Finish init booster") + + # Start finetuning + coordinator.print_on_master(f"Start finetuning") + for epoch in range(args.num_epoch): + model.train() + train_dataloader_iter = iter(dataloader) + total_len = len(train_dataloader_iter) + with tqdm( + range(total_len), + desc=f"Epoch [{epoch + 1}/{args.num_epoch}]", + disable=not coordinator.is_master(), + ) as pbar: + for step in pbar: + if use_pipeline: + # Forward pass + outputs = booster.execute_pipeline( + train_dataloader_iter, + model, + lambda x, y: x.loss, + optimizer, + return_loss=True, + return_outputs=True, + ) + # Backward and optimize + if is_pp_last_stage: + loss = outputs["loss"] + pbar.set_postfix({"loss": loss.item()}) + else: + # Forward pass + data = next(train_dataloader_iter) + data = move_to_cuda(data, torch.cuda.current_device()) + outputs = model(**data) + loss = outputs["loss"] + # Backward + booster.backward(loss, optimizer) + pbar.set_postfix({"loss": loss.item()}) + + optimizer.step() + optimizer.zero_grad() + + # Apply load balance + if ( + args.load_balance + and args.load_balance_interval > 0 + and (step + 1) % args.load_balance_interval == 0 + ): + coordinator.print_on_master(f"Apply load balance") + apply_load_balance(model, optimizer) + # save ckeckpoint + if (step + 1) % args.save_interval == 0: + coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}") + booster.save_model(model, args.output_path, shard=True) + + # save checkpoint at the end of each epochs + booster.save_model(model, args.output_path, shard=True) + coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}") + + # Finish training + coordinator.print_on_master(f"Finish training") + + +if __name__ == "__main__": + main() diff --git a/examples/language/openmoe/train.sh b/examples/language/openmoe/train.sh new file mode 100644 index 000000000..91cd3db8d --- /dev/null +++ b/examples/language/openmoe/train.sh @@ -0,0 +1,40 @@ +#!/bin/bash + +set -xue + +NUM_GPU=8 +MODEL="8b" +SEQ_LENGTH=2048 +BATCH_SIZE=1 +LR=0.00001 + +# ep zero +torchrun --standalone --nproc_per_node $NUM_GPU train.py \ + --num_epoch 1 \ + --model_name $MODEL \ + --plugin "ep_zero" \ + --batch_size $BATCH_SIZE \ + --lr $LR \ + --zero_stage 1 \ + --extra_dp_size 2 + +# ep +# torchrun --standalone --nproc_per_node $NUM_GPU train.py \ +# --num_epoch 1 \ +# --model_name $MODEL \ +# --plugin "ep_zero" \ +# --batch_size $BATCH_SIZE \ +# --lr $LR \ +# --zero_stage 1 + +# hybrid +# torchrun --standalone --nproc_per_node $NUM_GPU train.py \ +# --num_epoch 1 \ +# --model_name $MODEL \ +# --plugin "hybrid" \ +# --batch_size $BATCH_SIZE \ +# --lr $LR \ +# --zero_stage 1 \ +# --pp_size 2 \ +# --dp_size 1 \ +# --ep_size 2 \ diff --git a/pytest.ini b/pytest.ini index 38ad7d76d..598e0a74e 100644 
--- a/pytest.ini +++ b/pytest.ini @@ -2,4 +2,4 @@ markers = dist: tests which are run in a multi-GPU or multi-machine environment (at least 4 GPUs) largedist: tests which are run in a multi-GPU or multi-machine environment (at least 8 GPUs) -addopts = --ignore=tests/test_analyzer --ignore=tests/test_auto_parallel --ignore=tests/test_autochunk --ignore=tests/test_moe --ignore=tests/test_fx --ignore=tests/test_legacy +addopts = --ignore=tests/test_analyzer --ignore=tests/test_auto_parallel --ignore=tests/test_autochunk --ignore=tests/test_fx --ignore=tests/test_legacy diff --git a/tests/test_infer_ops/triton/test_llama_act_combine.py b/tests/test_infer_ops/triton/test_llama_act_combine.py new file mode 100644 index 000000000..5341aa35a --- /dev/null +++ b/tests/test_infer_ops/triton/test_llama_act_combine.py @@ -0,0 +1,56 @@ +import pytest +import torch +from packaging import version +from torch import nn + +from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine + +try: + import triton + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') + +BATCH_SIZE = 4 +SEQ_LEN = 16 +HIDDEN_SIZE = 32 + + +def SwiGLU(x): + """Gated linear unit activation function. + Args: + x : input array + axis: the axis along which the split should be computed (default: -1) + """ + size = x.shape[-1] + assert size % 2 == 0, "axis size must be divisible by 2" + x1, x2 = torch.split(x, size // 2, -1) + return x1 * (x2 * torch.sigmoid(x2.to(torch.float32)).to(x.dtype)) + + +@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton") +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16]) +def test_llama_act_combine(dtype: str): + x_gate = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE * 2, dtype=dtype).cuda() + x_gate_torch = nn.Parameter(x_gate.detach().clone()) + x_gate_kernel = nn.Parameter(x_gate.detach().clone()) + x_up = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, dtype=dtype).cuda() + x_up_torch = nn.Parameter(x_up.detach().clone()) + x_up_kernel = nn.Parameter(x_up.detach().clone()) + + torch_out = SwiGLU(x_gate_torch) * x_up_torch + kernel_out = LlamaActCombine.apply(x_gate_kernel, x_up_kernel) + atol = 1e-5 if dtype == torch.float32 else 5e-2 + assert torch.allclose(torch_out, kernel_out, atol=atol) + + torch_out.mean().backward() + kernel_out.mean().backward() + assert all(grad is not None for grad in [x_gate_torch.grad, x_up_torch.grad, x_gate_kernel.grad, x_up_kernel.grad]) + assert torch.allclose(x_gate_torch.grad, x_gate_kernel.grad, atol=atol) + assert torch.allclose(x_up_torch.grad, x_up_kernel.grad, atol=atol) + + +if __name__ == '__main__': + test_llama_act_combine(torch.float16) diff --git a/tests/test_moe/moe_utils.py b/tests/test_moe/moe_utils.py new file mode 100644 index 000000000..40adeab71 --- /dev/null +++ b/tests/test_moe/moe_utils.py @@ -0,0 +1,169 @@ +import torch +import torch.distributed as dist +import torch.nn as nn + +from colossalai.legacy.engine.gradient_handler._base_gradient_handler import BaseGradientHandler +from colossalai.legacy.engine.gradient_handler.utils import bucket_allreduce +from colossalai.legacy.registry import GRADIENT_HANDLER +from colossalai.moe import SparseMLP +from colossalai.moe.manager import MOE_MANAGER +from colossalai.moe.utils import get_moe_epsize_param_dict +from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_rank, 
get_ep_size, is_moe_tensor + + +class MoeModel(nn.Module): + def __init__(self, enable_load_balance: bool = False): + class TestSubModule(nn.Module): + def __init__(self): + super().__init__() + self.moe = SparseMLP( + num_experts=8, hidden_size=16, intermediate_size=32, enable_load_balance=enable_load_balance + ) + self.proj = nn.Linear(16, 4) + + def forward(self, x): + x = self.moe(x) + x = self.proj(x) + return x + + super().__init__() + self.test_embed = nn.Linear(4, 16) + self.test_transform = TestSubModule() + + def forward(self, x): + MOE_MANAGER.reset_loss() + + x = self.test_embed(x) + x = self.test_transform(x) + + return x + + +@GRADIENT_HANDLER.register_module +class MoeGradientHandler(BaseGradientHandler): + """A helper class to handle all-reduce operations in a data parallel group and + a MoE model parallel group. An all-reduce collective communication is performed in + :func:`handle_gradient` within a data parallel group. + For better performance, it bucketizes the gradients of all parameters that are + the same type to improve the efficiency of communication. + + Args: + model (Module): Model where the gradients accumulate. + optimizer (Optimizer): Optimizer for updating the parameters. + """ + + def __init__(self, model, optimizer=None): + super().__init__(model, optimizer) + + def handle_gradient(self): + """Run an all-reduce operation in the data parallel group, then run + an all-reduce operation for all expert parameters across the MoE + model parallel group. + """ + if dist.get_world_size() > 1: + epsize_param_dict = get_moe_epsize_param_dict(self._model) + + # epsize is 1, indicating the params are replicated among processes in data parallelism + # use the ParallelMode.DATA to get data parallel group + # reduce gradients for all parameters in data parallelism + if 1 in epsize_param_dict: + bucket_allreduce(param_list=epsize_param_dict[1]) + + for ep_size in epsize_param_dict: + if ep_size != 1 and ep_size != MOE_MANAGER.world_size: + bucket_allreduce( + param_list=epsize_param_dict[ep_size], group=MOE_MANAGER.parallel_info_dict[ep_size].dp_group + ) + + +def sync_tp_from_ep(tp_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None: + """Sync the parameters of the TP model from the EP model. + + Args: + tp_model (MoeModule) + ep_model (MoeModule) + """ + for (tp_name, tp_param), (ep_name, ep_param) in zip(tp_model.named_parameters(), ep_model.named_parameters()): + assert tp_name == ep_name + if not is_moe_tensor(tp_param): + if assert_grad_flag: + assert torch.allclose(tp_param, ep_param) + assert torch.allclose(tp_param.grad, ep_param.grad) + else: + tp_param.data.copy_(ep_param.data) + continue + + # gather param from ep model + param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))] + dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param)) + all_param = torch.cat(param_list, dim=0) + if assert_grad_flag: + grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))] + dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param)) + all_grad = torch.cat(grad_list, dim=0) + + # get tp param + tp_dim = [i for i, (d1, d2) in enumerate(zip(tp_param.shape[1:], all_param.shape[1:])) if d1 != d2] + tp_rank = get_ep_rank(tp_param) + tp_dim = tp_dim[0] + 1 + tp_slice = [slice(None)] * tp_dim + [ + slice(tp_param.shape[tp_dim] * tp_rank, tp_param.shape[tp_dim] * (tp_rank + 1)) + ] + new_tp_param = all_param[tuple(tp_slice)] + if assert_grad_flag: + new_grad = all_grad[tuple(tp_slice)] + if
assert_grad_flag: + assert torch.allclose(tp_param, new_tp_param) + assert torch.allclose(tp_param.grad, new_grad) + else: + tp_param.data.copy_(new_tp_param.data) + + +def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None: + """Sync the parameters of the local model from the EP model. + + Args: + local_model (MoeModule) + ep_model (MoeModule) + """ + for (local_name, local_param), (ep_name, ep_param) in zip( + local_model.named_parameters(), ep_model.named_parameters() + ): + assert local_name == ep_name + if "experts" not in local_name: + if assert_grad_flag: + assert torch.allclose(local_param, ep_param) + assert torch.allclose(local_param.grad, ep_param.grad) + else: + local_param.data.copy_(ep_param.data) + continue + + # gather param from ep model + param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))] + dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param)) + all_param = torch.cat(param_list, dim=0) + if assert_grad_flag: + grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))] + dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param)) + all_grad = torch.cat(grad_list, dim=0) + + if assert_grad_flag: + assert torch.allclose(local_param, all_param) + assert torch.allclose(local_param.grad, all_grad) + else: + local_param.data.copy_(all_param.data) + + +def assert_not_equal_in_group(tensor, process_group=None): + # all gather tensors from different ranks + world_size = dist.get_world_size(process_group) + tensor_list = [torch.empty_like(tensor) for _ in range(world_size)] + dist.all_gather(tensor_list, tensor, group=process_group) + + # check if they are equal one by one + for i in range(world_size - 1): + a = tensor_list[i] + b = tensor_list[i + 1] + assert not torch.allclose( + a, b + ), f"expected tensors on rank {i} and {i + 1} to be different but they are equal, {a} vs {b}" diff --git a/tests/test_moe/test_grad_handler.py b/tests/test_moe/test_grad_handler.py index 8742e5f41..28ee618e1 100644 --- a/tests/test_moe/test_grad_handler.py +++ b/tests/test_moe/test_grad_handler.py @@ -4,40 +4,58 @@ import torch.distributed as dist import torch.nn as nn import colossalai -from colossalai.context.moe_context import MOE_CONTEXT -from colossalai.legacy.engine.gradient_handler import MoeGradientHandler -from colossalai.nn.layer.moe import Experts, MoeLayer, Top1Router, UniformNoiseGenerator +from colossalai.moe import SparseMLP +from colossalai.moe.manager import MOE_MANAGER +from colossalai.moe.utils import sync_moe_model_param from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn from colossalai.utils import get_current_device -from colossalai.utils.moe import sync_moe_model_param +from tests.test_moe.moe_utils import MoeGradientHandler, assert_not_equal_in_group BATCH_SIZE = 4 DIM = 16 -CONFIG = dict() def run_test(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - expert_module = nn.Linear - expert_factor = dict(in_features=DIM, out_features=DIM, device=get_current_device()) + colossalai.launch( + config=dict(), + rank=rank, + world_size=world_size, + host="localhost", + port=port, + backend="nccl", + ) - MOE_CONTEXT.setup(42) # MOE initialization - noisy_func = UniformNoiseGenerator() - router = Top1Router(noisy_func=noisy_func) + MOE_MANAGER.setup(42, parallel="EP") # MOE initialization num_experts_list = [1, 2, 4] layer_list = [] for num_experts in
num_experts_list: - exp = Experts(expert_module, num_experts, **expert_factor) - moe_layer = MoeLayer(DIM, num_experts, router, exp) + moe_layer = SparseMLP( + hidden_size=DIM, + intermediate_size=DIM * 4, + num_experts=num_experts, + router_top_k=1, + router_noisy_policy="Jitter", + ) layer_list.append(moe_layer) model = nn.ModuleList(layer_list) model = model.to(get_current_device()) + dist_dict = MOE_MANAGER.parallel_info_dict + assert_not_equal_in_group(layer_list[0].experts.wi.data, dist_dict[1].dp_group) + assert_not_equal_in_group(layer_list[0].experts.wo.data, dist_dict[1].dp_group) + assert_not_equal_in_group(layer_list[1].experts.wi.data, dist_dict[2].dp_group) + assert_not_equal_in_group(layer_list[1].experts.wo.data, dist_dict[2].dp_group) + assert_not_equal_in_group(layer_list[2].experts.wi.data, dist_dict[4].dp_group) + assert_not_equal_in_group(layer_list[2].experts.wo.data, dist_dict[4].dp_group) + sync_moe_model_param(model) - dist_dict = MOE_CONTEXT.parallel_info_dict - assert_equal_in_group(layer_list[0].experts.experts[0].weight.data, dist_dict[1].dp_group) - assert_equal_in_group(layer_list[1].experts.experts[0].weight.data, dist_dict[2].dp_group) + assert_equal_in_group(layer_list[0].experts.wi.data, dist_dict[1].dp_group) + assert_equal_in_group(layer_list[0].experts.wo.data, dist_dict[1].dp_group) + assert_equal_in_group(layer_list[1].experts.wi.data, dist_dict[2].dp_group) + assert_equal_in_group(layer_list[1].experts.wo.data, dist_dict[2].dp_group) + assert_equal_in_group(layer_list[2].experts.wi.data, dist_dict[4].dp_group) + assert_equal_in_group(layer_list[2].experts.wo.data, dist_dict[4].dp_group) # MoE model synchronization passed grad_handler = MoeGradientHandler(model, 0) @@ -47,17 +65,18 @@ def run_test(rank, world_size, port): data = torch.randn(BATCH_SIZE, DIM, device=get_current_device()) grad = torch.randn_like(data) - MOE_CONTEXT.reset_loss() + MOE_MANAGER.reset_loss() for layer in layer_list: - data, _ = layer(data) + data = layer(data) data.backward(grad) grad_handler.handle_gradient() - assert_equal_in_group(layer_list[0].experts.experts[0].weight.grad, dist_dict[1].dp_group) - assert_equal_in_group(layer_list[0].experts.experts[0].bias.grad, dist_dict[1].dp_group) - - assert_equal_in_group(layer_list[1].experts.experts[0].weight.grad, dist_dict[2].dp_group) - assert_equal_in_group(layer_list[1].experts.experts[0].bias.grad, dist_dict[2].dp_group) + assert_equal_in_group(layer_list[0].experts.wi.grad, dist_dict[1].dp_group) + assert_equal_in_group(layer_list[0].experts.wo.grad, dist_dict[1].dp_group) + assert_equal_in_group(layer_list[1].experts.wi.grad, dist_dict[2].dp_group) + assert_equal_in_group(layer_list[1].experts.wo.grad, dist_dict[2].dp_group) + assert_equal_in_group(layer_list[2].experts.wi.grad, dist_dict[4].dp_group) + assert_equal_in_group(layer_list[2].experts.wo.grad, dist_dict[4].dp_group) # MoE grad handler test passed diff --git a/tests/test_moe/test_kernel.py b/tests/test_moe/test_kernel.py index 7a9c551d6..c710c7bf7 100644 --- a/tests/test_moe/test_kernel.py +++ b/tests/test_moe/test_kernel.py @@ -1,49 +1,47 @@ import pytest import torch -import torch.nn as nn +import torch.distributed as dist import colossalai -from colossalai.context.moe_context import MOE_CONTEXT -from colossalai.legacy.context import ParallelMode -from colossalai.legacy.core import global_context as gpc -from colossalai.nn.layer.moe import Experts, MoeLayer, Top1Router, Top2Router +from colossalai.moe import SparseMLP +from colossalai.moe.manager import 
MOE_MANAGER from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.utils import get_current_device -BATCH_SIZE = 16 +BATCH_SIZE = 4 NUM_EXPERTS = 4 -CONFIG = dict() def check_equal(tensor_a, tensor_b, atol=1e-06): assert torch.allclose(tensor_a, tensor_b, rtol=0, atol=atol) is True -def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32, router=Top2Router): +def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32, topk=1): # Here we do not need TF32, since it brings absolute error on results torch.backends.cuda.matmul.allow_tf32 = False - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - local_rank = gpc.get_local_rank(ParallelMode.GLOBAL) + colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + local_rank = dist.get_rank() - MOE_CONTEXT.setup(42) # MOE environment initialization - MOE_CONTEXT.reset_loss() - torch.manual_seed(rs + local_rank) # set each process has different random seed + MOE_MANAGER.setup(42, parallel="EP") # MOE environment initialization + MOE_MANAGER.reset_loss() + torch.manual_seed(rs + local_rank) # set each process has different random seed # get randomized data tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type, device=get_current_device(), requires_grad=True) - expert_module = nn.Linear - expert_factor = dict(in_features=hidden_size, out_features=hidden_size, device=get_current_device()) - expert = Experts(expert_module, NUM_EXPERTS, **expert_factor) - layer = MoeLayer(hidden_size, NUM_EXPERTS, router(capacity_factor_train=1.0), expert) + layer = SparseMLP(hidden_size=hidden_size, + intermediate_size=hidden_size * 2, + num_experts=NUM_EXPERTS, + router_top_k=topk, + router_capacity_factor_train=1.0) layer = layer.to(get_current_device()) if data_type == torch.float16: layer = layer.half() # use matrix multiplication instead of COL_MOE_KERNEL in MOE dispatch and combine - layer.use_kernel = False - old_out, _ = layer(tokens) + layer.enable_kernel = False + old_out = layer(tokens) ech = old_out.shape grad = torch.randn(ech, device=get_current_device()) old_out.backward(grad) # get gradient @@ -56,8 +54,8 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f tokens.grad.zero_() layer.gate_weight.grad.zero_() - layer.use_kernel = True - new_out, _ = layer(tokens) # get outputs through colossal kernel + layer.enable_kernel = True + new_out = layer(tokens) # get outputs through colossal kernel if data_type == torch.float32: check_equal(old_out, new_out) @@ -86,11 +84,11 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f @pytest.mark.parametrize("rs", [131]) @pytest.mark.parametrize("hidden_size", [32, 144]) @pytest.mark.parametrize("data_type", [torch.float32, torch.float16]) -@pytest.mark.parametrize("router", [Top1Router, Top2Router]) +@pytest.mark.parametrize("topk", [1, 2]) @rerun_if_address_is_in_use() -def test_moe_kernel(rs, hidden_size, data_type, router): - spawn(run_routing, 4, rs=rs, hidden_size=hidden_size, data_type=data_type, router=router) +def test_moe_kernel(rs, hidden_size, data_type, topk): + spawn(run_routing, 4, rs=rs, hidden_size=hidden_size, data_type=data_type, topk=topk) -if __name__ == "__main__": - test_moe_kernel(2, 256, torch.float16, Top2Router) +if __name__ == '__main__': + test_moe_kernel(2, 256, torch.float16, 2) diff --git 
a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py index b7024f32b..b68eaec50 100644 --- a/tests/test_moe/test_moe_checkpoint.py +++ b/tests/test_moe/test_moe_checkpoint.py @@ -1,50 +1,138 @@ +import importlib import os +import shutil +import sys import pytest import torch import torch.distributed as dist +from transformers.models.llama import LlamaConfig import colossalai -from colossalai.context import MOE_CONTEXT -from colossalai.nn.layer.moe import load_moe_model, save_moe_model +from colossalai.booster import Booster +from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin +from colossalai.moe.manager import MOE_MANAGER from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device -from colossalai.zero import ColoInitContext -from tests.test_moe.test_moe_zero_init import MoeModel -from tests.test_zero.test_legacy.common import CONFIG + +sys.path.append(os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(__file__))), + "examples/language/openmoe", +)) + +OpenMoeForCausalLM = importlib.import_module("model.modeling_openmoe").OpenMoeForCausalLM +set_openmoe_args = importlib.import_module("model.modeling_openmoe").set_openmoe_args +OpenMoeForCausalLMPolicy = importlib.import_module("model.openmoe_policy").OpenMoeForCausalLMPolicy -def exam_moe_checkpoint(): - with ColoInitContext(device=get_current_device()): - model = MoeModel(checkpoint=True) - save_moe_model(model, "temp_path.pth") +def get_config(): + config = LlamaConfig( + vocab_size=300, + hidden_size=16, + intermediate_size=32, + num_hidden_layers=4, + num_attention_heads=2, + head_dim=4, + dropout_rate=0.0, + hidden_act="swiglu", + ) + set_openmoe_args(config, num_experts=16, moe_layer_interval=1) + return config - with ColoInitContext(device=get_current_device()): - other_model = MoeModel(checkpoint=True) - load_moe_model(other_model, "temp_path.pth") - state_0 = model.state_dict() - state_1 = other_model.state_dict() - for k, v in state_0.items(): - u = state_1.get(k) +def get_model(parallel): + config = get_config() + model = OpenMoeForCausalLM(config) + + if parallel is None: + plugin = MoeHybridParallelPlugin( + tp_size=1, + pp_size=1, + zero_stage=0, + custom_policy=OpenMoeForCausalLMPolicy(), + ) + elif parallel == "zero_ep": + plugin = MoeHybridParallelPlugin( + tp_size=1, + pp_size=1, + zero_stage=2, + custom_policy=OpenMoeForCausalLMPolicy(), + ) + elif parallel == "hybrid": + plugin = MoeHybridParallelPlugin( + tp_size=1, + pp_size=2, + zero_stage=1, + microbatch_size=1, + custom_policy=OpenMoeForCausalLMPolicy(), + ) + booster = Booster(plugin=plugin) + model, _, _, _, _ = booster.boost(model=model) + return model, booster + + +def _test_moe_checkpoint(parallel, shard): + if parallel is None: + MOE_MANAGER.setup( + seed=42, + parallel=None, + ) + elif parallel == "zero_ep": + MOE_MANAGER.setup( + seed=42, + parallel="EP", + ) + elif parallel == "hybrid": + MOE_MANAGER.setup( + seed=42, + parallel="EP", + mode="fixed", + fixed_dp_size=1, + fixed_ep_size=2, + fixed_pp_size=2, + ) + model1, booster1 = get_model(parallel) + model2, booster2 = get_model(parallel) + + if shard: + booster1.save_model(model1, "./tmp_ckpt", shard=True, size_per_shard=1) + booster2.load_model(model2, "./tmp_ckpt") + else: + booster1.save_model(model1, "tmp_ckpt.pth") + booster2.load_model(model2, "tmp_ckpt.pth") + + state1 = model1.state_dict() + state2 = model2.state_dict() + for k, v in state1.items(): + u =
state2.get(k) assert torch.equal(u.data, v.data) if dist.get_rank() == 0: - os.remove("temp_path.pth") + if shard: + shutil.rmtree("./tmp_ckpt") + else: + os.remove("tmp_ckpt.pth") -def _run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - MOE_CONTEXT.setup(seed=42) - exam_moe_checkpoint() +def _run_dist(rank, world_size, port, parallel, shard): + colossalai.launch( + config=dict(), + rank=rank, + world_size=world_size, + host="localhost", + port=port, + backend="nccl", + ) + _test_moe_checkpoint(parallel, shard) @pytest.mark.dist -@pytest.mark.parametrize("world_size", [2, 4]) +@pytest.mark.parametrize("world_size", [4]) +@pytest.mark.parametrize("parallel", [None, "zero_ep", "hybrid"]) +@pytest.mark.parametrize("shard", [True, False]) @rerun_if_address_is_in_use() -def test_moe_checkpoint(world_size): - spawn(_run_dist) +def test_moe_checkpoint(world_size, parallel, shard): + spawn(_run_dist, world_size, parallel=parallel, shard=shard) if __name__ == "__main__": - test_moe_checkpoint(world_size=4) + test_moe_checkpoint(world_size=4, parallel="hybrid", shard=True) diff --git a/tests/test_moe/test_moe_colo_init.py b/tests/test_moe/test_moe_colo_init.py deleted file mode 100644 index 488573b73..000000000 --- a/tests/test_moe/test_moe_colo_init.py +++ /dev/null @@ -1,55 +0,0 @@ -import pytest -import torch -import torch.distributed as dist - -import colossalai -from colossalai.context import MOE_CONTEXT -from colossalai.tensor import ColoParameter -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device -from colossalai.zero import ColoInitContext -from tests.test_moe.test_moe_zero_init import MoeModel -from tests.test_zero.test_legacy.common import CONFIG - - -@parameterize("init_device_type", ["cpu", "cuda"]) -def exam_moe_colo_init(init_device_type): - world_size = dist.get_world_size() - - if init_device_type == "cuda": - init_device = get_current_device() - elif init_device_type == "cpu": - init_device = torch.device("cpu") - else: - raise NotImplementedError("Unknown device found.") - - with ColoInitContext(device=init_device): - model = MoeModel(checkpoint=True) - - for name, param in model.named_parameters(): - assert isinstance(param, ColoParameter), "parameter `{}` has an init problem".format(name) - - if hasattr(param, "moe_info"): - param.set_process_group(param.moe_info.pg) - - if hasattr(param, "moe_info"): - assert param.process_group.dp_world_size() == param.moe_info.dp_size - else: - assert param.process_group.dp_world_size() == world_size - - -def _run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - MOE_CONTEXT.setup(seed=42) - exam_moe_colo_init() - - -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [4]) -@rerun_if_address_is_in_use() -def test_moe_colo_init(world_size): - spawn(_run_dist, world_size) - - -if __name__ == "__main__": - test_moe_colo_init(world_size=4) diff --git a/tests/test_moe/test_moe_ep_tp.py b/tests/test_moe/test_moe_ep_tp.py new file mode 100644 index 000000000..11d0664fd --- /dev/null +++ b/tests/test_moe/test_moe_ep_tp.py @@ -0,0 +1,81 @@ +import pytest +import torch +import torch.distributed as dist + +import colossalai +from colossalai.moe import SparseMLP +from colossalai.moe.manager import MOE_MANAGER +from colossalai.moe.utils import sync_moe_model_param +from 
colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device +from tests.test_moe.moe_utils import MoeGradientHandler, sync_local_from_ep, sync_tp_from_ep + + +def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size: int, dim: int, seed: int): + assert batch_size % world_size == 0 + + colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + MOE_MANAGER.__init__() + MOE_MANAGER.setup(seed, parallel=None) + local_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2) + MOE_MANAGER.__init__() + MOE_MANAGER.setup(seed, parallel="EP") + ep_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2) + MOE_MANAGER.__init__() + MOE_MANAGER.setup(seed, parallel="TP") + tp_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2) + ep_model = ep_model.to(get_current_device()) + tp_model = tp_model.to(get_current_device()) + local_model = local_model.to(get_current_device()) + + # sync ep param + sync_moe_model_param(ep_model) + dist_dict = MOE_MANAGER.parallel_info_dict + assert_equal_in_group(ep_model.experts.wi.data, dist_dict[world_size].dp_group) + assert_equal_in_group(ep_model.experts.wo.data, dist_dict[world_size].dp_group) + grad_handler = MoeGradientHandler(ep_model) + # sync tp param + sync_tp_from_ep(tp_model, ep_model) + # sync local param + sync_local_from_ep(local_model, ep_model) + + rank = dist.get_rank() + torch.cuda.manual_seed(seed) + tp_data = torch.randn(batch_size, dim, device=get_current_device()) + micro_batch_size = batch_size // world_size + ep_data = tp_data.detach()[micro_batch_size * rank:micro_batch_size * (rank + 1)] + + out_local = local_model(tp_data) + MOE_MANAGER.reset_loss() + out_tp = tp_model(tp_data) + MOE_MANAGER.reset_loss() + out_ep = ep_model(ep_data) + MOE_MANAGER.reset_loss() + assert torch.allclose(out_ep, out_tp[micro_batch_size * rank:micro_batch_size * (rank + 1)]) + assert torch.allclose(out_ep, out_local[micro_batch_size * rank:micro_batch_size * (rank + 1)]) + + out_local.mean().backward() + out_tp.mean().backward() + out_ep.mean().backward() + grad_handler.handle_gradient() + + assert_equal_in_group(ep_model.experts.wi.grad, dist_dict[world_size].dp_group) + assert_equal_in_group(ep_model.experts.wo.grad, dist_dict[world_size].dp_group) + + sync_local_from_ep(local_model, ep_model, assert_grad_flag=True) + sync_tp_from_ep(tp_model, ep_model, assert_grad_flag=True) + + +@pytest.mark.dist +@pytest.mark.parametrize("num_experts", [4, 8]) +@pytest.mark.parametrize("batch_size", [4]) +@pytest.mark.parametrize("dim", [32]) +@pytest.mark.parametrize("seed", [42]) +@rerun_if_address_is_in_use() +def test_moe_ep_tp(num_experts: int, batch_size: int, dim: int, seed: int): + spawn(run_test, 2, num_experts=num_experts, batch_size=batch_size, dim=dim, seed=seed) + + +if __name__ == '__main__': + test_moe_ep_tp(num_experts=8, batch_size=8, dim=256, seed=42) diff --git a/tests/test_moe/test_moe_group.py b/tests/test_moe/test_moe_group.py index 300fb6c99..3cd5acc0d 100644 --- a/tests/test_moe/test_moe_group.py +++ b/tests/test_moe/test_moe_group.py @@ -3,66 +3,80 @@ import torch.distributed as dist import torch.nn as nn import colossalai -from colossalai.context.moe_context import MOE_CONTEXT -from colossalai.nn.layer.moe import Experts +from colossalai.moe.experts import MLPExperts +from colossalai.moe.manager import 
MOE_MANAGER +from colossalai.moe.utils import sync_moe_model_param from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn from colossalai.utils import get_current_device -from colossalai.utils.moe import sync_moe_model_param -D_MODEL = 4 -D_FF = 8 -CONFIG = dict() +HIDDEN_SIZE = 4 +INTERMEDIATE_SIZE = 8 -def run_test(rank, world_size, port): - world_size = 4 - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - expert_module = nn.Linear - expert_factor = dict(in_features=D_MODEL, out_features=D_FF, device=get_current_device()) +def run_moe_init(expert_parallel): + MOE_MANAGER.__init__() + MOE_MANAGER.setup(seed=42, parallel=expert_parallel) + expert_args = dict( + hidden_size=HIDDEN_SIZE, + intermediate_size=INTERMEDIATE_SIZE, + expert_parallel=expert_parallel, + ) + exp0 = MLPExperts(1, **expert_args) + exp1 = MLPExperts(2, **expert_args) + exp2 = MLPExperts(4, **expert_args) - MOE_CONTEXT.setup(42) # MOE environment initialization - exp0 = Experts(expert_module, 1, **expert_factor) - exp1 = Experts(expert_module, 2, **expert_factor) - exp2 = Experts(expert_module, 4, **expert_factor) - exp3 = Experts(expert_module, 8, **expert_factor) + if expert_parallel == "EP": + assert exp0.num_local_experts == 1 + assert exp1.num_local_experts == 1 + assert exp2.num_local_experts == 2 + else: + assert exp0.num_local_experts == 1 + assert exp1.num_local_experts == 2 + assert exp2.num_local_experts == 4 - assert exp0.num_local_experts == 1 - assert exp1.num_local_experts == 1 - assert exp2.num_local_experts == 1 - assert exp3.num_local_experts == 2 - # experts deployment passed - - parallel_info_dict = MOE_CONTEXT.parallel_info_dict + parallel_info_dict = MOE_MANAGER.parallel_info_dict rank = dist.get_rank() - assert len(parallel_info_dict) == 3 - assert dist.get_rank(parallel_info_dict[4].ep_group) == rank + # group creation assert + assert len(parallel_info_dict) == 2 assert dist.get_rank(parallel_info_dict[2].ep_group) == rank % 2 assert dist.get_rank(parallel_info_dict[1].ep_group) == 0 - assert dist.get_rank(parallel_info_dict[4].dp_group) == 0 assert dist.get_rank(parallel_info_dict[2].dp_group) == rank // 2 assert dist.get_rank(parallel_info_dict[1].dp_group) == rank - # group creation passed - model = nn.ModuleList([exp0, exp1, exp2, exp3]) + model = nn.ModuleList([exp0, exp1, exp2]) model = model.to(get_current_device()) sync_moe_model_param(model) - assert_equal_in_group(exp0.experts[0].weight.data, parallel_info_dict[1].dp_group) - assert_equal_in_group(exp0.experts[0].bias.data, parallel_info_dict[1].dp_group) # MOE experts layout success when ep_size = 1 + assert_equal_in_group(exp0.wi.data, parallel_info_dict[1].dp_group) + assert_equal_in_group(exp0.wo.data, parallel_info_dict[1].dp_group) - assert_equal_in_group(exp1.experts[0].weight.data, parallel_info_dict[2].dp_group) - assert_equal_in_group(exp1.experts[0].bias.data, parallel_info_dict[2].dp_group) # MOE experts layout success when ep_size = 2 + assert_equal_in_group(exp1.wi.data, parallel_info_dict[2].dp_group) + assert_equal_in_group(exp1.wo.data, parallel_info_dict[2].dp_group) + + +def _run_test(rank, world_size, port, expert_parallel): + colossalai.launch( + config=dict(), + rank=rank, + world_size=world_size, + host="localhost", + port=port, + backend="nccl", + ) + run_moe_init(expert_parallel) @pytest.mark.dist +@pytest.mark.parametrize("expert_parallel", ["EP", "TP"]) @rerun_if_address_is_in_use() -def test_moe_initialization(): 
- spawn(run_test, 4) +def test_moe_initialization(expert_parallel): + spawn(_run_test, 2, expert_parallel=expert_parallel) if __name__ == "__main__": - test_moe_initialization() + test_moe_initialization("EP") + test_moe_initialization("TP") diff --git a/tests/test_moe/test_moe_hybrid_zero.py b/tests/test_moe/test_moe_hybrid_zero.py new file mode 100644 index 000000000..e9f71d5ca --- /dev/null +++ b/tests/test_moe/test_moe_hybrid_zero.py @@ -0,0 +1,97 @@ +import pytest +import torch +import torch.distributed as dist + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import LowLevelZeroPlugin +from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel +from colossalai.moe.manager import MOE_MANAGER +from colossalai.tensor.moe_tensor.api import is_moe_tensor +from colossalai.testing import rerun_if_address_is_in_use, spawn +from tests.test_moe.moe_utils import MoeModel + + +def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False): + model.train() + with torch.cuda.amp.autocast(enabled=enable_autocast): + if criterion: + y = model(data) + loss = criterion(y, label) + else: + loss = model(data, label) + loss = loss.float() + + if isinstance(model, LowLevelZeroModel): + optimizer.backward(loss / 2) + else: + loss.backward() + return y + + +def run_zero_optim_test(local_rank, world_size, stage=1): + criterion = torch.nn.CrossEntropyLoss() + data = torch.randn(16, 4).cuda() + label = torch.randint(0, 4, (16,)).cuda() + + MOE_MANAGER.__init__() + MOE_MANAGER.setup(seed=42, parallel=None) + torch_model = MoeModel() + torch_optimizer = torch.optim.Adam(torch_model.parameters()) + torch_model = torch_model.cuda() + + MOE_MANAGER.__init__() + MOE_MANAGER.setup(seed=42, max_ep_size=2, use_ep_inside=False, parallel="EP") + zero_model = MoeModel() + extra_dp_group = MOE_MANAGER.parallel_info_dict[2].dp_group + ep_rank = dist.get_rank(MOE_MANAGER.parallel_info_dict[2].ep_group) + ep_size = MOE_MANAGER.parallel_info_dict[2].ep_size + for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()): + if is_moe_tensor(zero_param): + num_expert = torch_param.data.shape[0] + zero_param.data.copy_( + torch_param.data[ep_rank * (num_expert // ep_size) : (ep_rank + 1) * (num_expert // ep_size)] + .detach() + .clone() + ) + else: + zero_param.data.copy_(torch_param.data.detach().clone()) + zero_optimizer = torch.optim.Adam(zero_model.parameters()) + plugin = LowLevelZeroPlugin(stage=stage, precision="fp32") + plugin.zero_optim_kwargs["moe_extra_dp_process_group"] = extra_dp_group + booster = Booster(plugin=plugin) + zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer) + + run_fwd_bwd(torch_model, data, label, criterion, None) + torch_optimizer.step() + run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer) + zero_optimizer.step() + + for (torch_name, torch_param), (zero_name, zero_param) in zip( + torch_model.named_parameters(), zero_model.named_parameters() + ): + if is_moe_tensor(zero_param): + num_expert = torch_param.data.shape[0] + torch_param.data = torch_param.data[ + ep_rank * (num_expert // ep_size) : (ep_rank + 1) * (num_expert // ep_size) + ] + assert torch.allclose( + torch_param.data, zero_param.data, atol=1e-4 + ), f"{torch_name}\ntorch_param {torch_param.data}\nzero_param {zero_param.data}" + + +def run_dist(rank, world_size, port): + colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + 
run_zero_optim_test(rank, world_size, stage=1) + run_zero_optim_test(rank, world_size, stage=2) + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [4]) +@rerun_if_address_is_in_use() +def test_moe_zero_optim(world_size): + spawn(run_dist, world_size) + + +if __name__ == "__main__": + test_moe_zero_optim(world_size=4) diff --git a/tests/test_moe/test_moe_load_balance.py b/tests/test_moe/test_moe_load_balance.py new file mode 100644 index 000000000..173a7a356 --- /dev/null +++ b/tests/test_moe/test_moe_load_balance.py @@ -0,0 +1,190 @@ +import pytest +import torch +import torch.distributed as dist + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import LowLevelZeroPlugin +from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel +from colossalai.moe.layers import apply_load_balance +from colossalai.moe.manager import MOE_MANAGER +from colossalai.tensor.moe_tensor.api import is_moe_tensor +from colossalai.testing import rerun_if_address_is_in_use, spawn +from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel + + +def split_ddp_grad(grad, world_size): + with torch.no_grad(): + grad = grad.clone().detach().flatten() + padding_size = (world_size - grad.numel() % world_size) % world_size + if padding_size > 0: + grad = torch.nn.functional.pad(grad, [0, padding_size]) + splited_grad = grad.split(grad.numel() // world_size) + return splited_grad + + +def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False): + model.train() + with torch.cuda.amp.autocast(enabled=enable_autocast): + if criterion: + y = model(data) + loss = criterion(y, label) + else: + loss = model(data, label) + loss = loss.float() + + if isinstance(model, LowLevelZeroModel): + optimizer.backward(loss) + else: + loss.backward() + return y + + +def run_zero_optim_test(local_rank, world_size, stage=1): + criterion = torch.nn.CrossEntropyLoss() + + MOE_MANAGER.__init__() + MOE_MANAGER.setup( + seed=42, + parallel="EP", + ) + zero_model = MoeModel(enable_load_balance=True) + zero_optimizer = torch.optim.Adam(zero_model.parameters()) + plugin = LowLevelZeroPlugin(stage=stage, precision="bf16", verbose=True) + booster = Booster(plugin=plugin) + zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer) + + MOE_MANAGER.__init__() + MOE_MANAGER.setup(seed=42, parallel="EP") + torch_model = MoeModel() + for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()): + torch_param.data.copy_(zero_param.data) + torch_optimizer = torch.optim.Adam(torch_model.parameters()) + torch_model = torch_model.cuda().bfloat16() + grad_handler = MoeGradientHandler(torch_model) + + # run to update expert load + data = torch.randn(16, 4).cuda().bfloat16() / 1000 / (local_rank + 1) + label = torch.randint(0, 4, (16,)).cuda() + + # run torch model twice + run_fwd_bwd(torch_model, data, label, criterion, None) + grad_handler.handle_gradient() + torch_optimizer.step() + torch_optimizer.zero_grad() + run_fwd_bwd(torch_model, data, label, criterion, None) + grad_handler.handle_gradient() + + # get optim and load status in zero model + run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer) + zero_optimizer.step() + zero_optimizer.zero_grad() + with torch.no_grad(): + origin_out = zero_model(data) + + # load balance + apply_load_balance(zero_model, zero_optimizer) + + # run again to test + zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer) + torch.allclose(origin_out, zero_out) + + 
# assert optim + torch_optimizer.step() + torch_out = run_fwd_bwd(torch_model, data, label, criterion, None) + zero_optimizer.step() + zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer) + assert torch.allclose(zero_out, torch_out), f"zero_out:{zero_out}\ntorch_out{torch_out}" + + +def run_hybrid_zero_optim_test(local_rank, world_size, stage=1): + criterion = torch.nn.CrossEntropyLoss() + data = torch.randn(16, 4).cuda() + label = torch.randint(0, 4, (16,)).cuda() + + MOE_MANAGER.__init__() + MOE_MANAGER.setup(seed=42, parallel=None) + torch_model = MoeModel() + torch_optimizer = torch.optim.Adam(torch_model.parameters()) + torch_model = torch_model.cuda() + + MOE_MANAGER.__init__() + MOE_MANAGER.setup( + seed=42, + max_ep_size=2, + use_ep_inside=False, + parallel="EP", + ) + zero_model = MoeModel(enable_load_balance=True) + extra_dp_group = MOE_MANAGER.parallel_info_dict[2].dp_group + ep_rank = dist.get_rank(MOE_MANAGER.parallel_info_dict[2].ep_group) + ep_size = MOE_MANAGER.parallel_info_dict[2].ep_size + for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()): + if is_moe_tensor(zero_param): + num_expert = torch_param.data.shape[0] + zero_param.data.copy_( + torch_param.data[ep_rank * (num_expert // ep_size) : (ep_rank + 1) * (num_expert // ep_size)] + .detach() + .clone() + ) + else: + zero_param.data.copy_(torch_param.data.detach().clone()) + zero_optimizer = torch.optim.Adam(zero_model.parameters()) + plugin = LowLevelZeroPlugin(stage=stage, precision="fp32") + plugin.zero_optim_kwargs["moe_extra_dp_process_group"] = extra_dp_group + booster = Booster(plugin=plugin) + zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer) + + # run torch for twice + run_fwd_bwd(torch_model, data, label, criterion, None) + torch_optimizer.step() + torch_optimizer.zero_grad() + run_fwd_bwd(torch_model, data, label, criterion, None) + torch_optimizer.step() + + # run zero + run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer) + zero_optimizer.step() + zero_optimizer.zero_grad() + with torch.no_grad(): + origin_out = zero_model(data) + + # load balance + apply_load_balance(zero_model, zero_optimizer) + + # assert out + zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer) + torch.allclose(origin_out, zero_out) + + # assert optim + zero_optimizer.step() + zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer) + torch_out = run_fwd_bwd(torch_model, data, label, criterion, None) + # TODO: high atol, check if bug exists + assert torch.allclose(zero_out, torch_out, atol=8e-4), f"zero_out:{zero_out}\ntorch_out{torch_out}" + + +def run_dist(rank, world_size, port): + colossalai.launch( + config=dict(), + rank=rank, + world_size=world_size, + host="localhost", + port=port, + backend="nccl", + ) + run_zero_optim_test(rank, world_size, stage=1) + run_zero_optim_test(rank, world_size, stage=2) + run_hybrid_zero_optim_test(rank, world_size, stage=1) + run_hybrid_zero_optim_test(rank, world_size, stage=2) + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [4]) +@rerun_if_address_is_in_use() +def test_moe_load_balance(world_size): + spawn(run_dist, world_size) + + +if __name__ == "__main__": + test_moe_load_balance(world_size=4) diff --git a/tests/test_moe/test_moe_router.py b/tests/test_moe/test_moe_router.py new file mode 100644 index 000000000..fce0d1064 --- /dev/null +++ b/tests/test_moe/test_moe_router.py @@ -0,0 +1,41 @@ +import pytest +import torch + +from 
colossalai.moe.routers import MoeRouter, Top1Router, Top2Router, TopKRouter + + +@pytest.mark.parametrize(["router", "num_groups"], [ + (Top1Router(), 1), + (Top2Router(), 1), + (TopKRouter(num_selected_experts=3), 4), +]) +@pytest.mark.parametrize(["batch_size", "seq_len", "num_experts"], [ + (4, 5, 8), + (3, 4, 4), +]) +def test_router_forward(router: MoeRouter, batch_size: int, seq_len: int, num_experts: int, num_groups: int): + x = torch.randn((batch_size * seq_len, num_experts)).cuda() + if num_groups > 1: + x = x.expand(num_groups, -1, -1) + + router.train() + if isinstance(router, TopKRouter): + combine_array, dispatch_mask = router(x, expert_capacity=2) + else: + combine_array, dispatch_mask = router(x) + assert combine_array.shape[:-1] == x.shape + assert dispatch_mask.shape[:-1] == x.shape + assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value) + + router.eval() + if isinstance(router, TopKRouter): + combine_array, dispatch_mask = router(x, expert_capacity=2) + else: + combine_array, dispatch_mask = router(x) + assert combine_array.shape[:-1] == x.shape + assert dispatch_mask.shape[:-1] == x.shape + assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value) + + +if __name__ == "__main__": + test_router_forward(Top1Router(), 4, 4, 4, 1) diff --git a/tests/test_moe/test_moe_zero_fwd_bwd.py b/tests/test_moe/test_moe_zero_fwd_bwd.py new file mode 100644 index 000000000..8f046ab00 --- /dev/null +++ b/tests/test_moe/test_moe_zero_fwd_bwd.py @@ -0,0 +1,105 @@ +import pytest +import torch + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import LowLevelZeroPlugin +from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel +from colossalai.moe.manager import MOE_MANAGER +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.testing.random import seed_all +from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel + + +def split_ddp_grad(grad, world_size): + with torch.no_grad(): + grad = grad.clone().detach().flatten() + padding_size = (world_size - grad.numel() % world_size) % world_size + if padding_size > 0: + grad = torch.nn.functional.pad(grad, [0, padding_size]) + splited_grad = grad.split(grad.numel() // world_size) + return splited_grad + + +def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False): + model.train() + with torch.cuda.amp.autocast(enabled=enable_autocast): + if criterion: + y = model(data) + loss = criterion(y, label) + else: + loss = model(data, label) + loss = loss.float() + + if isinstance(model, LowLevelZeroModel): + optimizer.backward(loss) + else: + loss.backward() + return y + + +def run_zero_test(local_rank, world_size, stage=1): + criterion = torch.nn.CrossEntropyLoss() + + zero_model = MoeModel() + optimizer = torch.optim.Adam(zero_model.parameters()) + plugin = LowLevelZeroPlugin(stage=stage, precision="fp32") + booster = Booster(plugin=plugin) + zero_model, optimizer, _, _, _ = booster.boost(zero_model, optimizer) + + torch_model = MoeModel() + for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()): + torch_param.data.copy_(zero_param.data) + torch_model = torch_model.cuda() + grad_handler = MoeGradientHandler(torch_model) + + # assert zero model + for (torch_name, torch_param), (zero_name, zero_param) in zip( + torch_model.named_parameters(), zero_model.module.named_parameters() + ): + assert zero_name == torch_name + assert torch.allclose(zero_param.data, torch_param.data) + + data = 
torch.randn(16, 4).cuda() + label = torch.randint(0, 4, (16,)).cuda() + + torch_out = run_fwd_bwd(torch_model, data, label, criterion, None) + zero_out = run_fwd_bwd(zero_model, data, label, criterion, optimizer) + assert torch.allclose(torch_out, zero_out) + grad_handler.handle_gradient() + + for (zero_name, zero_param), (torch_name, torch_param) in zip( + zero_model.module.named_parameters(), torch_model.named_parameters() + ): + assert zero_name == torch_name + zero_grad_list = optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(zero_param)) + if hasattr(zero_param, "moe_info"): + assert len(zero_grad_list) == 0 + assert torch.allclose(zero_param.grad, torch_param.grad) + else: + assert len(zero_grad_list) > 0 + torch_grad_list = split_ddp_grad(torch_param.grad, world_size) + if stage == 2: + torch_grad_list = torch_grad_list[local_rank : local_rank + 1] + assert len(zero_grad_list) == len(torch_grad_list) + for zero_grad, torch_grad in zip(zero_grad_list, torch_grad_list): + assert torch.allclose(zero_grad, torch_grad) + + +def run_dist(rank, world_size, port): + colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + MOE_MANAGER.setup(seed=42, parallel="EP") + seed_all(42 + rank) + run_zero_test(rank, world_size, stage=1) + run_zero_test(rank, world_size, stage=2) + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [2]) +@rerun_if_address_is_in_use() +def test_moe_zero_model(world_size): + spawn(run_dist, world_size) + + +if __name__ == "__main__": + test_moe_zero_model(world_size=2) diff --git a/tests/test_moe/test_moe_zero_init.py b/tests/test_moe/test_moe_zero_init.py deleted file mode 100644 index c48f9a355..000000000 --- a/tests/test_moe/test_moe_zero_init.py +++ /dev/null @@ -1,106 +0,0 @@ -import pytest -import torch -import torch.nn as nn - -import colossalai -from colossalai.context import MOE_CONTEXT -from colossalai.logging import get_dist_logger -from colossalai.nn import CheckpointModule -from colossalai.nn.layer import MoeModule -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device -from colossalai.zero.legacy.init_ctx import ZeroInitContext -from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy -from tests.test_zero.test_legacy.common import CONFIG - - -class MoeModel(nn.Module): - def __init__(self, checkpoint: bool = False): - class TestSubModule(CheckpointModule): - def __init__(self): - super().__init__(checkpoint) - expert_cls = nn.Linear - expert_args_dict = dict(in_features=16, out_features=16) - self.moe = MoeModule( - dim_model=16, num_experts=8, use_residual=True, expert_cls=expert_cls, **expert_args_dict - ) - self.proj = nn.Linear(16, 4) - - def _forward(self, x): - x, y = self.moe(x) - x = self.proj(x) - return x, y - - super().__init__() - self.test_embed = nn.Linear(4, 16) - self.test_transform = TestSubModule() - - def forward(self, x): - MOE_CONTEXT.reset_loss() - - x = self.test_embed(x) - x, y = self.test_transform(x) - - MOE_CONTEXT.add_loss(y) - return x - - -@parameterize("init_device_type", ["cpu", "cuda"]) -@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) -def run_moe_zero_init(init_device_type, shard_strategy_class): - get_dist_logger("test_moe_zero_init") - - if init_device_type == "cuda": - init_device = get_current_device() - elif init_device_type == "cpu": - init_device = torch.device("cpu") - else: - raise 
NotImplementedError("Unknown device found.") - - model_numel_tensor = torch.zeros(1, dtype=torch.int) - with ZeroInitContext( - target_device=init_device, - shard_strategy=shard_strategy_class(), - shard_param=True, - model_numel_tensor=model_numel_tensor, - ): - model = MoeModel(checkpoint=True) - - for name, param in model.named_parameters(): - assert hasattr(param, "colo_attr") - - # the parameters in moe experts and its gate should not be sharded - if ("experts" in name) or ("gate" in name) or ("residual_combine" in name): - assert not param.colo_attr.sharded_data_tensor.is_sharded, "`{}` parameter has problem".format(name) - else: - assert param.colo_attr.sharded_data_tensor.is_sharded - - # the parameters in moe experts is not replicated - if "experts" in name: - assert not param.colo_attr.is_replicated - else: - assert param.colo_attr.is_replicated - - if param.colo_attr.param_is_sharded: - assert ( - param.colo_attr.data_payload.device.type == init_device.type - ), f"{param.colo_attr.data_payload.device.type} vs. {init_device.type}" - else: - assert param.colo_attr.data_payload.device.type == "cuda" - - -def _run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - MOE_CONTEXT.setup(seed=42) - run_moe_zero_init() - - -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [2, 4]) -@rerun_if_address_is_in_use() -def test_moe_zero_init(world_size): - spawn(_run_dist, world_size) - - -if __name__ == "__main__": - test_moe_zero_init(world_size=2) diff --git a/tests/test_moe/test_moe_zero_model.py b/tests/test_moe/test_moe_zero_model.py deleted file mode 100644 index 724d70d77..000000000 --- a/tests/test_moe/test_moe_zero_model.py +++ /dev/null @@ -1,70 +0,0 @@ -import pytest -import torch - -import colossalai -from colossalai.context import MOE_CONTEXT -from colossalai.legacy.engine.gradient_handler import MoeGradientHandler -from colossalai.nn import MoeLoss -from colossalai.testing import assert_equal_in_group, parameterize, rerun_if_address_is_in_use, spawn -from colossalai.zero.legacy.init_ctx import ZeroInitContext -from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy -from colossalai.zero.legacy.sharded_model import ShardedModelV2 -from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_fp16 -from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy -from tests.components_to_test.registry import non_distributed_component_funcs -from tests.test_moe.test_moe_zero_init import MoeModel -from tests.test_zero.test_legacy.common import CONFIG, check_grads_padding, run_fwd_bwd - - -@parameterize("enable_autocast", [False]) -@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) -def run_model_test(enable_autocast, shard_strategy_class): - shard_strategy = shard_strategy_class() - - get_components_func = non_distributed_component_funcs.get_callable("hanging_param_model") - _, train_dataloader, _, optimizer_class, _ = get_components_func() - criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss) - - with ZeroInitContext( - target_device=torch.device("cuda", torch.cuda.current_device()), shard_strategy=shard_strategy, shard_param=True - ): - zero_model = MoeModel(checkpoint=True) - zero_model = ShardedModelV2(zero_model, shard_strategy) - - # check whether parameters are identical in ddp - for name, p in zero_model.named_parameters(): - if not p.colo_attr.param_is_sharded and 
p.colo_attr.is_replicated: - assert_equal_in_group(p.colo_attr.data_payload) - - model = MoeModel(checkpoint=True).half() - col_model_deepcopy(zero_model, model) - model = model.cuda() - grad_handler = MoeGradientHandler(model) - - for i, (data, label) in enumerate(train_dataloader): - if i > 5: - break - - data, label = cast_tensor_to_fp16(data).cuda(), label.cuda() - run_fwd_bwd(model, data, label, criterion, enable_autocast) - run_fwd_bwd(zero_model, data, label, criterion, enable_autocast) - grad_handler.handle_gradient() - - check_grads_padding(model, zero_model, loose=True) - - -def run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - MOE_CONTEXT.setup(seed=42) - run_model_test() - - -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [2]) -@rerun_if_address_is_in_use() -def test_moe_zero_model(world_size): - spawn(run_dist, world_size) - - -if __name__ == "__main__": - test_moe_zero_model(world_size=2) diff --git a/tests/test_moe/test_moe_zero_optim.py b/tests/test_moe/test_moe_zero_optim.py index bb9822dae..ebea7509f 100644 --- a/tests/test_moe/test_moe_zero_optim.py +++ b/tests/test_moe/test_moe_zero_optim.py @@ -2,120 +2,91 @@ import pytest import torch import colossalai -from colossalai.context import MOE_CONTEXT -from colossalai.legacy.amp import convert_to_apex_amp -from colossalai.legacy.engine.gradient_handler import MoeGradientHandler -from colossalai.nn import MoeLoss -from colossalai.nn.optimizer import CPUAdam -from colossalai.testing import assert_equal_in_group, parameterize, rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device -from colossalai.zero.legacy.init_ctx import ZeroInitContext -from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy -from colossalai.zero.legacy.sharded_model import ShardedModelV2 -from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy -from colossalai.zero.legacy.sharded_optim import ShardedOptimizerV2 -from colossalai.zero.low_level._utils import has_inf_or_nan -from tests.components_to_test.registry import non_distributed_component_funcs -from tests.test_moe.test_moe_zero_init import MoeModel -from tests.test_zero.test_legacy.common import CONFIG, check_sharded_model_params +from colossalai.booster import Booster +from colossalai.booster.plugin import LowLevelZeroPlugin +from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel +from colossalai.moe.manager import MOE_MANAGER +from colossalai.testing import rerun_if_address_is_in_use, spawn +from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel -def _run_step(model, optimizer, data, label, criterion, grad_handler): +def split_ddp_grad(grad, world_size): + with torch.no_grad(): + grad = grad.clone().detach().flatten() + padding_size = (world_size - grad.numel() % world_size) % world_size + if padding_size > 0: + grad = torch.nn.functional.pad(grad, [0, padding_size]) + splited_grad = grad.split(grad.numel() // world_size) + return splited_grad + + +def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False): model.train() - optimizer.zero_grad() + with torch.cuda.amp.autocast(enabled=enable_autocast): + if criterion: + y = model(data) + loss = criterion(y, label) + else: + loss = model(data, label) + loss = loss.float() - if criterion: - y = model(data) - loss = criterion(y, label) - else: - loss = model(data, label) - - loss = loss.float() - if 
isinstance(model, ShardedModelV2): + if isinstance(model, LowLevelZeroModel): optimizer.backward(loss) else: loss.backward() + return y - if grad_handler is not None: + +def run_zero_optim_test(local_rank, world_size, stage=1): + criterion = torch.nn.CrossEntropyLoss() + + zero_model = MoeModel() + zero_optimizer = torch.optim.Adam(zero_model.parameters()) + plugin = LowLevelZeroPlugin(stage=stage, precision="fp32") + booster = Booster(plugin=plugin) + zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer) + + torch_model = MoeModel() + for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()): + torch_param.data.copy_(zero_param.data) + torch_optimizer = torch.optim.Adam(torch_model.parameters()) + torch_model = torch_model.cuda() + grad_handler = MoeGradientHandler(torch_model) + + for _ in range(2): + data = torch.randn(16, 4).cuda() / (local_rank + 1) + label = torch.randint(0, 4, (16,)).cuda() + run_fwd_bwd(torch_model, data, label, criterion, None) + run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer) grad_handler.handle_gradient() - optimizer.step() + torch_optimizer.step() + zero_optimizer.step() + + for (torch_name, torch_param), (zero_name, zero_param) in zip( + torch_model.named_parameters(), zero_model.named_parameters() + ): + assert torch.allclose( + torch_param.data, zero_param.data + ), f"{torch_name}\ntorch_param {torch_param.data}\nzero_param {zero_param.data}" + + torch_optimizer.zero_grad() + zero_optimizer.zero_grad() -@parameterize("cpu_offload", [True]) -@parameterize("use_cpuadam", [True]) # We do not use Hybrid Adam right now, since it has a little bug -@parameterize("reuse_fp16_shard", [True, False]) -@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) -def _run_test_sharded_optim_v2( - cpu_offload, shard_strategy_class, use_cpuadam, reuse_fp16_shard, gpu_margin_mem_ratio=0.0 -): - shard_strategy = shard_strategy_class() - if use_cpuadam and cpu_offload is False: - return - MOE_CONTEXT.reset_loss() - get_components_func = non_distributed_component_funcs.get_callable("hanging_param_model") - _, train_dataloader, _, optimizer_class, _ = get_components_func() - criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss) - - with ZeroInitContext( - target_device=torch.device("cpu") if cpu_offload else get_current_device(), - shard_strategy=shard_strategy, - shard_param=True, - ): - zero_model = MoeModel(checkpoint=True) - - zero_model = ShardedModelV2( - zero_model, - shard_strategy, - tensor_placement_policy="cpu" if cpu_offload else "cuda", - reuse_fp16_shard=reuse_fp16_shard, - ) - - # check whether parameters are identical in ddp - for name, p in zero_model.named_parameters(): - if not p.colo_attr.param_is_sharded and p.colo_attr.is_replicated: - assert_equal_in_group(p.colo_attr.data_payload.to(get_current_device())) - - model = MoeModel(checkpoint=True).half() - col_model_deepcopy(zero_model, model) - model = model.cuda().float() - - if use_cpuadam: - optimizer_class = CPUAdam - optim = optimizer_class(model.parameters(), lr=1e-3) - sharded_optim = optimizer_class(zero_model.parameters(), lr=1e-3) - sharded_optim = ShardedOptimizerV2( - zero_model, sharded_optim, initial_scale=2**5, gpu_margin_mem_ratio=gpu_margin_mem_ratio - ) - - amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False) - apex_model, apex_optimizer = convert_to_apex_amp(model, optim, amp_config) - apex_grad_handler = MoeGradientHandler(model) - - for i, (data, label) in 
enumerate(train_dataloader): - if i > 5: - break - data, label = data.cuda(), label.cuda() - _run_step(apex_model, apex_optimizer, data, label, criterion, apex_grad_handler) - _run_step(zero_model, sharded_optim, data, label, criterion, None) - check_sharded_model_params(model, zero_model, loose=True, reuse_fp16_shard=use_cpuadam) - for param in model.parameters(): - assert not has_inf_or_nan(param) +def run_dist(rank, world_size, port): + colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + MOE_MANAGER.setup(seed=42, parallel="EP") + run_zero_optim_test(rank, world_size, stage=1) + run_zero_optim_test(rank, world_size, stage=2) -def _run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - MOE_CONTEXT.setup(seed=42) - _run_test_sharded_optim_v2() - - -# use_cpuadam = True can be used with cpu_offload = False @pytest.mark.dist @pytest.mark.parametrize("world_size", [2]) @rerun_if_address_is_in_use() def test_moe_zero_optim(world_size): - spawn(_run_dist, world_size) + spawn(run_dist, world_size) if __name__ == "__main__": - test_moe_zero_optim(world_size=4) + test_moe_zero_optim(world_size=2)
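
A note on the API exercised by the tests in this patch: MOE_MANAGER.setup(...) is called after colossalai.launch and before any SparseMLP is constructed, and SparseMLP.forward now returns a single output tensor rather than the legacy (output, aux_loss) tuple of MoeLayer. The sketch below illustrates that calling pattern only; the hidden sizes, expert count, and batch shape are illustrative assumptions, not values taken from this patch.

import torch

import colossalai
from colossalai.moe import SparseMLP
from colossalai.moe.manager import MOE_MANAGER
from colossalai.testing import spawn
from colossalai.utils import get_current_device


def run(rank, world_size, port):
    # the MoE tests launch a distributed context before touching any MoE module
    colossalai.launch(config=dict(), rank=rank, world_size=world_size,
                      host="localhost", port=port, backend="nccl")
    # set up the global MoE manager before building any SparseMLP
    MOE_MANAGER.setup(42, parallel="EP")

    # kwargs mirror tests/test_moe/test_kernel.py; the sizes here are arbitrary
    layer = SparseMLP(hidden_size=32, intermediate_size=64,
                      num_experts=4, router_top_k=1).to(get_current_device())

    tokens = torch.randn(8, 32, device=get_current_device())
    MOE_MANAGER.reset_loss()  # clear the accumulated auxiliary loss, as the tests do before a forward pass
    out = layer(tokens)       # a single tensor; the legacy MoeLayer returned (output, loss)
    out.mean().backward()


if __name__ == "__main__":
    spawn(run, 2)  # two processes, matching the smallest test configurations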