Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-04 18:40:28 +00:00)
[hotfix] moe hybrid parallelism benchmark & follow-up fix (#6048)
* [example] pass use_fp8_comm flag to all plugins
* [example] add mixtral benchmark
* [moe] refine assertion and check
* [moe] fix mixtral & add more tests
* [moe] consider checking dp * sp group and moe_dp_group
* [mixtral] remove gate tp & add more tests
* [deepseek] fix tp & sp for deepseek
* [mixtral] minor fix
* [deepseek] add deepseek benchmark
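The first items wire an FP8 communication switch from the example/benchmark scripts into the booster plugins. Below is a minimal sketch of what that wiring typically looks like; the --use_fp8_comm flag name comes from the commit message, while the MoeHybridParallelPlugin keyword arguments (including fp8_communication) are assumptions for illustration and are not shown in this diff.

```python
# Illustrative only: forwards a --use_fp8_comm benchmark flag into the MoE plugin.
# Assumes MoeHybridParallelPlugin accepts an `fp8_communication` keyword; run under torchrun.
import argparse

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import MoeHybridParallelPlugin

parser = argparse.ArgumentParser()
parser.add_argument("--use_fp8_comm", action="store_true", help="enable FP8 for communication")
args = parser.parse_args()

colossalai.launch_from_torch()  # distributed init; requires a torchrun launch
plugin = MoeHybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    ep_size=2,
    zero_stage=1,
    fp8_communication=args.use_fp8_comm,  # assumed keyword, mirroring the commit message
)
booster = Booster(plugin=plugin)
```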
@@ -3,7 +3,7 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 import torch.functional as F
 from torch.distributed import ProcessGroup
 from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import Cache, DynamicCache
@@ -28,11 +28,13 @@ from colossalai.quantization.fp8 import all_reduce_fp8
 from colossalai.shardformer.layer._operation import (
     all_to_all_comm,
     gather_forward_split_backward,
+    linear_with_async_comm,
     split_forward_gather_backward,
 )
-from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row
+from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row, ParallelModule
 from colossalai.shardformer.shard import ShardConfig
 from colossalai.shardformer.shard.utils import set_tensors_to_none
+from colossalai.tensor.d_tensor.api import shard_rowwise, sharded_tensor_to_existing_param
 from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group
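The newly imported helpers above (linear_with_async_comm, gather_forward_split_backward, shard_rowwise) back the column-parallel gate added later in this diff: the gate weight of shape [n_routed_experts, hidden_size] is sharded row-wise across the tensor-parallel group, each rank computes logits for its local slice of experts, and the slices are gathered along the last dimension so every rank sees scores for all experts. A single-process sketch of that arithmetic, with made-up shapes and no distributed setup:

```python
# Single-process illustration of column-parallel gating; shapes are illustrative.
import torch

tp_size, n_routed_experts, hidden_size, num_tokens = 2, 8, 16, 5
weight = torch.randn(n_routed_experts, hidden_size)  # full gate weight
x = torch.randn(num_tokens, hidden_size)             # flattened hidden states

# shard_rowwise conceptually splits the weight along dim 0 across the TP group
shards = weight.chunk(tp_size, dim=0)

# each rank's local linear produces logits for its own experts
partial_logits = [x @ w.t() for w in shards]

# gather_forward_split_backward concatenates along dim=-1, recovering the full logits
logits = torch.cat(partial_logits, dim=-1)
assert torch.allclose(logits, x @ weight.t(), atol=1e-5)
```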
@@ -58,7 +60,7 @@ class AddAuxiliaryLoss(torch.autograd.Function):
         return grad_output, grad_loss


-class EPDeepseekMoE(nn.Module):
+class EPDeepseekMoE(ParallelModule):
     def __init__(self):
         raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")
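Rebasing EPDeepseekMoE onto ParallelModule matters because Shardformer builds replacement modules through from_native_module rather than __init__, which is why __init__ deliberately raises. A minimal sketch of that contract, reusing the ParallelModule import shown in the previous hunk; the class name and conversion logic here are illustrative, not part of the commit:

```python
# Illustrative ParallelModule: converted in place from an existing module, never constructed directly.
from torch import nn
from torch.distributed import ProcessGroup

from colossalai.shardformer.layer.linear import ParallelModule


class MyParallelOp(ParallelModule):
    def __init__(self):
        raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")

    @staticmethod
    def from_native_module(module: nn.Module, process_group: ProcessGroup = None, **kwargs) -> "MyParallelOp":
        # keep the original parameters, attach the communication group, then swap the class in place
        module.process_group = process_group
        module.__class__ = MyParallelOp
        return module


op = MyParallelOp.from_native_module(nn.Linear(8, 4, bias=False))  # process_group omitted in this sketch
```

The new DeepseekMoEGate_Col added at the end of this diff follows the same in-place conversion, reassigning module.__class__ inside from_native_module.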
@@ -214,6 +216,79 @@ class EPDeepseekMoE(nn.Module):
         return output_hidden_states


+class DeepseekMoEGate_Col(ParallelModule):
+    def parallel_linear(self, hidden_states):
+        assert (
+            hidden_states.shape[-1] == self.weight.shape[-1]
+        ), "Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.".format(
+            hidden_states.shape, self.weight.shape, self.weight.shape[-1]
+        )
+
+        output = linear_with_async_comm(
+            hidden_states, self.weight, None, self.process_group, True, fp8_communication=self.fp8_communication
+        )
+
+        # All-gather across the partitions.
+        output = gather_forward_split_backward(
+            output, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
+        )
+        return output
+
+    def forward(self, hidden_states):
+        bsz, seq_len, h = hidden_states.shape
+        ### compute gating score
+        hidden_states = hidden_states.view(-1, h)
+        logits = self.parallel_linear(hidden_states)
+        if self.scoring_func == "softmax":
+            scores = logits.softmax(dim=-1)
+        else:
+            raise NotImplementedError(f"insupportable scoring function for MoE gating: {self.scoring_func}")
+
+        ### select top-k experts
+        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
+
+        ### norm gate to sum 1
+        if self.top_k > 1 and self.norm_topk_prob:
+            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
+            topk_weight = topk_weight / denominator
+
+        ### expert-level computation auxiliary loss
+        if self.training and self.alpha > 0.0:
+            scores_for_aux = scores
+            aux_topk = self.top_k
+            # always compute aux loss based on the naive greedy topk method
+            topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
+            if self.seq_aux:
+                scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
+                ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
+                ce.scatter_add_(
+                    1, topk_idx_for_aux_loss, torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)
+                ).div_(seq_len * aux_topk / self.n_routed_experts)
+                aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
+            else:
+                mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
+                ce = mask_ce.float().mean(0)
+                Pi = scores_for_aux.mean(0)
+                fi = ce * self.n_routed_experts
+                aux_loss = (Pi * fi).sum() * self.alpha
+        else:
+            aux_loss = None
+
+        return topk_idx, topk_weight, aux_loss
+
+    @staticmethod
+    def from_native_module(
+        module, process_group: ProcessGroup, config, gather_output, fp8_communication
+    ) -> "DeepseekMoEGate_Col":
+        LazyInitContext.materialize(module)
+        module.process_group = process_group
+        module.fp8_communication = fp8_communication
+        sharded_weight = shard_rowwise(module.weight.data, process_group)
+        sharded_tensor_to_existing_param(sharded_weight, module.weight)
+        module.__class__ = DeepseekMoEGate_Col
+        return module
+
+
 class DeepseekPipelineForwards:
     """
     This class serves as a micro library for forward function substitution of Llama models
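For reference, the routing math in the new DeepseekMoEGate_Col.forward (softmax scoring, unsorted top-k selection, and the non-seq_aux load-balancing loss) can be reproduced on toy tensors without any parallelism. The snippet below is a standalone sketch with made-up hyper-parameters; a plain matmul stands in for the sharded parallel_linear, and one_hot is taken from torch.nn.functional:

```python
# Standalone sketch of the gate's routing math; hyper-parameters are illustrative.
import torch
import torch.nn.functional as F

bsz, seq_len, hidden, n_routed_experts, top_k, alpha = 2, 4, 8, 6, 2, 0.001

hidden_states = torch.randn(bsz, seq_len, hidden)
gate_weight = torch.randn(n_routed_experts, hidden)

# plain matmul stands in for the sharded parallel_linear + gather
logits = hidden_states.view(-1, hidden) @ gate_weight.t()
scores = logits.softmax(dim=-1)

# unsorted top-k expert selection, as in forward()
topk_weight, topk_idx = torch.topk(scores, k=top_k, dim=-1, sorted=False)

# non-seq_aux auxiliary loss: Pi is the mean routing probability per expert, fi the scaled
# share of routed assignments; the product penalizes experts with both high probability and high load
mask_ce = F.one_hot(topk_idx.view(-1), num_classes=n_routed_experts)
ce = mask_ce.float().mean(0)
Pi = scores.mean(0)
fi = ce * n_routed_experts
aux_loss = (Pi * fi).sum() * alpha
print(topk_idx.shape, topk_weight.shape, aux_loss.item())
```

The seq_aux branch in the diff computes the same balance penalty per sequence, using scatter_add_ on the selected indices instead of one_hot.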