mirror of https://github.com/hpcaitech/ColossalAI.git
synced 2026-01-29 21:49:54 +00:00
[test] add mixtral transformer test
@@ -4,8 +4,6 @@ import torch
 import torch.distributed as dist
 import torch.nn.functional as F
 from torch.distributed import ProcessGroup
-
-# from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo
 from torch.nn import CrossEntropyLoss
 from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.models.mixtral.modeling_mixtral import (
@@ -23,30 +21,34 @@ from colossalai.shardformer.shard.utils import set_tensors_to_none
 
 
 class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
-    def __init__(self, config):
-        self.moe_info = None
+    def __init__(self, config, ep_group):
         super().__init__(config)
+        self.setup_ep(ep_group)
 
     def setup_ep(self, ep_group: ProcessGroup):
-        ep_group = ep_group
         self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1
         self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0
-        assert self.num_experts % self.ep_size == 0
         self.ep_group = ep_group
+
+        if self.num_experts % self.ep_size != 0:
+            raise ValueError("The number of experts must be divisible by the number of expert parallel groups.")
+
         self.num_experts_per_ep = self.num_experts // self.ep_size
         self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
         held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]
+
         set_tensors_to_none(self.experts, exclude=set(held_experts))
         for p in self.experts.parameters():
             p.ep_group = ep_group
 
     @staticmethod
-    def from_native_module(module: MixtralSparseMoeBlock, *args, **kwargs) -> "EPMixtralSparseMoeBlock":
+    def from_native_module(
+        module: MixtralSparseMoeBlock, ep_group: ProcessGroup, *args, **kwargs
+    ) -> "EPMixtralSparseMoeBlock":
         LazyInitContext.materialize(module)
         module.__class__ = EPMixtralSparseMoeBlock
-        # if "ep_group" in kwargs:
-        assert "ep_group" in kwargs, "You should pass ep_group in SubModuleReplacementDescription via shard_config!!"
-        module.setup_ep(kwargs["ep_group"])
+        module.setup_ep(ep_group)
         return module
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
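
For context on what this hunk changes: setup_ep splits the block's experts evenly across the expert-parallel group, keeps only the local slice (expert_start_idx : expert_start_idx + num_experts_per_ep), and frees the rest with set_tensors_to_none, while from_native_module now takes ep_group as an explicit argument instead of digging it out of kwargs, converting a stock MixtralSparseMoeBlock in place by swapping its __class__ and calling setup_ep. The snippet below is a minimal, standalone sketch of just the partitioning arithmetic; local_expert_indices is a hypothetical helper (not part of ColossalAI or transformers), and it assumes num_experts % ep_size == 0, which is exactly the condition the new ValueError enforces.

# Standalone sketch of the expert-parallel partitioning used above.
# `local_expert_indices` is an illustrative helper, not ColossalAI API.

def local_expert_indices(num_experts: int, ep_size: int, ep_rank: int) -> range:
    """Return the expert indices held by `ep_rank` in an EP group of size `ep_size`."""
    if num_experts % ep_size != 0:
        # Mirrors the ValueError added in setup_ep above.
        raise ValueError("The number of experts must be divisible by the number of expert parallel groups.")
    num_experts_per_ep = num_experts // ep_size        # experts kept per rank
    expert_start_idx = ep_rank * num_experts_per_ep    # first local expert index
    return range(expert_start_idx, expert_start_idx + num_experts_per_ep)


if __name__ == "__main__":
    # Mixtral-8x7B has 8 experts; with ep_size=4 each rank keeps 2 of them.
    for rank in range(4):
        print(f"ep_rank {rank} holds experts {list(local_expert_indices(8, 4, rank))}")
    # ep_rank 0 holds experts [0, 1]
    # ep_rank 1 holds experts [2, 3]
    # ep_rank 2 holds experts [4, 5]
    # ep_rank 3 holds experts [6, 7]

From the caller's point of view the converted block should behave like the original MixtralSparseMoeBlock, but only the locally held experts occupy memory on each rank, since set_tensors_to_none releases the parameters of the non-local experts.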