From f71e63b0f39a108c512d81afd5272c41a977708c Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com> Date: Wed, 8 Nov 2023 23:07:03 +0800 Subject: [PATCH] [moe] support optimizer checkpoint (#5015) * Refactor MoE Manager setup method * unshard optim ckpt * optim io * update transformer version * update requirements * update ckpt * update ckpt * update ckpt * fix engine * fix engine --- .../plugin/moe_hybrid_parallel_plugin.py | 9 +- .../inference/tensor_parallel/engine.py | 10 +- colossalai/moe/__init__.py | 4 +- colossalai/moe/checkpoint.py | 553 +++++++++++++++++- colossalai/moe/experts.py | 17 +- colossalai/moe/manager.py | 29 +- colossalai/tensor/moe_tensor/api.py | 13 + .../openmoe/benchmark/benchmark_cai.py | 4 +- .../openmoe/benchmark/benchmark_fsdp.py | 2 +- examples/language/openmoe/requirements.txt | 2 +- examples/language/openmoe/train.py | 4 +- tests/test_moe/test_grad_handler.py | 14 +- tests/test_moe/test_kernel.py | 22 +- tests/test_moe/test_moe_checkpoint.py | 171 ++++-- tests/test_moe/test_moe_ep_tp.py | 16 +- tests/test_moe/test_moe_group.py | 2 +- tests/test_moe/test_moe_hybrid_zero.py | 4 +- tests/test_moe/test_moe_load_balance.py | 8 +- tests/test_moe/test_moe_zero_fwd_bwd.py | 2 +- tests/test_moe/test_moe_zero_optim.py | 2 +- 20 files changed, 738 insertions(+), 150 deletions(-) diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index 3f0e95d39..e976d0aaf 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -22,7 +22,7 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import ( ) from colossalai.cluster import ProcessGroupMesh from colossalai.interface import ModelWrapper, OptimizerWrapper -from colossalai.moe import MoeCheckpintIO +from colossalai.moe import MoECheckpintIO from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig @@ -322,8 +322,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): **_kwargs, ) - def get_checkpoint_io(self) -> MoeCheckpintIO: - self.checkpoint_io = MoeCheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) + def get_checkpoint_io(self) -> MoECheckpintIO: + self.checkpoint_io = MoECheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) return self.checkpoint_io def configure( @@ -359,9 +359,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): max_norm=self.max_norm, **self.amp_config, ) - self.checkpoint_io.link_master_and_working_param( - optimizer.working_to_master_map, optimizer.master_to_working_map - ) else: optimizer = HybridParallelNaiveOptimizer( optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index 2eadbcab1..2478b574d 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -79,13 +79,15 @@ class TPInferEngine: self.multi_query_group_num = model.config.num_attention_heads # default to attention_heads - self.multi_query_attention = model.config.multi_query_attention + if hasattr(model.config, "multi_query_attention"): + self.multi_query_attention = getattr(model.config, "multi_query_attention") if hasattr(model.config, "multi_query_group_num"): - 
self.multi_query_group_num = model.config.multi_query_group_num + self.multi_query_group_num = getattr(model.config, "multi_query_group_num") if hasattr(model.config, "num_key_value_heads"): - self.multi_query_group_num = model.config.num_key_value_heads + self.multi_query_group_num = getattr(model.config, "num_key_value_heads") + self.tp_size = -1 # to be set with given shard config in self.prepare_shard_config self.cache_manager = None @@ -108,7 +110,7 @@ class TPInferEngine: assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}" self.head_num //= self.tp_size # update sharded number of heads - if self.multi_query_attention: + if hasattr(self, "multi_query_attention"): # NOTE the logic of MQA tensor parallelism should be specified. assert ( self.multi_query_group_num % self.tp_size == 0 diff --git a/colossalai/moe/__init__.py b/colossalai/moe/__init__.py index f32e89dfa..721da69d0 100644 --- a/colossalai/moe/__init__.py +++ b/colossalai/moe/__init__.py @@ -1,4 +1,4 @@ -from .checkpoint import MoeCheckpintIO +from .checkpoint import MoECheckpintIO from .experts import MLPExperts from .layers import SparseMLP from .routers import MoeRouter, Top1Router, Top2Router, TopKRouter @@ -13,5 +13,5 @@ __all__ = [ "NormalNoiseGenerator", "UniformNoiseGenerator", "SparseMLP", - "MoeCheckpintIO", + "MoECheckpintIO", ] diff --git a/colossalai/moe/checkpoint.py b/colossalai/moe/checkpoint.py index 386fc2010..a8c50eab6 100644 --- a/colossalai/moe/checkpoint.py +++ b/colossalai/moe/checkpoint.py @@ -1,32 +1,46 @@ +import copy import logging import os -from copy import deepcopy from pathlib import Path -from typing import Iterator, Optional, OrderedDict, Tuple +from shutil import rmtree +from typing import Dict, Iterator, Optional, OrderedDict, Tuple import torch import torch.distributed as dist import torch.nn as nn from torch.distributed import ProcessGroup -from torch.optim import Optimizer from colossalai.checkpoint_io import CheckpointIndexFile, HybridParallelCheckpointIO from colossalai.checkpoint_io.utils import ( StateDictSharder, gather_distributed_param, get_model_base_filenames, + get_optimizer_base_filenames, is_safetensors_available, load_shard_state_dict, + load_state_dict, load_state_dict_into_model, + load_states_into_optimizer, save_config_file, + save_param_groups, + save_state_dict, save_state_dict_shards, + sharded_optimizer_loading_epilogue, ) +from colossalai.interface import OptimizerWrapper from colossalai.moe.manager import MOE_MANAGER -from colossalai.tensor.moe_tensor.api import get_dp_rank, get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor +from colossalai.tensor.moe_tensor.api import ( + get_dp_group, + get_dp_rank, + get_dp_size, + get_ep_group, + get_ep_rank, + get_ep_size, + is_moe_tensor, +) -class MoeCheckpintIO(HybridParallelCheckpointIO): - +class MoECheckpintIO(HybridParallelCheckpointIO): def __init__( self, dp_group: ProcessGroup, @@ -55,7 +69,7 @@ class MoeCheckpintIO(HybridParallelCheckpointIO): ep_size = get_ep_size(model_param) expert_num = param.shape[0] // ep_size assert param.shape[0] % ep_size == 0 - param = param[ep_rank * expert_num:(ep_rank + 1) * expert_num] + param = param[ep_rank * expert_num : (ep_rank + 1) * expert_num] state_dict[name] = param dist.barrier() return state_dict @@ -156,7 +170,7 @@ class MoeCheckpintIO(HybridParallelCheckpointIO): dp_rank = get_dp_rank(param) if dp_rank == 0: param = param.data.cuda() - all_param = [deepcopy(param) for _ in range(ep_size)] + all_param = 
[torch.zeros_like(param) for _ in range(ep_size)] # gather param from every ep rank dist.all_gather(all_param, param, group=ep_group) if ep_rank == 0: @@ -245,30 +259,523 @@ class MoeCheckpintIO(HybridParallelCheckpointIO): index_file.write_index_file(save_index_file) save_config_file(model, checkpoint) if self.verbose: - logging.info(f"The model is split into checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}.") + logging.info( + f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) dist.barrier() # ======================================================== # Abstract methods for optimizer loading/saving implementation # ======================================================== - def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str): - raise NotImplementedError() + def pre_load_optim( + self, + state: OrderedDict, + working_param, + current_shape: torch.Size, + original_shape: torch.Size, + device: torch.device, + inplace: bool, + ) -> OrderedDict: + """ + With complete optimizer states of a specific parameter loaded from checkpoint, + slice out the sharded optimizer states kept by current device. - def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path): - raise NotImplementedError() + Args: + state (OrderedDict): Complete optimizer states of a given parameter, loaded from checkpoint. + current_shape (torch.Size): The size of parameter after sharding. + original_shape (torch.Size): The size of parameter before sharding. + device (torch.device): The destination device of loaded optimizer states. + inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state. + + Returns: + OrderedDict: The sharded optimizer state of the given parameter. + """ + state_ = state if inplace else copy.deepcopy(state) + is_moe_tensor_flag = is_moe_tensor(working_param) + if is_moe_tensor_flag: + ep_rank = get_ep_rank(working_param) + ep_size = get_ep_size(working_param) + + for k, v in state_.items(): + if isinstance(v, torch.Tensor) and k != "step": + if is_moe_tensor_flag: + with torch.no_grad(): + expert_num = v.shape[0] // ep_size + assert v.shape[0] % ep_size == 0 + v = v[ep_rank * expert_num : (ep_rank + 1) * expert_num] + else: + # Shard state along data parallel group when using Zero. + padding_size = (self.dp_size - v.numel() % self.dp_size) % self.dp_size + with torch.no_grad(): + v = v.flatten() + if padding_size > 0: + v = torch.nn.functional.pad(v, [0, padding_size]) + slice_size = v.numel() // self.dp_size + v = v.split(slice_size, dim=0)[self.dp_rank] + + state_[k] = v.detach().clone().to(device) + + return state_ + + def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""): + """ + Load sharded optimizer with the given path to index file of checkpoint folder. + + Args: + optimizer (OptimizerWrapper): The optimizer to be loaded. + checkpoint_index_file (str): Path to the index file of checkpointing folder. + prefix (str): Not used. + """ + assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!" 
+ + def _get_param_id_from_optimizer_param( + param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None + ): + if master_to_working_map is not None and id(param) in master_to_working_map: + working_param = master_to_working_map[id(param)] + else: + working_param = param + return optimizer.param_info["param2id"][id(working_param)] + + # id_map is a mapping from param ids kept by current pipeline, to their corresponding parameter objects. + # When Zero is used, the mapped parameter objects should be fp32 master parameters. + # IDs should be obtained through saved param2id mapping earlier saved in optimizer.param_info. + id_map = {} + master_to_working_map = optimizer.get_master_to_working_map() + for pg in optimizer.optim.param_groups: + for param in pg["params"]: + param_id = _get_param_id_from_optimizer_param(param, master_to_working_map) + id_map[param_id] = param + + # Read checkpoint index file. + ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file) + ckpt_root_path = ckpt_index_file.root_path + weight_map = ckpt_index_file.weight_map + weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int + + # Load param_groups + param_group_path = ckpt_index_file.get_param_group_filename() + if param_group_path is None: + raise RuntimeError( + f"Invalid index file path {checkpoint_index_file} for an optimizer. \ + Lacking param group file under current directory." + ) + saved_groups = torch.load(param_group_path) + + updated_groups = [] + for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups): + # obtain updated param group + new_pg = copy.deepcopy(saved_pg) + new_pg["params"] = old_pg["params"] # The parameters in the same group shouln't change. + updated_groups.append(new_pg) + # ep extra group + if MOE_MANAGER.parallel == "EP": + new_pg = copy.deepcopy(saved_pg) + new_pg["params"] = optimizer.optim.param_groups[-1][ + "params" + ] # Only keep the parameters kept by current pipeline stage. + for param in new_pg["params"]: + param.data = param.data.to(torch.float32) + updated_groups.append(new_pg) + optimizer.optim.__dict__.update({"param_groups": updated_groups}) + + # Load saved states to optimizer. + # Keep a record of loaded files so that file will not be repeatedly loaded. + loaded_file = set() + for pg in optimizer.optim.param_groups: + for param in pg["params"]: + if param is None: + continue + param_id = _get_param_id_from_optimizer_param(param, master_to_working_map) + if param_id not in weight_map: + continue + filename = weight_map[param_id] + + # If this param's states has been loaded before, directly return. + if filename in loaded_file: + continue + + file_path = os.path.join(ckpt_root_path, filename) + state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False) + load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True) + loaded_file.add(filename) + + # Then shard the loaded optimizer states if using tp/zero. 
+ for param, state in optimizer.optim.state.items(): + device = param.device + if master_to_working_map is not None and id(param) in master_to_working_map: + working_param = master_to_working_map[id(param)] + else: + working_param = param + original_shape = optimizer.param_info["param2shape"][id(working_param)] + sharded_state = self.pre_load_optim( + state, + param, + current_shape=working_param.shape, + original_shape=original_shape, + device=device, + inplace=True, + ) + optimizer.optim.state[param] = sharded_state + + sharded_optimizer_loading_epilogue(optimizer.optim) + if self.verbose and self.coordinator.is_master(): + logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.") + dist.barrier() + + def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str): + """ + Load optimizer from a file with given path. + + Args: + optimizer (OptimizerWrapper): The optimizer to be loaded. + checkpoint_index_file (str): Path to the checkpoint file. + """ + + def _get_param_id_from_optimizer_param( + param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None + ): + if master_to_working_map is not None and id(param) in master_to_working_map: + working_param = master_to_working_map[id(param)] + else: + working_param = param + if id(working_param) in optimizer.param_info["param2id"]: + return optimizer.param_info["param2id"][id(working_param)] + else: + None + + if self.coordinator.is_master(): + logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!") + + assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!" + + # Complete optimizer state_dict loaded from checkpoint, need to be processed later. + state_dict = load_state_dict(checkpoint) + + # Load param_groups. + updated_groups = [] + saved_groups = state_dict["param_groups"] + for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups): + new_pg = copy.deepcopy(saved_pg) + new_pg["params"] = old_pg["params"] # Only keep the parameters kept by current pipeline stage. + updated_groups.append(new_pg) + # ep extra group + if MOE_MANAGER.parallel == "EP": + new_pg = copy.deepcopy(saved_pg) + new_pg["params"] = optimizer.optim.param_groups[-1][ + "params" + ] # Only keep the parameters kept by current pipeline stage. + for param in new_pg["params"]: + param.data = param.data.to(torch.float32) + updated_groups.append(new_pg) + optimizer.optim.__dict__.update({"param_groups": updated_groups}) + + # Load saved states to optimizer. First discard those states not belonging to current pipeline stage. + master_to_working_map = optimizer.get_master_to_working_map() + id_map = {} + for pg in optimizer.optim.param_groups: + for param in pg["params"]: + param_id = _get_param_id_from_optimizer_param(param, master_to_working_map) + if param_id is not None: + id_map[param_id] = param + load_states_into_optimizer(optimizer.optim, state_dict["state"], id_map, strict=True) + + # Then shard the loaded optimizer states if using tp/zero. 
+ for param, state in optimizer.optim.state.items(): + if param is None: + continue + device = param.device + if master_to_working_map is not None and id(param) in master_to_working_map: + working_param = master_to_working_map[id(param)] + else: + working_param = param + original_shape = optimizer.param_info["param2shape"][id(working_param)] + sharded_state = self.pre_load_optim( + state, + param, + current_shape=working_param.shape, + original_shape=original_shape, + device=device, + inplace=True, + ) + optimizer.optim.state[param] = sharded_state + sharded_optimizer_loading_epilogue(optimizer.optim) + dist.barrier() + + def pre_save_optim( + self, + state: OrderedDict, + param: torch.Tensor, + inplace: bool, + device: torch.device = torch.device("cpu"), + ) -> OrderedDict: + """ + With given parameter and its optimizer states, gather the complete optimizer state for saving. + + Args: + state (OrderedDict): Optimizer states of given parameter, might be distributed among tp/dp group if using TP/Zero. + param (torch.Tensor): The given parameter. It should be working_param when using Zero. + original_shape (torch.Size): The size of parameter before sharding. + dp_group (ProcessGroup): The process group of data parallel. + tp_group (ProcessGroup): The process group of tensor parallel. + use_zero (bool): Whether Zero is used. + inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state. + device (torch.device): The destination device of loaded optimizer states. Defaults to torch.device('cpu'). + + Returns: + OrderedDict: The complete optimizer state of given parameter. + """ + if is_moe_tensor(param): + moe_dp_group = get_dp_group(param) + moe_dp_size = get_dp_size(param) + moe_ep_group = get_ep_group(param) + moe_ep_size = get_ep_size(param) + state_ = state if inplace else copy.deepcopy(state) + + for k, v in state_.items(): + if isinstance(v, torch.Tensor) and k != "step": + # moe param + if is_moe_tensor(param): + # dp gather + v = v.cuda() + gather_tensor = [torch.zeros_like(v) for _ in range(moe_dp_size)] + dist.all_gather(gather_tensor, v, group=moe_dp_group) + v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param) + # ep gather + gather_tensor = [torch.zeros_like(v) for _ in range(moe_ep_size)] + dist.all_gather(gather_tensor, v, group=moe_ep_group) + v = torch.cat(gather_tensor, dim=0) + else: + # global dp + v = v.cuda() + gather_tensor = [torch.zeros_like(v) for _ in range(dist.get_world_size(self.dp_group))] + dist.all_gather(gather_tensor, v, group=self.dp_group) + v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param) + + state_[k] = v.detach().clone().to(device) + + return state_ + + def _optimizer_sharder( + self, + optimizer: OptimizerWrapper, + size_per_shard: int = 1024, + ): + # An internel method that breaks state_dict of optimizer into shards within limited size. 
+ + state_dict_sharder = StateDictSharder(size_per_shard) + param_info = optimizer.param_info + master_to_working_map = optimizer.get_master_to_working_map() + + for param, state in optimizer.optim.state.items(): + if param is None: + continue + + if master_to_working_map is not None and id(param) in master_to_working_map: + working_param = master_to_working_map[id(param)] + else: + working_param = param + + param_id = param_info["param2id"][id(working_param)] + state_ = self.pre_save_optim( + state, + working_param, + inplace=False, + device=torch.device("cuda"), + ) + + block, block_size = state_dict_sharder.append_optim_state(param_id, state_) + if block is not None: + yield block, block_size + + # Return the last block in sharder. + yield state_dict_sharder.current_block, state_dict_sharder.current_block_size def save_sharded_optimizer( self, - optimizer: Optimizer, - checkpoint: Path, - gather_dtensor: bool, - prefix: str, - size_per_shard: int, + optimizer: OptimizerWrapper, + checkpoint: str, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + size_per_shard: int = 1024, ): - raise NotImplementedError() + """ + Save sharded optimizer checkpoint under the given checkpointing path. + The following files will be created under the path: + - An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names + - A group file (pytorch_optim_group.bin) recording information of param_groups + - Multiple files that store state tensors of optimizers. + If pipeline parallelism is used, the filenames are in the form of "pytorch_optim.-stage-000XX-shard-000XX.bin". + If pipeline parallelism is not used, "pytorch_optim.-000XX.bin" - def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool): - raise NotImplementedError() + Args: + optimizer (OptimizerWrapper): Optimizer to save sharded state_dict + checkpoint (str): Path to save optimizer state_dict + gather_dtensor (bool): Whether to gather_dtensor, not used + prefix (str): Perfix of file to save + size_per_shard (int): Max file size of each file shard that store state tensors + """ + assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!" + if os.path.isfile(checkpoint): + logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + return + + Path(checkpoint).mkdir(parents=True, exist_ok=True) + + # Devices along the same dp_group share the same copies of states when zero is not used. + # In this case only let the device with dp_rank == 0 save the model. + if not self.use_zero and self.dp_rank != 0: + return + + # Then collect the sharded states along dp_group(if using zero)/tp_group. + # Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving. + state_dict_shard = self._optimizer_sharder( + optimizer, + size_per_shard=size_per_shard, + ) + states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix) + index_file = CheckpointIndexFile(checkpoint) + control_saving = self.dp_rank == 0 and self.tp_rank == 0 + if self.pp_size == 1: + # When pipeline is not used, save the optimizer shards as in general checkpointIO + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=control_saving, + ) + + if control_saving: + # Store param groups. 
+ index_file.append_meta_data("param_groups", param_group_file) + group_file_path = os.path.join(checkpoint, param_group_file) + save_param_groups(optimizer.param_info, group_file_path) + # Store index file. + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + if self.verbose and self.coordinator.is_master(): + logging.info( + f"The optimizer is going to be split to checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + + else: + # When pipeline is used, each stage produces its own shard files and index files. + # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/ + # After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder. + + final_index_file_path = copy.deepcopy(save_index_file) + tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files") + Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True) + + # Manage filenames of sharded weights and index file for each pipeline stage. + states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin") + save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json") + save_index_file = os.path.join("tmp_index_files", save_index_file) + + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=control_saving, + use_pp_format=True, + ) + + if control_saving: + assert ( + self.dp_rank == 0 and self.tp_rank == 0 + ), "The saving process should have both dp_rank and tp_rank as 0." + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + else: + return + + dist.barrier(self.pp_group) + + # The global master rank integrates the index files and clean the folder. + if self.pp_rank == 0: + final_index_file = CheckpointIndexFile(checkpoint) + final_index_file.append_meta_data("total_size", 0) + + for filename in os.listdir(tmp_index_file_folder): + stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename)) + final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"] + for param_id, state_filename in stage_index_file.weight_map.items(): + final_index_file.append_weight_map(param_id, state_filename) + + # Store param groups. + final_index_file.append_meta_data("param_groups", param_group_file) + group_file_path = os.path.join(checkpoint, param_group_file) + save_param_groups(optimizer.param_info, group_file_path) + + final_index_file.write_index_file(final_index_file_path) + rmtree(tmp_index_file_folder) + + if self.verbose and self.coordinator.is_master(): + logging.info( + f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {final_index_file_path}." + ) + + def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool): + """ + Save optimizer state dict to a file with given path. + + Args: + optimizer (OptimizerWrapper): Optimizer to save sharded state_dict. + checkpoint (str): Path to save optimizer state_dict. + gather_dtensor (bool): Whether to gather_dtensor, not used. 
+ """ + if self.coordinator.is_master(): + logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!") + + assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!" + + # optimizer states of parameters kept by local device('s pipeline stage) + local_states = dict() + + for param, state in optimizer.optim.state.items(): + if param is None: + continue + + # working param is needed for obtaining correct param_id + master_to_working_map = optimizer.get_master_to_working_map() + if master_to_working_map is not None and id(param) in master_to_working_map: + working_param = master_to_working_map[id(param)] + else: + working_param = param + + # gather complete state from tp shards & dp shards + param_id = optimizer.param_info["param2id"][id(working_param)] + local_states[param_id] = self.pre_save_optim( + state, + working_param, + inplace=False, + device=torch.device("cuda"), + ) + + if self.pp_size == 1: + # When pipeline is not used, let master rank directly save the collected state_dict. + state_dict = {"param_groups": optimizer.optim.param_groups, "state": local_states} + if self.coordinator.is_master(): + save_state_dict(state_dict, checkpoint, use_safetensors=False) + else: + # When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict. + states_list = [None for _ in range(self.pp_size)] + dist.barrier(self.pp_group) + dist.all_gather_object(states_list, local_states, self.pp_group) + + # Only the master rank do the saving. + if self.coordinator.is_master(): + state_dict = {"param_groups": optimizer.optim.param_groups, "state": dict()} + for _states in states_list: + state_dict["state"].update(_states) + save_state_dict(state_dict, checkpoint, use_safetensors=False) + dist.barrier() diff --git a/colossalai/moe/experts.py b/colossalai/moe/experts.py index 3471b2876..477b76547 100644 --- a/colossalai/moe/experts.py +++ b/colossalai/moe/experts.py @@ -9,7 +9,7 @@ from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.utils import get_activation from colossalai.shardformer.layer.utils import Randomizer -from colossalai.tensor.moe_tensor.api import get_ep_size, set_moe_tensor_info +from colossalai.tensor.moe_tensor.api import get_ep_rank, get_ep_size, set_moe_tensor_info if HAS_TRITON: from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine @@ -53,7 +53,8 @@ class MLPExperts(nn.Module): # get expert parallel info if expert_parallel is not None: self.num_local_experts, self.moe_info = MOE_MANAGER.get_info( - num_experts, use_tp=True if expert_parallel == "TP" else False) + num_experts, use_tp=True if expert_parallel == "TP" else False + ) # get settings for different parallel self.ep_size = get_ep_size(self) if expert_parallel == "TP": @@ -87,7 +88,7 @@ class MLPExperts(nn.Module): def reset_parameters(self): # expert param should be different if self.expert_parallel is not None: - seed_ctx = Randomizer(MOE_MANAGER.seed).fork_rng(enable_cpu=True) + seed_ctx = Randomizer(get_ep_rank(self)).fork_rng(enable_cpu=True) else: seed_ctx = Randomizer(42).fork_rng(enable_cpu=True) with seed_ctx: @@ -99,10 +100,10 @@ class MLPExperts(nn.Module): torch.nn.init.normal_(self.wo, std=math.sqrt(0.1 / self.intermediate_size)) def forward( - self, - x: torch.Tensor, - param_slice: Tuple[slice] = (slice(None),), - use_sparse: bool = True, + self, + x: torch.Tensor, + param_slice: 
Tuple[slice] = (slice(None),), + use_sparse: bool = True, ) -> torch.Tensor: """ forward: hidden_size --> intermediate_size --> hidden_size @@ -129,7 +130,7 @@ class MLPExperts(nn.Module): mask = torch.sum(mask, dim=-1) x_list = [] for i in range(e): - x_list.append(x[i, :mask[i]]) + x_list.append(x[i, : mask[i]]) x = x_list if self.gated: diff --git a/colossalai/moe/manager.py b/colossalai/moe/manager.py index f237ea134..3e64d796c 100644 --- a/colossalai/moe/manager.py +++ b/colossalai/moe/manager.py @@ -8,14 +8,13 @@ from colossalai.tensor.moe_tensor.api import get_moe_info from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo -class MoeManager(metaclass=SingletonMeta): +class MoEManager(metaclass=SingletonMeta): """MoE manager. This class manages different parallel groups in MoE context and MoE loss in training. """ def __init__(self): self.parallel = None - self.seed = None self.mode = None self.use_ep_inside = None self.world_size = None @@ -48,7 +47,6 @@ class MoeManager(metaclass=SingletonMeta): def setup( self, - seed: int, parallel: str = None, mode: str = "dynamic", max_ep_size: int = 8, @@ -73,10 +71,9 @@ class MoeManager(metaclass=SingletonMeta): fixed_pp_size (int, optional): Fixed pp size in fixed mode. Defaults to 0. use_ep_inside (bool, optional): Use ep inside dp if True, dp inside ep if Fasle. Defaults to True. """ - assert (not self.is_initialized), "MoE distributed context shouldn't be set up again" + assert not self.is_initialized, "MoE distributed context shouldn't be set up again" assert torch.cuda.is_available(), "MoE requires to enable CUDA first" - self.seed = seed + dist.get_rank() self.parallel = parallel self.use_ep_inside = use_ep_inside self.world_size = dist.get_world_size() @@ -87,10 +84,12 @@ class MoeManager(metaclass=SingletonMeta): if self.mode == "dynamic": self.max_ep_size = min(max_ep_size, self.world_size) else: - assert (fixed_dp_size > 0 and fixed_ep_size > 0 - and fixed_pp_size > 0), "dp_size, ep_size and pp_size should be greater than 0" - assert (isinstance(fixed_dp_size, int) and isinstance(fixed_ep_size, int) - and isinstance(fixed_pp_size, int)), "dp_size, ep_size and pp_size should be int" + assert ( + fixed_dp_size > 0 and fixed_ep_size > 0 and fixed_pp_size > 0 + ), "dp_size, ep_size and pp_size should be greater than 0" + assert ( + isinstance(fixed_dp_size, int) and isinstance(fixed_ep_size, int) and isinstance(fixed_pp_size, int) + ), "dp_size, ep_size and pp_size should be int" self.ep_size = fixed_ep_size self.dp_size = fixed_dp_size self.pp_size = fixed_pp_size @@ -112,10 +111,12 @@ class MoeManager(metaclass=SingletonMeta): """ if self.mode == "dynamic": - gt_flag = (num_experts % self.max_ep_size == 0) # check whether num_experts is greater - lt_flag = (self.max_ep_size % num_experts == 0) # check whether num_experts is less - assert gt_flag or lt_flag, ("Automatic experts placement dose not not support expert number" - " is not a multiple of ep size or vice versa.") + gt_flag = num_experts % self.max_ep_size == 0 # check whether num_experts is greater + lt_flag = self.max_ep_size % num_experts == 0 # check whether num_experts is less + assert gt_flag or lt_flag, ( + "Automatic experts placement dose not not support expert number" + " is not a multiple of ep size or vice versa." 
+ ) dp_size = 1 if gt_flag else self.world_size // num_experts ep_size = min(self.world_size // dp_size, self.max_ep_size) dp_size = self.world_size // ep_size @@ -159,4 +160,4 @@ class MoeManager(metaclass=SingletonMeta): return self.parallel -MOE_MANAGER = MoeManager() +MOE_MANAGER = MoEManager() diff --git a/colossalai/tensor/moe_tensor/api.py b/colossalai/tensor/moe_tensor/api.py index c9efec63f..c452f0d63 100644 --- a/colossalai/tensor/moe_tensor/api.py +++ b/colossalai/tensor/moe_tensor/api.py @@ -72,6 +72,19 @@ def get_ep_size(tensor: torch.Tensor) -> int: return tensor.moe_info.ep_size +def get_dp_size(tensor: torch.Tensor) -> int: + """ + Get the data parallel size of the given tensor. + + Args: + tensor (torch.Tensor): The tensor to be checked. + + Returns: + int: The data parallel size of the given tensor. + """ + return tensor.moe_info.dp_size + + def get_dp_group(tensor: torch.Tensor) -> ProcessGroup: """ Get the data parallel group of the given tensor. diff --git a/examples/language/openmoe/benchmark/benchmark_cai.py b/examples/language/openmoe/benchmark/benchmark_cai.py index 112a12cb6..f48ba9ef8 100644 --- a/examples/language/openmoe/benchmark/benchmark_cai.py +++ b/examples/language/openmoe/benchmark/benchmark_cai.py @@ -155,9 +155,7 @@ def main(): "precision": "bf16", "zero_stage": args.zero_stage, } - mgr_dict = { - "seed": 42, - } + mgr_dict = {} if args.plugin == "ep": dp_size = dist.get_world_size() plugin = MoeHybridParallelPlugin( diff --git a/examples/language/openmoe/benchmark/benchmark_fsdp.py b/examples/language/openmoe/benchmark/benchmark_fsdp.py index 45a11ad63..7f438fc5a 100644 --- a/examples/language/openmoe/benchmark/benchmark_fsdp.py +++ b/examples/language/openmoe/benchmark/benchmark_fsdp.py @@ -41,7 +41,7 @@ def fsdp_main(rank, world_size, args): # initialize the process group dist.init_process_group("nccl") - MOE_MANAGER.setup(seed=42, parallel=None) + MOE_MANAGER.setup(parallel=None) dp_size = dist.get_world_size() dataset = RandomDataset( diff --git a/examples/language/openmoe/requirements.txt b/examples/language/openmoe/requirements.txt index ccf02ba1d..6b9f80711 100644 --- a/examples/language/openmoe/requirements.txt +++ b/examples/language/openmoe/requirements.txt @@ -1,5 +1,5 @@ colossalai >= 0.3.3 torch >= 1.8.1 -transformers >= 4.20.0 +transformers >= 4.20.0, <= 4.34.0 sentencepiece datasets diff --git a/examples/language/openmoe/train.py b/examples/language/openmoe/train.py index e8c2f6aaa..b4c45416c 100644 --- a/examples/language/openmoe/train.py +++ b/examples/language/openmoe/train.py @@ -213,9 +213,7 @@ def main(): "precision": args.precision, "zero_stage": args.zero_stage, } - mgr_dict = { - "seed": 42, - } + mgr_dict = {} if args.plugin == "ep": dp_size = dist.get_world_size() plugin = MoeHybridParallelPlugin( diff --git a/tests/test_moe/test_grad_handler.py b/tests/test_moe/test_grad_handler.py index 28ee618e1..3fac62472 100644 --- a/tests/test_moe/test_grad_handler.py +++ b/tests/test_moe/test_grad_handler.py @@ -6,10 +6,9 @@ import torch.nn as nn import colossalai from colossalai.moe import SparseMLP from colossalai.moe.manager import MOE_MANAGER -from colossalai.moe.utils import sync_moe_model_param from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn from colossalai.utils import get_current_device -from tests.test_moe.moe_utils import MoeGradientHandler, assert_not_equal_in_group +from tests.test_moe.moe_utils import MoeGradientHandler BATCH_SIZE = 4 DIM = 16 @@ -25,7 +24,7 @@ def run_test(rank, 
world_size, port): backend="nccl", ) - MOE_MANAGER.setup(42, parallel="EP") # MOE initialization + MOE_MANAGER.setup(parallel="EP") # MOE initialization num_experts_list = [1, 2, 4] layer_list = [] for num_experts in num_experts_list: @@ -41,15 +40,6 @@ def run_test(rank, world_size, port): model = nn.ModuleList(layer_list) model = model.to(get_current_device()) dist_dict = MOE_MANAGER.parallel_info_dict - assert_not_equal_in_group(layer_list[0].experts.wi.data, dist_dict[1].dp_group) - assert_not_equal_in_group(layer_list[0].experts.wo.data, dist_dict[1].dp_group) - assert_not_equal_in_group(layer_list[1].experts.wi.data, dist_dict[2].dp_group) - assert_not_equal_in_group(layer_list[1].experts.wo.data, dist_dict[2].dp_group) - assert_not_equal_in_group(layer_list[2].experts.wi.data, dist_dict[4].dp_group) - assert_not_equal_in_group(layer_list[2].experts.wo.data, dist_dict[4].dp_group) - - sync_moe_model_param(model) - assert_equal_in_group(layer_list[0].experts.wi.data, dist_dict[1].dp_group) assert_equal_in_group(layer_list[0].experts.wo.data, dist_dict[1].dp_group) assert_equal_in_group(layer_list[1].experts.wi.data, dist_dict[2].dp_group) diff --git a/tests/test_moe/test_kernel.py b/tests/test_moe/test_kernel.py index c710c7bf7..255ec7444 100644 --- a/tests/test_moe/test_kernel.py +++ b/tests/test_moe/test_kernel.py @@ -20,21 +20,23 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f # Here we do not need TF32, since it brings absolute error on results torch.backends.cuda.matmul.allow_tf32 = False - colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") local_rank = dist.get_rank() - MOE_MANAGER.setup(42, parallel="EP") # MOE environment initialization + MOE_MANAGER.setup(parallel="EP") # MOE environment initialization MOE_MANAGER.reset_loss() - torch.manual_seed(rs + local_rank) # set each process has different random seed + torch.manual_seed(rs + local_rank) # set each process has different random seed # get randomized data tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type, device=get_current_device(), requires_grad=True) - layer = SparseMLP(hidden_size=hidden_size, - intermediate_size=hidden_size * 2, - num_experts=NUM_EXPERTS, - router_top_k=topk, - router_capacity_factor_train=1.0) + layer = SparseMLP( + hidden_size=hidden_size, + intermediate_size=hidden_size * 2, + num_experts=NUM_EXPERTS, + router_top_k=topk, + router_capacity_factor_train=1.0, + ) layer = layer.to(get_current_device()) if data_type == torch.float16: layer = layer.half() @@ -55,7 +57,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f layer.gate_weight.grad.zero_() layer.enable_kernel = True - new_out = layer(tokens) # get outputs through colossal kernel + new_out = layer(tokens) # get outputs through colossal kernel if data_type == torch.float32: check_equal(old_out, new_out) @@ -90,5 +92,5 @@ def test_moe_kernel(rs, hidden_size, data_type, topk): spawn(run_routing, 4, rs=rs, hidden_size=hidden_size, data_type=data_type, topk=topk) -if __name__ == '__main__': +if __name__ == "__main__": test_moe_kernel(2, 256, torch.float16, 2) diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py index b68eaec50..bd1103df3 100644 --- a/tests/test_moe/test_moe_checkpoint.py +++ b/tests/test_moe/test_moe_checkpoint.py @@ -12,53 +12,112 @@ import 
colossalai from colossalai.booster import Booster from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin from colossalai.moe.manager import MOE_MANAGER -from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.testing import DummyDataloader, check_state_dict_equal, rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device -sys.path.append(os.path.join( - os.path.dirname(os.path.dirname(os.path.dirname(__file__))), - "examples/language/openmoe", -)) +sys.path.append( + os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(__file__))), + "examples/language/openmoe", + ) +) OpenMoeForCausalLM = importlib.import_module("model.modeling_openmoe").OpenMoeForCausalLM set_openmoe_args = importlib.import_module("model.modeling_openmoe").set_openmoe_args OpenMoeForCausalLMPolicy = importlib.import_module("model.openmoe_policy").OpenMoeForCausalLMPolicy +def data_gen_fn(batch_size: int = 2, max_length: int = 4, vocab_size: int = 20): + input_ids = torch.randint(0, vocab_size, (batch_size, max_length), device=get_current_device()) + attention_mask = torch.ones_like(input_ids) + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "labels": input_ids, + } + + +def run_fwd_bwd( + model, data, label, criterion, optimizer, enable_autocast=False, pipeline=False, booster=None, plugin=None +): + model.train() + if pipeline: + train_dataloader_iter = DummyDataloader(data_gen_fn, length=1) + is_pp_last_stage = booster.plugin.stage_manager.is_last_stage() + y = booster.execute_pipeline( + train_dataloader_iter, + model, + lambda x, y: x.loss, + optimizer, + return_loss=True, + return_outputs=True, + ) + # Backward and optimize + if is_pp_last_stage: + loss = y["loss"] + else: + if criterion: + y = model(data).logits + loss = criterion(y) + else: + loss = model(data, label) + loss = loss.float() + + if optimizer is not None: + optimizer.backward(loss) + else: + loss.backward() + return y + + def get_config(): config = LlamaConfig( vocab_size=300, hidden_size=16, intermediate_size=32, - num_hidden_layers=4, + num_hidden_layers=2, num_attention_heads=2, head_dim=4, dropout_rate=0.0, hidden_act="swiglu", ) - set_openmoe_args(config, num_experts=16, moe_layer_interval=1) + set_openmoe_args(config, num_experts=8, moe_layer_interval=1) return config def get_model(parallel): config = get_config() model = OpenMoeForCausalLM(config) + optim = torch.optim.Adam(model.parameters()) if parallel == None: plugin = MoeHybridParallelPlugin( - tp_size=1, - pp_size=1, - zero_stage=0, - custom_policy=OpenMoeForCausalLMPolicy(), - ) - elif parallel == "zero_ep": - plugin = MoeHybridParallelPlugin( + precision="bf16", tp_size=1, pp_size=1, zero_stage=2, custom_policy=OpenMoeForCausalLMPolicy(), ) + elif parallel == "ep": + plugin = MoeHybridParallelPlugin( + precision="bf16", + tp_size=1, + pp_size=1, + zero_stage=2, + custom_policy=OpenMoeForCausalLMPolicy(), + ) + elif parallel == "ep_zero": + plugin = MoeHybridParallelPlugin( + precision="bf16", + tp_size=1, + pp_size=1, + zero_stage=2, + extra_dp_size=2, + custom_policy=OpenMoeForCausalLMPolicy(), + ) elif parallel == "hybrid": plugin = MoeHybridParallelPlugin( + precision="bf16", tp_size=1, pp_size=2, zero_stage=1, @@ -66,54 +125,77 @@ def get_model(parallel): custom_policy=OpenMoeForCausalLMPolicy(), ) booster = Booster(plugin=plugin) - model, _, _, _, _ = booster.boost(model=model) - return model, booster + model, optim, _, _, _ = booster.boost(model=model, 
optimizer=optim) + return model, booster, optim -def _test_moe_checkpoint(parallel, shard): +def _test_moe_checkpoint(rank, parallel): if parallel == None: MOE_MANAGER.setup( - seed=42, parallel=None, ) - elif parallel == "zero2_ep": + elif parallel == "ep": MOE_MANAGER.setup( - seed=42, parallel="EP", ) + elif parallel == "ep_zero": + MOE_MANAGER.setup( + parallel="EP", + max_ep_size=2, + ) elif parallel == "hybrid": MOE_MANAGER.setup( - seed=42, parallel="EP", mode="fixed", fixed_dp_size=1, fixed_ep_size=2, fixed_pp_size=2, ) - model1, booster1 = get_model(parallel) - model2, booster2 = get_model(parallel) + model1, booster1, optim1 = get_model(parallel) + model2, booster2, optim2 = get_model(parallel) + model3, booster3, optim3 = get_model(parallel) - if shard: - booster1.save_model(model1, "./tmp_ckpt", shard=True, size_per_shard=1) - booster2.load_model(model2, "./tmp_ckpt") + # param ckpt + # shard + booster1.save_model(model1, "./tmp_ckpt1", shard=True, size_per_shard=1) + booster2.load_model(model2, "./tmp_ckpt1") + # unshard + booster1.save_model(model1, "./tmp_ckpt1.pth") + booster3.load_model(model3, "./tmp_ckpt1.pth") + # check + check_state_dict_equal(model1.state_dict(), model2.state_dict(), False) + check_state_dict_equal(model1.state_dict(), model3.state_dict(), False) + + # optim ckpt + criterion = lambda x: x.mean() + data = torch.randint(0, 4, (2, 4)).cuda() + label = torch.randint(0, 4, (2,)).cuda() + if parallel == "hybrid": + kwargs = {"pipeline": True, "booster": booster1, "plugin": booster1.plugin} else: - booster1.save_model(model1, "tmp_ckpt.pth") - booster2.load_model(model2, "tmp_ckpt.pth") - - state1 = model1.state_dict() - state2 = model2.state_dict() - for k, v in state1.items(): - u = state2.get(k) - assert torch.equal(u.data, v.data) + kwargs = {} + run_fwd_bwd(model1, data, label, criterion, optim1, **kwargs) + optim1.step() + optim1.zero_grad() + # shard + booster1.save_optimizer(optim1, "./tmp_ckpt2", shard=True, size_per_shard=1) + dist.barrier() + booster2.load_optimizer(optim2, "./tmp_ckpt2") + # unshard + booster1.save_optimizer(optim1, "./tmp_ckpt2.pth") + booster3.load_optimizer(optim3, "./tmp_ckpt2.pth") + # check + check_state_dict_equal(optim1.optim.state_dict(), optim2.optim.state_dict(), False) + check_state_dict_equal(optim1.optim.state_dict(), optim3.optim.state_dict(), False) if dist.get_rank() == 0: - if shard: - shutil.rmtree("./tmp_ckpt") - else: - os.remove("tmp_ckpt.pth") + shutil.rmtree("./tmp_ckpt1") + shutil.rmtree("./tmp_ckpt2") + os.remove("./tmp_ckpt1.pth") + os.remove("./tmp_ckpt2.pth") -def _run_dist(rank, world_size, port, parallel, shard): +def _run_dist(rank, world_size, port, parallel): colossalai.launch( config=dict(), rank=rank, @@ -122,17 +204,16 @@ def _run_dist(rank, world_size, port, parallel, shard): port=port, backend="nccl", ) - _test_moe_checkpoint(parallel, shard) + _test_moe_checkpoint(rank, parallel) @pytest.mark.dist @pytest.mark.parametrize("world_size", [4]) -@pytest.mark.parametrize("parallel", [None, "zero_ep", "hybrid"]) -@pytest.mark.parametrize("shard", [True, False]) +@pytest.mark.parametrize("parallel", [None, "ep", "ep_zero", "hybrid"]) @rerun_if_address_is_in_use() -def test_moe_checkpoint(world_size, parallel, shard): - spawn(_run_dist, world_size, parallel=parallel, shard=shard) +def test_moe_checkpoint(world_size, parallel): + spawn(_run_dist, world_size, parallel=parallel) if __name__ == "__main__": - test_moe_checkpoint(world_size=4, parallel="hybrid", shard=True) + 
test_moe_checkpoint(world_size=4, parallel="hybrid") diff --git a/tests/test_moe/test_moe_ep_tp.py b/tests/test_moe/test_moe_ep_tp.py index 11d0664fd..2c9bbd446 100644 --- a/tests/test_moe/test_moe_ep_tp.py +++ b/tests/test_moe/test_moe_ep_tp.py @@ -14,16 +14,16 @@ from tests.test_moe.moe_utils import MoeGradientHandler, sync_local_from_ep, syn def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size: int, dim: int, seed: int): assert batch_size % world_size == 0 - colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") MOE_MANAGER.__init__() - MOE_MANAGER.setup(seed, parallel=None) + MOE_MANAGER.setup(parallel=None) local_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2) MOE_MANAGER.__init__() - MOE_MANAGER.setup(seed, parallel="EP") + MOE_MANAGER.setup(parallel="EP") ep_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2) MOE_MANAGER.__init__() - MOE_MANAGER.setup(seed, parallel="TP") + MOE_MANAGER.setup(parallel="TP") tp_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2) ep_model = ep_model.to(get_current_device()) tp_model = tp_model.to(get_current_device()) @@ -44,7 +44,7 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size torch.cuda.manual_seed(seed) tp_data = torch.randn(batch_size, dim, device=get_current_device()) micro_batch_size = batch_size // world_size - ep_data = tp_data.detach()[micro_batch_size * rank:micro_batch_size * (rank + 1)] + ep_data = tp_data.detach()[micro_batch_size * rank : micro_batch_size * (rank + 1)] out_local = local_model(tp_data) MOE_MANAGER.reset_loss() @@ -52,8 +52,8 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size MOE_MANAGER.reset_loss() out_ep = ep_model(ep_data) MOE_MANAGER.reset_loss() - assert torch.allclose(out_ep, out_tp[micro_batch_size * rank:micro_batch_size * (rank + 1)]) - assert torch.allclose(out_ep, out_local[micro_batch_size * rank:micro_batch_size * (rank + 1)]) + assert torch.allclose(out_ep, out_tp[micro_batch_size * rank : micro_batch_size * (rank + 1)]) + assert torch.allclose(out_ep, out_local[micro_batch_size * rank : micro_batch_size * (rank + 1)]) out_local.mean().backward() out_tp.mean().backward() @@ -77,5 +77,5 @@ def test_moe_ep_tp(num_experts: int, batch_size: int, dim: int, seed: int): spawn(run_test, 2, num_experts=num_experts, batch_size=batch_size, dim=dim, seed=seed) -if __name__ == '__main__': +if __name__ == "__main__": test_moe_ep_tp(num_experts=8, batch_size=8, dim=256, seed=42) diff --git a/tests/test_moe/test_moe_group.py b/tests/test_moe/test_moe_group.py index 3cd5acc0d..95c0e715d 100644 --- a/tests/test_moe/test_moe_group.py +++ b/tests/test_moe/test_moe_group.py @@ -15,7 +15,7 @@ INTERMEDIATE_SIZE = 8 def run_moe_init(expert_parallel): MOE_MANAGER.__init__() - MOE_MANAGER.setup(seed=42, parallel=expert_parallel) + MOE_MANAGER.setup(parallel=expert_parallel) expert_args = dict( hidden_size=HIDDEN_SIZE, intermediate_size=INTERMEDIATE_SIZE, diff --git a/tests/test_moe/test_moe_hybrid_zero.py b/tests/test_moe/test_moe_hybrid_zero.py index e9f71d5ca..7ada4090f 100644 --- a/tests/test_moe/test_moe_hybrid_zero.py +++ b/tests/test_moe/test_moe_hybrid_zero.py @@ -35,13 +35,13 @@ def run_zero_optim_test(local_rank, world_size, stage=1): label = 
torch.randint(0, 4, (16,)).cuda() MOE_MANAGER.__init__() - MOE_MANAGER.setup(seed=42, parallel=None) + MOE_MANAGER.setup(parallel=None) torch_model = MoeModel() torch_optimizer = torch.optim.Adam(torch_model.parameters()) torch_model = torch_model.cuda() MOE_MANAGER.__init__() - MOE_MANAGER.setup(seed=42, max_ep_size=2, use_ep_inside=False, parallel="EP") + MOE_MANAGER.setup(max_ep_size=2, use_ep_inside=False, parallel="EP") zero_model = MoeModel() extra_dp_group = MOE_MANAGER.parallel_info_dict[2].dp_group ep_rank = dist.get_rank(MOE_MANAGER.parallel_info_dict[2].ep_group) diff --git a/tests/test_moe/test_moe_load_balance.py b/tests/test_moe/test_moe_load_balance.py index 173a7a356..717bb99fb 100644 --- a/tests/test_moe/test_moe_load_balance.py +++ b/tests/test_moe/test_moe_load_balance.py @@ -45,7 +45,6 @@ def run_zero_optim_test(local_rank, world_size, stage=1): MOE_MANAGER.__init__() MOE_MANAGER.setup( - seed=42, parallel="EP", ) zero_model = MoeModel(enable_load_balance=True) @@ -55,7 +54,7 @@ def run_zero_optim_test(local_rank, world_size, stage=1): zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer) MOE_MANAGER.__init__() - MOE_MANAGER.setup(seed=42, parallel="EP") + MOE_MANAGER.setup(parallel="EP") torch_model = MoeModel() for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()): torch_param.data.copy_(zero_param.data) @@ -94,7 +93,7 @@ def run_zero_optim_test(local_rank, world_size, stage=1): torch_out = run_fwd_bwd(torch_model, data, label, criterion, None) zero_optimizer.step() zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer) - assert torch.allclose(zero_out, torch_out), f"zero_out:{zero_out}\ntorch_out{torch_out}" + assert torch.allclose(zero_out, torch_out, atol=3e-5), f"zero_out:{zero_out}\ntorch_out{torch_out}" def run_hybrid_zero_optim_test(local_rank, world_size, stage=1): @@ -103,14 +102,13 @@ def run_hybrid_zero_optim_test(local_rank, world_size, stage=1): label = torch.randint(0, 4, (16,)).cuda() MOE_MANAGER.__init__() - MOE_MANAGER.setup(seed=42, parallel=None) + MOE_MANAGER.setup(parallel=None) torch_model = MoeModel() torch_optimizer = torch.optim.Adam(torch_model.parameters()) torch_model = torch_model.cuda() MOE_MANAGER.__init__() MOE_MANAGER.setup( - seed=42, max_ep_size=2, use_ep_inside=False, parallel="EP", diff --git a/tests/test_moe/test_moe_zero_fwd_bwd.py b/tests/test_moe/test_moe_zero_fwd_bwd.py index 8f046ab00..f0795a4c7 100644 --- a/tests/test_moe/test_moe_zero_fwd_bwd.py +++ b/tests/test_moe/test_moe_zero_fwd_bwd.py @@ -88,7 +88,7 @@ def run_zero_test(local_rank, world_size, stage=1): def run_dist(rank, world_size, port): colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - MOE_MANAGER.setup(seed=42, parallel="EP") + MOE_MANAGER.setup(parallel="EP") seed_all(42 + rank) run_zero_test(rank, world_size, stage=1) run_zero_test(rank, world_size, stage=2) diff --git a/tests/test_moe/test_moe_zero_optim.py b/tests/test_moe/test_moe_zero_optim.py index ebea7509f..0d2e2fb1b 100644 --- a/tests/test_moe/test_moe_zero_optim.py +++ b/tests/test_moe/test_moe_zero_optim.py @@ -76,7 +76,7 @@ def run_zero_optim_test(local_rank, world_size, stage=1): def run_dist(rank, world_size, port): colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - MOE_MANAGER.setup(seed=42, parallel="EP") + MOE_MANAGER.setup(parallel="EP") run_zero_optim_test(rank, world_size, stage=1) 
run_zero_optim_test(rank, world_size, stage=2)
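
A minimal usage sketch of the sharded/unsharded optimizer checkpointing added by this commit, distilled from tests/test_moe/test_moe_checkpoint.py (not part of the diff above). It assumes a multi-GPU distributed launch; build_model() is a placeholder for any MoE model (the test uses OpenMoeForCausalLM with OpenMoeForCausalLMPolicy as the plugin's custom_policy), and the plugin arguments should be adapted to the actual parallel configuration. Checkpointing goes through the Booster API, which routes to the new MoECheckpintIO via MoeHybridParallelPlugin.get_checkpoint_io().

import torch

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.moe.manager import MOE_MANAGER


def checkpoint_roundtrip(rank: int, world_size: int, port: int):
    colossalai.launch(config=dict(), rank=rank, world_size=world_size,
                      host="localhost", port=port, backend="nccl")
    # After this patch, MOE_MANAGER.setup() no longer takes a `seed` argument.
    MOE_MANAGER.setup(parallel="EP")

    model = build_model()  # placeholder: any MoE model, e.g. OpenMoeForCausalLM
    optimizer = torch.optim.Adam(model.parameters())

    plugin = MoeHybridParallelPlugin(precision="bf16", tp_size=1, pp_size=1, zero_stage=2)
    booster = Booster(plugin=plugin)  # plugin.get_checkpoint_io() returns MoECheckpintIO
    model, optimizer, _, _, _ = booster.boost(model=model, optimizer=optimizer)

    # ... run at least one forward/backward/step so the optimizer holds states ...

    # Sharded checkpoint: writes an index file, a param-group file and state shards.
    booster.save_optimizer(optimizer, "./optim_ckpt", shard=True, size_per_shard=1024)
    booster.load_optimizer(optimizer, "./optim_ckpt")

    # Unsharded checkpoint: a single state-dict file gathered across dp/ep/pp groups.
    booster.save_optimizer(optimizer, "./optim_ckpt.pth")
    booster.load_optimizer(optimizer, "./optim_ckpt.pth")

Under pipeline parallelism the sharded path has each stage write its own shards and a per-stage index under tmp_index_files/, which the pp_rank-0 process then merges into the final pytorch_optim.bin.index.json, as described in save_sharded_optimizer above.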