diff --git a/applications/ColossalMoE/README.md b/applications/ColossalMoE/README.md new file mode 100644 index 000000000..be50a8f9f Binary files /dev/null and b/applications/ColossalMoE/README.md differ diff --git a/applications/ColossalMoE/colossal_moe/__init__.py b/applications/ColossalMoE/colossal_moe/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/applications/ColossalMoE/colossal_moe/models/__init__.py b/applications/ColossalMoE/colossal_moe/models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py b/applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py new file mode 100644 index 000000000..ddef565c5 --- /dev/null +++ b/applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py @@ -0,0 +1,205 @@ +import logging +import os +from pathlib import Path + +import torch +import torch.distributed as dist +import torch.nn as nn + +from colossalai.checkpoint_io import CheckpointIndexFile +from colossalai.checkpoint_io.utils import is_safetensors_available, load_shard_state_dict, load_state_dict_into_model +from colossalai.moe import MoECheckpintIO +from colossalai.tensor.moe_tensor.api import get_dp_rank, get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor + + +class MixtralMoECheckpointIO(MoECheckpintIO): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @torch.no_grad() + def pre_load_model(self, model: nn.Module, state_dict: dict) -> dict: + """ + Preprocess state_dict before loading and slice the state_dict of MOE tensors. + """ + model_param_dict = dict(model.named_parameters()) + for name, param in list(state_dict.items()): + if ".gate.weight" in name: + new_name = "module." + name.replace(".gate.weight", ".gate_weight") + state_dict[new_name] = state_dict.pop(name) + elif ".experts." in name: + # if is moe tensor + # in our moe module, expert is cat as one tensor + # but mixtral's experts is not cat + # we will insert the loaded expert into the position of cat tensor + + # get model param + str_idx = name.index(".experts.") + expert_idx = int(name.split(".")[-3]) + if ".w1." in name: + model_param_name = name.replace(name[str_idx:], ".experts.wi_gate") + elif ".w2." in name: + model_param_name = name.replace(name[str_idx:], ".experts.wo") + elif ".w3." in name: + model_param_name = name.replace(name[str_idx:], ".experts.wi_up") + model_param_name = "module." + model_param_name + # skip for pipeline + if model_param_name not in model_param_dict: + continue + model_param = model_param_dict[model_param_name] + assert is_moe_tensor(model_param) + # get expert range + ep_rank = get_ep_rank(model_param) + ep_size = get_ep_size(model_param) + expert_num = 8 // ep_size + expert_range = list(range(ep_rank * expert_num, (ep_rank + 1) * expert_num)) + # insert new param + if expert_idx in expert_range: + new_param = model_param + new_param[expert_idx - ep_rank * expert_num] = param.transpose(0, 1) + state_dict[model_param_name] = new_param + state_dict.pop(name) + else: + new_name = "module." + name + state_dict[new_name] = state_dict.pop(name) + + dist.barrier() + return state_dict + + def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False): + """ + Load sharded model with the given path to index file of checkpoint folder. + + Args: + model (nn.Module): The model to be loaded. + checkpoint_index_file (str): Path to the index file of checkpointing folder. 
+ strict (bool, optional): For name matching during loading state_dict. Defaults to False. + This argument should be manually set to False since params on same device might be stored in different files. + """ + + # Check whether the checkpoint uses safetensors. + use_safetensors = False + if "safetensors" in checkpoint_index_file.name: + use_safetensors = True + + if use_safetensors and not is_safetensors_available(): + raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.") + + # Read checkpoint index file. + ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file) + ckpt_root_path = ckpt_index_file.root_path + weight_map = ckpt_index_file.weight_map + strict = False + + # Load params & buffers to model. + # Keep a record of loaded files so that file will not be repeatedly loaded. + loaded_file = set() + + def _load(name: str): + if name not in weight_map: + raise ValueError(f"{name} is not stored in checkpoint, please check your checkpointing configuration!") + filename = weight_map[name] + + # If this param/buffer has been loaded before, directly return. + if filename in loaded_file: + return + + file_path = os.path.join(ckpt_root_path, filename) + state_dict = load_shard_state_dict(Path(file_path), use_safetensors) + state_dict = self.pre_load_model(model, state_dict) + missing_keys = [] + + load_state_dict_into_model( + model, + state_dict, + missing_keys=missing_keys, + strict=strict, + load_sub_module=True, + ) + loaded_file.add(filename) + + # Load parameters. + for name, _ in model.named_parameters(): + name = name.replace("module.", "") + name = name.replace(".gate_weight", ".gate.weight") + if ".experts.wi_gate" in name: + for i in range(8): + new_name = name.replace(".experts.wi_gate", f".experts.{i}.w1.weight") + _load(new_name) + elif ".experts.wi_up" in name: + for i in range(8): + new_name = name.replace(".experts.wi_up", f".experts.{i}.w3.weight") + _load(new_name) + elif ".experts.wo" in name: + for i in range(8): + new_name = name.replace(".experts.wo", f".experts.{i}.w2.weight") + _load(new_name) + else: + _load(name) + + if self.verbose: + logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.") + + @torch.no_grad() + def pre_save_model(self, model: nn.Module) -> dict: + torch.cuda.empty_cache() + state_dict = model.state_dict() + for name, param in list(model.named_parameters()): + if ".gate_weight" in name: + new_name = name.replace(".gate_weight", ".gate.weight") + state_dict[new_name] = state_dict.pop(name).cpu() + elif ".experts." 
in name:
+                ep_group = get_ep_group(param)
+                ep_rank = get_ep_rank(param)
+                ep_size = get_ep_size(param)
+                dp_rank = get_dp_rank(param)
+
+                if dp_rank == 0:
+                    param = param.data.cuda()
+                    all_param = [torch.zeros_like(param) for _ in range(ep_size)]
+                    # gather param from every ep rank
+                    dist.all_gather(all_param, param, group=ep_group)
+                    if ep_rank == 0:
+                        all_param = torch.cat(all_param, dim=0)
+                        assert all_param.shape[0] == 8
+                        for i in range(8):
+                            if ".wi_gate" in name:
+                                new_name = name.replace(".experts.wi_gate", f".experts.{i}.w1.weight")
+                            elif ".wi_up" in name:
+                                new_name = name.replace(".experts.wi_up", f".experts.{i}.w3.weight")
+                            elif ".wo" in name:
+                                new_name = name.replace(".experts.wo", f".experts.{i}.w2.weight")
+                            new_name = new_name.replace("module.", "")
+                            new_param = all_param[i].transpose(-1, -2)
+                            state_dict[new_name] = new_param.cpu()
+                        state_dict.pop(name)
+            else:
+                state_dict[name] = param.cpu()
+
+        for name, param in list(state_dict.items()):
+            new_name = name.replace("module.", "")
+            state_dict[new_name] = state_dict.pop(name)
+
+        torch.cuda.empty_cache()
+        if self.pp_size > 1:
+            if self.dp_rank == 0:
+                # gather state_dict from every pp rank
+                # because the full ckpt is large, split it into up to 30 chunks
+                # and gather them one by one
+                new_state_dict = {}
+                state_dict_keys = list(state_dict.keys())
+                gap_key_num = min(30, len(state_dict_keys))
+                gap_keys = (len(state_dict_keys) + gap_key_num - 1) // gap_key_num
+                for i in range(gap_key_num):
+                    cur_keys = state_dict_keys[i * gap_keys : (i + 1) * gap_keys]
+                    cur_state_dict = {}
+                    for k in cur_keys:
+                        cur_state_dict[k] = state_dict[k]
+                    out = [None for _ in range(self.pp_size)]
+                    dist.all_gather_object(out, cur_state_dict, group=self.pp_group)
+                    if self.pp_rank == 0:
+                        for o in out:
+                            for k, v in o.items():
+                                new_state_dict[k] = v.cpu()
+                state_dict = new_state_dict
+        dist.barrier()
+        return state_dict
diff --git a/applications/ColossalMoE/colossal_moe/models/mixtral_layer.py b/applications/ColossalMoE/colossal_moe/models/mixtral_layer.py
new file mode 100644
index 000000000..e395c8578
--- /dev/null
+++ b/applications/ColossalMoE/colossal_moe/models/mixtral_layer.py
@@ -0,0 +1,80 @@
+import torch
+import torch.nn as nn
+from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralSparseMoeBlock
+
+from colossalai.lazy import LazyInitContext
+from colossalai.moe import SparseMLP
+
+
+class MixtralSparseMLP:
+    r"""
+    A wrapper that converts a Hugging Face MixtralSparseMoeBlock into ColossalAI's SparseMLP. It is meant to be used only with the from_native_module interface.
+    """
+
+    def __init__(self) -> None:
+        raise NotImplementedError(
+            "MixtralSparseMLP is not implemented as a physical class. "
+            "It is meant to be used only with the from_native_module interface to convert a Mixtral sparse MoE block into a ColossalAI SparseMLP module."
+        )
+
+    @staticmethod
+    def from_native_module(module: MixtralSparseMoeBlock, enable_kernel: bool) -> nn.Module:
+        r"""
+        Convert a Hugging Face MixtralSparseMoeBlock into a ColossalAI SparseMLP module,
+        reusing the block's configuration (hidden size, FFN size and top-k routing).
+
+        Args:
+            module (MixtralSparseMoeBlock): The native Mixtral sparse MoE block to be converted.
+            enable_kernel (bool): Whether to enable the optimized MoE kernel in the converted module.
+
+        Returns:
+            nn.Module: The converted SparseMLP module.
+
+        Note:
+            Parameters are not copied here; expert weights are expected to be loaded afterwards, e.g. via MixtralMoECheckpointIO.
+ """ + with torch.no_grad(): + LazyInitContext.materialize(module) + + # get the attributes of the module + moe_kwargs = dict( + num_experts=8, + hidden_size=module.hidden_dim, + intermediate_size=module.ffn_dim, + router_top_k=module.top_k, + router_norm=True, + router_loss=False, + # router_capacity_factor_train= + # router_capacity_factor_eval= + mlp_activation="silu", + mlp_gated=True, + # enable_load_balance= + # load_balance_tolerance= + # load_balance_beam_width= + # load_balance_group_swap_factor= + enable_kernel=enable_kernel, + # enable_comm_overlap= + # enable_hierarchical_comm= + return_gate_logits=True, + ) + dtype = module.gate.weight.dtype + device = module.gate.weight.device + sparse_mlp = SparseMLP(**moe_kwargs).to(dtype).to(device) + + return sparse_mlp + + +def replace_moe_layer(model: nn.Module, enable_kernel: bool = False) -> nn.Module: + """ + Reverse the replace layer operation + + Args: + module (torch.nn.Module): The object of layer to shard + """ + if isinstance(model, MixtralDecoderLayer): + model.block_sparse_moe = MixtralSparseMLP.from_native_module( + model.block_sparse_moe, enable_kernel=enable_kernel + ) + else: + for _, child in model.named_children(): + replace_moe_layer(child, enable_kernel) diff --git a/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py b/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py new file mode 100644 index 000000000..2f6021f2d --- /dev/null +++ b/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py @@ -0,0 +1,543 @@ +from functools import partial +from typing import Callable, Dict, List, Optional, Union + +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn import CrossEntropyLoss, Module +from transformers.models.mixtral.modeling_mixtral import ( + MixtralDecoderLayer, + MixtralForCausalLM, + MixtralModel, + MoeCausalLMOutputWithPast, + _prepare_4d_causal_attention_mask, + load_balancing_loss_func, +) +from transformers.utils import logging + +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col +from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +from colossalai.shardformer.shard import ShardConfig + +__all__ = ["MixtralPolicy", "MixtralForCausalLMPolicy"] + + +class MixtralPolicy(Policy): + def config_sanity_check(self): + pass + + def preprocess(self): + if self.shard_config.enable_tensor_parallelism: + # Resize embedding + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + + return self.model + + def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: + policy = {} + + if self.shard_config.enable_sequence_parallelism: + self.shard_config.enable_sequence_parallelism = False + raise NotImplementedError( + "Mixtral dosen't support sequence parallelism now, will ignore the sequence parallelism flag." 
+ ) + + if self.shard_config.enable_tensor_parallelism: + raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.") + + # optimization configuration + if self.shard_config.enable_fused_normalization: + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="input_layernorm", + target_module=FusedRMSNorm, + ), + SubModuleReplacementDescription( + suffix="post_attention_layernorm", + target_module=FusedRMSNorm, + ), + ], + policy=policy, + target_key=MixtralDecoderLayer, + ) + + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="norm", + target_module=FusedRMSNorm, + ), + policy=policy, + target_key=MixtralModel, + ) + + if self.shard_config.enable_flash_attention: + raise NotImplementedError("Flash attention has already been replaced in mixtral.") + + return policy + + def postprocess(self): + return self.model + + def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: + """If under pipeline parallel setting, replacing the original forward method of huggingface + to customized forward method, and add this changing to policy.""" + if self.pipeline_stage_manager: + stage_manager = self.pipeline_stage_manager + if self.model.__class__.__name__ == "MixtralModel": + module = self.model + else: + module = self.model.model + + layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=model_cls + ) + + return + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + assert self.pipeline_stage_manager is not None + + if self.model.__class__.__name__ == "MixtralModel": + module = self.model + else: + module = self.model.model + stage_manager = self.pipeline_stage_manager + + held_layers = [] + layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.embed_tokens) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.layers[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.norm) + + return held_layers + + +class MixtralModelPolicy(MixtralPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + policy = super().module_policy() + if self.pipeline_stage_manager: + # set None as default + self.set_pipeline_forward( + model_cls=MixtralModel, + new_forward=MixtralPipelineForwards.mixtral_model_forward, + policy=policy, + ) + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + held_layers = super().get_held_layers() + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in llama model""" + return [] + + +class MixtralForCausalLMPolicy(MixtralPolicy): + def module_policy(self): + policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + # add a new item for casual lm + new_item = { + MixtralForCausalLM: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", + 
target_module=Linear1D_Col, + kwargs=dict(gather_output=True), + ) + ] + ) + } + policy.update(new_item) + + if self.pipeline_stage_manager: + # set None as default + self.set_pipeline_forward( + model_cls=MixtralForCausalLM, + new_forward=MixtralPipelineForwards.mixtral_for_causal_lm_forward, + policy=policy, + ) + + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + stage_manager = self.pipeline_stage_manager + held_layers = super().get_held_layers() + if stage_manager.is_last_stage(): + held_layers.append(self.model.lm_head) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + llama_model = self.model.model + if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: + if ( + id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight) + and self.pipeline_stage_manager.num_stages > 1 + ): + # tie weights + return [ + { + 0: llama_model.embed_tokens.weight, + self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight, + } + ] + return [] + + +class MixtralPipelineForwards: + """ + This class serves as a micro library for forward function substitution of Llama models + under pipeline setting. + """ + + @staticmethod + def mixtral_model_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + past_router_logits: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ): + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MixtralForCausalLM + + >>> model = MixtralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + logger = logging.get_logger(__name__) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if stage_manager.is_first_stage(): + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + device = input_ids.device if input_ids is not None else inputs_embeds.device + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = inputs_embeds + else: + input_shape = hidden_states.shape[:-1] + batch_size, seq_length = input_shape + device = hidden_states.device + + seq_length_with_past = seq_length + past_key_values_length = 0 + + # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + if use_cache: + logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") + use_cache = False + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + # embed positions, for the first stage, hidden_states is the input embeddings, + # for the other stages, hidden_states is the output of the previous stage + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + hidden_states, + past_key_values_length, + sliding_window=self.config.sliding_window, + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + next_decoder_cache = None + + start_idx, end_idx = stage_index[0], stage_index[1] + for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + None, + output_attentions, + output_router_logits, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions, + output_router_logits, + use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = (layer_outputs[2 if output_attentions else 1],) + if output_attentions: + all_self_attns += (layer_outputs[1],) + if output_router_logits: + all_router_logits += (layer_outputs[-1],) + + if stage_manager.is_last_stage(): + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + + if output_router_logits and past_router_logits is not None: + all_router_logits = past_router_logits + all_router_logits + if stage_manager.is_last_stage(): + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] + if v is not None + ) + # always return dict for imediate stage + return { + "hidden_states": hidden_states, + "past_router_logits": all_router_logits, + } + + @staticmethod + def mixtral_for_causal_lm_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = True, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + past_router_logits: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ): + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
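+            stage_manager (`PipelineStageManager`, *optional*):
+                Manager of the current pipeline stage; only used when running under pipeline parallelism.
+            hidden_states (`torch.FloatTensor`, *optional*):
+                Hidden states produced by the previous pipeline stage, consumed instead of `input_ids` on non-first stages.
+            past_router_logits (`torch.FloatTensor`, *optional*):
+                Router logits accumulated from earlier pipeline stages, forwarded so the last stage can compute the auxiliary load-balancing loss.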
+ + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MixtralForCausalLM + + >>> model = MixtralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + logger = logging.get_logger(__name__) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = MixtralPipelineForwards.mixtral_model_forward( + self.model, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + past_router_logits=past_router_logits, + ) + past_key_values = None + + if stage_manager.is_last_stage(): + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss + + if not return_dict: + output = (logits,) + outputs[1:] + if output_router_logits: + output = (aux_loss,) + output + return (loss,) + output if loss is not None else output + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=None, + hidden_states=outputs[0], + attentions=None, + router_logits=outputs[-1], + ) + else: + out = {} + hidden_states = outputs.get("hidden_states") + out["hidden_states"] = hidden_states + if output_router_logits: 
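+                # forward the accumulated router logits to the next pipeline stage so that the
+                # last stage can compute the auxiliary load-balancing loss over all layers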
+ out["past_router_logits"] = outputs["past_router_logits"] + return out diff --git a/applications/ColossalMoE/colossal_moe/utils.py b/applications/ColossalMoE/colossal_moe/utils.py new file mode 100644 index 000000000..70b827264 --- /dev/null +++ b/applications/ColossalMoE/colossal_moe/utils.py @@ -0,0 +1,102 @@ +import json +import os +from typing import Any, Dict, Tuple, Union + +import torch +from huggingface_hub import snapshot_download +from torch.optim.lr_scheduler import _LRScheduler +from torch.optim.optimizer import Optimizer + +from colossalai.booster import Booster +from colossalai.cluster import DistCoordinator + + +def move_to_cuda(batch, device): + return {k: v.to(device) for k, v in batch.items()} + + +@torch.no_grad() +def load_model(ckpt_path: str, model, booster: Booster, optimizer=None): + # pytorch ckpt + if os.path.exists(os.path.join(ckpt_path, "model.safetensors.index.json")): + ckpt_path = os.path.join(ckpt_path, "model.safetensors.index.json") + # saved ckpt + elif os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin.index.json")): + ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin.index.json") + # download + else: + ckpt_path = snapshot_download(ckpt_path) + booster.load_model(model, ckpt_path) + if optimizer is not None: + optimizer.sync_moe_master_param() + optimizer.update_master_params(model) + + +def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]: + """ + Load file in JSON format + """ + with open(file=file_path, mode="r", encoding="utf-8") as fp: + return json.load(fp) + + +def save_json(data: Dict[str, Any], file_path: Union[str, os.PathLike]) -> None: + """ + Save as JSON format + """ + with open(file=file_path, mode="w", encoding="utf-8") as fp: + json.dump(data, fp=fp, ensure_ascii=False, indent=4) + + +def save_checkpoint( + save_dir: Union[str, os.PathLike], + booster: Booster, + model: torch.nn.Module, + optimizer: Optimizer, + lr_scheduler: _LRScheduler, + epoch: int, + step: int, + batch_size: int, + coordinator: DistCoordinator, +) -> None: + """ + Save model checkpoint, optimizer, LR scheduler and intermedidate running states. + """ + + save_dir = os.path.join(save_dir, f"epoch-{epoch}_step-{step}") + os.makedirs(os.path.join(save_dir, "modeling"), exist_ok=True) + + booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True) + booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True) + booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler")) + running_states = { + "epoch": epoch, + "step": step, + "sample_start_index": step * batch_size, + } + if coordinator.is_master(): + save_json(running_states, os.path.join(save_dir, "running_states.json")) + + +def load_checkpoint( + load_dir: Union[str, os.PathLike], + booster: Booster, + model: torch.nn.Module, + optimizer: Optimizer, + lr_scheduler: _LRScheduler, +) -> Tuple[int, int, int]: + """ + Load model checkpoint, optimizer, LR scheduler and intermedidate running states. + """ + + # Update booster params states. 
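+    # Restore model weights first, then optimizer and LR scheduler states, and finally read back
+    # the JSON running states (epoch / step / sample start index) saved alongside the checkpoint.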
+ load_model(os.path.join(load_dir, "modeling"), model, booster, optimizer) + booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer")) + booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler")) + + running_states = load_json(file_path=os.path.join(load_dir, "running_states.json")) + return ( + running_states["epoch"], + running_states["step"], + running_states["sample_start_index"], + ) diff --git a/applications/ColossalMoE/infer.py b/applications/ColossalMoE/infer.py new file mode 100644 index 000000000..70ddff940 --- /dev/null +++ b/applications/ColossalMoE/infer.py @@ -0,0 +1,138 @@ +import argparse + +import torch +import torch.distributed as dist +from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO +from colossal_moe.models.mixtral_layer import replace_moe_layer +from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy +from colossal_moe.utils import load_model +from transformers import AutoTokenizer +from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin +from colossalai.cluster import DistCoordinator +from colossalai.moe import MOE_MANAGER +from colossalai.moe.utils import skip_init +from colossalai.utils import get_current_device + + +def parse_args(): + # basic settings + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_name", + type=str, + default="mistralai/Mixtral-8x7B-v0.1", + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--plugin", + type=str, + default="hybrid", + choices=["ep"], + help="Parallel methos.", + ) + parser.add_argument( + "--output_path", + type=str, + default="./outputs", + help="The path of your saved model after finetuning.", + ) + parser.add_argument( + "--precision", + type=str, + default="bf16", + choices=["fp32", "bf16", "fp16"], + help="The mixed precision training.", + ) + parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") + + # kernel + parser.add_argument( + "--use_kernel", + action="store_true", + help="Use kernel optim. Need to install flash attention and triton to enable all kernel optimizations. Skip if not installed.", + ) + parser.add_argument( + "--use_layernorm_kernel", + action="store_true", + help="Use layernorm kernel. Need to install apex. 
Raise error if not installed.", + ) + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + # Launch ColossalAI + colossalai.launch_from_torch(config={}, seed=args.seed) + coordinator = DistCoordinator() + + # Set plugin + booster_kwargs = {} + hybrid_dict = { + "tp_size": 1, + "custom_policy": MixtralForCausalLMPolicy(), + "enable_fused_normalization": args.use_layernorm_kernel, + "enable_jit_fused": args.use_kernel, + "precision": args.precision, + "checkpoint_io": MixtralMoECheckpointIO, + "zero_stage": 1, + } + mgr_dict = {} + if args.plugin == "ep": + dp_size = dist.get_world_size() + plugin = MoeHybridParallelPlugin( + pp_size=1, + **hybrid_dict, + ) + MOE_MANAGER.setup( + parallel="EP", + max_ep_size=dp_size, + **mgr_dict, + ) + else: + raise ValueError(f"Invalid plugin {args.plugin}") + coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}") + + # Build mixtral model + config = MixtralConfig.from_pretrained(args.model_name) + config.num_local_experts = 1 # dont change this. it will not affect model + with skip_init(): + model = MixtralForCausalLM(config) + model.num_experts = 8 + model = model.to(torch.bfloat16) if args.precision == "bf16" else model.to(torch.float16) + model = model.to(get_current_device()) + coordinator.print_on_master(f"Finish init model with config:\n{config}") + + # Replace moe + with skip_init(): + replace_moe_layer(model) + model.eval() + coordinator.print_on_master(f"Finish replace moe module") + + # Prepare tokenizer and dataloader + tokenizer = AutoTokenizer.from_pretrained(args.model_name) + + # Set booster + booster = Booster(plugin=plugin, **booster_kwargs) + model, _, _, _, _ = booster.boost(model=model) + coordinator.print_on_master(f"Finish init booster") + + # load ckpt + load_model(args.model_name, model, booster) + coordinator.print_on_master(f"Finish load ckpt") + + text = ["Hello my name is", "1+1=?"] + tokenizer.pad_token = tokenizer.unk_token + inputs = tokenizer(text, return_tensors="pt", padding=True).to(torch.cuda.current_device()) + outputs = model.module.generate(**inputs, max_new_tokens=20) + outputs = tokenizer.batch_decode(outputs)[0] + print(outputs) + + +if __name__ == "__main__": + main() diff --git a/applications/ColossalMoE/infer.sh b/applications/ColossalMoE/infer.sh new file mode 100644 index 000000000..0487fe9c1 --- /dev/null +++ b/applications/ColossalMoE/infer.sh @@ -0,0 +1,7 @@ +NUM_GPU=2 +MODEL="mistralai/Mixtral-8x7B-v0.1" + +# ep +torchrun --standalone --nproc_per_node $NUM_GPU infer.py \ + --model_name $MODEL \ + --plugin "ep" \ diff --git a/applications/ColossalMoE/requirements.txt b/applications/ColossalMoE/requirements.txt new file mode 100644 index 000000000..9a5738c41 --- /dev/null +++ b/applications/ColossalMoE/requirements.txt @@ -0,0 +1,5 @@ +colossalai >= 0.3.3 +torch >= 1.8.1 +transformers == 4.36.0 +sentencepiece +datasets diff --git a/applications/ColossalMoE/setup.py b/applications/ColossalMoE/setup.py new file mode 100644 index 000000000..275f59e10 --- /dev/null +++ b/applications/ColossalMoE/setup.py @@ -0,0 +1,43 @@ +from setuptools import find_packages, setup + + +def fetch_requirements(path): + with open(path, "r") as fd: + return [r.strip() for r in fd.readlines()] + + +def fetch_readme(): + with open("README.md", encoding="utf-8") as f: + return f.read() + + +def fetch_version(): + with open("version.txt", "r") as f: + return f.read().strip() + + +setup( + name="colossal_moe", + version=fetch_version(), + packages=find_packages( + exclude=( + "tests", + 
"benchmarks", + "*.egg-info", + ) + ), + description="Colossal-AI MoE", + long_description=fetch_readme(), + long_description_content_type="text/markdown", + license="Apache Software License 2.0", + url="https://github.com/hpcaitech", + install_requires=fetch_requirements("requirements.txt"), + python_requires=">=3.6", + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", + "Environment :: GPU :: NVIDIA CUDA", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: System :: Distributed Computing", + ], +) diff --git a/applications/ColossalMoE/tests/__init__.py b/applications/ColossalMoE/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/applications/ColossalMoE/tests/test_moe_checkpoint.py b/applications/ColossalMoE/tests/test_moe_checkpoint.py new file mode 100644 index 000000000..7c6012a70 --- /dev/null +++ b/applications/ColossalMoE/tests/test_moe_checkpoint.py @@ -0,0 +1,185 @@ +import os +import shutil + +import pytest +import torch +import torch.distributed as dist +from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO +from colossal_moe.models.mixtral_layer import replace_moe_layer +from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy +from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin +from colossalai.moe.manager import MOE_MANAGER +from colossalai.testing import DummyDataloader, check_state_dict_equal, rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device + + +def data_gen_fn(batch_size: int = 2, max_length: int = 4, vocab_size: int = 20): + input_ids = torch.randint(0, vocab_size, (batch_size, max_length), device=get_current_device()) + attention_mask = torch.ones_like(input_ids) + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "labels": input_ids, + } + + +def run_fwd_bwd( + model, data, label, criterion, optimizer, enable_autocast=False, pipeline=False, booster=None, plugin=None +): + model.train() + if pipeline: + train_dataloader_iter = DummyDataloader(data_gen_fn, length=1) + is_pp_last_stage = booster.plugin.stage_manager.is_last_stage() + y = booster.execute_pipeline( + train_dataloader_iter, + model, + lambda x, y: x.loss, + optimizer, + return_loss=True, + return_outputs=True, + ) + # Backward and optimize + if is_pp_last_stage: + loss = y["loss"] + else: + if criterion: + y = model(data).logits + loss = criterion(y) + else: + loss = model(data, label) + loss = loss.float() + + if optimizer is not None: + optimizer.backward(loss) + else: + loss.backward() + return y + + +def get_config(): + config = MixtralConfig( + vocab_size=300, + hidden_size=32, + intermediate_size=16, + num_hidden_layers=2, + dropout_rate=0.0, + ) + return config + + +def get_model(parallel): + config = get_config() + model = MixtralForCausalLM(config).to(torch.bfloat16) + replace_moe_layer(model) + optim = torch.optim.Adam(model.parameters()) + args = dict( + precision="bf16", + tp_size=1, + zero_stage=1, + custom_policy=MixtralForCausalLMPolicy(), + checkpoint_io=MixtralMoECheckpointIO, + ) + if parallel == "ep": + plugin = MoeHybridParallelPlugin( + pp_size=1, + **args, + ) + elif parallel == "hybrid": + plugin = MoeHybridParallelPlugin( + pp_size=2, + microbatch_size=1, + **args, + ) + booster = Booster(plugin=plugin) + model, 
optim, _, _, _ = booster.boost(model=model, optimizer=optim) + return model, booster, optim + + +def _test_moe_checkpoint(parallel): + if dist.get_rank() == 0: + if os.path.exists("./tmp_ckpt1"): + shutil.rmtree("./tmp_ckpt1") + if os.path.exists("./tmp_ckpt2"): + shutil.rmtree("./tmp_ckpt2") + dist.barrier() + + if parallel == None: + MOE_MANAGER.setup( + parallel=None, + ) + elif parallel == "ep": + MOE_MANAGER.setup( + parallel="EP", + ) + elif parallel == "hybrid": + MOE_MANAGER.setup( + parallel="EP", + mode="fixed", + fixed_dp_size=1, + fixed_ep_size=2, + fixed_pp_size=2, + ) + model1, booster1, optim1 = get_model(parallel) + model2, booster2, optim2 = get_model(parallel) + # param ckpt + # check not equal + try: + check_state_dict_equal(model1.state_dict(), model2.state_dict(), False) + raise AssertionError("state_dict should not be equal") + except: + pass + # shard + booster1.save_model(model1, "./tmp_ckpt1", shard=True, size_per_shard=1) + booster2.load_model(model2, "./tmp_ckpt1") + # check + check_state_dict_equal(model1.state_dict(), model2.state_dict(), False) + + # optim ckpt + criterion = lambda x: x.mean() + data = torch.randint(0, 4, (2, 4)).cuda() + label = torch.randint(0, 4, (2,)).cuda() + if parallel == "hybrid": + kwargs = {"pipeline": True, "booster": booster1, "plugin": booster1.plugin} + else: + kwargs = {} + run_fwd_bwd(model1, data, label, criterion, optim1, **kwargs) + optim1.step() + optim1.zero_grad() + # shard + booster1.save_optimizer(optim1, "./tmp_ckpt2", shard=True, size_per_shard=1) + dist.barrier() + booster2.load_optimizer(optim2, "./tmp_ckpt2") + # check + check_state_dict_equal(optim1.optim.state_dict(), optim2.optim.state_dict(), False) + + if dist.get_rank() == 0: + shutil.rmtree("./tmp_ckpt1") + shutil.rmtree("./tmp_ckpt2") + + +def _run_dist(rank, world_size, port, parallel): + colossalai.launch( + config=dict(), + rank=rank, + world_size=world_size, + host="localhost", + port=port, + backend="nccl", + ) + _test_moe_checkpoint(parallel) + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [4]) +@pytest.mark.parametrize("parallel", ["ep", "hybrid"]) +@rerun_if_address_is_in_use() +def test_moe_checkpoint(world_size, parallel): + spawn(_run_dist, world_size, parallel=parallel) + + +if __name__ == "__main__": + test_moe_checkpoint(world_size=4, parallel="hybrid") diff --git a/applications/ColossalMoE/tests/test_moe_layer.py b/applications/ColossalMoE/tests/test_moe_layer.py new file mode 100644 index 000000000..8b090c427 --- /dev/null +++ b/applications/ColossalMoE/tests/test_moe_layer.py @@ -0,0 +1,31 @@ +import copy + +import torch +from colossal_moe.models.mixtral_layer import MixtralSparseMLP +from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock + + +class Config: + def __init__(self, hidden_size, intermediate_size, num_local_experts, num_experts_per_tok, hidden_act): + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_local_experts = num_local_experts + self.num_experts_per_tok = num_experts_per_tok + self.hidden_act = hidden_act + + +def test_moe_layer(): + config = Config(hidden_size=4, intermediate_size=8, num_local_experts=32, num_experts_per_tok=2, hidden_act="silu") + mistral_moe = MixtralSparseMoeBlock(config).cuda() + colossal_moe = MixtralSparseMLP.from_native_module(copy.deepcopy(mistral_moe)).cuda() + + data = torch.randn(2, 8, 4).cuda() + mistral_output = mistral_moe(data)[0] + colossal_output = colossal_moe(data)[0] + assert torch.allclose( + mistral_output, 
colossal_output + ), f"mistral_output: {mistral_output}\ncolossal_output: {colossal_output}" + + +if __name__ == "__main__": + test_moe_layer() diff --git a/applications/ColossalMoE/train.py b/applications/ColossalMoE/train.py new file mode 100644 index 000000000..1d0441a5a --- /dev/null +++ b/applications/ColossalMoE/train.py @@ -0,0 +1,320 @@ +import argparse + +import torch +import torch.distributed as dist +from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO +from colossal_moe.models.mixtral_layer import replace_moe_layer +from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy +from colossal_moe.utils import load_checkpoint, load_model, move_to_cuda, save_checkpoint +from torch.utils.data import Dataset +from tqdm import tqdm +from transformers import AutoTokenizer +from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin +from colossalai.cluster import DistCoordinator +from colossalai.moe import MOE_MANAGER, apply_load_balance +from colossalai.moe.layers import apply_load_balance +from colossalai.moe.manager import MOE_MANAGER +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.nn.optimizer import HybridAdam +from colossalai.utils import get_current_device + + +@torch.no_grad() +def get_global_loss(loss, booster): + global_loss = loss.clone().detach() + dist.all_reduce(tensor=global_loss, op=dist.ReduceOp.SUM, group=booster.plugin.dp_group) + global_loss.div_(booster.plugin.dp_size) + return global_loss + + +class RandomDataset(Dataset): + def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 100, tokenizer=None): + self.num_samples = num_samples + self.max_length = max_length + self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device()) + self.attention_mask = torch.ones_like(self.input_ids) + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + return { + "input_ids": self.input_ids[idx], + "attention_mask": self.attention_mask[idx], + "labels": self.input_ids[idx], + } + + +def parse_args(): + # basic settings + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_name", + type=str, + default="mistralai/Mixtral-8x7B-v0.1", + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument("--load_checkpoint", type=str, default=None, help="Load checkpoint") + parser.add_argument( + "--plugin", + type=str, + default="hybrid", + choices=["hybrid"], + help="Parallel methods.", + ) + parser.add_argument( + "--output_path", + type=str, + default="./outputs", + help="The path of your saved model after finetuning.", + ) + parser.add_argument("--num_epoch", type=int, default=1, help="Number of epochs.") + parser.add_argument( + "--batch_size", + type=int, + default=1, + help="Batch size (per dp group) for the training dataloader.", + ) + parser.add_argument( + "--save_interval", + type=int, + default=1000, + help=" The interval (steps) of saving checkpoints.", + ) + parser.add_argument( + "--precision", + type=str, + default="bf16", + choices=["fp32", "bf16", "fp16"], + help="The mixed precision training.", + ) + parser.add_argument("--max_length", type=int, default=2048, help="Max sequence length.") + parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") + + # 
optim + parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.") + parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") + + # lr scheduler + parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs") + parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps") + + # zero stage for all plugins + parser.add_argument("--zero_stage", type=int, default=2, help="zero stage.") + # hybrid plugin + parser.add_argument("--pp_size", type=int, default=2, help="pp size for hybrid plugin") + parser.add_argument("--dp_size", type=int, default=1, help="dp size for hybrid plugin") + parser.add_argument("--ep_size", type=int, default=2, help="ep size for hybrid plugin") + parser.add_argument("--microbatch_size", type=int, default=1, help="Microbatch size in pipeline for hybrid plugin") + + # kernel + parser.add_argument( + "--use_kernel", + action="store_true", + help="Use kernel optim. Need to install flash attention and triton to enable all kernel optimizations. Skip if not installed.", + ) + parser.add_argument( + "--use_layernorm_kernel", + action="store_true", + help="Use layernorm kernel. Need to install apex. Raise error if not installed.", + ) + + # load balance + parser.add_argument( + "--load_balance", action="store_true", help="Expert load balance. Defaults to False. Recommend to enable." + ) + parser.add_argument("--load_balance_interval", type=int, default=1000, help="Expert load balance interval.") + # communicate overlap + parser.add_argument( + "--comm_overlap", + action="store_true", + help="Use communication overlap for MoE. Recommended to enable for muiti-node training.", + ) + # hierarchical all-to-all + parser.add_argument( + "--hierarchical_alltoall", + action="store_true", + help="Use hierarchical all-to-all for MoE. 
Recommended to enable for muiti-node training.", + ) + + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + # Launch ColossalAI + colossalai.launch_from_torch(config={}, seed=args.seed) + coordinator = DistCoordinator() + + # Set plugin + booster_kwargs = {} + hybrid_dict = { + "tp_size": 1, + "custom_policy": MixtralForCausalLMPolicy(), + "enable_fused_normalization": args.use_layernorm_kernel, + "enable_jit_fused": args.use_kernel, + "precision": args.precision, + "zero_stage": args.zero_stage, + "checkpoint_io": MixtralMoECheckpointIO, + } + mgr_dict = {} + if args.plugin == "hybrid": + plugin = MoeHybridParallelPlugin( + pp_size=args.pp_size, + microbatch_size=args.microbatch_size, + **hybrid_dict, + ) + MOE_MANAGER.setup( + parallel="EP", + mode="fixed", + fixed_dp_size=args.dp_size, + fixed_ep_size=args.ep_size, + fixed_pp_size=args.pp_size, + **mgr_dict, + ) + else: + raise ValueError(f"Invalid plugin {args.plugin}") + coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}") + + # Build Mixtral model + config = MixtralConfig.from_pretrained(args.model_name) + config.use_cache = False + config.num_local_experts = 1 + model = MixtralForCausalLM(config) + model.num_experts = 8 + model = model.to(torch.bfloat16) if args.precision == "bf16" else model.to(torch.float16) + model = model.to(get_current_device()) + replace_moe_layer(model, enable_kernel=args.use_kernel) + coordinator.print_on_master(f"Finish init model with config:\n{config}") + + # Enable gradient checkpointing + model.gradient_checkpointing_enable() + + # Prepare tokenizer and dataloader + tokenizer = AutoTokenizer.from_pretrained(args.model_name) + dataset = RandomDataset(num_samples=100, tokenizer=tokenizer) + collate_fn = None + dataloader = plugin.prepare_dataloader( + dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn + ) + + # Set optimizer + optimizer = HybridAdam( + model_params=model.parameters(), + lr=args.lr, + betas=(0.9, 0.95), + weight_decay=args.weight_decay, + adamw_mode=True, + ) + + # Set lr scheduler + lr_scheduler = CosineAnnealingWarmupLR( + optimizer=optimizer, + total_steps=args.num_epochs * len(dataloader), + warmup_steps=args.warmup_steps + if args.warmup_steps is not None + else int(args.num_epochs * len(dataloader) * 0.025), + eta_min=0.1 * args.lr, + ) + + # Set booster + booster = Booster(plugin=plugin, **booster_kwargs) + model, optimizer, _, dataloader, lr_scheduler = booster.boost( + model=model, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + dataloader=dataloader, + ) + use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1 + is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() + coordinator.print_on_master(f"Finish init booster") + + # Load ckpt + if args.load_checkpoint is None: + load_model(args.model_name, model, booster, optimizer) + coordinator.print_on_master(f"Finish load checkpoint") + else: + load_checkpoint(args.load_checkpoint, booster, model, optimizer, lr_scheduler) + coordinator.print_on_master(f"Finish load optimizer") + + # Start finetuning + coordinator.print_on_master(f"Start finetuning") + for epoch in range(args.num_epoch): + model.train() + train_dataloader_iter = iter(dataloader) + total_len = len(train_dataloader_iter) + with tqdm( + range(total_len), + desc=f"Epoch [{epoch + 1}/{args.num_epoch}]", + disable=not coordinator.is_master() if use_pipeline == False else not is_pp_last_stage, + ) as pbar: + 
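+            # Two paths below: pipeline-parallel runs step through booster.execute_pipeline() and
+            # only the last stage receives the loss, while the non-pipeline path runs a plain
+            # forward/backward on the boosted model.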
for step in pbar: + if use_pipeline: + # Forward pass + outputs = booster.execute_pipeline( + train_dataloader_iter, + model, + lambda x, y: x.loss, + optimizer, + return_loss=True, + return_outputs=True, + ) + # Backward and optimize + if is_pp_last_stage: + loss = outputs["loss"] + global_loss = get_global_loss(loss, booster) + if coordinator._local_rank == "0": + pbar.set_postfix({"Loss": global_loss.item()}) + else: + # Forward pass + data = next(train_dataloader_iter) + data = move_to_cuda(data, torch.cuda.current_device()) + outputs = model(**data) + loss = outputs["loss"] + # Backward + booster.backward(loss, optimizer) + pbar.set_postfix({"loss": loss.item()}) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Apply load balance + if ( + args.load_balance + and args.load_balance_interval > 0 + and (step + 1) % args.load_balance_interval == 0 + ): + coordinator.print_on_master(f"Apply load balance") + apply_load_balance(model, optimizer) + # save ckeckpoint + if (step + 1) % args.save_interval == 0: + coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}") + save_checkpoint( + args.output_path, + booster, + model, + optimizer, + lr_scheduler, + epoch, + step, + args.batch_size, + coordinator, + ) + + # save checkpoint at the end of each epochs + booster.save_model(model, args.output_path, shard=True, size_per_shard=5120) + coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}") + + # Finish training + coordinator.print_on_master(f"Finish training") + + +if __name__ == "__main__": + main() diff --git a/applications/ColossalMoE/train.sh b/applications/ColossalMoE/train.sh new file mode 100644 index 000000000..bee7f5c8f --- /dev/null +++ b/applications/ColossalMoE/train.sh @@ -0,0 +1,19 @@ +NUM_GPU=8 +MODEL="mistralai/Mixtral-8x7B-v0.1" +SEQ_LENGTH=2048 +BATCH_SIZE=1 +LR=0.00001 + +# hybrid +# torchrun --standalone --nproc_per_node $NUM_GPU \ +colossalai run --nproc_per_node $NUM_GPU --hostfile "hostfile" \ + train.py \ + --num_epoch 1 \ + --model_name $MODEL \ + --plugin "hybrid" \ + --batch_size $BATCH_SIZE \ + --lr $LR \ + --zero_stage 1 \ + --pp_size 2 \ + --dp_size 1 \ + --ep_size 8 \ diff --git a/applications/ColossalMoE/version.txt b/applications/ColossalMoE/version.txt new file mode 100644 index 000000000..3eefcb9dd --- /dev/null +++ b/applications/ColossalMoE/version.txt @@ -0,0 +1 @@ +1.0.0 diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index e976d0aaf..07cbc14a7 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -181,6 +181,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): overlap_communication: bool = True, use_ep_inside: bool = True, custom_policy: Policy = None, + checkpoint_io: Optional[MoECheckpintIO] = None, ) -> None: assert ( dist.get_world_size() % (tp_size * pp_size) == 0 @@ -200,6 +201,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): self.enable_flash_attention = enable_flash_attention self.enable_jit_fused = enable_jit_fused self.enable_sequence_parallelism = enable_sequence_parallelism + self.checkpoint_io = checkpoint_io # we change pg mesh to (pp, dp, tp) for better moe performance self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size) @@ -323,7 +325,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): ) def get_checkpoint_io(self) -> MoECheckpintIO: - self.checkpoint_io = 
MoECheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) + if self.checkpoint_io is None: + self.checkpoint_io = MoECheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) + else: + self.checkpoint_io = self.checkpoint_io(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) return self.checkpoint_io def configure( diff --git a/colossalai/moe/__init__.py b/colossalai/moe/__init__.py index 721da69d0..6dd0a5fc3 100644 --- a/colossalai/moe/__init__.py +++ b/colossalai/moe/__init__.py @@ -1,6 +1,7 @@ from .checkpoint import MoECheckpintIO from .experts import MLPExperts -from .layers import SparseMLP +from .layers import SparseMLP, apply_load_balance +from .manager import MOE_MANAGER from .routers import MoeRouter, Top1Router, Top2Router, TopKRouter from .utils import NormalNoiseGenerator, UniformNoiseGenerator @@ -14,4 +15,6 @@ __all__ = [ "UniformNoiseGenerator", "SparseMLP", "MoECheckpintIO", + "MOE_MANAGER", + "apply_load_balance", ] diff --git a/colossalai/moe/checkpoint.py b/colossalai/moe/checkpoint.py index a8c50eab6..b37ffabea 100644 --- a/colossalai/moe/checkpoint.py +++ b/colossalai/moe/checkpoint.py @@ -224,6 +224,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO): size_per_shard (int, optional): Size per shard in MB. Defaults to 1024. use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False. """ + torch.cuda.empty_cache() if os.path.isfile(checkpoint): logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") return @@ -265,6 +266,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO): f"index located at {save_index_file}." ) dist.barrier() + torch.cuda.empty_cache() # ======================================================== # Abstract methods for optimizer loading/saving implementation @@ -332,10 +334,12 @@ class MoECheckpintIO(HybridParallelCheckpointIO): assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!" def _get_param_id_from_optimizer_param( - param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None + param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None, optimizer=None ): if master_to_working_map is not None and id(param) in master_to_working_map: working_param = master_to_working_map[id(param)] + elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map: + working_param = optimizer.moe_master_to_working_map[id(param)] else: working_param = param return optimizer.param_info["param2id"][id(working_param)] @@ -347,7 +351,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO): master_to_working_map = optimizer.get_master_to_working_map() for pg in optimizer.optim.param_groups: for param in pg["params"]: - param_id = _get_param_id_from_optimizer_param(param, master_to_working_map) + param_id = _get_param_id_from_optimizer_param(param, master_to_working_map, optimizer) id_map[param_id] = param # Read checkpoint index file. @@ -371,14 +375,10 @@ class MoECheckpintIO(HybridParallelCheckpointIO): new_pg = copy.deepcopy(saved_pg) new_pg["params"] = old_pg["params"] # The parameters in the same group shouln't change. updated_groups.append(new_pg) - # ep extra group - if MOE_MANAGER.parallel == "EP": + # ep param group + if len(optimizer.optim.param_groups) > len(saved_groups): new_pg = copy.deepcopy(saved_pg) - new_pg["params"] = optimizer.optim.param_groups[-1][ - "params" - ] # Only keep the parameters kept by current pipeline stage. 
- for param in new_pg["params"]: - param.data = param.data.to(torch.float32) + new_pg["params"] = optimizer.optim.param_groups[-1]["params"] updated_groups.append(new_pg) optimizer.optim.__dict__.update({"param_groups": updated_groups}) @@ -389,7 +389,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO): for param in pg["params"]: if param is None: continue - param_id = _get_param_id_from_optimizer_param(param, master_to_working_map) + param_id = _get_param_id_from_optimizer_param(param, master_to_working_map, optimizer) if param_id not in weight_map: continue filename = weight_map[param_id] @@ -400,27 +400,34 @@ class MoECheckpintIO(HybridParallelCheckpointIO): file_path = os.path.join(ckpt_root_path, filename) state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False) + + # Then shard the loaded optimizer states if using tp/zero. + for pid, state in list(state_dict.items()): + if pid in id_map: + param = id_map[pid] + if master_to_working_map is not None and id(param) in master_to_working_map: + working_param = master_to_working_map[id(param)] + elif ( + hasattr(optimizer, "moe_master_to_working_map") + and id(param) in optimizer.moe_master_to_working_map + ): + working_param = optimizer.moe_master_to_working_map[id(param)] + else: + working_param = param + original_shape = optimizer.param_info["param2shape"][id(working_param)] + sharded_state = self.pre_load_optim( + state, + working_param, + current_shape=working_param.shape, + original_shape=original_shape, + device="cpu", + inplace=True, + ) + state_dict[pid] = sharded_state + load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True) loaded_file.add(filename) - # Then shard the loaded optimizer states if using tp/zero. - for param, state in optimizer.optim.state.items(): - device = param.device - if master_to_working_map is not None and id(param) in master_to_working_map: - working_param = master_to_working_map[id(param)] - else: - working_param = param - original_shape = optimizer.param_info["param2shape"][id(working_param)] - sharded_state = self.pre_load_optim( - state, - param, - current_shape=working_param.shape, - original_shape=original_shape, - device=device, - inplace=True, - ) - optimizer.optim.state[param] = sharded_state - sharded_optimizer_loading_epilogue(optimizer.optim) if self.verbose and self.coordinator.is_master(): logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.") @@ -576,6 +583,8 @@ class MoECheckpintIO(HybridParallelCheckpointIO): if master_to_working_map is not None and id(param) in master_to_working_map: working_param = master_to_working_map[id(param)] + elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map: + working_param = optimizer.moe_master_to_working_map[id(param)] else: working_param = param @@ -618,6 +627,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO): prefix (str): Perfix of file to save size_per_shard (int): Max file size of each file shard that store state tensors """ + torch.cuda.empty_cache() assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!" if os.path.isfile(checkpoint): logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") @@ -723,6 +733,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO): f"You can find where each parameters has been saved in the " f"index located at {final_index_file_path}." 
) + torch.cuda.empty_cache() def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool): """ diff --git a/colossalai/moe/experts.py b/colossalai/moe/experts.py index 477b76547..8e6ea3884 100644 --- a/colossalai/moe/experts.py +++ b/colossalai/moe/experts.py @@ -67,7 +67,11 @@ class MLPExperts(nn.Module): self.ep_size = 1 if gated: - self.wi_gate = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size * 2)) + self.wi_gate = nn.Parameter( + torch.empty( + num_experts, hidden_size, intermediate_size * 2 if activation == "swiglu" else intermediate_size + ) + ) self.wi_up = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size)) else: self.wi = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size)) diff --git a/colossalai/moe/layers.py b/colossalai/moe/layers.py index b768fb94a..2ac5b186d 100644 --- a/colossalai/moe/layers.py +++ b/colossalai/moe/layers.py @@ -51,6 +51,8 @@ class SparseMLP(nn.Module): hidden_size: int, intermediate_size: int, router_top_k: int = 1, + router_loss: bool = True, + router_norm: bool = False, router_capacity_factor_train: float = 1.25, router_capacity_factor_eval: float = 2.0, router_min_capacity: int = 4, @@ -65,15 +67,19 @@ class SparseMLP(nn.Module): enable_kernel: bool = False, enable_comm_overlap: bool = False, enable_hierarchical_comm: bool = False, + return_gate_logits: bool = False, ): super().__init__() self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.num_experts = num_experts self.gated = mlp_gated + self.return_gate_logits = return_gate_logits self.enable_kernel = enable_kernel self.enable_comm_overlap = enable_comm_overlap self.expert_parallel = MOE_MANAGER.get_parallel() + self.router_loss = router_loss + self.router_norm = router_norm # moe router noisy_func = get_noise_generator(router_noisy_policy, num_experts) @@ -150,9 +156,8 @@ class SparseMLP(nn.Module): tokens = inputs.reshape(-1, self.hidden_size) # the data type of the inputs in the gating should be fp32 - fp32_input = tokens.to(torch.float) - fp32_weight = self.gate_weight.to(torch.float) - gate_output = F.linear(fp32_input, fp32_weight) + gate_logits = F.linear(tokens, self.gate_weight) + gate_output = gate_logits.to(torch.float) # update expert load if self.enable_load_balance == True: @@ -165,7 +170,12 @@ class SparseMLP(nn.Module): # the result from the router used_capacity, *route_result_list = self.router( - inputs=gate_output, use_kernel=self.enable_kernel, ep_group=self.ep_group) + inputs=gate_output, + use_kernel=self.enable_kernel, + ep_group=self.ep_group, + use_loss=self.router_loss, + use_norm=self.router_norm, + ) # dispatch_data: (num_experts, capacity, hidden_size) if self.enable_kernel: @@ -177,22 +187,15 @@ class SparseMLP(nn.Module): # expert_output: (num_groups, num_experts, capacity, hidden_size) if self.expert_parallel == "EP": - expert_output = self._ep_process( - dispatch_data, - used_capacity, - overlap=self.enable_comm_overlap - ) + expert_output = self._ep_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap) elif self.expert_parallel == "TP": - expert_output = self._tp_process( - dispatch_data, - used_capacity, - overlap=self.enable_comm_overlap - ) + expert_output = self._tp_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap) elif self.expert_parallel is None: expert_output = self._local_process(dispatch_data) else: - raise NotImplementedError("This kind of communication has not been implemented yet.\n" - 
"Please use Experts build function.") + raise NotImplementedError( + "This kind of communication has not been implemented yet.\n" "Please use Experts build function." + ) if self.enable_kernel: expert_output = expert_output.reshape(-1, self.hidden_size) @@ -204,7 +207,11 @@ class SparseMLP(nn.Module): ans = torch.matmul(combine_weights, expert_output) ans = ans.reshape(inputs.shape) - return ans + + if self.return_gate_logits: + return ans, gate_logits + else: + return ans def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor: expert_in = expert_in.unsqueeze(0) @@ -212,10 +219,7 @@ class SparseMLP(nn.Module): return expert_out def _ep_process( - self, - dispatch_data: torch.Tensor, - used_capacity: torch.Tensor, - overlap: bool = False + self, dispatch_data: torch.Tensor, used_capacity: torch.Tensor, overlap: bool = False ) -> torch.Tensor: """ Expert Parallel @@ -228,10 +232,14 @@ class SparseMLP(nn.Module): """ if not overlap or dist.get_world_size(self.ep_group) == 1: if self.ep_hierarchical_group is not None: - expert_input = HierarchicalAllToAll.apply(dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank) + expert_input = HierarchicalAllToAll.apply( + dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank + ) expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size) expert_output = self.experts(expert_input) - expert_output = HierarchicalAllToAll.apply(expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank) + expert_output = HierarchicalAllToAll.apply( + expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank + ) return expert_output else: expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0] @@ -249,7 +257,7 @@ class SparseMLP(nn.Module): NUM_CHUNK = 4 NUM_STAGES = 4 - assert (dispatch_data.shape[1] % NUM_CHUNK == 0), "arbitrary chunk num is not supported yet" + assert dispatch_data.shape[1] % NUM_CHUNK == 0, "arbitrary chunk num is not supported yet" chunk_size = dispatch_data.shape[1] // NUM_CHUNK input_shape = (self.ep_size, self.num_local_experts, -1, self.hidden_size) dispatch_data = dispatch_data.reshape(*input_shape) @@ -262,13 +270,15 @@ class SparseMLP(nn.Module): for i in range(NUM_CHUNK + NUM_STAGES - 1): if expert_out is not None: expert_out.handle.wait() - output[:, :, offset:offset + chunk_size, :] = expert_out.data + output[:, :, offset : offset + chunk_size, :] = expert_out.data offset += chunk_size expert_out = None # all2all last output if _expert_out is not None: - expert_out = Capsule(*AllToAll.apply(_expert_out.data, self.ep_group, True),) + expert_out = Capsule( + *AllToAll.apply(_expert_out.data, self.ep_group, True), + ) _expert_out = None # all2all next input @@ -288,10 +298,7 @@ class SparseMLP(nn.Module): return output def _tp_process( - self, - dispatch_data: torch.Tensor, - used_capacity: torch.Tensor, - overlap: bool = False + self, dispatch_data: torch.Tensor, used_capacity: torch.Tensor, overlap: bool = False ) -> torch.Tensor: """ without overlap: @@ -326,8 +333,9 @@ class SparseMLP(nn.Module): NUM_CHUNK = 4 NUM_STAGES = 4 - assert dispatch_data.shape[0] % NUM_CHUNK == 0, \ - "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts" + assert ( + dispatch_data.shape[0] % NUM_CHUNK == 0 + ), "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts" chunk_size = dispatch_data.shape[0] // NUM_CHUNK chunk_data = torch.split(dispatch_data, chunk_size, dim=0) output = 
torch.empty_like(dispatch_data) diff --git a/colossalai/moe/routers.py b/colossalai/moe/routers.py index f5815d05d..5c7d06656 100644 --- a/colossalai/moe/routers.py +++ b/colossalai/moe/routers.py @@ -150,7 +150,14 @@ class Top1Router(MoeRouter): high=torch.tensor(1.0, device=get_accelerator().get_current_device()), ).rsample - def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple: + def forward( + self, + inputs: torch.Tensor, + use_kernel: bool = False, + ep_group: Optional[ProcessGroup] = None, + use_loss: bool = False, + use_norm: bool = False, + ) -> Tuple: """ Args: inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts). @@ -207,7 +214,7 @@ class Top1Router(MoeRouter): weight = mask * probs.type_as(inputs) combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1) sec_mask = combine_weights.bool() - return used_capacity, combine_weights, sec_mask + return used_capacity, combine_weights, sec_mask, probs class Top2Router(MoeRouter): @@ -240,7 +247,14 @@ class Top2Router(MoeRouter): drop_tks=drop_tks, ) - def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple: + def forward( + self, + inputs: torch.Tensor, + use_kernel: bool = False, + ep_group: Optional[ProcessGroup] = None, + use_norm: bool = False, + use_loss: bool = True, + ) -> Tuple: """ Args: inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts). @@ -257,6 +271,10 @@ class Top2Router(MoeRouter): assert inputs.dtype == torch.float probs = F.softmax(inputs, dim=-1) + if use_norm: + routing_weights, _ = torch.topk(probs, 2, dim=-1) + probs = probs / routing_weights.sum(dim=-1, keepdim=True) + num_experts = probs.size(-1) capacity = self.get_capacity(inputs.shape) @@ -270,10 +288,11 @@ class Top2Router(MoeRouter): cmask = cmask.float() / 2.0 # div 2 to normalize it to 1 # calculate loss - expert_indices = torch.stack([top1_idx, top2_idx], dim=-1) - self.set_aux_loss(probs, expert_indices, num_experts) - self.set_z_loss(inputs) - self.pop_router_loss() + if use_loss: + expert_indices = torch.stack([top1_idx, top2_idx], dim=-1) + self.set_aux_loss(probs, expert_indices, num_experts) + self.set_z_loss(inputs) + self.pop_router_loss() if not self.training and not self.drop_tks and ep_group is not None: max_num = torch.max(torch.sum(cmask, dim=0)) diff --git a/colossalai/moe/utils.py b/colossalai/moe/utils.py index e25e7dd48..c642f1a44 100644 --- a/colossalai/moe/utils.py +++ b/colossalai/moe/utils.py @@ -83,6 +83,8 @@ def get_activation(act: str) -> Callable: return torch.nn.GELU() elif act == "swiglu": return SwiGLU + elif act == "silu": + return torch.nn.SiLU() else: raise NotImplementedError("Unsupported activation function") diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index e01c852be..47bc7603a 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -141,7 +141,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): # because they have different parallel strategy # so we need to store them separately in param_groups # instead of working_groups - moe_params = list() + self.working_moe_params = list() # iterate over the param group in the optimizer # partition these param groups for data parallel training @@ -153,7 +153,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): if self.moe_extra_dp_pg is None: # skip moe param if is_moe_tensor(param): 
- moe_params.append(param) + self.working_moe_params.append(param) continue group_params.append(param) @@ -168,13 +168,23 @@ class LowLevelZeroOptimizer(OptimizerWrapper): # managed by this data parallel rank param_group["params"] = master_param_current_rank - # if there are moe params, store in additional group in optim - if len(moe_params) > 0: + # if there are moe params, store in addtional group in optim + if len(self.working_moe_params) > 0: + self._sync_master_param = False param_group = dict() + # create fp32 master param for key, value in self.optim.param_groups[0].items(): if key != "params": param_group[key] = value - param_group["params"] = moe_params + self.master_moe_params = [] + for param in self.working_moe_params: + self.master_moe_params.append(param.clone().to(torch.float32).detach()) + # create mapping from master to working for optimizer io + self.moe_master_to_working_map = {} + for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params): + self.moe_master_to_working_map[id(master_moe_param)] = working_moe_param + # add to optim + param_group["params"] = self.master_moe_params self.optim.param_groups.append(param_group) # initialize communication stream for @@ -593,24 +603,40 @@ class LowLevelZeroOptimizer(OptimizerWrapper): # update the params in the optimizer self.optim.param_groups[group_id]["params"] = real_master_params[group_id] + # update param for moe ep + # move grad to master param and compute norm + if len(self.working_moe_params) > 0: + moe_grads = [] + for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params): + if master_moe_param.grad is not None: + raise RuntimeError("Moe param should not have grad here") + grad = working_moe_param.grad + # no need to copy fp32 grad if master_weights is False + if self._master_weights: + grad = grad.to(master_moe_param.dtype).to(master_moe_param.device) + master_moe_param.grad = grad + working_moe_param.grad = None + moe_grads.append(grad) + grad_partition_groups.append(grad) + norm_group = self._compute_grad_norm(gradients=moe_grads) + norm_groups.append(norm_group) + self.optim.param_groups[-1]["params"] = self.master_moe_params + del moe_grads + # unscale and clip grads global_norm = calculate_global_norm_from_list(norm_list=norm_groups) self._unscale_and_clip_grads(grad_partition_groups, global_norm) - # TODO: we should store master param for ep - if len(self.param_groups) > len(self._working_param_groups): - for param in self.param_groups[-1]["params"]: - param.data = param.data.to(torch.float32) - param.grad = param.grad.to(torch.float32) - # update the parameters self.optim.step() - # release the moe gradm - if len(self.param_groups) > len(self._working_param_groups): - for param in self.param_groups[-1]["params"]: - param.grad = None - param.data = param.data.to(self._dtype) + # release moe grad + if len(self.working_moe_params) > 0: + for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params): + master_moe_param.grad = None + working_moe_param.data = ( + master_moe_param.data.to(working_moe_param.device).to(working_moe_param.dtype).detach() + ) # release the grad grad_partition_groups = [] @@ -640,6 +666,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper): working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param)) self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id] + def sync_moe_master_param(self): + for 
master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params): + master_moe_param.data = working_moe_param.data.clone().to(torch.float32).detach() + def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float: r""" Compute and return the gradient norm for gradient clipping. diff --git a/tests/test_moe/moe_utils.py b/tests/test_moe/moe_utils.py index 721a4796a..17b790e3e 100644 --- a/tests/test_moe/moe_utils.py +++ b/tests/test_moe/moe_utils.py @@ -1,13 +1,22 @@ import torch import torch.distributed as dist import torch.nn as nn +from torch.testing import assert_close +from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel from colossalai.legacy.engine.gradient_handler._base_gradient_handler import BaseGradientHandler from colossalai.legacy.engine.gradient_handler.utils import bucket_allreduce from colossalai.legacy.registry import GRADIENT_HANDLER from colossalai.moe import SparseMLP from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.utils import get_moe_epsize_param_dict +from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_size + + +def delete_moe_info(model): + for _, param in model.named_parameters(): + if hasattr(param, "moe_info"): + delattr(param, "moe_info") class MoeModel(nn.Module): @@ -85,6 +94,74 @@ def assert_not_equal_in_group(tensor, process_group=None): for i in range(world_size - 1): a = tensor_list[i] b = tensor_list[i + 1] - assert not torch.allclose(a, b), \ - (f"expected tensors on rank {i} and {i + 1} not to be equal " - f"but they are, {a} vs {b}") + assert not torch.allclose(a, b), ( + f"expected tensors on rank {i} and {i + 1} not to be equal " f"but they are, {a} vs {b}" + ) + + +def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False): + model.train() + with torch.cuda.amp.autocast(enabled=enable_autocast): + if criterion: + y = model(data) + loss = criterion(y, label) + else: + loss = model(data, label) + loss = loss.float() + + if isinstance(model, LowLevelZeroModel): + optimizer.backward(loss) + else: + loss.backward() + return y + + +def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None: + """Sync the parameters of tp model from ep model + + Args: + local_model (MoeModule) + ep_model (MoeModule) + """ + for (local_name, local_param), (ep_name, ep_param) in zip( + local_model.named_parameters(), ep_model.named_parameters() + ): + assert local_name in ep_name, print(f"{local_name} != {ep_name}") + if "experts" not in local_name: + if assert_grad_flag: + assert torch.allclose(local_param, ep_param), f"local_param: {local_param}, ep_param: {ep_param}" + assert torch.allclose(local_param.grad, ep_param.grad) + else: + local_param.data.copy_(ep_param.data) + continue + + # gather param from ep model + param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))] + dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param)) + all_param = torch.cat(param_list, dim=0) + if assert_grad_flag: + grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))] + dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param)) + all_grad = torch.cat(grad_list, dim=0) + + if assert_grad_flag: + assert torch.allclose(local_param, all_param) + assert torch.allclose(local_param.grad, all_grad) + else: + local_param.data.copy_(all_param.data) + + +def loose_close(a, b, dtype: torch.dtype = torch.float32): + rtol = None + atol = None + if dtype is 
torch.float16: + rtol = 5e-2 + atol = 5e-4 + elif dtype is torch.bfloat16: + rtol = 4e-3 + atol = 4e-3 + + a = a.detach().to(dtype) + b = b.detach().to(dtype).to(a.device) + + assert_close(a, b, rtol=rtol, atol=atol) diff --git a/tests/test_moe/test_moe_zero_fwd_bwd.py b/tests/test_moe/test_moe_zero_fwd_bwd.py index f0795a4c7..1bff21066 100644 --- a/tests/test_moe/test_moe_zero_fwd_bwd.py +++ b/tests/test_moe/test_moe_zero_fwd_bwd.py @@ -4,102 +4,75 @@ import torch import colossalai from colossalai.booster import Booster from colossalai.booster.plugin import LowLevelZeroPlugin -from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel from colossalai.moe.manager import MOE_MANAGER from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.testing.random import seed_all -from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel +from tests.test_moe.moe_utils import MoeModel, delete_moe_info, run_fwd_bwd, sync_local_from_ep -def split_ddp_grad(grad, world_size): - with torch.no_grad(): - grad = grad.clone().detach().flatten() - padding_size = (world_size - grad.numel() % world_size) % world_size - if padding_size > 0: - grad = torch.nn.functional.pad(grad, [0, padding_size]) - splited_grad = grad.split(grad.numel() // world_size) - return splited_grad - - -def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False): - model.train() - with torch.cuda.amp.autocast(enabled=enable_autocast): - if criterion: - y = model(data) - loss = criterion(y, label) - else: - loss = model(data, label) - loss = loss.float() - - if isinstance(model, LowLevelZeroModel): - optimizer.backward(loss) - else: - loss.backward() - return y - - -def run_zero_test(local_rank, world_size, stage=1): +def run_zero_test(local_rank, stage=1): criterion = torch.nn.CrossEntropyLoss() - zero_model = MoeModel() - optimizer = torch.optim.Adam(zero_model.parameters()) - plugin = LowLevelZeroPlugin(stage=stage, precision="fp32") - booster = Booster(plugin=plugin) - zero_model, optimizer, _, _, _ = booster.boost(zero_model, optimizer) + MOE_MANAGER.__init__() + MOE_MANAGER.setup(parallel="EP") + moe_model = MoeModel().bfloat16() + moe_optimizer = torch.optim.Adam(moe_model.parameters()) + moe_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16") + moe_booster = Booster(plugin=moe_plugin) + moe_model, moe_optimizer, _, _, _ = moe_booster.boost(moe_model, moe_optimizer) - torch_model = MoeModel() - for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()): - torch_param.data.copy_(zero_param.data) - torch_model = torch_model.cuda() - grad_handler = MoeGradientHandler(torch_model) + MOE_MANAGER.__init__() + MOE_MANAGER.setup(parallel=None) + zero_model = MoeModel().bfloat16() + delete_moe_info(zero_model) + zero_optimizer = torch.optim.Adam(zero_model.parameters()) + zero_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16") + zero_booster = Booster(plugin=zero_plugin) + zero_model, zero_optimizer, _, _, _ = zero_booster.boost(zero_model, zero_optimizer) + sync_local_from_ep(zero_model, moe_model) - # assert zero model - for (torch_name, torch_param), (zero_name, zero_param) in zip( - torch_model.named_parameters(), zero_model.module.named_parameters() - ): - assert zero_name == torch_name - assert torch.allclose(zero_param.data, torch_param.data) - - data = torch.randn(16, 4).cuda() + data = torch.randn(16, 4).bfloat16().cuda() label = torch.randint(0, 4, (16,)).cuda() - torch_out = run_fwd_bwd(torch_model, data, label, 
criterion, None) - zero_out = run_fwd_bwd(zero_model, data, label, criterion, optimizer) - assert torch.allclose(torch_out, zero_out) - grad_handler.handle_gradient() + zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer) + moe_out = run_fwd_bwd(moe_model, data, label, criterion, moe_optimizer) + assert torch.allclose(zero_out, moe_out) - for (zero_name, zero_param), (torch_name, torch_param) in zip( - zero_model.module.named_parameters(), torch_model.named_parameters() + for (moe_name, moe_param), (zero_name, zero_param) in zip( + moe_model.module.named_parameters(), zero_model.module.named_parameters() ): - assert zero_name == torch_name - zero_grad_list = optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(zero_param)) - if hasattr(zero_param, "moe_info"): - assert len(zero_grad_list) == 0 - assert torch.allclose(zero_param.grad, torch_param.grad) + assert moe_name == zero_name + moe_grad_list = moe_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(moe_param)) + zero_grad_list = zero_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(zero_param)) + if hasattr(moe_param, "moe_info"): + assert len(moe_grad_list) == 0 + if stage == 1: + zero_grad = zero_grad_list[local_rank].view(moe_param.grad.shape) + else: + zero_grad = zero_grad_list[0].view(moe_param.grad.shape) + assert torch.allclose( + moe_param.grad, zero_grad, atol=1e-5 + ), f"zero grad:\n{moe_param.grad}\ntorch grad:\n{zero_grad}\nmax diff: {(moe_param.grad - zero_grad).abs().max()}, mean diff: {(moe_param.grad - zero_grad).abs().mean()}" else: - assert len(zero_grad_list) > 0 - torch_grad_list = split_ddp_grad(torch_param.grad, world_size) - if stage == 2: - torch_grad_list = torch_grad_list[local_rank : local_rank + 1] - assert len(zero_grad_list) == len(torch_grad_list) - for zero_grad, torch_grad in zip(zero_grad_list, torch_grad_list): - assert torch.allclose(zero_grad, torch_grad) + assert len(moe_grad_list) > 0 + assert len(moe_grad_list) == len(zero_grad_list) + for moe_grad, zero_grad in zip(moe_grad_list, zero_grad_list): + assert torch.allclose(moe_grad, zero_grad) -def run_dist(rank, world_size, port): +def run_dist(rank, world_size, port, stage): colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - MOE_MANAGER.setup(parallel="EP") seed_all(42 + rank) - run_zero_test(rank, world_size, stage=1) - run_zero_test(rank, world_size, stage=2) + run_zero_test(rank, stage=stage) @pytest.mark.dist @pytest.mark.parametrize("world_size", [2]) +@pytest.mark.parametrize("stage", [1, 2]) @rerun_if_address_is_in_use() -def test_moe_zero_model(world_size): - spawn(run_dist, world_size) +def test_moe_zero_model(world_size, stage): + spawn(run_dist, world_size, stage=stage) if __name__ == "__main__": - test_moe_zero_model(world_size=2) + test_moe_zero_model(world_size=2, stage=1) diff --git a/tests/test_moe/test_moe_zero_optim.py b/tests/test_moe/test_moe_zero_optim.py index 0d2e2fb1b..4f6067aaa 100644 --- a/tests/test_moe/test_moe_zero_optim.py +++ b/tests/test_moe/test_moe_zero_optim.py @@ -4,89 +4,80 @@ import torch import colossalai from colossalai.booster import Booster from colossalai.booster.plugin import LowLevelZeroPlugin -from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel from colossalai.moe.manager import MOE_MANAGER +from colossalai.tensor.moe_tensor.api import is_moe_tensor from colossalai.testing import rerun_if_address_is_in_use, spawn -from tests.test_moe.moe_utils 
import MoeGradientHandler, MoeModel +from colossalai.testing.random import seed_all +from tests.test_moe.moe_utils import MoeModel, delete_moe_info, loose_close, run_fwd_bwd, sync_local_from_ep -def split_ddp_grad(grad, world_size): - with torch.no_grad(): - grad = grad.clone().detach().flatten() - padding_size = (world_size - grad.numel() % world_size) % world_size - if padding_size > 0: - grad = torch.nn.functional.pad(grad, [0, padding_size]) - splited_grad = grad.split(grad.numel() // world_size) - return splited_grad - - -def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False): - model.train() - with torch.cuda.amp.autocast(enabled=enable_autocast): - if criterion: - y = model(data) - loss = criterion(y, label) - else: - loss = model(data, label) - loss = loss.float() - - if isinstance(model, LowLevelZeroModel): - optimizer.backward(loss) - else: - loss.backward() - return y - - -def run_zero_optim_test(local_rank, world_size, stage=1): +def run_zero_test(local_rank, stage=1): criterion = torch.nn.CrossEntropyLoss() - zero_model = MoeModel() - zero_optimizer = torch.optim.Adam(zero_model.parameters()) - plugin = LowLevelZeroPlugin(stage=stage, precision="fp32") - booster = Booster(plugin=plugin) - zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer) + MOE_MANAGER.__init__() + MOE_MANAGER.setup(parallel="EP") + moe_model = MoeModel().bfloat16() + moe_optimizer = torch.optim.Adam(moe_model.parameters(), lr=1.0) + moe_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16") + moe_booster = Booster(plugin=moe_plugin) + moe_model, moe_optimizer, _, _, _ = moe_booster.boost(moe_model, moe_optimizer) - torch_model = MoeModel() - for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()): - torch_param.data.copy_(zero_param.data) - torch_optimizer = torch.optim.Adam(torch_model.parameters()) - torch_model = torch_model.cuda() - grad_handler = MoeGradientHandler(torch_model) + MOE_MANAGER.__init__() + MOE_MANAGER.setup(parallel=None) + zero_model = MoeModel().bfloat16() + delete_moe_info(zero_model) + sync_local_from_ep(zero_model, moe_model) + zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1.0) + zero_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16") + zero_booster = Booster(plugin=zero_plugin) + zero_model, zero_optimizer, _, _, _ = zero_booster.boost(zero_model, zero_optimizer) - for _ in range(2): - data = torch.randn(16, 4).cuda() / (local_rank + 1) - label = torch.randint(0, 4, (16,)).cuda() - run_fwd_bwd(torch_model, data, label, criterion, None) - run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer) - grad_handler.handle_gradient() + for (moe_name, moe_param), (zero_name, zero_param) in zip( + moe_model.named_parameters(), zero_model.named_parameters() + ): + if ".experts." 
in moe_name: + continue + assert moe_name == zero_name + assert torch.allclose( + moe_param.data, zero_param.data + ), f"{moe_name}\ntorch_param {moe_param.data}\nzero_param {zero_param.data}" - torch_optimizer.step() + for _ in range(1): + data = torch.randn(2, 4).bfloat16().cuda() + label = torch.randint(0, 4, (2,)).cuda() + + moe_out = run_fwd_bwd(moe_model, data, label, criterion, moe_optimizer) + zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer) + assert torch.allclose(zero_out, moe_out) + moe_optimizer.step() zero_optimizer.step() - for (torch_name, torch_param), (zero_name, zero_param) in zip( - torch_model.named_parameters(), zero_model.named_parameters() + for (moe_name, moe_param), (zero_name, zero_param) in zip( + moe_model.named_parameters(), zero_model.named_parameters() ): - assert torch.allclose( - torch_param.data, zero_param.data - ), f"{torch_name}\ntorch_param {torch_param.data}\nzero_param {zero_param.data}" + assert moe_name == zero_name + if is_moe_tensor(moe_param): + param_size = moe_param.shape[0] + zero_param = zero_param[local_rank * param_size : (local_rank + 1) * param_size] + loose_close(moe_param.data, zero_param.data, dtype=moe_param.dtype) - torch_optimizer.zero_grad() + moe_optimizer.zero_grad() zero_optimizer.zero_grad() -def run_dist(rank, world_size, port): +def run_dist(rank, world_size, port, stage): colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - MOE_MANAGER.setup(parallel="EP") - run_zero_optim_test(rank, world_size, stage=1) - run_zero_optim_test(rank, world_size, stage=2) + seed_all(42 + rank) + run_zero_test(rank, stage=stage) @pytest.mark.dist @pytest.mark.parametrize("world_size", [2]) +@pytest.mark.parametrize("stage", [1, 2]) @rerun_if_address_is_in_use() -def test_moe_zero_optim(world_size): - spawn(run_dist, world_size) +def test_moe_zero_optim(world_size, stage): + spawn(run_dist, world_size, stage=stage) if __name__ == "__main__": - test_moe_zero_optim(world_size=2) + test_moe_zero_optim(world_size=2, stage=1)
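
Taken together, the changes above are exercised roughly as follows. This is a minimal sketch, not part of the patch: it assumes the ColossalMoE application package exposes MixtralMoECheckpointIO at the import path shown, uses illustrative sizes matching train.sh (8 GPUs, ep_size=8, pp_size=2), and is meant to run under `colossalai run`/`torchrun` like train.py.

# Sketch: wiring the new `checkpoint_io` hook, the fixed-EP MOE_MANAGER setup,
# and apply_load_balance together (assumptions noted inline).
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.moe import MOE_MANAGER, apply_load_balance

# Assumption: import path of the application-level checkpoint IO class.
from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO

colossalai.launch_from_torch(config={}, seed=42)

# The plugin receives the checkpoint IO *class*; get_checkpoint_io() instantiates it
# later with (dp_group, pp_group, tp_group, zero_stage), falling back to MoECheckpintIO
# when no class is given.
plugin = MoeHybridParallelPlugin(
    tp_size=1,
    pp_size=2,  # illustrative, mirrors train.sh
    zero_stage=1,
    microbatch_size=1,
    precision="bf16",
    checkpoint_io=MixtralMoECheckpointIO,
)

# Fixed expert-parallel layout; dp/ep/pp sizes must agree with the launched world size.
MOE_MANAGER.setup(parallel="EP", mode="fixed", fixed_dp_size=1, fixed_ep_size=8, fixed_pp_size=2)

booster = Booster(plugin=plugin)
# model, optimizer, dataloader and lr_scheduler are built as in train.py, then boosted:
# model, optimizer, _, dataloader, lr_scheduler = booster.boost(
#     model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, dataloader=dataloader
# )
# During training, expert re-balancing is triggered periodically:
# apply_load_balance(model, optimizer)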