[test] add mixtral transformer test

This commit is contained in:
hxwang
2024-07-02 09:08:41 +00:00
committed by Hongxin Liu
parent f9b6fcf81f
commit 0b76b57cd6
6 changed files with 281 additions and 30 deletions

View File

@@ -4,8 +4,6 @@ import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed import ProcessGroup
# from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo
from torch.nn import CrossEntropyLoss
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.models.mixtral.modeling_mixtral import (
@@ -23,30 +21,34 @@ from colossalai.shardformer.shard.utils import set_tensors_to_none
class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
def __init__(self, config):
self.moe_info = None
def __init__(self, config, ep_group):
super().__init__(config)
self.setup_ep(ep_group)
def setup_ep(self, ep_group: ProcessGroup):
ep_group = ep_group
self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1
self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0
assert self.num_experts % self.ep_size == 0
self.ep_group = ep_group
if self.num_experts % self.ep_size != 0:
raise ValueError("The number of experts must be divisible by the number of expert parallel groups.")
self.num_experts_per_ep = self.num_experts // self.ep_size
self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]
set_tensors_to_none(self.experts, exclude=set(held_experts))
for p in self.experts.parameters():
p.ep_group = ep_group
@staticmethod
def from_native_module(module: MixtralSparseMoeBlock, *args, **kwargs) -> "EPMixtralSparseMoeBlock":
def from_native_module(
module: MixtralSparseMoeBlock, ep_group: ProcessGroup, *args, **kwargs
) -> "EPMixtralSparseMoeBlock":
LazyInitContext.materialize(module)
module.__class__ = EPMixtralSparseMoeBlock
# if "ep_group" in kwargs:
assert "ep_group" in kwargs, "You should pass ep_group in SubModuleReplacementDescription via shard_config!!"
module.setup_ep(kwargs["ep_group"])
module.setup_ep(ep_group)
return module
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: